Skip to content

Commit

Permalink
MatchExpression struct in Expressions (#1354)
Browse files Browse the repository at this point in the history
This PR is part of issue #1345.
In particular, it adds the struct MatchExpression to Expressions to
homogenise the structure before including source references.
  • Loading branch information
gzanitti authored May 22, 2024
1 parent 98da495 commit c10ffb1
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 46 deletions.
35 changes: 23 additions & 12 deletions asm-to-pil/src/vm_to_constrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use powdr_ast::{
build::{self, absolute_reference, direct_reference, next_reference},
visitor::ExpressionVisitable,
ArrayExpression, BinaryOperation, BinaryOperator, Expression, FunctionCall,
FunctionDefinition, FunctionKind, LambdaExpression, MatchArm, Number, Pattern,
PilStatement, PolynomialName, SelectedExpressions, UnaryOperation, UnaryOperator,
FunctionDefinition, FunctionKind, LambdaExpression, MatchArm, MatchExpression, Number,
Pattern, PilStatement, PolynomialName, SelectedExpressions, UnaryOperation, UnaryOperator,
},
SourceRef,
};
Expand Down Expand Up @@ -744,7 +744,7 @@ impl<T: FieldElement> VMConverter<T> {
Expression::String(_) => panic!(),
Expression::Tuple(_) => panic!(),
Expression::ArrayLiteral(_) => panic!(),
Expression::MatchExpression(_, _) => panic!(),
Expression::MatchExpression(_) => panic!(),
Expression::IfExpression(_) => panic!(),
Expression::BlockExpression(_, _) => panic!(),
Expression::FreeInput(expr) => {
Expand Down Expand Up @@ -981,17 +981,28 @@ impl<T: FieldElement> VMConverter<T> {
value: absolute_reference("::std::prover::Query::None"),
});

FunctionDefinition::Expression(Expression::LambdaExpression(LambdaExpression {
let scrutinee = Box::new(
FunctionCall {
function: Box::new(absolute_reference("::std::prover::eval")),
arguments: vec![direct_reference(pc_name.as_ref().unwrap())],
}
.into(),
);

let lambda = LambdaExpression {
kind: FunctionKind::Query,
params: vec![Pattern::Variable("__i".to_string())],
body: Box::new(Expression::MatchExpression(
Box::new(Expression::FunctionCall(FunctionCall {
function: Box::new(absolute_reference("::std::prover::eval")),
arguments: vec![direct_reference(pc_name.as_ref().unwrap())],
})),
prover_query_arms,
)),
}))
body: Box::new(
MatchExpression {
scrutinee,
arms: prover_query_arms,
}
.into(),
),
}
.into();

FunctionDefinition::Expression(lambda)
});
witness_column(SourceRef::unknown(), free_value, prover_query)
})
Expand Down
14 changes: 10 additions & 4 deletions ast/src/parsed/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,10 +614,8 @@ impl<Ref: Display> Display for Expression<Ref> {
Expression::IndexAccess(index_access) => write!(f, "{index_access}"),
Expression::FunctionCall(fun_call) => write!(f, "{fun_call}"),
Expression::FreeInput(input) => write!(f, "${{ {input} }}"),
Expression::MatchExpression(scrutinee, arms) => {
writeln!(f, "match {scrutinee} {{")?;
write_items_indented(f, arms)?;
write!(f, "}}")
Expression::MatchExpression(match_expr) => {
write!(f, "{match_expr}")
}
Expression::IfExpression(e) => write!(f, "{e}"),
Expression::BlockExpression(statements, expr) => {
Expand Down Expand Up @@ -673,6 +671,14 @@ impl<E: Display> Display for LambdaExpression<E> {
}
}

impl<E: Display> Display for MatchExpression<E> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
writeln!(f, "match {} {{", self.scrutinee)?;
write_items_indented(f, &self.arms)?;
write!(f, "}}")
}
}

impl Display for FunctionKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand Down
53 changes: 44 additions & 9 deletions ast/src/parsed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ pub enum Expression<Ref = NamespacedPolynomialReference> {
IndexAccess(IndexAccess<Self>),
FunctionCall(FunctionCall<Self>),
FreeInput(Box<Self>),
MatchExpression(Box<Self>, Vec<MatchArm<Self>>),
MatchExpression(MatchExpression<Self>),
IfExpression(IfExpression<Self>),
BlockExpression(Vec<StatementInsideBlock<Self>>, Box<Self>),
}
Expand Down Expand Up @@ -422,6 +422,33 @@ impl<Ref> Expression<Ref> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
pub struct MatchExpression<E = Expression<NamespacedPolynomialReference>> {
pub scrutinee: Box<E>,
pub arms: Vec<MatchArm<E>>,
}

impl<Ref> From<MatchExpression<Expression<Ref>>> for Expression<Ref> {
fn from(match_expr: MatchExpression<Expression<Ref>>) -> Self {
Expression::MatchExpression(match_expr)
}
}

impl<E> Children<E> for MatchExpression<E> {
fn children(&self) -> Box<dyn Iterator<Item = &E> + '_> {
Box::new(
once(self.scrutinee.as_ref()).chain(self.arms.iter().flat_map(|arm| arm.children())),
)
}

fn children_mut(&mut self) -> Box<dyn Iterator<Item = &mut E> + '_> {
Box::new(
once(self.scrutinee.as_mut())
.chain(self.arms.iter_mut().flat_map(|arm| arm.children_mut())),
)
}
}

impl<Ref> From<BinaryOperation<Expression<Ref>>> for Expression<Ref> {
fn from(operation: BinaryOperation<Expression<Ref>>) -> Self {
Expression::BinaryOperation(operation)
Expand Down Expand Up @@ -486,7 +513,7 @@ impl<R> Expression<R> {
| Expression::String(_)
| Expression::Number(_) => empty(),
Expression::Tuple(v) => v.iter(),
Expression::LambdaExpression(LambdaExpression { body, .. }) => once(body.as_ref()),
Expression::LambdaExpression(lambda) => lambda.children(),
Expression::ArrayLiteral(ArrayLiteral { items }) => items.iter(),
Expression::BinaryOperation(BinaryOperation { left, right, .. }) => {
[left.as_ref(), right.as_ref()].into_iter()
Expand All @@ -500,9 +527,7 @@ impl<R> Expression<R> {
arguments,
}) => once(function.as_ref()).chain(arguments.iter()),
Expression::FreeInput(e) => once(e.as_ref()),
Expression::MatchExpression(e, arms) => {
once(e.as_ref()).chain(arms.iter().flat_map(|arm| arm.children()))
}
Expression::MatchExpression(match_expr) => match_expr.children(),
Expression::IfExpression(IfExpression {
condition,
body,
Expand All @@ -527,7 +552,7 @@ impl<R> Expression<R> {
}
Expression::Number(_) => empty(),
Expression::Tuple(v) => v.iter_mut(),
Expression::LambdaExpression(LambdaExpression { body, .. }) => once(body.as_mut()),
Expression::LambdaExpression(lambda) => lambda.children_mut(),
Expression::ArrayLiteral(ArrayLiteral { items }) => items.iter_mut(),
Expression::BinaryOperation(BinaryOperation { left, right, .. }) => {
[left.as_mut(), right.as_mut()].into_iter()
Expand All @@ -541,9 +566,7 @@ impl<R> Expression<R> {
arguments,
}) => once(function.as_mut()).chain(arguments.iter_mut()),
Expression::FreeInput(e) => once(e.as_mut()),
Expression::MatchExpression(e, arms) => {
once(e.as_mut()).chain(arms.iter_mut().flat_map(|arm| arm.children_mut()))
}
Expression::MatchExpression(match_expr) => match_expr.children_mut(),
Expression::IfExpression(IfExpression {
condition,
body,
Expand Down Expand Up @@ -601,6 +624,12 @@ pub struct LambdaExpression<E = Expression<NamespacedPolynomialReference>> {
pub body: Box<E>,
}

impl<Ref> From<LambdaExpression<Expression<Ref>>> for Expression<Ref> {
fn from(lambda: LambdaExpression<Expression<Ref>>) -> Self {
Expression::LambdaExpression(lambda)
}
}

impl<E> Children<E> for LambdaExpression<E> {
fn children(&self) -> Box<dyn Iterator<Item = &E> + '_> {
Box::new(once(self.body.as_ref()))
Expand Down Expand Up @@ -786,6 +815,12 @@ pub struct FunctionCall<E = Expression<NamespacedPolynomialReference>> {
pub arguments: Vec<E>,
}

impl<Ref> From<FunctionCall<Expression<Ref>>> for Expression<Ref> {
fn from(call: FunctionCall<Expression<Ref>>) -> Self {
Expression::FunctionCall(call)
}
}

impl<E> Children<E> for FunctionCall<E> {
fn children(&self) -> Box<dyn Iterator<Item = &E> + '_> {
Box::new(once(self.function.as_ref()).chain(self.arguments.iter()))
Expand Down
14 changes: 7 additions & 7 deletions importer/src/path_canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use powdr_ast::parsed::{
types::{Type, TypeScheme},
visitor::{Children, ExpressionVisitable},
ArrayLiteral, BinaryOperation, EnumDeclaration, EnumVariant, Expression, FunctionCall,
IndexAccess, LambdaExpression, LetStatementInsideBlock, MatchArm, Pattern, PilStatement,
StatementInsideBlock, TypedExpression, UnaryOperation,
IndexAccess, LambdaExpression, LetStatementInsideBlock, MatchArm, MatchExpression, Pattern,
PilStatement, StatementInsideBlock, TypedExpression, UnaryOperation,
};

/// Changes all symbol references (symbol paths) from relative paths
Expand Down Expand Up @@ -172,7 +172,7 @@ fn free_inputs_in_expression<'a>(
Expression::LambdaExpression(_) => todo!(),
Expression::ArrayLiteral(_) => todo!(),
Expression::IndexAccess(_) => todo!(),
Expression::MatchExpression(_, _) => todo!(),
Expression::MatchExpression(_) => todo!(),
Expression::IfExpression(_) => todo!(),
Expression::BlockExpression(_, _) => todo!(),
}
Expand Down Expand Up @@ -209,7 +209,7 @@ fn free_inputs_in_expression_mut<'a>(
Expression::LambdaExpression(_) => todo!(),
Expression::ArrayLiteral(_) => todo!(),
Expression::IndexAccess(_) => todo!(),
Expression::MatchExpression(_, _) => todo!(),
Expression::MatchExpression(_) => todo!(),
Expression::IfExpression(_) => todo!(),
Expression::BlockExpression(_, _) => todo!(),
}
Expand Down Expand Up @@ -242,8 +242,8 @@ fn canonicalize_inside_expression(
canonicalize_inside_pattern(p, path, paths);
});
}
Expression::MatchExpression(_, match_arms) => {
match_arms.iter_mut().for_each(|MatchArm { pattern, .. }| {
Expression::MatchExpression(MatchExpression { scrutinee: _, arms }) => {
arms.iter_mut().for_each(|MatchArm { pattern, .. }| {
canonicalize_inside_pattern(pattern, path, paths);
})
}
Expand Down Expand Up @@ -679,7 +679,7 @@ fn check_expression(
check_expression(location, function, state, local_variables)?;
check_expressions(location, arguments, state, local_variables)
}
Expression::MatchExpression(scrutinee, arms) => {
Expression::MatchExpression(MatchExpression { scrutinee, arms }) => {
check_expression(location, scrutinee, state, local_variables)?;
arms.iter().try_for_each(|MatchArm { pattern, value }| {
let mut local_variables = local_variables.clone();
Expand Down
2 changes: 1 addition & 1 deletion parser/src/powdr.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ GenericReference: NamespacedPolynomialReference = {
}

MatchExpression: Box<Expression> = {
"match" <BoxedExpression> "{" <MatchArms> "}" => Box::new(Expression::MatchExpression(<>))
"match" <scrutinee:BoxedExpression> "{" <arms:MatchArms> "}" => Box::new(MatchExpression{scrutinee, arms}.into())
}

MatchArms: Vec<MatchArm> = {
Expand Down
11 changes: 7 additions & 4 deletions pil-analyzer/src/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use powdr_ast::{
display::quote,
types::{Type, TypeScheme},
ArrayLiteral, BinaryOperation, BinaryOperator, FunctionCall, IfExpression, IndexAccess,
LambdaExpression, LetStatementInsideBlock, MatchArm, Number, Pattern, StatementInsideBlock,
UnaryOperation, UnaryOperator,
LambdaExpression, LetStatementInsideBlock, MatchArm, MatchExpression, Number, Pattern,
StatementInsideBlock, UnaryOperation, UnaryOperator,
},
SourceRef,
};
Expand Down Expand Up @@ -726,7 +726,10 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> {
.extend(arguments.iter().rev().map(Operation::Expand));
self.expand(function)?;
}
Expression::MatchExpression(condition, _)
Expression::MatchExpression(MatchExpression {
scrutinee: condition,
arms: _,
})
| Expression::IfExpression(IfExpression { condition, .. }) => {
// Only handle the scrutinee / condition for now, we do not want to evaluate all arms.
self.op_stack.push(Operation::Combine(expr));
Expand Down Expand Up @@ -869,7 +872,7 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> {
let function = self.value_stack.pop().unwrap();
return self.combine_function_call(function, arguments);
}
Expression::MatchExpression(_, arms) => {
Expression::MatchExpression(MatchExpression { scrutinee: _, arms }) => {
let v = self.value_stack.pop().unwrap();
let (vars, body) = arms
.iter()
Expand Down
15 changes: 9 additions & 6 deletions pil-analyzer/src/expression_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use powdr_ast::{
analyzed::{Expression, PolynomialReference, Reference, RepeatedArray},
parsed::{
self, asm::SymbolPath, ArrayExpression, ArrayLiteral, BinaryOperation, IfExpression,
LambdaExpression, LetStatementInsideBlock, MatchArm, NamespacedPolynomialReference, Number,
Pattern, SelectedExpressions, StatementInsideBlock, SymbolCategory, UnaryOperation,
LambdaExpression, LetStatementInsideBlock, MatchArm, MatchExpression,
NamespacedPolynomialReference, Number, Pattern, SelectedExpressions, StatementInsideBlock,
SymbolCategory, UnaryOperation,
},
};
use powdr_number::DegreeType;
Expand Down Expand Up @@ -121,9 +122,10 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> {
function: Box::new(self.process_expression(*c.function)),
arguments: self.process_expressions(c.arguments),
}),
PExpression::MatchExpression(scrutinee, arms) => Expression::MatchExpression(
Box::new(self.process_expression(*scrutinee)),
arms.into_iter()
PExpression::MatchExpression(MatchExpression { scrutinee, arms }) => MatchExpression {
scrutinee: Box::new(self.process_expression(*scrutinee)),
arms: arms
.into_iter()
.map(|MatchArm { pattern, value }| {
let vars = self.save_local_variables();
let pattern = self.process_pattern(pattern);
Expand All @@ -132,7 +134,8 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> {
MatchArm { pattern, value }
})
.collect(),
),
}
.into(),
PExpression::IfExpression(IfExpression {
condition,
body,
Expand Down
5 changes: 3 additions & 2 deletions pil-analyzer/src/type_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use powdr_ast::{
types::{ArrayType, FunctionType, TupleType, Type, TypeBounds, TypeScheme},
visitor::ExpressionVisitable,
ArrayLiteral, BinaryOperation, FunctionCall, IndexAccess, LambdaExpression,
LetStatementInsideBlock, MatchArm, Number, Pattern, StatementInsideBlock, UnaryOperation,
LetStatementInsideBlock, MatchArm, MatchExpression, Number, Pattern, StatementInsideBlock,
UnaryOperation,
},
};

Expand Down Expand Up @@ -585,7 +586,7 @@ impl<'a> TypeChecker<'a> {
})?
}
Expression::FreeInput(_) => todo!(),
Expression::MatchExpression(scrutinee, arms) => {
Expression::MatchExpression(MatchExpression { scrutinee, arms }) => {
let scrutinee_type = self.infer_type_of_expression(scrutinee)?;
let result = self.new_type_var();
for MatchArm { pattern, value } in arms {
Expand Down
2 changes: 1 addition & 1 deletion riscv-executor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> {
}
}
}
Expression::MatchExpression(_, _) => todo!(),
Expression::MatchExpression(_) => todo!(),
Expression::IfExpression(_) => panic!(),
Expression::BlockExpression(_, _) => panic!(),
Expression::IndexAccess(_) => todo!(),
Expand Down

0 comments on commit c10ffb1

Please sign in to comment.