Skip to content

Commit

Permalink
Merge pull request #1106 from daniel-larraz/forall-enum-record-fields
Browse files Browse the repository at this point in the history
Fix handling of enum fields in quantified records
  • Loading branch information
daniel-larraz authored Oct 8, 2024
2 parents 9e815c4 + fa80e03 commit 81056b5
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 8 deletions.
82 changes: 78 additions & 4 deletions src/lustre/lustreAstNormalizer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,81 @@ let mk_fresh_node_arg_local info pos is_const expr_type expr =
NodeArgCache.add node_arg_cache expr nexpr;
nexpr, gids

let mk_range_expr ctx node_id expr_type expr =
let mk_range_expr ctx node_id expr_type expr =
let rec mk ctx n expr_type expr =
let expr_type = Chk.expand_type_syn_reftype_history ctx expr_type |> unwrap in
match expr_type with
| A.IntRange (_, l, u) ->
let original_ty, _ = Chk.infer_type_expr ctx node_id expr |> unwrap in
let original_ty = Chk.expand_type_syn_reftype_history ctx original_ty |> unwrap in
let user_prop, is_original = match original_ty with
| A.IntRange (_, l', u') ->
let eval_int_expr_opt expr = match expr with
| Some expr -> Some (AIC.eval_int_expr ctx expr)
| None -> None
in
let is_original =
let (l, u) = eval_int_expr_opt l, eval_int_expr_opt u in
let (l', u') = eval_int_expr_opt l', eval_int_expr_opt u' in
(match (l, u, l', u') with
| Some (Ok l), Some (Ok u), Some (Ok l'), Some (Ok u') -> l = l' && u = u'
| Some (Ok l), None, Some (Ok l'), None -> l = l'
| None, Some (Ok u), None, Some (Ok u') -> u = u'
| None, None, None, None -> true
| _ -> false)
in
let user_prop = if is_original then []
else
match l', u' with
| Some l', Some u' ->
let l' = A.CompOp (dpos, A.Lte, l', expr) in
let u' = A.CompOp (dpos, A.Lte, expr, u') in
[A.BinaryOp (dpos, A.And, l', u'), true]
| Some l', None -> [A.CompOp (dpos, A.Lte, l', expr), true]
| None, Some u' -> [A.CompOp (dpos, A.Lte, expr, u'), true]
| None, None -> [(A.Const (dpos, A.True)), true]
in
user_prop, is_original
| A.Int _ -> [], false
| _ -> assert false
in (match l, u with
| Some l, Some u ->
let l = A.CompOp (dpos, A.Lte, l, expr) in
let u = A.CompOp (dpos, A.Lte, expr, u) in
[A.BinaryOp (dpos, A.And, l, u), is_original] @ user_prop
| Some l, None ->
[A.CompOp (dpos, A.Lte, l, expr), is_original] @ user_prop
| None, Some u ->
[A.CompOp (dpos, A.Lte, expr, u), is_original] @ user_prop
| None, None -> [(A.Const (dpos, A.True)), is_original] @ user_prop
)
| A.ArrayType (_, (ty, upper_bound)) ->
let id_str = HString.concat2 (HString.mk_hstring "x") (HString.mk_hstring (string_of_int n)) in
let id = A.Ident (dpos, id_str) in
let ctx = Ctx.add_ty ctx id_str (A.Int dpos) in
let expr = A.ArrayIndex (dpos, expr, id) in
let rexpr = mk ctx (succ n) ty expr in
let l = A.CompOp (dpos, A.Lte, A.Const (dpos, A.Num (HString.mk_hstring "0")), id) in
let u = A.CompOp (dpos, A.Lt, id, upper_bound) in
let assumption = A.BinaryOp (dpos, A.And, l, u) in
let var = dpos, id_str, (A.Int dpos) in
let body = fun e -> A.BinaryOp (dpos, A.Impl, assumption, e) in
List.map (fun (e, is_original) -> A.Quantifier (dpos, A.Forall, [var], body e), is_original) rexpr
| TupleType (_, tys) ->
let mk_proj i = A.TupleProject (dpos, expr, i) in
let tys = List.filter (fun ty -> Ctx.type_contains_subrange ctx ty) tys in
let tys = List.mapi (fun i ty -> mk ctx n ty (mk_proj i)) tys in
List.fold_left (@) [] tys
| RecordType (_, _, tys) ->
let mk_proj i = A.RecordProject (dpos, expr, i) in
let tys = List.filter (fun (_, _, ty) -> Ctx.type_contains_subrange ctx ty) tys in
let tys = List.map (fun (_, i, ty) -> mk ctx n ty (mk_proj i)) tys in
List.fold_left (@) [] tys
| _ -> []
in
mk ctx 0 expr_type expr

let mk_enum_range_expr ctx node_id expr_type expr =
let rec mk ctx n expr_type expr =
let expr_type = Chk.expand_type_syn_reftype_history ctx expr_type |> unwrap in
match expr_type with
Expand Down Expand Up @@ -473,12 +547,12 @@ let mk_range_expr ctx node_id expr_type expr =
List.map (fun (e, is_original) -> A.Quantifier (dpos, A.Forall, [var], body e), is_original) rexpr
| TupleType (_, tys) ->
let mk_proj i = A.TupleProject (dpos, expr, i) in
let tys = List.filter (fun ty -> Ctx.type_contains_subrange ctx ty) tys in
let tys = List.filter (fun ty -> Ctx.type_contains_enum_or_subrange ctx ty) tys in
let tys = List.mapi (fun i ty -> mk ctx n ty (mk_proj i)) tys in
List.fold_left (@) [] tys
| RecordType (_, _, tys) ->
let mk_proj i = A.RecordProject (dpos, expr, i) in
let tys = List.filter (fun (_, _, ty) -> Ctx.type_contains_subrange ctx ty) tys in
let tys = List.filter (fun (_, _, ty) -> Ctx.type_contains_enum_or_subrange ctx ty) tys in
let tys = List.map (fun (_, i, ty) -> mk ctx n ty (mk_proj i)) tys in
List.fold_left (@) [] tys
| _ -> []
Expand Down Expand Up @@ -1683,7 +1757,7 @@ and normalize_expr ?guard info node_id map =
(fun acc (_, id, ty) ->
let expr = A.Ident(dpos, id) in
let range_exprs =
List.map fst (mk_range_expr info.context (Some node_id) ty expr) @
List.map fst (mk_enum_range_expr info.context (Some node_id) ty expr) @
List.map snd (mk_ref_type_expr info.context expr Local ty)
in
range_exprs :: acc
Expand Down
6 changes: 6 additions & 0 deletions src/lustre/lustreAstNormalizer.mli
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ val mk_range_expr : TypeCheckerContext.tc_context ->
LustreAst.expr ->
(LustreAst.expr * bool) list

val mk_enum_range_expr : TypeCheckerContext.tc_context ->
HString.t option ->
LustreAst.lustre_type ->
LustreAst.expr ->
(LustreAst.expr * bool) list

val normalize : TypeCheckerContext.tc_context ->
LustreAbstractInterpretation.context ->
LustreAst.t ->
Expand Down
55 changes: 51 additions & 4 deletions src/lustre/typeCheckerContext.ml
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,35 @@ let rec type_contains_subrange ctx = function
| Some ty -> type_contains_subrange ctx ty
| None -> assert false
)
| _ -> false
| Bool _ | Int _ | Real _ | EnumType _
| UInt8 _| UInt16 _| UInt32 _| UInt64 _
| Int8 _ |Int16 _ |Int32 _ | Int64 _
| AbstractType _ -> false

let rec type_contains_enum_or_subrange ctx = function
| LA.IntRange _
| EnumType _ -> true
| RefinementType (_, (_, _, ty), _) -> type_contains_enum_or_subrange ctx ty
| TupleType (_, tys) | GroupType (_, tys) ->
List.fold_left (fun acc ty -> acc || type_contains_enum_or_subrange ctx ty) false tys
| RecordType (_, _, tys) ->
List.fold_left (fun acc (_, _, ty) -> acc || type_contains_enum_or_subrange ctx ty)
false tys
| ArrayType (_, (ty, _)) -> type_contains_enum_or_subrange ctx ty
| TArr (_, ty1, ty2) -> type_contains_enum_or_subrange ctx ty1 || type_contains_enum_or_subrange ctx ty2
| History (_, id) ->
(match lookup_ty ctx id with
| Some ty -> type_contains_enum_or_subrange ctx ty
| _ -> assert false)
| UserType (_, ty_args, id) -> (
match lookup_ty_syn ctx id ty_args with
| Some ty -> type_contains_enum_or_subrange ctx ty
| None -> assert false
)
| Bool _ | Int _ | Real _
| UInt8 _| UInt16 _| UInt32 _| UInt64 _
| Int8 _ |Int16 _ |Int32 _ | Int64 _
| AbstractType _ -> false

let rec type_contains_ref ctx = function
| LA.RefinementType _ -> true
Expand All @@ -716,7 +744,15 @@ let rec type_contains_subrange ctx = function
(match lookup_ty ctx id with
| Some ty -> type_contains_ref ctx ty
| _ -> assert false)
| _ -> false
| UserType (_, ty_args, id) -> (
match lookup_ty_syn ctx id ty_args with
| Some ty -> type_contains_ref ctx ty
| None -> false
)
| Bool _ | Int _ | Real _ | EnumType _ | IntRange _
| UInt8 _| UInt16 _| UInt32 _| UInt64 _
| Int8 _ |Int16 _ |Int32 _ | Int64 _
| AbstractType _ -> false

let rec type_contains_enum_subrange_reftype ctx = function
| LA.IntRange _
Expand All @@ -733,7 +769,15 @@ let rec type_contains_enum_subrange_reftype ctx = function
(match lookup_ty ctx id with
| Some ty -> type_contains_enum_subrange_reftype ctx ty
| _ -> assert false)
| _ -> false
| UserType (_, ty_args, id) -> (
match lookup_ty_syn ctx id ty_args with
| Some ty -> type_contains_enum_subrange_reftype ctx ty
| None -> assert false
)
| Bool _ | Int _ | Real _
| UInt8 _| UInt16 _| UInt32 _| UInt64 _
| Int8 _ |Int16 _ |Int32 _ | Int64 _
| AbstractType _ -> false

let rec type_contains_abstract ctx = function
| LA.UserType (_, ty_args, id) ->
Expand All @@ -753,7 +797,10 @@ let rec type_contains_abstract ctx = function
(match lookup_ty ctx id with
| Some ty -> type_contains_abstract ctx ty
| _ -> assert false)
| _ -> false
| Bool _ | Int _ | Real _ | EnumType _ | IntRange _
| UInt8 _| UInt16 _| UInt32 _| UInt64 _
| Int8 _ |Int16 _ |Int32 _ | Int64 _
| AbstractType _ -> false

let rec ty_vars_of_expr ctx node_name expr =
let call = ty_vars_of_expr ctx node_name in match expr with
Expand Down
3 changes: 3 additions & 0 deletions src/lustre/typeCheckerContext.mli
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ val is_machine_type_of_associated_width: tc_context -> (LA.lustre_type * LA.lust
val type_contains_subrange : tc_context -> LA.lustre_type -> bool
(** Returns true if the lustre type expression contains an IntRange or if it is an IntRange *)

val type_contains_enum_or_subrange : tc_context -> LA.lustre_type -> bool
(** Returns true if the lustre type expression contains an EnumType/IntRange or if it is an EnumType/IntRange *)

val type_contains_ref : tc_context -> LA.lustre_type -> bool
(** Returns true if the lustre type expression contains a RefinementType or if it is an RefinementType *)

Expand Down
10 changes: 10 additions & 0 deletions tests/regression/success/forall_enum.lus
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
type E = enum { E1, E2 };

type R = struct {
f: E;
};

node N() returns (y:int);
let
check forall (x: R) (x.f=E1 or x.f=E2);
tel

0 comments on commit 81056b5

Please sign in to comment.