diff --git a/asm-to-pil/src/vm_to_constrained.rs b/asm-to-pil/src/vm_to_constrained.rs index be4b99593..813027e6f 100644 --- a/asm-to-pil/src/vm_to_constrained.rs +++ b/asm-to-pil/src/vm_to_constrained.rs @@ -13,8 +13,8 @@ use powdr_ast::{ build::{self, absolute_reference, direct_reference, next_reference}, visitor::ExpressionVisitable, ArrayExpression, BinaryOperation, BinaryOperator, Expression, FunctionCall, - FunctionDefinition, FunctionKind, LambdaExpression, MatchArm, Number, Pattern, - PilStatement, PolynomialName, SelectedExpressions, UnaryOperation, UnaryOperator, + FunctionDefinition, FunctionKind, LambdaExpression, MatchArm, MatchExpression, Number, + Pattern, PilStatement, PolynomialName, SelectedExpressions, UnaryOperation, UnaryOperator, }, SourceRef, }; @@ -744,7 +744,7 @@ impl VMConverter { Expression::String(_) => panic!(), Expression::Tuple(_) => panic!(), Expression::ArrayLiteral(_) => panic!(), - Expression::MatchExpression(_, _) => panic!(), + Expression::MatchExpression(_) => panic!(), Expression::IfExpression(_) => panic!(), Expression::BlockExpression(_, _) => panic!(), Expression::FreeInput(expr) => { @@ -981,17 +981,28 @@ impl VMConverter { value: absolute_reference("::std::prover::Query::None"), }); - FunctionDefinition::Expression(Expression::LambdaExpression(LambdaExpression { + let scrutinee = Box::new( + FunctionCall { + function: Box::new(absolute_reference("::std::prover::eval")), + arguments: vec![direct_reference(pc_name.as_ref().unwrap())], + } + .into(), + ); + + let lambda = LambdaExpression { kind: FunctionKind::Query, params: vec![Pattern::Variable("__i".to_string())], - body: Box::new(Expression::MatchExpression( - Box::new(Expression::FunctionCall(FunctionCall { - function: Box::new(absolute_reference("::std::prover::eval")), - arguments: vec![direct_reference(pc_name.as_ref().unwrap())], - })), - prover_query_arms, - )), - })) + body: Box::new( + MatchExpression { + scrutinee, + arms: prover_query_arms, + } + .into(), + ), + } + .into(); + + FunctionDefinition::Expression(lambda) }); witness_column(SourceRef::unknown(), free_value, prover_query) }) diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index b4ff39192..54d9d4e9b 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -614,10 +614,8 @@ impl Display for Expression { Expression::IndexAccess(index_access) => write!(f, "{index_access}"), Expression::FunctionCall(fun_call) => write!(f, "{fun_call}"), Expression::FreeInput(input) => write!(f, "${{ {input} }}"), - Expression::MatchExpression(scrutinee, arms) => { - writeln!(f, "match {scrutinee} {{")?; - write_items_indented(f, arms)?; - write!(f, "}}") + Expression::MatchExpression(match_expr) => { + write!(f, "{match_expr}") } Expression::IfExpression(e) => write!(f, "{e}"), Expression::BlockExpression(statements, expr) => { @@ -673,6 +671,14 @@ impl Display for LambdaExpression { } } +impl Display for MatchExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + writeln!(f, "match {} {{", self.scrutinee)?; + write_items_indented(f, &self.arms)?; + write!(f, "}}") + } +} + impl Display for FunctionKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index a4624acd2..6cfa5e99c 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -347,7 +347,7 @@ pub enum Expression { IndexAccess(IndexAccess), FunctionCall(FunctionCall), FreeInput(Box), - MatchExpression(Box, Vec>), + MatchExpression(MatchExpression), IfExpression(IfExpression), BlockExpression(Vec>, Box), } @@ -422,6 +422,33 @@ impl Expression { } } +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] +pub struct MatchExpression> { + pub scrutinee: Box, + pub arms: Vec>, +} + +impl From>> for Expression { + fn from(match_expr: MatchExpression>) -> Self { + Expression::MatchExpression(match_expr) + } +} + +impl Children for MatchExpression { + fn children(&self) -> Box + '_> { + Box::new( + once(self.scrutinee.as_ref()).chain(self.arms.iter().flat_map(|arm| arm.children())), + ) + } + + fn children_mut(&mut self) -> Box + '_> { + Box::new( + once(self.scrutinee.as_mut()) + .chain(self.arms.iter_mut().flat_map(|arm| arm.children_mut())), + ) + } +} + impl From>> for Expression { fn from(operation: BinaryOperation>) -> Self { Expression::BinaryOperation(operation) @@ -486,7 +513,7 @@ impl Expression { | Expression::String(_) | Expression::Number(_) => empty(), Expression::Tuple(v) => v.iter(), - Expression::LambdaExpression(LambdaExpression { body, .. }) => once(body.as_ref()), + Expression::LambdaExpression(lambda) => lambda.children(), Expression::ArrayLiteral(ArrayLiteral { items }) => items.iter(), Expression::BinaryOperation(BinaryOperation { left, right, .. }) => { [left.as_ref(), right.as_ref()].into_iter() @@ -500,9 +527,7 @@ impl Expression { arguments, }) => once(function.as_ref()).chain(arguments.iter()), Expression::FreeInput(e) => once(e.as_ref()), - Expression::MatchExpression(e, arms) => { - once(e.as_ref()).chain(arms.iter().flat_map(|arm| arm.children())) - } + Expression::MatchExpression(match_expr) => match_expr.children(), Expression::IfExpression(IfExpression { condition, body, @@ -527,7 +552,7 @@ impl Expression { } Expression::Number(_) => empty(), Expression::Tuple(v) => v.iter_mut(), - Expression::LambdaExpression(LambdaExpression { body, .. }) => once(body.as_mut()), + Expression::LambdaExpression(lambda) => lambda.children_mut(), Expression::ArrayLiteral(ArrayLiteral { items }) => items.iter_mut(), Expression::BinaryOperation(BinaryOperation { left, right, .. }) => { [left.as_mut(), right.as_mut()].into_iter() @@ -541,9 +566,7 @@ impl Expression { arguments, }) => once(function.as_mut()).chain(arguments.iter_mut()), Expression::FreeInput(e) => once(e.as_mut()), - Expression::MatchExpression(e, arms) => { - once(e.as_mut()).chain(arms.iter_mut().flat_map(|arm| arm.children_mut())) - } + Expression::MatchExpression(match_expr) => match_expr.children_mut(), Expression::IfExpression(IfExpression { condition, body, @@ -601,6 +624,12 @@ pub struct LambdaExpression> { pub body: Box, } +impl From>> for Expression { + fn from(lambda: LambdaExpression>) -> Self { + Expression::LambdaExpression(lambda) + } +} + impl Children for LambdaExpression { fn children(&self) -> Box + '_> { Box::new(once(self.body.as_ref())) @@ -786,6 +815,12 @@ pub struct FunctionCall> { pub arguments: Vec, } +impl From>> for Expression { + fn from(call: FunctionCall>) -> Self { + Expression::FunctionCall(call) + } +} + impl Children for FunctionCall { fn children(&self) -> Box + '_> { Box::new(once(self.function.as_ref()).chain(self.arguments.iter())) diff --git a/importer/src/path_canonicalizer.rs b/importer/src/path_canonicalizer.rs index dec476b3f..5f7fd118d 100644 --- a/importer/src/path_canonicalizer.rs +++ b/importer/src/path_canonicalizer.rs @@ -15,8 +15,8 @@ use powdr_ast::parsed::{ types::{Type, TypeScheme}, visitor::{Children, ExpressionVisitable}, ArrayLiteral, BinaryOperation, EnumDeclaration, EnumVariant, Expression, FunctionCall, - IndexAccess, LambdaExpression, LetStatementInsideBlock, MatchArm, Pattern, PilStatement, - StatementInsideBlock, TypedExpression, UnaryOperation, + IndexAccess, LambdaExpression, LetStatementInsideBlock, MatchArm, MatchExpression, Pattern, + PilStatement, StatementInsideBlock, TypedExpression, UnaryOperation, }; /// Changes all symbol references (symbol paths) from relative paths @@ -172,7 +172,7 @@ fn free_inputs_in_expression<'a>( Expression::LambdaExpression(_) => todo!(), Expression::ArrayLiteral(_) => todo!(), Expression::IndexAccess(_) => todo!(), - Expression::MatchExpression(_, _) => todo!(), + Expression::MatchExpression(_) => todo!(), Expression::IfExpression(_) => todo!(), Expression::BlockExpression(_, _) => todo!(), } @@ -209,7 +209,7 @@ fn free_inputs_in_expression_mut<'a>( Expression::LambdaExpression(_) => todo!(), Expression::ArrayLiteral(_) => todo!(), Expression::IndexAccess(_) => todo!(), - Expression::MatchExpression(_, _) => todo!(), + Expression::MatchExpression(_) => todo!(), Expression::IfExpression(_) => todo!(), Expression::BlockExpression(_, _) => todo!(), } @@ -242,8 +242,8 @@ fn canonicalize_inside_expression( canonicalize_inside_pattern(p, path, paths); }); } - Expression::MatchExpression(_, match_arms) => { - match_arms.iter_mut().for_each(|MatchArm { pattern, .. }| { + Expression::MatchExpression(MatchExpression { scrutinee: _, arms }) => { + arms.iter_mut().for_each(|MatchArm { pattern, .. }| { canonicalize_inside_pattern(pattern, path, paths); }) } @@ -679,7 +679,7 @@ fn check_expression( check_expression(location, function, state, local_variables)?; check_expressions(location, arguments, state, local_variables) } - Expression::MatchExpression(scrutinee, arms) => { + Expression::MatchExpression(MatchExpression { scrutinee, arms }) => { check_expression(location, scrutinee, state, local_variables)?; arms.iter().try_for_each(|MatchArm { pattern, value }| { let mut local_variables = local_variables.clone(); diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index cd712141e..3b3676224 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -609,7 +609,7 @@ GenericReference: NamespacedPolynomialReference = { } MatchExpression: Box = { - "match" "{" "}" => Box::new(Expression::MatchExpression(<>)) + "match" "{" "}" => Box::new(MatchExpression{scrutinee, arms}.into()) } MatchArms: Vec = { diff --git a/pil-analyzer/src/evaluator.rs b/pil-analyzer/src/evaluator.rs index 18b50a467..b8a73f16e 100644 --- a/pil-analyzer/src/evaluator.rs +++ b/pil-analyzer/src/evaluator.rs @@ -17,8 +17,8 @@ use powdr_ast::{ display::quote, types::{Type, TypeScheme}, ArrayLiteral, BinaryOperation, BinaryOperator, FunctionCall, IfExpression, IndexAccess, - LambdaExpression, LetStatementInsideBlock, MatchArm, Number, Pattern, StatementInsideBlock, - UnaryOperation, UnaryOperator, + LambdaExpression, LetStatementInsideBlock, MatchArm, MatchExpression, Number, Pattern, + StatementInsideBlock, UnaryOperation, UnaryOperator, }, SourceRef, }; @@ -726,7 +726,10 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> { .extend(arguments.iter().rev().map(Operation::Expand)); self.expand(function)?; } - Expression::MatchExpression(condition, _) + Expression::MatchExpression(MatchExpression { + scrutinee: condition, + arms: _, + }) | Expression::IfExpression(IfExpression { condition, .. }) => { // Only handle the scrutinee / condition for now, we do not want to evaluate all arms. self.op_stack.push(Operation::Combine(expr)); @@ -869,7 +872,7 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> { let function = self.value_stack.pop().unwrap(); return self.combine_function_call(function, arguments); } - Expression::MatchExpression(_, arms) => { + Expression::MatchExpression(MatchExpression { scrutinee: _, arms }) => { let v = self.value_stack.pop().unwrap(); let (vars, body) = arms .iter() diff --git a/pil-analyzer/src/expression_processor.rs b/pil-analyzer/src/expression_processor.rs index b0bb9fde3..937f82dec 100644 --- a/pil-analyzer/src/expression_processor.rs +++ b/pil-analyzer/src/expression_processor.rs @@ -8,8 +8,9 @@ use powdr_ast::{ analyzed::{Expression, PolynomialReference, Reference, RepeatedArray}, parsed::{ self, asm::SymbolPath, ArrayExpression, ArrayLiteral, BinaryOperation, IfExpression, - LambdaExpression, LetStatementInsideBlock, MatchArm, NamespacedPolynomialReference, Number, - Pattern, SelectedExpressions, StatementInsideBlock, SymbolCategory, UnaryOperation, + LambdaExpression, LetStatementInsideBlock, MatchArm, MatchExpression, + NamespacedPolynomialReference, Number, Pattern, SelectedExpressions, StatementInsideBlock, + SymbolCategory, UnaryOperation, }, }; use powdr_number::DegreeType; @@ -121,9 +122,10 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> { function: Box::new(self.process_expression(*c.function)), arguments: self.process_expressions(c.arguments), }), - PExpression::MatchExpression(scrutinee, arms) => Expression::MatchExpression( - Box::new(self.process_expression(*scrutinee)), - arms.into_iter() + PExpression::MatchExpression(MatchExpression { scrutinee, arms }) => MatchExpression { + scrutinee: Box::new(self.process_expression(*scrutinee)), + arms: arms + .into_iter() .map(|MatchArm { pattern, value }| { let vars = self.save_local_variables(); let pattern = self.process_pattern(pattern); @@ -132,7 +134,8 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> { MatchArm { pattern, value } }) .collect(), - ), + } + .into(), PExpression::IfExpression(IfExpression { condition, body, diff --git a/pil-analyzer/src/type_inference.rs b/pil-analyzer/src/type_inference.rs index 7b35aa709..3ea9499b1 100644 --- a/pil-analyzer/src/type_inference.rs +++ b/pil-analyzer/src/type_inference.rs @@ -8,7 +8,8 @@ use powdr_ast::{ types::{ArrayType, FunctionType, TupleType, Type, TypeBounds, TypeScheme}, visitor::ExpressionVisitable, ArrayLiteral, BinaryOperation, FunctionCall, IndexAccess, LambdaExpression, - LetStatementInsideBlock, MatchArm, Number, Pattern, StatementInsideBlock, UnaryOperation, + LetStatementInsideBlock, MatchArm, MatchExpression, Number, Pattern, StatementInsideBlock, + UnaryOperation, }, }; @@ -585,7 +586,7 @@ impl<'a> TypeChecker<'a> { })? } Expression::FreeInput(_) => todo!(), - Expression::MatchExpression(scrutinee, arms) => { + Expression::MatchExpression(MatchExpression { scrutinee, arms }) => { let scrutinee_type = self.infer_type_of_expression(scrutinee)?; let result = self.new_type_var(); for MatchArm { pattern, value } in arms { diff --git a/riscv-executor/src/lib.rs b/riscv-executor/src/lib.rs index 2ce0a3d58..7e41aa273 100644 --- a/riscv-executor/src/lib.rs +++ b/riscv-executor/src/lib.rs @@ -962,7 +962,7 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { } } } - Expression::MatchExpression(_, _) => todo!(), + Expression::MatchExpression(_) => todo!(), Expression::IfExpression(_) => panic!(), Expression::BlockExpression(_, _) => panic!(), Expression::IndexAccess(_) => todo!(),