From 9472671b9bb2f936c372fa98dbce381fda61c999 Mon Sep 17 00:00:00 2001 From: Amin Latifi Date: Thu, 9 May 2024 07:33:26 +0000 Subject: [PATCH] Reduce Number of Expression Parentheses (#1289) Adds binary operation precedence support to avoid unnecessary parentheses in expression printed format - #962 --------- Co-authored-by: chriseth --- analysis/README.md | 2 +- ast/src/parsed/display.rs | 100 +++++++++++++++++--- ast/src/parsed/mod.rs | 72 ++++++++++++++ linker/src/lib.rs | 140 ++++++++++++++-------------- parser/src/lib.rs | 49 +++------- parser/src/powdr.lalrpop | 2 +- parser/src/test_utils.rs | 26 ++++++ parser/tests/parentheses_test.rs | 136 +++++++++++++++++++++++++++ pil-analyzer/tests/condenser.rs | 8 +- pil-analyzer/tests/parse_display.rs | 44 ++++----- pil-analyzer/tests/side_effects.rs | 4 +- pilopt/src/lib.rs | 4 +- 12 files changed, 440 insertions(+), 147 deletions(-) create mode 100644 parser/src/test_utils.rs create mode 100644 parser/tests/parentheses_test.rs diff --git a/analysis/README.md b/analysis/README.md index ee8534672..10f774a63 100644 --- a/analysis/README.md +++ b/analysis/README.md @@ -200,7 +200,7 @@ The diff for our example program is as follows: + _output_0 = ((((read__output_0_pc * pc) + (read__output_0__input_0 * _input_0)) + _output_0_const) + (_output_0_read_free * _output_0_free_value)); + pol constant first_step = [1] + [0]*; + ((1 - instr__reset) * _input_0') = ((1 - instr__reset) * _input_0); -+ pc' = ((1 - first_step') * ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1)))); ++ pc' = (1 - first_step') * (instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1)); + pol constant p_line = [0, 1, 2, 3, 4, 5] + [5]*; + pol commit _output_0_free_value; + pol constant p__output_0_const = [0, 0, 0, 0, 1, 0] + [0]*; diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index 15e6873c9..e77a76eb1 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -329,15 +329,23 @@ impl Display for Params { } } -impl Display for IndexAccess { +impl Display for IndexAccess> { fn fmt(&self, f: &mut Formatter<'_>) -> Result { - write!(f, "{}[{}]", self.array, self.index) + if self.array.precedence().is_none() { + write!(f, "{}[{}]", self.array, self.index) + } else { + write!(f, "({})[{}]", self.array, self.index) + } } } -impl Display for FunctionCall { +impl Display for FunctionCall> { fn fmt(&self, f: &mut Formatter<'_>) -> Result { - write!(f, "{}({})", self.function, format_list(&self.arguments)) + if self.function.precedence().is_none() { + write!(f, "{}({})", self.function, format_list(&self.arguments)) + } else { + write!(f, "({})({})", self.function, format_list(&self.arguments)) + } } } @@ -587,6 +595,80 @@ fn format_list, I: Display>(list: L) -> String { format!("{}", list.into_iter().format(", ")) } +impl Expression { + pub fn precedence(&self) -> Option { + match self { + Expression::UnaryOperation(op, _) => Some(op.precedence()), + Expression::BinaryOperation(_, op, _) => Some(op.precedence()), + _ => None, + } + } + + pub fn format_unary_operation( + &self, + op: &UnaryOperator, + exp: &Expression, + f: &mut Formatter<'_>, + ) -> Result { + let exp_string = match (self.precedence(), exp.precedence()) { + (Some(precedence), Some(inner_precedence)) if precedence < inner_precedence => { + format!("({exp})") + } + _ => { + format!("{exp}") + } + }; + + if op.is_prefix() { + write!(f, "{op}{exp_string}") + } else { + write!(f, "{exp_string}{op}") + } + } + + pub fn format_binary_operation( + left: &Expression, + op: &BinaryOperator, + right: &Expression, + f: &mut Formatter<'_>, + ) -> Result { + let force_parentheses = matches!(op, BinaryOperator::Pow); + + let use_left_parentheses = match left.precedence() { + Some(left_precedence) => { + force_parentheses + || left_precedence > op.precedence() + || (left_precedence == op.precedence() + && op.associativity() != BinaryOperatorAssociativity::Left) + } + None => false, + }; + + let use_right_parentheses = match right.precedence() { + Some(right_precedence) => { + force_parentheses + || right_precedence > op.precedence() + || (right_precedence == op.precedence() + && op.associativity() != BinaryOperatorAssociativity::Right) + } + None => false, + }; + + let left_string = if use_left_parentheses { + format!("({left})") + } else { + format!("{left}") + }; + let right_string = if use_right_parentheses { + format!("({right})") + } else { + format!("{right}") + }; + + write!(f, "{left_string} {op} {right_string}") + } +} + impl Display for Expression { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self { @@ -597,14 +679,10 @@ impl Display for Expression { Expression::Tuple(items) => write!(f, "({})", format_list(items)), Expression::LambdaExpression(lambda) => write!(f, "{lambda}"), Expression::ArrayLiteral(array) => write!(f, "{array}"), - Expression::BinaryOperation(left, op, right) => write!(f, "({left} {op} {right})"), - Expression::UnaryOperation(op, exp) => { - if op.is_prefix() { - write!(f, "{op}{exp}") - } else { - write!(f, "{exp}{op}") - } + Expression::BinaryOperation(left, op, right) => { + Expression::format_binary_operation(left, op, right, f) } + Expression::UnaryOperation(op, exp) => self.format_unary_operation(op, exp, f), Expression::IndexAccess(index_access) => write!(f, "{index_access}"), Expression::FunctionCall(fun_call) => write!(f, "{fun_call}"), Expression::FreeInput(input) => write!(f, "${{ {input} }}"), diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index bad238fa7..62956110f 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -352,6 +352,8 @@ pub enum Expression { BlockExpression(Vec>, Box), } +pub type ExpressionPrecedence = u64; + impl Expression { pub fn new_binary(left: Self, op: BinaryOperator, right: Self) -> Self { Expression::BinaryOperation(Box::new(left), op, Box::new(right)) @@ -638,6 +640,76 @@ pub enum BinaryOperator { Greater, } +#[derive(Debug, PartialEq, Eq)] +pub enum BinaryOperatorAssociativity { + Left, + Right, + RequireParentheses, +} + +trait Precedence { + fn precedence(&self) -> ExpressionPrecedence; +} + +impl Precedence for UnaryOperator { + fn precedence(&self) -> ExpressionPrecedence { + use UnaryOperator::*; + match self { + // NOTE: Any modification must be done with care to not overlap with BinaryOperator's precedence + Next => 1, + Minus | LogicalNot => 2, + } + } +} + +impl Precedence for BinaryOperator { + fn precedence(&self) -> ExpressionPrecedence { + use BinaryOperator::*; + match self { + // NOTE: Any modification must be done with care to not overlap with LambdaExpression's precedence + // Unary Oprators + // ** + Pow => 3, + // * / % + Mul | Div | Mod => 4, + // + - + Add | Sub => 5, + // << >> + ShiftLeft | ShiftRight => 6, + // & + BinaryAnd => 7, + // ^ + BinaryXor => 8, + // | + BinaryOr => 9, + // = == != < > <= >= + Identity | Equal | NotEqual | Less | Greater | LessEqual | GreaterEqual => 10, + // && + LogicalAnd => 11, + // || + LogicalOr => 12, + // .. ..= + // ?? + } + } +} + +impl BinaryOperator { + pub fn associativity(&self) -> BinaryOperatorAssociativity { + use BinaryOperator::*; + use BinaryOperatorAssociativity::*; + match self { + Identity | Equal | NotEqual | Less | Greater | LessEqual | GreaterEqual => { + RequireParentheses + } + Pow => Right, + + // .. ..= => RequireParentheses, + _ => Left, + } + } +} + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] pub struct IndexAccess> { pub array: Box, diff --git a/linker/src/lib.rs b/linker/src/lib.rs index a68828d33..9895c1726 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -305,19 +305,19 @@ mod test { #[test] fn compile_empty_vm() { - let expectation = r#"namespace main((4 + 4)); + let expectation = r#"namespace main(4 + 4); pol commit _operation_id(i) query std::prover::Query::Hint(2); pol constant _block_enforcer_last_step = [0]* + [1]; - let _operation_id_no_change = ((1 - _block_enforcer_last_step) * (1 - instr_return)); - ((_operation_id_no_change * (_operation_id' - _operation_id)) = 0); + let _operation_id_no_change = (1 - _block_enforcer_last_step) * (1 - instr_return); + _operation_id_no_change * (_operation_id' - _operation_id) = 0; pol commit pc; pol commit instr__jump_to_operation; pol commit instr__reset; pol commit instr__loop; pol commit instr_return; pol constant first_step = [1] + [0]*; - pol pc_update = ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1))); - (pc' = ((1 - first_step') * pc_update)); + pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pc' = (1 - first_step') * pc_update; pol constant p_line = [0, 1, 2] + [2]*; pol constant p_instr__jump_to_operation = [0, 1, 0] + [0]*; pol constant p_instr__loop = [0, 0, 1] + [1]*; @@ -348,7 +348,7 @@ mod test { #[test] fn compile_pil_without_machine() { - let input = " let even = std::array::new(5, (|i| (2 * i)));"; + let input = " let even = std::array::new(5, (|i| 2 * i));"; let graph = parse_analyze_and_compile::(input); let pil = link(graph).unwrap().to_string(); assert_eq!(&pil[0..input.len()], input); @@ -359,8 +359,8 @@ mod test { let expectation = r#"namespace main(16); pol commit _operation_id(i) query std::prover::Query::Hint(4); pol constant _block_enforcer_last_step = [0]* + [1]; - let _operation_id_no_change = ((1 - _block_enforcer_last_step) * (1 - instr_return)); - ((_operation_id_no_change * (_operation_id' - _operation_id)) = 0); + let _operation_id_no_change = (1 - _block_enforcer_last_step) * (1 - instr_return); + _operation_id_no_change * (_operation_id' - _operation_id) = 0; pol commit pc; pol commit X; pol commit Y; @@ -378,16 +378,16 @@ mod test { pol commit X_read_free; pol commit read_X_A; pol commit read_X_pc; - (X = ((((read_X_A * A) + (read_X_pc * pc)) + X_const) + (X_read_free * X_free_value))); + X = read_X_A * A + read_X_pc * pc + X_const + X_read_free * X_free_value; pol commit Y_const; pol commit Y_read_free; pol commit read_Y_A; pol commit read_Y_pc; - (Y = ((((read_Y_A * A) + (read_Y_pc * pc)) + Y_const) + (Y_read_free * Y_free_value))); + Y = read_Y_A * A + read_Y_pc * pc + Y_const + Y_read_free * Y_free_value; pol constant first_step = [1] + [0]*; - (A' = ((((reg_write_X_A * X) + (reg_write_Y_A * Y)) + (instr__reset * 0)) + ((1 - ((reg_write_X_A + reg_write_Y_A) + instr__reset)) * A))); - pol pc_update = ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1))); - (pc' = ((1 - first_step') * pc_update)); + A' = reg_write_X_A * X + reg_write_Y_A * Y + instr__reset * 0 + (1 - (reg_write_X_A + reg_write_Y_A + instr__reset)) * A; + pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pc' = (1 - first_step') * pc_update; pol constant p_line = [0, 1, 2, 3, 4] + [4]*; pol commit X_free_value; pol commit Y_free_value; @@ -413,12 +413,12 @@ mod test { instr_one { 4, Y } in main_sub.instr_return { main_sub._operation_id, main_sub._output_0 }; instr_nothing { 3 } in main_sub.instr_return { main_sub._operation_id }; pol constant _linker_first_step = [1] + [0]*; - ((_linker_first_step * (_operation_id - 2)) = 0); + _linker_first_step * (_operation_id - 2) = 0; namespace main_sub(16); pol commit _operation_id(i) query std::prover::Query::Hint(5); pol constant _block_enforcer_last_step = [0]* + [1]; - let _operation_id_no_change = ((1 - _block_enforcer_last_step) * (1 - instr_return)); - ((_operation_id_no_change * (_operation_id' - _operation_id)) = 0); + let _operation_id_no_change = (1 - _block_enforcer_last_step) * (1 - instr_return); + _operation_id_no_change * (_operation_id' - _operation_id) = 0; pol commit pc; pol commit _input_0; pol commit _output_0; @@ -430,11 +430,11 @@ namespace main_sub(16); pol commit _output_0_read_free; pol commit read__output_0_pc; pol commit read__output_0__input_0; - (_output_0 = ((((read__output_0_pc * pc) + (read__output_0__input_0 * _input_0)) + _output_0_const) + (_output_0_read_free * _output_0_free_value))); + _output_0 = read__output_0_pc * pc + read__output_0__input_0 * _input_0 + _output_0_const + _output_0_read_free * _output_0_free_value; pol constant first_step = [1] + [0]*; - (((1 - instr__reset) * (_input_0' - _input_0)) = 0); - pol pc_update = ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1))); - (pc' = ((1 - first_step') * pc_update)); + (1 - instr__reset) * (_input_0' - _input_0) = 0; + pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pc' = (1 - first_step') * pc_update; pol constant p_line = [0, 1, 2, 3, 4, 5] + [5]*; pol commit _output_0_free_value; pol constant p__output_0_const = [0, 0, 0, 0, 1, 0] + [0]*; @@ -462,13 +462,13 @@ namespace main_sub(16); let expectation = r#"namespace main(1024); pol commit XInv; pol commit XIsZero; - (XIsZero = (1 - (X * XInv))); - ((XIsZero * X) = 0); - ((XIsZero * (1 - XIsZero)) = 0); + XIsZero = 1 - X * XInv; + XIsZero * X = 0; + XIsZero * (1 - XIsZero) = 0; pol commit _operation_id(i) query std::prover::Query::Hint(10); pol constant _block_enforcer_last_step = [0]* + [1]; - let _operation_id_no_change = ((1 - _block_enforcer_last_step) * (1 - instr_return)); - ((_operation_id_no_change * (_operation_id' - _operation_id)) = 0); + let _operation_id_no_change = (1 - _block_enforcer_last_step) * (1 - instr_return); + _operation_id_no_change * (_operation_id' - _operation_id) = 0; pol commit pc; pol commit X; pol commit reg_write_X_A; @@ -477,13 +477,13 @@ namespace main_sub(16); pol commit CNT; pol commit instr_jmpz; pol commit instr_jmpz_param_l; - pol instr_jmpz_pc_update = (XIsZero * instr_jmpz_param_l); - pol instr_jmpz_pc_update_1 = ((1 - XIsZero) * (pc + 1)); + pol instr_jmpz_pc_update = XIsZero * instr_jmpz_param_l; + pol instr_jmpz_pc_update_1 = (1 - XIsZero) * (pc + 1); pol commit instr_jmp; pol commit instr_jmp_param_l; pol commit instr_dec_CNT; pol commit instr_assert_zero; - ((instr_assert_zero * (XIsZero - 1)) = 0); + instr_assert_zero * (XIsZero - 1) = 0; pol commit instr__jump_to_operation; pol commit instr__reset; pol commit instr__loop; @@ -493,16 +493,16 @@ namespace main_sub(16); pol commit read_X_A; pol commit read_X_CNT; pol commit read_X_pc; - (X = (((((read_X_A * A) + (read_X_CNT * CNT)) + (read_X_pc * pc)) + X_const) + (X_read_free * X_free_value))); + X = read_X_A * A + read_X_CNT * CNT + read_X_pc * pc + X_const + X_read_free * X_free_value; pol constant first_step = [1] + [0]*; - (A' = (((reg_write_X_A * X) + (instr__reset * 0)) + ((1 - (reg_write_X_A + instr__reset)) * A))); - (CNT' = ((((reg_write_X_CNT * X) + (instr_dec_CNT * (CNT - 1))) + (instr__reset * 0)) + ((1 - ((reg_write_X_CNT + instr_dec_CNT) + instr__reset)) * CNT))); - pol pc_update = ((((((instr_jmpz * (instr_jmpz_pc_update + instr_jmpz_pc_update_1)) + (instr_jmp * instr_jmp_param_l)) + (instr__jump_to_operation * _operation_id)) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((((instr_jmpz + instr_jmp) + instr__jump_to_operation) + instr__loop) + instr_return)) * (pc + 1))); - (pc' = ((1 - first_step') * pc_update)); + A' = reg_write_X_A * X + instr__reset * 0 + (1 - (reg_write_X_A + instr__reset)) * A; + CNT' = reg_write_X_CNT * X + instr_dec_CNT * (CNT - 1) + instr__reset * 0 + (1 - (reg_write_X_CNT + instr_dec_CNT + instr__reset)) * CNT; + pol pc_update = instr_jmpz * (instr_jmpz_pc_update + instr_jmpz_pc_update_1) + instr_jmp * instr_jmp_param_l + instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr_jmpz + instr_jmp + instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pc' = (1 - first_step') * pc_update; pol constant p_line = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + [10]*; pol commit X_free_value(__i) query match std::prover::eval(pc) { 2 => std::prover::Query::Input(1), - 4 => std::prover::Query::Input(std::convert::int((std::prover::eval(CNT) + 1))), + 4 => std::prover::Query::Input(std::convert::int(std::prover::eval(CNT) + 1)), 7 => std::prover::Query::Input(0), _ => std::prover::Query::None, }; @@ -525,7 +525,7 @@ namespace main_sub(16); pol constant p_reg_write_X_CNT = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0] + [0]*; { pc, reg_write_X_A, reg_write_X_CNT, instr_jmpz, instr_jmpz_param_l, instr_jmp, instr_jmp_param_l, instr_dec_CNT, instr_assert_zero, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_CNT, read_X_pc } in { p_line, p_reg_write_X_A, p_reg_write_X_CNT, p_instr_jmpz, p_instr_jmpz_param_l, p_instr_jmp, p_instr_jmp_param_l, p_instr_dec_CNT, p_instr_assert_zero, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return, p_X_const, p_X_read_free, p_read_X_A, p_read_X_CNT, p_read_X_pc }; pol constant _linker_first_step = [1] + [0]*; - ((_linker_first_step * (_operation_id - 2)) = 0); + _linker_first_step * (_operation_id - 2) = 0; "#; let file_name = format!( "{}/../test_data/asm/simple_sum.asm", @@ -557,8 +557,8 @@ machine Machine { let expectation = r#"namespace main(1024); pol commit _operation_id(i) query std::prover::Query::Hint(4); pol constant _block_enforcer_last_step = [0]* + [1]; - let _operation_id_no_change = ((1 - _block_enforcer_last_step) * (1 - instr_return)); - ((_operation_id_no_change * (_operation_id' - _operation_id)) = 0); + let _operation_id_no_change = (1 - _block_enforcer_last_step) * (1 - instr_return); + _operation_id_no_change * (_operation_id' - _operation_id) = 0; pol commit pc; pol commit fp; pol commit instr_inc_fp; @@ -571,9 +571,9 @@ machine Machine { pol commit instr__loop; pol commit instr_return; pol constant first_step = [1] + [0]*; - (fp' = ((((instr_inc_fp * (fp + instr_inc_fp_param_amount)) + (instr_adjust_fp * (fp + instr_adjust_fp_param_amount))) + (instr__reset * 0)) + ((1 - ((instr_inc_fp + instr_adjust_fp) + instr__reset)) * fp))); - pol pc_update = (((((instr_adjust_fp * label) + (instr__jump_to_operation * _operation_id)) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - (((instr_adjust_fp + instr__jump_to_operation) + instr__loop) + instr_return)) * (pc + 1))); - (pc' = ((1 - first_step') * pc_update)); + fp' = instr_inc_fp * (fp + instr_inc_fp_param_amount) + instr_adjust_fp * (fp + instr_adjust_fp_param_amount) + instr__reset * 0 + (1 - (instr_inc_fp + instr_adjust_fp + instr__reset)) * fp; + pol pc_update = instr_adjust_fp * label + instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr_adjust_fp + instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pc' = (1 - first_step') * pc_update; pol constant p_line = [0, 1, 2, 3, 4] + [4]*; pol constant p_instr__jump_to_operation = [0, 1, 0, 0, 0] + [0]*; pol constant p_instr__loop = [0, 0, 0, 0, 1] + [1]*; @@ -586,7 +586,7 @@ machine Machine { pol constant p_instr_return = [0]*; { pc, instr_inc_fp, instr_inc_fp_param_amount, instr_adjust_fp, instr_adjust_fp_param_amount, instr_adjust_fp_param_t, instr__jump_to_operation, instr__reset, instr__loop, instr_return } in { p_line, p_instr_inc_fp, p_instr_inc_fp_param_amount, p_instr_adjust_fp, p_instr_adjust_fp_param_amount, p_instr_adjust_fp_param_t, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return }; pol constant _linker_first_step = [1] + [0]*; - ((_linker_first_step * (_operation_id - 2)) = 0); + _linker_first_step * (_operation_id - 2) = 0; "#; let graph = parse_analyze_and_compile::(source); let pil = link(graph).unwrap(); @@ -644,8 +644,8 @@ machine Main { let expected = r#"namespace main(1024); pol commit _operation_id(i) query std::prover::Query::Hint(3); pol constant _block_enforcer_last_step = [0]* + [1]; - let _operation_id_no_change = ((1 - _block_enforcer_last_step) * (1 - instr_return)); - ((_operation_id_no_change * (_operation_id' - _operation_id)) = 0); + let _operation_id_no_change = (1 - _block_enforcer_last_step) * (1 - instr_return); + _operation_id_no_change * (_operation_id' - _operation_id) = 0; pol commit pc; pol commit X; pol commit reg_write_X_A; @@ -659,11 +659,11 @@ machine Main { pol commit X_read_free; pol commit read_X_A; pol commit read_X_pc; - (X = ((((read_X_A * A) + (read_X_pc * pc)) + X_const) + (X_read_free * X_free_value))); + X = read_X_A * A + read_X_pc * pc + X_const + X_read_free * X_free_value; pol constant first_step = [1] + [0]*; - (A' = ((((reg_write_X_A * X) + (instr_add5_into_A * A')) + (instr__reset * 0)) + ((1 - ((reg_write_X_A + instr_add5_into_A) + instr__reset)) * A))); - pol pc_update = ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1))); - (pc' = ((1 - first_step') * pc_update)); + A' = reg_write_X_A * X + instr_add5_into_A * A' + instr__reset * 0 + (1 - (reg_write_X_A + instr_add5_into_A + instr__reset)) * A; + pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pc' = (1 - first_step') * pc_update; pol constant p_line = [0, 1, 2, 3] + [3]*; pol commit X_free_value; pol constant p_X_const = [0, 0, 10, 0] + [0]*; @@ -679,13 +679,13 @@ machine Main { { pc, reg_write_X_A, instr_add5_into_A, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_pc } in { p_line, p_reg_write_X_A, p_instr_add5_into_A, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return, p_X_const, p_X_read_free, p_read_X_A, p_read_X_pc }; instr_add5_into_A { 0, X, A' } in main_vm.latch { main_vm.operation_id, main_vm.x, main_vm.y }; pol constant _linker_first_step = [1] + [0]*; - ((_linker_first_step * (_operation_id - 2)) = 0); + _linker_first_step * (_operation_id - 2) = 0; namespace main_vm(1024); pol commit operation_id; pol constant latch = [1]*; pol commit x; pol commit y; - (y = (x + 5)); + y = x + 5; "#; let graph = parse_analyze_and_compile::(asm); let pil = link(graph).unwrap(); @@ -697,8 +697,8 @@ namespace main_vm(1024); let expected = r#"namespace main(65536); pol commit _operation_id(i) query std::prover::Query::Hint(13); pol constant _block_enforcer_last_step = [0]* + [1]; - let _operation_id_no_change = ((1 - _block_enforcer_last_step) * (1 - instr_return)); - ((_operation_id_no_change * (_operation_id' - _operation_id)) = 0); + let _operation_id_no_change = (1 - _block_enforcer_last_step) * (1 - instr_return); + _operation_id_no_change * (_operation_id' - _operation_id) = 0; pol commit pc; pol commit X; pol commit Y; @@ -714,7 +714,7 @@ namespace main_vm(1024); pol commit instr_or; pol commit instr_or_into_B; pol commit instr_assert_eq; - ((instr_assert_eq * (X - Y)) = 0); + instr_assert_eq * (X - Y) = 0; pol commit instr__jump_to_operation; pol commit instr__reset; pol commit instr__loop; @@ -724,24 +724,24 @@ namespace main_vm(1024); pol commit read_X_A; pol commit read_X_B; pol commit read_X_pc; - (X = (((((read_X_A * A) + (read_X_B * B)) + (read_X_pc * pc)) + X_const) + (X_read_free * X_free_value))); + X = read_X_A * A + read_X_B * B + read_X_pc * pc + X_const + X_read_free * X_free_value; pol commit Y_const; pol commit Y_read_free; pol commit read_Y_A; pol commit read_Y_B; pol commit read_Y_pc; - (Y = (((((read_Y_A * A) + (read_Y_B * B)) + (read_Y_pc * pc)) + Y_const) + (Y_read_free * Y_free_value))); + Y = read_Y_A * A + read_Y_B * B + read_Y_pc * pc + Y_const + Y_read_free * Y_free_value; pol commit Z_const; pol commit Z_read_free; pol commit read_Z_A; pol commit read_Z_B; pol commit read_Z_pc; - (Z = (((((read_Z_A * A) + (read_Z_B * B)) + (read_Z_pc * pc)) + Z_const) + (Z_read_free * Z_free_value))); + Z = read_Z_A * A + read_Z_B * B + read_Z_pc * pc + Z_const + Z_read_free * Z_free_value; pol constant first_step = [1] + [0]*; - (A' = (((((reg_write_X_A * X) + (reg_write_Y_A * Y)) + (reg_write_Z_A * Z)) + (instr__reset * 0)) + ((1 - (((reg_write_X_A + reg_write_Y_A) + reg_write_Z_A) + instr__reset)) * A))); - (B' = ((((((reg_write_X_B * X) + (reg_write_Y_B * Y)) + (reg_write_Z_B * Z)) + (instr_or_into_B * B')) + (instr__reset * 0)) + ((1 - ((((reg_write_X_B + reg_write_Y_B) + reg_write_Z_B) + instr_or_into_B) + instr__reset)) * B))); - pol pc_update = ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1))); - (pc' = ((1 - first_step') * pc_update)); + A' = reg_write_X_A * X + reg_write_Y_A * Y + reg_write_Z_A * Z + instr__reset * 0 + (1 - (reg_write_X_A + reg_write_Y_A + reg_write_Z_A + instr__reset)) * A; + B' = reg_write_X_B * X + reg_write_Y_B * Y + reg_write_Z_B * Z + instr_or_into_B * B' + instr__reset * 0 + (1 - (reg_write_X_B + reg_write_Y_B + reg_write_Z_B + instr_or_into_B + instr__reset)) * B; + pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pc' = (1 - first_step') * pc_update; pol constant p_line = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + [13]*; pol commit X_free_value; pol commit Y_free_value; @@ -775,28 +775,28 @@ namespace main_vm(1024); pol constant p_reg_write_Z_A = [0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0] + [0]*; pol constant p_reg_write_Z_B = [0]*; { pc, reg_write_X_A, reg_write_Y_A, reg_write_Z_A, reg_write_X_B, reg_write_Y_B, reg_write_Z_B, instr_or, instr_or_into_B, instr_assert_eq, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_B, read_X_pc, Y_const, Y_read_free, read_Y_A, read_Y_B, read_Y_pc, Z_const, Z_read_free, read_Z_A, read_Z_B, read_Z_pc } in { p_line, p_reg_write_X_A, p_reg_write_Y_A, p_reg_write_Z_A, p_reg_write_X_B, p_reg_write_Y_B, p_reg_write_Z_B, p_instr_or, p_instr_or_into_B, p_instr_assert_eq, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return, p_X_const, p_X_read_free, p_read_X_A, p_read_X_B, p_read_X_pc, p_Y_const, p_Y_read_free, p_read_Y_A, p_read_Y_B, p_read_Y_pc, p_Z_const, p_Z_read_free, p_read_Z_A, p_read_Z_B, p_read_Z_pc }; - instr_or { 0, X, Y, Z } is (main_bin.latch * main_bin.sel[0]) { main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C }; - instr_or_into_B { 0, X, Y, B' } is (main_bin.latch * main_bin.sel[1]) { main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C }; + instr_or { 0, X, Y, Z } is main_bin.latch * main_bin.sel[0] { main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C }; + instr_or_into_B { 0, X, Y, B' } is main_bin.latch * main_bin.sel[1] { main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C }; pol constant _linker_first_step = [1] + [0]*; - ((_linker_first_step * (_operation_id - 2)) = 0); + _linker_first_step * (_operation_id - 2) = 0; namespace main_bin(65536); pol commit operation_id; - pol constant latch(i) { if ((i % 4) == 3) { 1 } else { 0 } }; - pol constant FACTOR(i) { (1 << (((i + 1) % 4) * 8)) }; - let a = (|i| (i % 256)); + pol constant latch(i) { if i % 4 == 3 { 1 } else { 0 } }; + pol constant FACTOR(i) { 1 << (i + 1) % 4 * 8 }; + let a = (|i| i % 256); pol constant P_A(i) { a(i) }; - let b = (|i| ((i >> 8) % 256)); + let b = (|i| (i >> 8) % 256); pol constant P_B(i) { b(i) }; - pol constant P_C(i) { ((a(i) | b(i)) & 255) }; + pol constant P_C(i) { (a(i) | b(i)) & 255 }; pol commit A_byte; pol commit B_byte; pol commit C_byte; pol commit A; pol commit B; pol commit C; - (A' = ((A * (1 - latch)) + (A_byte * FACTOR))); - (B' = ((B * (1 - latch)) + (B_byte * FACTOR))); - (C' = ((C * (1 - latch)) + (C_byte * FACTOR))); + A' = A * (1 - latch) + A_byte * FACTOR; + B' = B * (1 - latch) + B_byte * FACTOR; + C' = C * (1 - latch) + C_byte * FACTOR; { A_byte, B_byte, C_byte } in { P_A, P_B, P_C }; pol commit sel[2]; std::array::map(sel, std::utils::force_bool); diff --git a/parser/src/lib.rs b/parser/src/lib.rs index 187e91c85..233139fda 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -13,6 +13,8 @@ use powdr_parser_util::{handle_parse_error, ParseError}; use std::sync::Arc; +pub mod test_utils; + lalrpop_mod!( #[allow(clippy::all)] #[allow(clippy::uninlined_format_args)] @@ -243,30 +245,6 @@ mod test { }) } - // helper function to clear SourceRef's inside the AST so we can compare for equality - fn pil_clear_source_refs(ast: &mut PILFile) { - ast.0.iter_mut().for_each(pil_statement_clear_source_ref); - } - - fn pil_statement_clear_source_ref(stmt: &mut PilStatement) { - match stmt { - PilStatement::Include(s, _) - | PilStatement::Namespace(s, _, _) - | PilStatement::LetStatement(s, _, _, _) - | PilStatement::PolynomialDefinition(s, _, _) - | PilStatement::PublicDeclaration(s, _, _, _, _) - | PilStatement::PolynomialConstantDeclaration(s, _) - | PilStatement::PolynomialConstantDefinition(s, _, _) - | PilStatement::PolynomialCommitDeclaration(s, _, _, _) - | PilStatement::PlookupIdentity(s, _, _) - | PilStatement::PermutationIdentity(s, _, _) - | PilStatement::ConnectIdentity(s, _, _) - | PilStatement::ConstantDefinition(s, _, _) - | PilStatement::Expression(s, _) - | PilStatement::EnumDeclaration(s, _) => *s = SourceRef::unknown(), - } - } - // helper function to clear SourceRef's inside the AST so we can compare for equality fn asm_clear_source_refs(ast: &mut ASMProgram) { use powdr_ast::parsed::asm::{ @@ -275,6 +253,7 @@ mod test { }; fn clear_machine_stmt(stmt: &mut MachineStatement) { + use test_utils::pil_statement_clear_source_ref; match stmt { MachineStatement::Submachine(s, _, _) | MachineStatement::RegisterDeclaration(s, _, _) @@ -365,6 +344,7 @@ mod test { #[test] /// Test that (source -> AST -> source -> AST) works properly for pil files fn parse_write_reparse_pil() { + use test_utils::pil_clear_source_refs; let crate_dir = env!("CARGO_MANIFEST_DIR"); let basedir = std::path::PathBuf::from(format!("{crate_dir}/../test_data/")); let pil_files = find_files_with_ext(basedir, "pil".into()); @@ -404,8 +384,8 @@ mod test { let input = r#" constant %N = 16; namespace Fibonacci(%N); - constant %last_row = (%N - 1); - let bool: expr -> expr = (|X| (X * (1 - X))); + constant %last_row = %N - 1; + let bool: expr -> expr = (|X| X * (1 - X)); let one_hot = (|i, which| match i { which => 1, _ => 0, @@ -413,9 +393,9 @@ namespace Fibonacci(%N); pol constant ISLAST(i) { one_hot(i, %last_row) }; pol commit arr[8]; pol commit x, y; - { (x + 2), y' } in { ISLAST, 7 }; - y { (x + 2), y' } is ISLAST { ISLAST, 7 }; - (((x - 2) * y) = 8); + { x + 2, y' } in { ISLAST, 7 }; + y { x + 2, y' } is ISLAST { ISLAST, 7 }; + (x - 2) * y = 8; public out = y(%last_row);"#; let printed = format!("{}", parse(Some("input"), input).unwrap()); assert_eq!(input.trim(), printed.trim()); @@ -430,7 +410,8 @@ namespace Fibonacci(%N); #[test] fn reparse_arrays() { - let input = " pol commit y[3];\n ((y - 2) = 0);\n ((y[2] - 2) = 0);\n public out = y[1](2);"; + let input = + " pol commit y[3];\n y - 2 = 0;\n y[2] - 2 = 0;\n public out = y[1](2);"; let printed = format!("{}", parse(Some("input"), input).unwrap()); assert_eq!(input.trim(), printed.trim()); } @@ -444,7 +425,7 @@ namespace Fibonacci(%N); #[test] fn array_literals() { - let input = r#"let x = [[1], [2], [(3 + 7)]];"#; + let input = r#"let x = [[1], [2], [3 + 7]];"#; let printed = format!("{}", parse(Some("input"), input).unwrap_err_to_stderr()); assert_eq!(input.trim(), printed.trim()); } @@ -512,7 +493,7 @@ namespace N(2); fn type_args() { let input = r#" namespace N(2); - let max: T, T -> T = (|a, b| if (a < b) { b } else { a }); + let max: T, T -> T = (|a, b| if a < b { b } else { a }); let left: T1, T2 -> T1 = (|a, b| a); let seven = max::(3, 7); let five = left::(5, [7]); @@ -526,12 +507,12 @@ namespace N(2); fn type_args_with_space() { let input = r#" namespace N(2); - let max: T, T -> T = (|a, b| if (a < b) { b } else { a }); + let max: T, T -> T = (|a, b| if a < b { b } else { a }); let seven = max :: (3, 7); "#; let expected = r#" namespace N(2); - let max: T, T -> T = (|a, b| if (a < b) { b } else { a }); + let max: T, T -> T = (|a, b| if a < b { b } else { a }); let seven = max::(3, 7); "#; let printed = format!("{}", parse(Some("input"), input).unwrap_err_to_stderr()); diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 13f8c454d..9bfb5ee4d 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -545,7 +545,7 @@ ProductOp: BinaryOperator = { } Power: Box = { - => Box::new(Expression::BinaryOperation(<>)), + => Box::new(Expression::BinaryOperation(<>)), Unary, } diff --git a/parser/src/test_utils.rs b/parser/src/test_utils.rs new file mode 100644 index 000000000..7c99efc9c --- /dev/null +++ b/parser/src/test_utils.rs @@ -0,0 +1,26 @@ +use powdr_ast::{ + parsed::{PILFile, PilStatement}, + SourceRef, +}; +pub fn pil_statement_clear_source_ref(stmt: &mut PilStatement) { + match stmt { + PilStatement::Include(s, _) + | PilStatement::Namespace(s, _, _) + | PilStatement::LetStatement(s, _, _, _) + | PilStatement::PolynomialDefinition(s, _, _) + | PilStatement::PublicDeclaration(s, _, _, _, _) + | PilStatement::PolynomialConstantDeclaration(s, _) + | PilStatement::PolynomialConstantDefinition(s, _, _) + | PilStatement::PolynomialCommitDeclaration(s, _, _, _) + | PilStatement::PlookupIdentity(s, _, _) + | PilStatement::PermutationIdentity(s, _, _) + | PilStatement::ConnectIdentity(s, _, _) + | PilStatement::ConstantDefinition(s, _, _) + | PilStatement::Expression(s, _) + | PilStatement::EnumDeclaration(s, _) => *s = SourceRef::unknown(), + } +} +// helper function to clear SourceRef's inside the AST so we can compare for equality +pub fn pil_clear_source_refs(ast: &mut PILFile) { + ast.0.iter_mut().for_each(pil_statement_clear_source_ref); +} diff --git a/parser/tests/parentheses_test.rs b/parser/tests/parentheses_test.rs new file mode 100644 index 000000000..14279d91d --- /dev/null +++ b/parser/tests/parentheses_test.rs @@ -0,0 +1,136 @@ +#[cfg(test)] +mod test { + use powdr_parser::{parse, test_utils::pil_clear_source_refs}; + use powdr_parser_util::UnwrapErrToStderr; + use pretty_assertions::assert_eq; + use test_log::test; + + type TestCase = (&'static str, &'static str); + + fn test_paren(test_case: &TestCase) { + let (input, expected) = test_case; + let mut parsed = parse(None, input).unwrap_err_to_stderr(); + let printed = parsed.to_string(); + assert_eq!(expected.trim(), printed.trim()); + let mut re_parsed = parse(None, printed.as_str()).unwrap_err_to_stderr(); + + pil_clear_source_refs(&mut parsed); + pil_clear_source_refs(&mut re_parsed); + assert_eq!(parsed, re_parsed); + } + + #[test] + fn test_binary_op_parentheses() { + let test_cases: Vec = vec![ + // Complete line + ("let t = ((x + y) * z);", "let t = (x + y) * z;"), + // Don't add extra + ("-x + y * !z;", "-x + y * !z;"), + ("x = (y <= z);", "x = (y <= z);"), + ("(x = y) <= z;", "(x = y) <= z;"), + ("x + y + z;", "x + y + z;"), + ("x * y * z;", "x * y * z;"), + ("x / y / z;", "x / y / z;"), + // Remove unneeded + ("(-x) + y * (!z);", "-x + y * !z;"), + ("(x * y) * z;", "x * y * z;"), + ("(x / y) / z;", "x / y / z;"), + ("(x ** (y ** z));", "x ** (y ** z);"), + ("(x - (y + z));", "x - (y + z);"), + // Observe associativity + ("x * (y * z);", "x * (y * z);"), + ("x / (y / z);", "x / (y / z);"), + ("x ** (y ** z);", "x ** (y ** z);"), + ("(x ** y) ** z;", "(x ** y) ** z;"), + // Don't remove needed + ("(x + y) * z;", "(x + y) * z;"), + ("((x + y) * z);", "(x + y) * z;"), + ("-(x + y);", "-(x + y);"), + // function call + ("(a + b)(2);", "(a + b)(2);"), + // Index access + ("(a + b)[2];", "(a + b)[2];"), + ("(i < 7) && (6 >= -i);", "i < 7 && 6 >= -i;"), + // Power test + ("(-x) ** (-y);", "(-x) ** (-y);"), + ("2 ** x';", "2 ** (x');"), + ("(2 ** x)';", "(2 ** x)';"), + ]; + + for test_case in test_cases { + test_paren(&test_case); + } + } + + #[test] + fn test_lambda_ex_parentheses() { + let test_cases: Vec = vec![ + ("let x = 1 + (|i| i + 2);", "let x = 1 + (|i| i + 2);"), + ("let x = 1 + (|i| i) + 2;", "let x = 1 + (|i| i) + 2;"), + ("let x = 1 + (|i| (i + 2));", "let x = 1 + (|i| i + 2);"), + ("let x = (1 + (|i| i)) + 2;", "let x = 1 + (|i| i) + 2;"), + ("let x = (1 + (|i| (i + 2)));", "let x = 1 + (|i| i + 2);"), + ("let x = (1 + (|i| i + 2));", "let x = 1 + (|i| i + 2);"), + // Index access + ("(|i| i)[j];", "(|i| i)[j];"), + ]; + + for test_case in test_cases { + test_paren(&test_case); + } + } + + #[test] + fn test_parentheses_complex() { + let test_cases: Vec = vec![ + // Don't change concise expression + ( + "a | b * (c << d + e) & (f ^ g) = h * (i + g);", + "a | b * (c << d + e) & (f ^ g) = h * (i + g);", + ), + // Remove extra parentheses + ( + "(a | ((b * (c << (d + e))) & (f ^ g))) = (h * ((i + g)));", + "a | b * (c << d + e) & (f ^ g) = h * (i + g);", + ), + ( + "instr_or { 0, X, Y, Z } is (main_bin.latch * main_bin.sel[0]) { main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C };", + "instr_or { 0, X, Y, Z } is main_bin.latch * main_bin.sel[0] { main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C };", + ), + ( + "instr_or { 0, X, Y, Z } is main_bin.latch * main_bin.sel[0] { main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C };", + "instr_or { 0, X, Y, Z } is main_bin.latch * main_bin.sel[0] { main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C };", + ), + ( + "pc' = (1 - first_step') * ((((instr__jump_to_operation * _operation_id) + (instr__loop * pc)) + (instr_return * 0)) + ((1 - ((instr__jump_to_operation + instr__loop) + instr_return)) * (pc + 1)));", + "pc' = (1 - first_step') * (instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1));", + ), + ( + "let root_of_unity_for_log_degree: int -> fe = |n| root_of_unity ** (2**(32 - n));", + "let root_of_unity_for_log_degree: int -> fe = (|n| root_of_unity ** (2 ** (32 - n)));", + ), + ]; + + for test_case in test_cases { + test_paren(&test_case); + } + } + + #[test] + fn test_index_access_parentheses() { + let test_cases: Vec = vec![ + ("(x')(2);", "(x')(2);"), + ("x[2](2);", "x[2](2);"), + ("(x')[2];", "(x')[2];"), + ("-x[2];", "-x[2];"), + ("(-x)[2];", "(-x)[2];"), + ("-(x[2]);", "-x[2];"), + ("1 + x[2];", "1 + x[2];"), + ("1 + x(2);", "1 + x(2);"), + ]; + + for test_case in test_cases { + test_paren(&test_case); + } + } +} diff --git a/pil-analyzer/tests/condenser.rs b/pil-analyzer/tests/condenser.rs index fdf5e51a1..23733b908 100644 --- a/pil-analyzer/tests/condenser.rs +++ b/pil-analyzer/tests/condenser.rs @@ -19,7 +19,7 @@ fn new_witness_column() { t[0] = t[1]; "#; let expected = r#"namespace N(16); - col fixed even(i) { (i * 2) }; + col fixed even(i) { i * 2 }; let new_wit: -> expr = (constr || { let x; x @@ -81,7 +81,7 @@ fn create_constraints() { y = x_is_zero + 2; "#; let expected = r#"namespace N(16); - let force_bool: expr -> std::prelude::Constr = (|c| ((c * (1 - c)) = 0)); + let force_bool: expr -> std::prelude::Constr = (|c| c * (1 - c) = 0); let new_bool: -> expr = (constr || { let x; N.force_bool(x); @@ -91,8 +91,8 @@ fn create_constraints() { let x_is_zero; N.force_bool(x_is_zero); let x_inv; - (x_is_zero = (1 - (x * x_inv))); - ((x_is_zero * x) = 0); + x_is_zero = 1 - x * x_inv; + x_is_zero * x = 0; x_is_zero }); col witness x; diff --git a/pil-analyzer/tests/parse_display.rs b/pil-analyzer/tests/parse_display.rs index d14ff90a3..ffe142a30 100644 --- a/pil-analyzer/tests/parse_display.rs +++ b/pil-analyzer/tests/parse_display.rs @@ -22,7 +22,7 @@ namespace std::convert(65536); namespace T(65536); col fixed first_step = [1] + [0]*; col fixed line(i) { i }; - let ops: int -> bool = (|i| ((i < 7) && (6 >= -i))); + let ops: int -> bool = (|i| i < 7 && 6 >= -i); col witness pc; col witness XInv; col witness XIsZero; @@ -50,7 +50,7 @@ namespace T(65536); T.A' = (((T.first_step' * 0) + (T.reg_write_X_A * T.X)) + ((1 - (T.first_step' + T.reg_write_X_A)) * T.A)); col witness X_free_value(__i) query match std::prover::eval(T.pc) { 0 => std::prover::Query::Input(1), - 3 => std::prover::Query::Input(std::convert::int::((std::prover::eval(T.CNT) + 1))), + 3 => std::prover::Query::Input(std::convert::int::(std::prover::eval(T.CNT) + 1)), 7 => std::prover::Query::Input(0), _ => std::prover::Query::None, }; @@ -117,9 +117,9 @@ namespace N(%r); namespace N(65536); col witness x; let z: int = 2; - col fixed t(i) { (i + N.z) }; + col fixed t(i) { i + N.z }; let other: int[] = [1, N.z]; - let other_fun: int, fe -> (int, (int -> int)) = (|i, j| ((i + 7), (|k| (k - i)))); + let other_fun: int, fe -> (int, (int -> int)) = (|i, j| (i + 7, (|k| k - i))); "#; let formatted = analyze_string::(input).to_string(); assert_eq!(formatted, expected); @@ -168,8 +168,8 @@ fn namespaced_call() { "#; let expected = r#"namespace Assembly(2); let A: int -> int = (|i| 0); - let C: int -> int = (|i| (Assembly.A((i + 2)) + 3)); - let D: int -> int = (|i| Assembly.C((i + 3))); + let C: int -> int = (|i| Assembly.A(i + 2) + 3); + let D: int -> int = (|i| Assembly.C(i + 3)); "#; let formatted = analyze_string::(input).to_string(); assert_eq!(formatted, expected); @@ -184,8 +184,8 @@ fn if_expr() { "#; let expected = r#"namespace Assembly(2); col fixed A = [0]*; - let c: int -> int = (|i| if (i < 3) { i } else { (i + 9) }); - col fixed D(i) { if (Assembly.c(i) != 0) { 3 } else { 2 } }; + let c: int -> int = (|i| if i < 3 { i } else { i + 9 }); + col fixed D(i) { if Assembly.c(i) != 0 { 3 } else { 2 } }; "#; let formatted = analyze_string::(input).to_string(); assert_eq!(formatted, expected); @@ -205,11 +205,11 @@ fn symbolic_functions() { "#; let expected = r#"namespace N(16); let last_row: int = 15; - col fixed ISLAST(i) { if (i == N.last_row) { 1 } else { 0 } }; + col fixed ISLAST(i) { if i == N.last_row { 1 } else { 0 } }; col witness x; col witness y; - let constrain_equal_expr: expr, expr -> expr = (|A, B| (A - B)); - let on_regular_row: expr -> expr = (|cond| ((1 - N.ISLAST) * cond)); + let constrain_equal_expr: expr, expr -> expr = (|A, B| A - B); + let on_regular_row: expr -> expr = (|cond| (1 - N.ISLAST) * cond); ((1 - N.ISLAST) * (N.x' - N.y)) = 0; ((1 - N.ISLAST) * (N.y' - (N.x + N.y))) = 0; "#; @@ -228,7 +228,7 @@ fn next_op_on_param() { let expected = r#"namespace N(16); col witness x; col witness y; - let next_is_seven: expr -> expr = (|t| (t' - 7)); + let next_is_seven: expr -> expr = (|t| t' - 7); (N.y' - 7) = 0; "#; let formatted = analyze_string::(input).to_string(); @@ -247,7 +247,7 @@ fn fixed_symbolic() { "#; let expected = r#"namespace N(16); let last_row: int = 15; - let islast: int -> fe = (|i| if (i == N.last_row) { 1 } else { 0 }); + let islast: int -> fe = (|i| if i == N.last_row { 1 } else { 0 }); col fixed ISLAST(i) { N.islast(i) }; col witness x; col witness y; @@ -292,8 +292,8 @@ fn complex_type_resolution() { let z: (((int -> int), int -> int)[], expr) = ([x, x, x, x, x, x, x, x], y[0]); "#; let expected = r#"namespace N(16); - let f: int -> int = (|i| (i + 10)); - let x: (int -> int), int -> int = (|k, i| k((2 ** i))); + let f: int -> int = (|i| i + 10); + let x: (int -> int), int -> int = (|k, i| k(2 ** i)); col witness y[14]; let z: (((int -> int), int -> int)[], expr) = ([N.x, N.x, N.x, N.x, N.x, N.x, N.x, N.x], N.y[0]); "#; @@ -306,7 +306,7 @@ fn function_type_display() { let input = r#"namespace N(16); let f: (-> int)[] = [(|| 10), (|| 12)]; let g: (int -> int) -> int = (|f| f(0)); - let h: int -> (int -> int) = (|x| (|i| (x + i))); + let h: int -> (int -> int) = (|x| (|i| x + i)); "#; let formatted = analyze_string::(input).to_string(); assert_eq!(formatted, input); @@ -323,8 +323,8 @@ fn expr_and_identity() { g((x)); "#; let expected = r#"namespace N(16); - let f: expr, expr -> std::prelude::Constr[] = (|x, y| [(x = y)]); - let g: expr -> std::prelude::Constr[] = (|x| [(x = 0)]); + let f: expr, expr -> std::prelude::Constr[] = (|x, y| [x = y]); + let g: expr -> std::prelude::Constr[] = (|x| [x = 0]); col witness x; col witness y; N.x = N.y; @@ -394,7 +394,7 @@ fn to_expr() { let expected = r#"namespace std::convert(16); let expr = []; namespace N(16); - let mul_two: int -> int = (|i| (i * 2)); + let mul_two: int -> int = (|i| i * 2); col witness y; N.y = (N.y * 14); "#; @@ -481,7 +481,7 @@ fn let_inside_block() { let t: int -> expr = constr |i| match i { 0 => { let x; x }, 1 => w, - _ => if (i < 3) { let y; y } else { w }, + _ => if i < 3 { let y; y } else { w }, }; { let z; @@ -497,7 +497,7 @@ fn let_inside_block() { x }, 1 => Main.w, - _ => if (i < 3) { + _ => if i < 3 { let y; y } else { Main.w }, @@ -642,7 +642,7 @@ fn disjoint_block_shadowing() { }; { let x = 3; - (x + b) + x + b } }; "; diff --git a/pil-analyzer/tests/side_effects.rs b/pil-analyzer/tests/side_effects.rs index 17309490b..9197d4445 100644 --- a/pil-analyzer/tests/side_effects.rs +++ b/pil-analyzer/tests/side_effects.rs @@ -12,7 +12,7 @@ fn new_wit_in_pure() { } #[test] -#[should_panic = "Tried to add a constraint in a pure context: (x = 7)"] +#[should_panic = "Tried to add a constraint in a pure context: x = 7"] fn constr_in_pure() { let input = r#"namespace N(16); let new_col = |x| { x = 7; [] }; @@ -77,7 +77,7 @@ fn constr_lambda_in_pure() { } #[test] -#[should_panic = "Tried to add a constraint in a pure context: (x = 7)"] +#[should_panic = "Tried to add a constraint in a pure context: x = 7"] fn reset_context() { let input = r#"namespace N(16); let new_col = |x| { x = 7; [] }; diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 699d3cb0f..ad10a6bc7 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -661,7 +661,7 @@ namespace N(65536); let expectation = r#"namespace N(65536); col witness x; col fixed cnt(i) { N.inc(i) }; - let inc: int -> int = (|x| (x + 1)); + let inc: int -> int = (|x| x + 1); { N.x } in { N.cnt }; "#; let optimized = optimize(analyze_string::(input)).to_string(); @@ -714,7 +714,7 @@ namespace N(65536); T, } let t: N::X[] -> int = (|r| 1); - col fixed f(i) { if (i == 0) { N.t([]) } else { (|x| 1)(N::Y::F([])) } }; + col fixed f(i) { if i == 0 { N.t([]) } else { (|x| 1)(N::Y::F([])) } }; col witness x; N.x = N.f; "#;