Skip to content

Commit

Permalink
BinaryExpression struct in Expressions (#1351)
Browse files Browse the repository at this point in the history
This PR is part of issue #1345.
In particular, it adds the struct BinaryOperation to Expressions to
homogenise the structure before including source references.

---------

Co-authored-by: Thibaut Schaeffer <[email protected]>
  • Loading branch information
gzanitti and Schaeff authored May 21, 2024
1 parent a3087ea commit 29b0ee6
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 105 deletions.
21 changes: 15 additions & 6 deletions asm-to-pil/src/vm_to_constrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ use powdr_ast::{
asm::{CallableRef, InstructionBody, InstructionParams},
build::{self, absolute_reference, direct_reference, next_reference},
visitor::ExpressionVisitable,
ArrayExpression, BinaryOperator, Expression, FunctionCall, FunctionDefinition,
FunctionKind, LambdaExpression, MatchArm, Number, Pattern, PilStatement, PolynomialName,
SelectedExpressions, UnaryOperator,
ArrayExpression, BinaryOperation, BinaryOperator, Expression, FunctionCall,
FunctionDefinition, FunctionKind, LambdaExpression, MatchArm, Number, Pattern,
PilStatement, PolynomialName, SelectedExpressions, UnaryOperator,
},
SourceRef,
};
Expand Down Expand Up @@ -750,7 +750,7 @@ impl<T: FieldElement> VMConverter<T> {
Expression::LambdaExpression(_) => {
unreachable!("lambda expressions should have been removed")
}
Expression::BinaryOperation(left, op, right) => match op {
Expression::BinaryOperation(BinaryOperation { left, op, right }) => match op {
BinaryOperator::Add => self.add_assignment_value(
self.process_assignment_value(*left),
self.process_assignment_value(*right),
Expand Down Expand Up @@ -1078,7 +1078,11 @@ impl<T: FieldElement> VMConverter<T> {
expr: Expression,
) -> (usize, Expression) {
match expr {
Expression::BinaryOperation(left, operator, right) => match operator {
Expression::BinaryOperation(BinaryOperation {
left,
op: operator,
right,
}) => match operator {
BinaryOperator::Add => {
let (counter, left) = self.linearize_rec(prefix, counter, *left);
let (counter, right) = self.linearize_rec(prefix, counter, *right);
Expand Down Expand Up @@ -1208,7 +1212,12 @@ fn witness_column<S: Into<String>>(
}

fn extract_update(expr: Expression) -> (Option<String>, Expression) {
let Expression::BinaryOperation(left, BinaryOperator::Identity, right) = expr else {
let Expression::BinaryOperation(BinaryOperation {
left,
op: BinaryOperator::Identity,
right,
}) = expr
else {
panic!("Invalid statement for instruction body, expected constraint: {expr}");
};
// TODO check that there are no other "next" references in the expression
Expand Down
8 changes: 7 additions & 1 deletion ast/src/analyzed/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{parsed::FunctionKind, writeln_indented, writeln_indented_by};
use self::parsed::{
asm::{AbsoluteSymbolPath, SymbolPath},
display::format_type_scheme_around_name,
BinaryOperation,
};

use super::*;
Expand Down Expand Up @@ -272,7 +273,12 @@ impl Display for Identity<Expression> {
match self.kind {
IdentityKind::Polynomial => {
let expression = self.expression_for_poly_id();
if let Expression::BinaryOperation(left, BinaryOperator::Sub, right) = expression {
if let Expression::BinaryOperation(BinaryOperation {
left,
op: BinaryOperator::Sub,
right,
}) = expression
{
write!(f, "{left} = {right};")
} else {
write!(f, "{expression} = 0;")
Expand Down
9 changes: 7 additions & 2 deletions ast/src/parsed/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::parsed::Expression;

use super::{
asm::{parse_absolute_path, Part, SymbolPath},
BinaryOperator, IndexAccess, NamespacedPolynomialReference, UnaryOperator,
BinaryOperation, BinaryOperator, IndexAccess, NamespacedPolynomialReference, UnaryOperator,
};

pub fn absolute_reference(name: &str) -> Expression {
Expand Down Expand Up @@ -40,5 +40,10 @@ pub fn index_access(expr: Expression, index: Option<BigUint>) -> Expression {
}

pub fn identity(lhs: Expression, rhs: Expression) -> Expression {
Expression::BinaryOperation(Box::new(lhs), BinaryOperator::Identity, Box::new(rhs))
BinaryOperation {
left: Box::new(lhs),
op: BinaryOperator::Identity,
right: Box::new(rhs),
}
.into()
}
97 changes: 45 additions & 52 deletions ast/src/parsed/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,14 +596,6 @@ fn format_list<L: IntoIterator<Item = I>, I: Display>(list: L) -> String {
}

impl<E: Display> Expression<E> {
pub fn precedence(&self) -> Option<ExpressionPrecedence> {
match self {
Expression::UnaryOperation(op, _) => Some(op.precedence()),
Expression::BinaryOperation(_, op, _) => Some(op.precedence()),
_ => None,
}
}

pub fn format_unary_operation(
&self,
op: &UnaryOperator,
Expand All @@ -625,48 +617,6 @@ impl<E: Display> Expression<E> {
write!(f, "{exp_string}{op}")
}
}

pub fn format_binary_operation(
left: &Expression<E>,
op: &BinaryOperator,
right: &Expression<E>,
f: &mut Formatter<'_>,
) -> Result {
let force_parentheses = matches!(op, BinaryOperator::Pow);

let use_left_parentheses = match left.precedence() {
Some(left_precedence) => {
force_parentheses
|| left_precedence > op.precedence()
|| (left_precedence == op.precedence()
&& op.associativity() != BinaryOperatorAssociativity::Left)
}
None => false,
};

let use_right_parentheses = match right.precedence() {
Some(right_precedence) => {
force_parentheses
|| right_precedence > op.precedence()
|| (right_precedence == op.precedence()
&& op.associativity() != BinaryOperatorAssociativity::Right)
}
None => false,
};

let left_string = if use_left_parentheses {
format!("({left})")
} else {
format!("{left}")
};
let right_string = if use_right_parentheses {
format!("({right})")
} else {
format!("{right}")
};

write!(f, "{left_string} {op} {right_string}")
}
}

impl<Ref: Display> Display for Expression<Ref> {
Expand All @@ -679,8 +629,8 @@ impl<Ref: Display> Display for Expression<Ref> {
Expression::Tuple(items) => write!(f, "({})", format_list(items)),
Expression::LambdaExpression(lambda) => write!(f, "{lambda}"),
Expression::ArrayLiteral(array) => write!(f, "{array}"),
Expression::BinaryOperation(left, op, right) => {
Expression::format_binary_operation(left, op, right, f)
Expression::BinaryOperation(binaryop) => {
write!(f, "{binaryop}")
}
Expression::UnaryOperation(op, exp) => self.format_unary_operation(op, exp, f),
Expression::IndexAccess(index_access) => write!(f, "{index_access}"),
Expand Down Expand Up @@ -761,6 +711,49 @@ impl<E: Display> Display for ArrayLiteral<E> {
}
}

impl<E> Display for BinaryOperation<E>
where
E: Display + Precedence,
{
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
let force_parentheses = matches!(self.op, BinaryOperator::Pow);

let op_precedence = self.op.precedence().unwrap();
let use_left_parentheses = match self.left.precedence() {
Some(left_precedence) => {
force_parentheses
|| left_precedence > op_precedence
|| (left_precedence == op_precedence
&& self.op.associativity() != BinaryOperatorAssociativity::Left)
}
None => false,
};

let use_right_parentheses = match self.right.precedence() {
Some(right_precedence) => {
force_parentheses
|| right_precedence > op_precedence
|| (right_precedence == op_precedence
&& self.op.associativity() != BinaryOperatorAssociativity::Right)
}
None => false,
};

let left_string = if use_left_parentheses {
format!("({})", self.left)
} else {
format!("{}", self.left)
};
let right_string = if use_right_parentheses {
format!("({})", self.right)
} else {
format!("{}", self.right)
};

write!(f, "{left_string} {} {right_string}", self.op)
}
}

impl Display for BinaryOperator {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(
Expand Down
51 changes: 41 additions & 10 deletions ast/src/parsed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ pub enum Expression<Ref = NamespacedPolynomialReference> {
Tuple(Vec<Self>),
LambdaExpression(LambdaExpression<Self>),
ArrayLiteral(ArrayLiteral<Self>),
BinaryOperation(Box<Self>, BinaryOperator, Box<Self>),
BinaryOperation(BinaryOperation<Self>),
UnaryOperation(UnaryOperator, Box<Self>),
IndexAccess(IndexAccess<Self>),
FunctionCall(FunctionCall<Self>),
Expand All @@ -352,6 +352,13 @@ pub enum Expression<Ref = NamespacedPolynomialReference> {
BlockExpression(Vec<StatementInsideBlock<Self>>, Box<Self>),
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
pub struct BinaryOperation<E = Expression<NamespacedPolynomialReference>> {
pub left: Box<E>,
pub op: BinaryOperator,
pub right: Box<E>,
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Number {
#[schemars(skip)]
Expand Down Expand Up @@ -380,7 +387,11 @@ pub type ExpressionPrecedence = u64;

impl<Ref> Expression<Ref> {
pub fn new_binary(left: Self, op: BinaryOperator, right: Self) -> Self {
Expression::BinaryOperation(Box::new(left), op, Box::new(right))
Expression::BinaryOperation(BinaryOperation {
left: Box::new(left),
op,
right: Box::new(right),
})
}

/// Visits this expression and all of its sub-expressions and returns true
Expand All @@ -399,6 +410,12 @@ impl<Ref> Expression<Ref> {
}
}

impl<Ref> From<BinaryOperation<Expression<Ref>>> for Expression<Ref> {
fn from(operation: BinaryOperation<Expression<Ref>>) -> Self {
Expression::BinaryOperation(operation)
}
}

impl Expression<NamespacedPolynomialReference> {
pub fn try_to_identifier(&self) -> Option<&String> {
if let Expression::Reference(r) = self {
Expand Down Expand Up @@ -459,7 +476,7 @@ impl<R> Expression<R> {
Expression::Tuple(v) => v.iter(),
Expression::LambdaExpression(LambdaExpression { body, .. }) => once(body.as_ref()),
Expression::ArrayLiteral(ArrayLiteral { items }) => items.iter(),
Expression::BinaryOperation(left, _, right) => {
Expression::BinaryOperation(BinaryOperation { left, right, .. }) => {
[left.as_ref(), right.as_ref()].into_iter()
}
Expression::UnaryOperation(_, e) => once(e.as_ref()),
Expand Down Expand Up @@ -500,7 +517,7 @@ impl<R> Expression<R> {
Expression::Tuple(v) => v.iter_mut(),
Expression::LambdaExpression(LambdaExpression { body, .. }) => once(body.as_mut()),
Expression::ArrayLiteral(ArrayLiteral { items }) => items.iter_mut(),
Expression::BinaryOperation(left, _, right) => {
Expression::BinaryOperation(BinaryOperation { left, right, .. }) => {
[left.as_mut(), right.as_mut()].into_iter()
}
Expression::UnaryOperation(_, e) => once(e.as_mut()),
Expand Down Expand Up @@ -659,24 +676,26 @@ pub enum BinaryOperatorAssociativity {
}

trait Precedence {
fn precedence(&self) -> ExpressionPrecedence;
fn precedence(&self) -> Option<ExpressionPrecedence>;
}

impl Precedence for UnaryOperator {
fn precedence(&self) -> ExpressionPrecedence {
fn precedence(&self) -> Option<ExpressionPrecedence> {
use UnaryOperator::*;
match self {
let precedence = match self {
// NOTE: Any modification must be done with care to not overlap with BinaryOperator's precedence
Next => 1,
Minus | LogicalNot => 2,
}
};

Some(precedence)
}
}

impl Precedence for BinaryOperator {
fn precedence(&self) -> ExpressionPrecedence {
fn precedence(&self) -> Option<ExpressionPrecedence> {
use BinaryOperator::*;
match self {
let precedence = match self {
// NOTE: Any modification must be done with care to not overlap with LambdaExpression's precedence
// Unary Oprators
// **
Expand All @@ -701,6 +720,18 @@ impl Precedence for BinaryOperator {
LogicalOr => 12,
// .. ..=
// ??
};

Some(precedence)
}
}

impl<E> Precedence for Expression<E> {
fn precedence(&self) -> Option<ExpressionPrecedence> {
match self {
Expression::UnaryOperation(op, _) => op.precedence(),
Expression::BinaryOperation(operation) => operation.op.precedence(),
_ => None,
}
}
}
Expand Down
12 changes: 7 additions & 5 deletions importer/src/path_canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use powdr_ast::parsed::{
folder::Folder,
types::{Type, TypeScheme},
visitor::{Children, ExpressionVisitable},
ArrayLiteral, EnumDeclaration, EnumVariant, Expression, FunctionCall, IndexAccess,
LambdaExpression, LetStatementInsideBlock, MatchArm, Pattern, PilStatement,
ArrayLiteral, BinaryOperation, EnumDeclaration, EnumVariant, Expression, FunctionCall,
IndexAccess, LambdaExpression, LetStatementInsideBlock, MatchArm, Pattern, PilStatement,
StatementInsideBlock, TypedExpression,
};

Expand Down Expand Up @@ -156,7 +156,7 @@ fn free_inputs_in_expression<'a>(
| Expression::PublicReference(_)
| Expression::Number(_)
| Expression::String(_) => Box::new(None.into_iter()),
Expression::BinaryOperation(left, _, right) => {
Expression::BinaryOperation(BinaryOperation { left, right, .. }) => {
Box::new(free_inputs_in_expression(left).chain(free_inputs_in_expression(right)))
}
Expression::UnaryOperation(_, expr) => free_inputs_in_expression(expr),
Expand Down Expand Up @@ -188,7 +188,7 @@ fn free_inputs_in_expression_mut<'a>(
| Expression::PublicReference(_)
| Expression::Number(_)
| Expression::String(_) => Box::new(None.into_iter()),
Expression::BinaryOperation(left, _, right) => Box::new(
Expression::BinaryOperation(BinaryOperation { left, right, .. }) => Box::new(
free_inputs_in_expression_mut(left).chain(free_inputs_in_expression_mut(right)),
),
Expression::UnaryOperation(_, expr) => free_inputs_in_expression_mut(expr),
Expand Down Expand Up @@ -660,7 +660,9 @@ fn check_expression(
local_variables.extend(check_patterns(location, params, state)?);
check_expression(location, body, state, &local_variables)
}
Expression::BinaryOperation(a, _, b)
Expression::BinaryOperation(BinaryOperation {
left: a, right: b, ..
})
| Expression::IndexAccess(IndexAccess { array: a, index: b }) => {
check_expression(location, a.as_ref(), state, local_variables)?;
check_expression(location, b.as_ref(), state, local_variables)
Expand Down
Loading

0 comments on commit 29b0ee6

Please sign in to comment.