From c7e5ac9a1b0871af21f647e21639e521c950b779 Mon Sep 17 00:00:00 2001 From: Claudio Russo Date: Tue, 30 Jul 2024 15:13:46 +0100 Subject: [PATCH] refactor: simplify `await.ml` (#4633) Simple inspection reveals that we never add `MetaCont` to the context, which means we can simplify the code to enforce that invariant and consider fewer cases. This is an experiment to that effect, building on #4632. --- src/ir_passes/await.ml | 121 +++++++++++++++++++++++------------------ 1 file changed, 67 insertions(+), 54 deletions(-) diff --git a/src/ir_passes/await.ml b/src/ir_passes/await.ml index a6cff58cafe..787a2c8e1d4 100644 --- a/src/ir_passes/await.ml +++ b/src/ir_passes/await.ml @@ -39,31 +39,19 @@ let letcont k scope = blockE [funcD k' v e] (* at this point, I'm really worried about variable capture *) (scope k') -(* pre-compose a continuation with a call to a `finally`-thunk *) -let precont k vthunk = - let finally e = blockE [expD (varE vthunk -*- unitE ())] e in - match k with - | ContVar k' -> - let typ = match typ_of_var k' with - | T.(Func (Local, Returns, [], ts1, _)) -> T.seq ts1 - | _ -> assert false in - MetaCont (typ, fun v -> finally (varE k' -*- varE v)) - | MetaCont (typ, cont) -> - MetaCont (typ, fun v -> finally (cont v)) - (* Named labels for break, special labels for return, throw and cleanup *) type label = Return | Throw | Cleanup | Named of string let ( -@- ) k exp2 = match k with | ContVar v -> - varE v -*- exp2 + varE v -*- exp2 | MetaCont (typ0, k) -> - match exp2.it with - | VarE v -> k (var v (typ exp2)) - | _ -> - let u = fresh_var "u" typ0 in - letE u exp2 (k u) + match exp2.it with + | VarE v -> k (var v (typ exp2)) + | _ -> + let u = fresh_var "u" typ0 in + letE u exp2 (k u) (* Label environments *) @@ -71,7 +59,30 @@ module LabelEnv = Env.Make(struct type t = label let compare = compare end) module PatEnv = Env.Make(String) -type label_sort = Cont of kont | Label +type label_sort = Cont of var | Label + +let precompose vthunk k = + let typ0 = match typ_of_var k with + | T.(Func (Local, Returns, [], ts1, _)) -> T.seq ts1 + | _ -> assert false in + let v = fresh_var "v" typ0 in + let e = blockE [expD (varE vthunk -*- unitE ())] (varE k -*- varE v) in + let k' = fresh_cont typ0 (typ e) in + (k', funcD k' v e) + +let preconts context vthunk scope = + let (ds, ctxt) = LabelEnv.fold + (fun lab sort (ds, ctxt) -> + match sort with + | Label -> assert false + | Cont k -> + let (k', d) = precompose vthunk k in + (d :: ds, + LabelEnv.add lab (Cont k') ctxt)) + context + ([], LabelEnv.empty) + in + blockE ds (scope ctxt) let typ_cases cases = List.fold_left (fun t case -> T.lub t (typ case.it.exp)) T.Non cases @@ -84,9 +95,9 @@ let rec t_async context exp = let k_fail = fresh_err_cont T.unit in let k_clean = fresh_bail_cont T.unit in let context' = - LabelEnv.add Cleanup (Cont (ContVar k_clean)) - (LabelEnv.add Return (Cont (ContVar k_ret)) - (LabelEnv.singleton Throw (Cont (ContVar k_fail)))) + LabelEnv.add Cleanup (Cont k_clean) + (LabelEnv.add Return (Cont k_ret) + (LabelEnv.singleton Throw (Cont k_fail))) in cps_asyncE s typ1 (typ exp1) (forall [tb] ([k_ret; k_fail; k_clean] -->* @@ -122,14 +133,14 @@ and t_exp' context exp = | PrimE (BreakPrim id, [exp1]) -> begin match LabelEnv.find_opt (Named id) context with - | Some (Cont k) -> (retE (k -@- t_exp context exp1)).it + | Some (Cont k) -> (retE (varE k -*- t_exp context exp1)).it | Some Label -> (breakE id (t_exp context exp1)).it | None -> assert false end | PrimE (RetPrim, [exp1]) -> begin match LabelEnv.find_opt Return context with - | Some (Cont k) -> (retE (k -@- t_exp context exp1)).it + | Some (Cont k) -> (retE (varE k -*- t_exp context exp1)).it | Some Label -> (retE (t_exp context exp1)).it | None -> assert false end @@ -341,21 +352,25 @@ and c_exp' context exp k = note = Note.{ exp.note with typ = typ' } })) end) | TryE (exp1, cases, finally_opt) -> - let pre k = + let precont k scope = + match finally_opt with + | Some (id2, typ2) -> + let vthunk = var id2 typ2 in + let (k', d) = precompose vthunk k in + blockE [d] (scope k') + | None -> + scope k in + let finalise context scope = match finally_opt with - | Some (id2, typ2) -> precont k (var id2 typ2) - | None -> k in - let pre' = function - | Cont k -> Cont (pre k) - | Label -> assert false in - (* All control-flow out must pass through the potential `finally` thunk *) - let context = LabelEnv.map pre' context in - (* assert that a context (top-level or async) has set up a `Cleanup` cont *) + | Some (id2, typ2) -> preconts context (var id2 typ2) scope + | None -> scope context + in + finalise context (fun context -> + (* assert that a context (top-level or async) has set up a `Cleanup` and `Throw` cont *) assert (LabelEnv.find_opt Cleanup context <> None); - (* TODO: do we need to reify f? *) let f = match LabelEnv.find Throw context with Cont f -> f | _ -> assert false in - letcont f (fun f -> - letcont (pre k) (fun k -> + letcont k (fun k -> + precont k (fun k -> match eff exp1 with | T.Triv -> varE k -*- t_exp context exp1 @@ -379,7 +394,7 @@ and c_exp' context exp k = cases @ if omit_rethrow then [] else [rethrow] in let throw = fresh_err_cont (answerT (typ_of_var k)) in - let context' = LabelEnv.add Throw (Cont (ContVar throw)) context in + let context' = LabelEnv.add Throw (Cont throw) context in blockE [ let e = fresh_var "e" T.catch in funcD throw e { @@ -389,30 +404,30 @@ and c_exp' context exp k = } ] (c_exp context' exp1 (ContVar k)) - )) + ))) | LoopE exp1 -> c_loop context k exp1 | LabelE (id, _typ, exp1) -> letcont k (fun k -> - let context' = LabelEnv.add (Named id) (Cont (ContVar k)) context in + let context' = LabelEnv.add (Named id) (Cont k) context in c_exp context' exp1 (ContVar k)) (* TODO optimize me, if possible *) | PrimE (BreakPrim id, [exp1]) -> begin match LabelEnv.find_opt (Named id) context with - | Some (Cont k') -> c_exp context exp1 k' + | Some (Cont k') -> c_exp context exp1 (ContVar k') | _ -> assert false end | PrimE (RetPrim, [exp1]) -> begin match LabelEnv.find_opt Return context with - | Some (Cont k') -> c_exp context exp1 k' + | Some (Cont k') -> c_exp context exp1 (ContVar k') | _ -> assert false end | PrimE (ThrowPrim, [exp1]) -> begin match LabelEnv.find_opt Throw context with - | Some (Cont k') -> c_exp context exp1 k' + | Some (Cont k') -> c_exp context exp1 (ContVar k') | _ -> assert false end | AsyncE (T.Cmp, tb, exp1, typ1) -> @@ -423,9 +438,9 @@ and c_exp' context exp k = let k_fail = fresh_err_cont T.unit in let k_clean = fresh_bail_cont T.unit in let context' = - LabelEnv.add Cleanup (Cont (ContVar k_clean)) - (LabelEnv.add Return (Cont (ContVar k_ret)) - (LabelEnv.singleton Throw (Cont (ContVar k_fail)))) + LabelEnv.add Cleanup (Cont k_clean) + (LabelEnv.add Return (Cont k_ret) + (LabelEnv.singleton Throw (Cont k_fail))) in let r = match LabelEnv.find_opt Throw context with | Some (Cont r) -> r @@ -439,7 +454,7 @@ and c_exp' context exp k = (fun v -> check_call_perform_status (k -@- varE v) - (fun e -> r -@- e)) + (fun e -> varE r -*- e)) in k' -@- cps_async | PrimE (AwaitPrim s, [exp1]) -> @@ -451,8 +466,6 @@ and c_exp' context exp k = | Some (Cont r) -> r | _ -> assert false in - letcont b (fun b -> - letcont r (fun r -> letcont k (fun k -> let krb = List.map varE [k; r; b] |> tupE in match eff exp1 with @@ -461,7 +474,7 @@ and c_exp' context exp k = | T.Await -> c_exp context exp1 (meta (typ exp1) (fun v1 -> (cps_awaitE s (typ_of_var k) (varE v1) krb))) - ))) + ) | DeclareE (id, typ, exp1) -> unary context k (fun v1 -> e (DeclareE (id, typ, varE v1))) exp1 | DefineE (id, mut, exp1) -> @@ -477,7 +490,7 @@ and c_exp' context exp k = (fun v -> check_call_perform_status (k -@- varE v) - (fun e -> r -@- e)) + (fun e -> varE r -*- e)) in nary context k' (fun vs -> e (PrimE (p, vs))) exps | PrimE (p, exps) -> @@ -623,8 +636,8 @@ and t_comp_unit context = function | T.Await -> let throw = fresh_err_cont T.unit in let context' = - LabelEnv.add Cleanup (Cont (ContVar (var "@cleanup" bail_contT))) - (LabelEnv.add Throw (Cont (ContVar throw)) context) in + LabelEnv.add Cleanup (Cont (var "@cleanup" bail_contT)) + (LabelEnv.add Throw (Cont throw) context) in let e = fresh_var "e" T.catch in ProgU [ funcD throw e (assertE (falseE ())); @@ -649,8 +662,8 @@ and t_ignore_throw context exp = | _ -> let throw = fresh_err_cont T.unit in let context' = - LabelEnv.add Cleanup (Cont (ContVar (var "@cleanup" bail_contT))) - (LabelEnv.add Throw (Cont (ContVar throw)) context) in + LabelEnv.add Cleanup (Cont (var "@cleanup" bail_contT)) + (LabelEnv.add Throw (Cont throw) context) in let e = fresh_var "e" T.catch in { (blockE [ funcD throw e (tupE[]);