Skip to content

Commit

Permalink
refactor: simplify await.ml (#4633)
Browse files Browse the repository at this point in the history
Simple inspection reveals that we never add `MetaCont` to the context, which means we can simplify the code to enforce that invariant and consider fewer cases. This is an experiment to that effect, building on #4632.
  • Loading branch information
crusso authored Jul 30, 2024
1 parent 75c59e1 commit c7e5ac9
Showing 1 changed file with 67 additions and 54 deletions.
121 changes: 67 additions & 54 deletions src/ir_passes/await.ml
Original file line number Diff line number Diff line change
Expand Up @@ -39,39 +39,50 @@ let letcont k scope =
blockE [funcD k' v e] (* at this point, I'm really worried about variable capture *)
(scope k')

(* pre-compose a continuation with a call to a `finally`-thunk *)
let precont k vthunk =
let finally e = blockE [expD (varE vthunk -*- unitE ())] e in
match k with
| ContVar k' ->
let typ = match typ_of_var k' with
| T.(Func (Local, Returns, [], ts1, _)) -> T.seq ts1
| _ -> assert false in
MetaCont (typ, fun v -> finally (varE k' -*- varE v))
| MetaCont (typ, cont) ->
MetaCont (typ, fun v -> finally (cont v))

(* Named labels for break, special labels for return, throw and cleanup *)
type label = Return | Throw | Cleanup | Named of string

let ( -@- ) k exp2 =
match k with
| ContVar v ->
varE v -*- exp2
varE v -*- exp2
| MetaCont (typ0, k) ->
match exp2.it with
| VarE v -> k (var v (typ exp2))
| _ ->
let u = fresh_var "u" typ0 in
letE u exp2 (k u)
match exp2.it with
| VarE v -> k (var v (typ exp2))
| _ ->
let u = fresh_var "u" typ0 in
letE u exp2 (k u)

(* Label environments *)

module LabelEnv = Env.Make(struct type t = label let compare = compare end)

module PatEnv = Env.Make(String)

type label_sort = Cont of kont | Label
type label_sort = Cont of var | Label

let precompose vthunk k =
let typ0 = match typ_of_var k with
| T.(Func (Local, Returns, [], ts1, _)) -> T.seq ts1
| _ -> assert false in
let v = fresh_var "v" typ0 in
let e = blockE [expD (varE vthunk -*- unitE ())] (varE k -*- varE v) in
let k' = fresh_cont typ0 (typ e) in
(k', funcD k' v e)

let preconts context vthunk scope =
let (ds, ctxt) = LabelEnv.fold
(fun lab sort (ds, ctxt) ->
match sort with
| Label -> assert false
| Cont k ->
let (k', d) = precompose vthunk k in
(d :: ds,
LabelEnv.add lab (Cont k') ctxt))
context
([], LabelEnv.empty)
in
blockE ds (scope ctxt)

let typ_cases cases = List.fold_left (fun t case -> T.lub t (typ case.it.exp)) T.Non cases

Expand All @@ -84,9 +95,9 @@ let rec t_async context exp =
let k_fail = fresh_err_cont T.unit in
let k_clean = fresh_bail_cont T.unit in
let context' =
LabelEnv.add Cleanup (Cont (ContVar k_clean))
(LabelEnv.add Return (Cont (ContVar k_ret))
(LabelEnv.singleton Throw (Cont (ContVar k_fail))))
LabelEnv.add Cleanup (Cont k_clean)
(LabelEnv.add Return (Cont k_ret)
(LabelEnv.singleton Throw (Cont k_fail)))
in
cps_asyncE s typ1 (typ exp1)
(forall [tb] ([k_ret; k_fail; k_clean] -->*
Expand Down Expand Up @@ -122,14 +133,14 @@ and t_exp' context exp =
| PrimE (BreakPrim id, [exp1]) ->
begin
match LabelEnv.find_opt (Named id) context with
| Some (Cont k) -> (retE (k -@- t_exp context exp1)).it
| Some (Cont k) -> (retE (varE k -*- t_exp context exp1)).it
| Some Label -> (breakE id (t_exp context exp1)).it
| None -> assert false
end
| PrimE (RetPrim, [exp1]) ->
begin
match LabelEnv.find_opt Return context with
| Some (Cont k) -> (retE (k -@- t_exp context exp1)).it
| Some (Cont k) -> (retE (varE k -*- t_exp context exp1)).it
| Some Label -> (retE (t_exp context exp1)).it
| None -> assert false
end
Expand Down Expand Up @@ -341,21 +352,25 @@ and c_exp' context exp k =
note = Note.{ exp.note with typ = typ' } }))
end)
| TryE (exp1, cases, finally_opt) ->
let pre k =
let precont k scope =
match finally_opt with
| Some (id2, typ2) ->
let vthunk = var id2 typ2 in
let (k', d) = precompose vthunk k in
blockE [d] (scope k')
| None ->
scope k in
let finalise context scope =
match finally_opt with
| Some (id2, typ2) -> precont k (var id2 typ2)
| None -> k in
let pre' = function
| Cont k -> Cont (pre k)
| Label -> assert false in
(* All control-flow out must pass through the potential `finally` thunk *)
let context = LabelEnv.map pre' context in
(* assert that a context (top-level or async) has set up a `Cleanup` cont *)
| Some (id2, typ2) -> preconts context (var id2 typ2) scope
| None -> scope context
in
finalise context (fun context ->
(* assert that a context (top-level or async) has set up a `Cleanup` and `Throw` cont *)
assert (LabelEnv.find_opt Cleanup context <> None);
(* TODO: do we need to reify f? *)
let f = match LabelEnv.find Throw context with Cont f -> f | _ -> assert false in
letcont f (fun f ->
letcont (pre k) (fun k ->
letcont k (fun k ->
precont k (fun k ->
match eff exp1 with
| T.Triv ->
varE k -*- t_exp context exp1
Expand All @@ -379,7 +394,7 @@ and c_exp' context exp k =
cases
@ if omit_rethrow then [] else [rethrow] in
let throw = fresh_err_cont (answerT (typ_of_var k)) in
let context' = LabelEnv.add Throw (Cont (ContVar throw)) context in
let context' = LabelEnv.add Throw (Cont throw) context in
blockE
[ let e = fresh_var "e" T.catch in
funcD throw e {
Expand All @@ -389,30 +404,30 @@ and c_exp' context exp k =
}
]
(c_exp context' exp1 (ContVar k))
))
)))
| LoopE exp1 ->
c_loop context k exp1
| LabelE (id, _typ, exp1) ->
letcont k
(fun k ->
let context' = LabelEnv.add (Named id) (Cont (ContVar k)) context in
let context' = LabelEnv.add (Named id) (Cont k) context in
c_exp context' exp1 (ContVar k)) (* TODO optimize me, if possible *)
| PrimE (BreakPrim id, [exp1]) ->
begin
match LabelEnv.find_opt (Named id) context with
| Some (Cont k') -> c_exp context exp1 k'
| Some (Cont k') -> c_exp context exp1 (ContVar k')
| _ -> assert false
end
| PrimE (RetPrim, [exp1]) ->
begin
match LabelEnv.find_opt Return context with
| Some (Cont k') -> c_exp context exp1 k'
| Some (Cont k') -> c_exp context exp1 (ContVar k')
| _ -> assert false
end
| PrimE (ThrowPrim, [exp1]) ->
begin
match LabelEnv.find_opt Throw context with
| Some (Cont k') -> c_exp context exp1 k'
| Some (Cont k') -> c_exp context exp1 (ContVar k')
| _ -> assert false
end
| AsyncE (T.Cmp, tb, exp1, typ1) ->
Expand All @@ -423,9 +438,9 @@ and c_exp' context exp k =
let k_fail = fresh_err_cont T.unit in
let k_clean = fresh_bail_cont T.unit in
let context' =
LabelEnv.add Cleanup (Cont (ContVar k_clean))
(LabelEnv.add Return (Cont (ContVar k_ret))
(LabelEnv.singleton Throw (Cont (ContVar k_fail))))
LabelEnv.add Cleanup (Cont k_clean)
(LabelEnv.add Return (Cont k_ret)
(LabelEnv.singleton Throw (Cont k_fail)))
in
let r = match LabelEnv.find_opt Throw context with
| Some (Cont r) -> r
Expand All @@ -439,7 +454,7 @@ and c_exp' context exp k =
(fun v ->
check_call_perform_status
(k -@- varE v)
(fun e -> r -@- e))
(fun e -> varE r -*- e))
in
k' -@- cps_async
| PrimE (AwaitPrim s, [exp1]) ->
Expand All @@ -451,8 +466,6 @@ and c_exp' context exp k =
| Some (Cont r) -> r
| _ -> assert false
in
letcont b (fun b ->
letcont r (fun r ->
letcont k (fun k ->
let krb = List.map varE [k; r; b] |> tupE in
match eff exp1 with
Expand All @@ -461,7 +474,7 @@ and c_exp' context exp k =
| T.Await ->
c_exp context exp1
(meta (typ exp1) (fun v1 -> (cps_awaitE s (typ_of_var k) (varE v1) krb)))
)))
)
| DeclareE (id, typ, exp1) ->
unary context k (fun v1 -> e (DeclareE (id, typ, varE v1))) exp1
| DefineE (id, mut, exp1) ->
Expand All @@ -477,7 +490,7 @@ and c_exp' context exp k =
(fun v ->
check_call_perform_status
(k -@- varE v)
(fun e -> r -@- e))
(fun e -> varE r -*- e))
in
nary context k' (fun vs -> e (PrimE (p, vs))) exps
| PrimE (p, exps) ->
Expand Down Expand Up @@ -623,8 +636,8 @@ and t_comp_unit context = function
| T.Await ->
let throw = fresh_err_cont T.unit in
let context' =
LabelEnv.add Cleanup (Cont (ContVar (var "@cleanup" bail_contT)))
(LabelEnv.add Throw (Cont (ContVar throw)) context) in
LabelEnv.add Cleanup (Cont (var "@cleanup" bail_contT))
(LabelEnv.add Throw (Cont throw) context) in
let e = fresh_var "e" T.catch in
ProgU [
funcD throw e (assertE (falseE ()));
Expand All @@ -649,8 +662,8 @@ and t_ignore_throw context exp =
| _ ->
let throw = fresh_err_cont T.unit in
let context' =
LabelEnv.add Cleanup (Cont (ContVar (var "@cleanup" bail_contT)))
(LabelEnv.add Throw (Cont (ContVar throw)) context) in
LabelEnv.add Cleanup (Cont (var "@cleanup" bail_contT))
(LabelEnv.add Throw (Cont throw) context) in
let e = fresh_var "e" T.catch in
{ (blockE [
funcD throw e (tupE[]);
Expand Down

0 comments on commit c7e5ac9

Please sign in to comment.