From a5bdd58023f2457dd6b3931e8406855a69da8f0a Mon Sep 17 00:00:00 2001 From: Leandro Pacheco Date: Mon, 16 Sep 2024 16:21:34 -0300 Subject: [PATCH 01/16] unify Query::Input/DataIdentifier handling (#1809) handle both queries in the same place --- pipeline/src/lib.rs | 58 +++++++++++++++++---------------------------- 1 file changed, 22 insertions(+), 36 deletions(-) diff --git a/pipeline/src/lib.rs b/pipeline/src/lib.rs index 4084f3ace9..52a73274d1 100644 --- a/pipeline/src/lib.rs +++ b/pipeline/src/lib.rs @@ -90,26 +90,6 @@ pub fn parse_query(query: &str) -> Result<(&str, Vec<&str>), String> { } } -pub fn access_element( - name: &str, - elements: &[T], - index_str: &str, -) -> Result, String> { - let index = index_str - .parse::() - .map_err(|e| format!("Error parsing index: {e})"))?; - let value = elements.get(index).cloned(); - if let Some(value) = value { - log::trace!("Query for {name}: Index {index} -> {value}"); - Ok(Some(value)) - } else { - Err(format!( - "Error accessing {name}: Index {index} out of bounds {}", - elements.len() - )) - } -} - pub fn serde_data_to_query_callback( channel: u32, bytes: Vec, @@ -117,7 +97,6 @@ pub fn serde_data_to_query_callback( move |query: &str| -> Result, String> { let (id, data) = parse_query(query)?; match id { - "None" => Ok(None), "DataIdentifier" => { let [index, cb_channel] = data[..] else { panic!() @@ -151,7 +130,6 @@ pub fn dict_data_to_query_callback( move |query: &str| -> Result, String> { let (id, data) = parse_query(query)?; match id { - "None" => Ok(None), "DataIdentifier" => { let [index, cb_channel] = data[..] else { panic!() @@ -160,7 +138,7 @@ pub fn dict_data_to_query_callback( .parse::() .map_err(|e| format!("Error parsing callback data channel: {e})"))?; - let Some(bytes) = dict.get(&cb_channel) else { + let Some(elems) = dict.get(&cb_channel) else { return Err("Callback channel mismatch".to_string()); }; @@ -170,29 +148,37 @@ pub fn dict_data_to_query_callback( // query index 0 means the length Ok(Some(match index { - 0 => (bytes.len() as u64).into(), - index => bytes[index - 1], + 0 => (elems.len() as u64).into(), + index => elems[index - 1], })) } - _ => Err(format!("Unsupported query: {query}")), - } - } -} - -pub fn inputs_to_query_callback(inputs: Vec) -> impl QueryCallback { - move |query: &str| -> Result, String> { - let (id, data) = parse_query(query)?; - match id { - "None" => Ok(None), "Input" => { assert_eq!(data.len(), 1); - access_element("prover inputs", &inputs, data[0]) + let index = data[0] + .parse::() + .map_err(|e| format!("Error parsing index: {e})"))?; + + let Some(elems) = dict.get(&0) else { + return Err("No prover inputs given".to_string()); + }; + + elems + .get(index) + .cloned() + .map(Some) + .ok_or_else(|| format!("Index out of bounds: {index}")) } _ => Err(format!("Unsupported query: {query}")), } } } +pub fn inputs_to_query_callback(inputs: Vec) -> impl QueryCallback { + let mut dict = BTreeMap::new(); + dict.insert(0, inputs); + dict_data_to_query_callback(dict) +} + #[allow(clippy::print_stdout)] pub fn handle_simple_queries_callback<'a, T: FieldElement>() -> impl QueryCallback + 'a { move |query: &str| -> Result, String> { From 7bf17e965134dea5738e8a70cf99b9e86a430f1e Mon Sep 17 00:00:00 2001 From: Lucas Clemente Vella Date: Tue, 17 Sep 2024 10:18:35 +0200 Subject: [PATCH 02/16] Avoid an array copy in the input and output of the poseidon_gl function. (#1808) The original array is mutated in-place instead. --------- Co-authored-by: Leo --- riscv-runtime/src/hash.rs | 33 ++++++++++--------- .../poseidon_gl_via_coprocessor/src/main.rs | 20 +++++------ 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/riscv-runtime/src/hash.rs b/riscv-runtime/src/hash.rs index 88715114df..71d277e73a 100644 --- a/riscv-runtime/src/hash.rs +++ b/riscv-runtime/src/hash.rs @@ -1,27 +1,30 @@ use core::arch::asm; +use core::convert::TryInto; use core::mem; use powdr_riscv_syscalls::Syscall; const GOLDILOCKS: u64 = 0xffffffff00000001; -/// Calls the low level Poseidon PIL machine, where -/// the last 4 elements are the "cap" -/// and the return value is placed in data[0:4]. -/// This is unsafe because it does not check if the u64 elements fit the Goldilocks field. -pub fn poseidon_gl_unsafe(mut data: [u64; 12]) -> [u64; 4] { +/// Calls the low level Poseidon PIL machine, where the last 4 elements are the +/// "cap", the return value is placed in data[..4] and the reference to this +/// sub-array is returned. +/// +/// This is unsafe because it does not check if the u64 elements fit the +/// Goldilocks field. +pub fn poseidon_gl_unsafe(data: &mut [u64; 12]) -> &[u64; 4] { unsafe { - asm!("ecall", in("a0") &mut data as *mut [u64; 12], in("t0") u32::from(Syscall::PoseidonGL)); + asm!("ecall", in("a0") data as *mut [u64; 12], in("t0") u32::from(Syscall::PoseidonGL)); } - - [data[0], data[1], data[2], data[3]] + data[..4].try_into().unwrap() } -/// Calls the low level Poseidon PIL machine, where -/// the last 4 elements are the "cap" -/// and the return value is placed in data[0:4]. -/// This function will panic if any of the u64 elements doesn't fit the Goldilocks field. -pub fn poseidon_gl(data: [u64; 12]) -> [u64; 4] { +/// Calls the low level Poseidon PIL machine, where the last 4 elements are the +/// "cap" and the return value is placed in data[0:4]. +/// +/// This function will panic if any of the u64 elements doesn't fit the +/// Goldilocks field. +pub fn poseidon_gl(data: &mut [u64; 12]) -> &[u64; 4] { for &n in data.iter() { assert!(n < GOLDILOCKS); } @@ -43,9 +46,9 @@ const W: usize = 32; /// Keccak function that calls the keccakf machine. /// Input is a byte array of arbitrary length and a delimiter byte. -/// Output is a byte array of length W. +/// Output is a byte array of length W. pub fn keccak(data: &[u8], delim: u8) -> [u8; W] { - let mut b = [[0u8; 200]; 2]; + let mut b = [[0u8; 200]; 2]; let [mut b_input, mut b_output] = &mut b; let rate = 200 - (2 * W); let mut pt = 0; diff --git a/riscv/tests/riscv_data/poseidon_gl_via_coprocessor/src/main.rs b/riscv/tests/riscv_data/poseidon_gl_via_coprocessor/src/main.rs index 1aa15bfa40..e5606a5902 100644 --- a/riscv/tests/riscv_data/poseidon_gl_via_coprocessor/src/main.rs +++ b/riscv/tests/riscv_data/poseidon_gl_via_coprocessor/src/main.rs @@ -5,29 +5,29 @@ use powdr_riscv_runtime::hash::{poseidon_gl, poseidon_gl_unsafe}; #[no_mangle] fn main() { - let i: [u64; 12] = [0; 12]; - let h = poseidon_gl(i); + let mut i: [u64; 12] = [0; 12]; + let h = poseidon_gl(&mut i); assert_eq!(h[0], 4330397376401421145); assert_eq!(h[1], 14124799381142128323); assert_eq!(h[2], 8742572140681234676); assert_eq!(h[3], 14345658006221440202); - let i: [u64; 12] = [1; 12]; - let h = poseidon_gl(i); + let mut i: [u64; 12] = [1; 12]; + let h = poseidon_gl(&mut i); assert_eq!(h[0], 16428316519797902711); assert_eq!(h[1], 13351830238340666928); assert_eq!(h[2], 682362844289978626); assert_eq!(h[3], 12150588177266359240); let minus_one = 0xffffffff00000001 - 1; - let i: [u64; 12] = [minus_one; 12]; - let h = poseidon_gl(i); + let mut i: [u64; 12] = [minus_one; 12]; + let h = poseidon_gl(&mut i); assert_eq!(h[0], 13691089994624172887); assert_eq!(h[1], 15662102337790434313); assert_eq!(h[2], 14940024623104903507); assert_eq!(h[3], 10772674582659927682); - let i: [u64; 12] = [ + let mut i: [u64; 12] = [ 18446744069414584321, 18446744069414584321, 18446744069414584321, @@ -41,13 +41,13 @@ fn main() { 0, 0, ]; - let h = poseidon_gl_unsafe(i); + let h = poseidon_gl_unsafe(&mut i); assert_eq!(h[0], 4330397376401421145); assert_eq!(h[1], 14124799381142128323); assert_eq!(h[2], 8742572140681234676); assert_eq!(h[3], 14345658006221440202); - let i: [u64; 12] = [ + let mut i: [u64; 12] = [ 923978, 235763497586, 9827635653498, @@ -61,7 +61,7 @@ fn main() { 0, 0, ]; - let h = poseidon_gl(i); + let h = poseidon_gl(&mut i); assert_eq!(h[0], 1892171027578617759); assert_eq!(h[1], 984732815927439256); assert_eq!(h[2], 7866041765487844082); From 3e0c96cfc85f3e994d6180e9b59307a9dde9c3a7 Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Tue, 17 Sep 2024 13:45:57 -0400 Subject: [PATCH 03/16] Div and mul for 32 bit numbers for baby bear (#1778) Two operations: - `mul` takes two 32 bit-numbers, multiply them, and outputs two 32 bit-numbers, which are the upper and lower halves of the product - `div` takes two 32 bit-numbers, which are the dividend and divisor, respectively, and outputs two 32 bit-numbers, which are the quotient and the remainder, respectively Warning: - Although we do provide hints to calculate a unique remainder that's less than the divisor, this is not constrained and thus witness for `div` is not unique and not sound. Testing: - I created new test vectors but should be quite comprehensive (uses all variables). Implementation: - Uses the affine equation `x1 * y1 + x2 = y2 * 2**32 + y3`. - `x2` is constrained to 0 for `mul`. - `y2` is constrained to 0 for `div`. - This approach has the disadvantage of using five sets of columns, i.e. `x1`, `y1`, `x2`, `y2`, and `y3`, instead of just four sets. However, five sets of columns has the advantage of using roughly half of the number of constraints compared to if we were to have four sets of columns. This is because `mul` and `div` can be both defined by affine with five sets of columns, whereas defining `mul` and `div` separately in four sets of columns requires using the variables for different purposes depending on the operations, and thus more constraints. - Tl;dr: it's a trade off between one more variable (4 more columns) and fewer constraints. --- std/machines/arith16.asm | 189 +++++++++++++++++++++++++++++++++ std/machines/mod.asm | 1 + test_data/std/arith16_test.asm | 62 +++++++++++ 3 files changed, 252 insertions(+) create mode 100644 std/machines/arith16.asm create mode 100644 test_data/std/arith16_test.asm diff --git a/std/machines/arith16.asm b/std/machines/arith16.asm new file mode 100644 index 0000000000..45822182d9 --- /dev/null +++ b/std/machines/arith16.asm @@ -0,0 +1,189 @@ +use std::array; +use std::utils::unchanged_until; +use std::utils::force_bool; +use std::utils::sum; +use std::math::ff; +use std::check::panic; +use std::convert::int; +use std::convert::fe; +use std::convert::expr; +use std::prover::eval; +use std::prelude::Query; +use std::machines::range::Byte; + +// Arithmetic machine, ported mainly from Polygon: https://github.com/0xPolygonHermez/zkevm-proverjs/blob/main/pil/arith.pil +// This machine supports eq0, which is the affine equation. Currently we only expose operations for mul and div. +machine Arith16(byte: Byte) with + latch: CLK8_7, + operation_id: operation_id, + // Allow this machine to be connected via a permutation + call_selectors: sel, +{ + col witness operation_id; + + // operation_id has to be either mul or div. + force_bool(operation_id); + + // Computes x1 * y1 + x2, where all inputs / outputs are 32-bit words (represented as 16-bit limbs in big-endian order). + // More precisely, affine_256(x1, y1, x2) = (y2, y3), where x1 * y1 + x2 = 2**16 * y2 + y3 + + // x1 * y1 = y2 * 2**16 + y3 + operation mul<0> x1c[1], x1c[0], y1c[1], y1c[0] -> y2c[1], y2c[0], y3c[1], y3c[0]; + + // Constrain that x2 = 0 when operation is mul. + array::new(4, |i| (1 - operation_id) * x2[i] = 0); + + // y3 / x1 = y1 (remainder x2) + // WARNING: it's not constrained that remainder is less than the divisor. + // This is done in the main machine, e.g. our RISCV BabyBear machine, that uses this operation. + operation div<1> y3c[1], y3c[0], x1c[1], x1c[0] -> y1c[1], y1c[0], x2c[1], x2c[0]; + + // Constrain that y2 = 0 when operation is div. + array::new(4, |i| operation_id * y2[i] = 0); + + // We need to provide hints for the quotient and remainder, because they are not unique under our current constraints. + // They are unique given additional main machine constraints, but it's still good to provide hints for the solver. + let quotient_hint = query |limb| match(eval(operation_id)) { + 1 => { + let y3 = y3_int(); + let x1 = x1_int(); + let quotient = y3 / x1; + Query::Hint(fe(select_limb(quotient, limb))) + }, + _ => Query::None + }; + + col witness y1_0(i) query quotient_hint(0); + col witness y1_1(i) query quotient_hint(1); + col witness y1_2(i) query quotient_hint(2); + col witness y1_3(i) query quotient_hint(3); + + let y1: expr[] = [y1_0, y1_1, y1_2, y1_3]; + + let remainder_hint = query |limb| match(eval(operation_id)) { + 1 => { + let y3 = y3_int(); + let x1 = x1_int(); + let remainder = y3 % x1; + Query::Hint(fe(select_limb(remainder, limb))) + }, + _ => Query::None + }; + + col witness x2_0(i) query remainder_hint(0); + col witness x2_1(i) query remainder_hint(1); + col witness x2_2(i) query remainder_hint(2); + col witness x2_3(i) query remainder_hint(3); + + let x2: expr[] = [x2_0, x2_1, x2_2, x2_3]; + + pol commit x1[4], y2[4], y3[4]; + + // Selects the ith limb of x (little endian) + // All limbs are 8 bits + let select_limb = |x, i| if i >= 0 { + (x >> (i * 8)) & 0xff + } else { + 0 + }; + + let limbs_to_int: expr[] -> int = query |limbs| array::sum(array::map_enumerated(limbs, |i, limb| int(eval(limb)) << (i * 8))); + + let x1_int = query || limbs_to_int(x1); + let y1_int = query || limbs_to_int(y1); + let x2_int = query || limbs_to_int(x2); + let y2_int = query || limbs_to_int(y2); + let y3_int = query || limbs_to_int(y3); + + let combine: expr[] -> expr[] = |x| array::new(array::len(x) / 2, |i| x[2 * i + 1] * 2**8 + x[2 * i]); + // Intermediate polynomials, arrays of 16 columns, 16 bit per column. + col x1c[2] = combine(x1); + col y1c[2] = combine(y1); + col x2c[2] = combine(x2); + col y2c[2] = combine(y2); + col y3c[2] = combine(y3); + + let CLK8: col[8] = array::new(8, |i| |row| if row % 8 == i { 1 } else { 0 }); + let CLK8_7: expr = CLK8[7]; + + /**** + * + * LATCH POLS: x1,y1,x2,y2,y3 + * + *****/ + + let fixed_inside_8_block = |e| unchanged_until(e, CLK8[7]); + + array::map(x1, fixed_inside_8_block); + array::map(y1, fixed_inside_8_block); + array::map(x2, fixed_inside_8_block); + array::map(y2, fixed_inside_8_block); + array::map(y3, fixed_inside_8_block); + + /**** + * + * RANGE CHECK x1,y1,x2,y2,y3 + * + *****/ + + link => byte.check(sum(4, |i| x1[i] * CLK8[i]) + sum(4, |i| y1[i] * CLK8[4 + i])); + link => byte.check(sum(4, |i| x2[i] * CLK8[i]) + sum(4, |i| y2[i] * CLK8[4 + i])); + link => byte.check(sum(4, |i| y3[i] * CLK8[i])); + + /******* + * + * EQ0: A(x1) * B(y1) + C(x2) = D (y2) * 2 ** 16 + op (y3) + * x1 * y1 + x2 - y2 * 2**256 - y3 = 0 + * + *******/ + + /// returns a(0) * b(0) + ... + a(n - 1) * b(n - 1) + let dot_prod = |n, a, b| sum(n, |i| a(i) * b(i)); + /// returns |n| a(0) * b(n) + ... + a(n) * b(0) + let product = |a, b| |n| dot_prod(n + 1, a, |i| b(n - i)); + + /// Converts array to function, extended by zeros. + let array_as_fun: expr[] -> (int -> expr) = |arr| |i| if 0 <= i && i < array::len(arr) { + arr[i] + } else { + 0 + }; + let shift_right = |fn, amount| |i| fn(i - amount); + + let x1f = array_as_fun(x1); + let y1f = array_as_fun(y1); + let x2f = array_as_fun(x2); + let y2f = array_as_fun(y2); + let y3f = array_as_fun(y3); + + // Defined for arguments from 0 to 7 (inclusive) + let eq0 = |nr| + product(x1f, y1f)(nr) + + x2f(nr) + - shift_right(y2f, 4)(nr) + - y3f(nr); + + /******* + * + * Carry + * + *******/ + + pol witness carry_low, carry_high; + link => byte.check(carry_low); + link => byte.check(carry_high); + + let carry = carry_high * 2**8 + carry_low; + + carry * CLK8[0] = 0; + + /******* + * + * Putting everything together + * + *******/ + + col eq0_sum = sum(8, |i| eq0(i) * CLK8[i]); + + eq0_sum + carry = carry' * 2**8; +} diff --git a/std/machines/mod.asm b/std/machines/mod.asm index 5350d4675c..cee2d20682 100644 --- a/std/machines/mod.asm +++ b/std/machines/mod.asm @@ -1,4 +1,5 @@ mod arith; +mod arith16; mod binary; mod binary_bb; mod range; diff --git a/test_data/std/arith16_test.asm b/test_data/std/arith16_test.asm new file mode 100644 index 0000000000..c0c18c83cc --- /dev/null +++ b/test_data/std/arith16_test.asm @@ -0,0 +1,62 @@ +use std::machines::arith16::Arith16; +use std::machines::range::Byte; + +machine Main with degree: 65536 { + reg pc[@pc]; + reg A0[<=]; + reg A1[<=]; + reg B0[<=]; + reg B1[<=]; + reg C0[<=]; + reg C1[<=]; + reg D0[<=]; + reg D1[<=]; + + reg t_0_0; + reg t_0_1; + reg t_1_0; + reg t_1_1; + + Byte byte; + + Arith16 arith(byte); + + instr mul A0, A1, B0, B1 -> C0, C1, D0, D1 + link ~> (C0, C1, D0, D1) = arith.mul(A0, A1, B0, B1); + + instr div A0, A1, B0, B1 -> C0, C1, D0, D1 + link ~> (C0, C1, D0, D1) = arith.div(A0, A1, B0, B1); + + instr assert_eq A0, A1, B0, B1, C0, C1, D0, D1 { + A0 = C0, + A1 = C1, + B0 = D0, + B1 = D1 + } + + function main { + // 2 * 3 = 6 + t_0_0, t_0_1, t_1_0, t_1_1 <== mul(0, 2, 0, 3); + assert_eq t_0_0, t_0_1, t_1_0, t_1_1, 0, 0, 0, 6; + + // (2**32 - 1) * (2**32 - 1) = 2**64 - 2**33 + 1 + t_0_0, t_0_1, t_1_0, t_1_1 <== mul(0xffff, 0xffff, 0xffff, 0xffff); + assert_eq t_0_0, t_0_1, t_1_0, t_1_1, 0xffff, 0xfffe, 0x0000, 0x0001; + + // 7 / 3 = 2 (remainder 1) + t_0_0, t_0_1, t_1_0, t_1_1 <== div(0, 7, 0, 3); + assert_eq t_0_0, t_0_1, t_1_0, t_1_1, 0, 2, 0, 1; + + // 0xffffffff / 0xfffff = 0x1000 (remainder 0xfff) + t_0_0, t_0_1, t_1_0, t_1_1 <== div(0xffff, 0xffff, 0xf, 0xffff); + assert_eq t_0_0, t_0_1, t_1_0, t_1_1, 0, 0x1000, 0, 0xfff; + + // 0xfffffffe / 0xff = 0x1010100 (remainder 0xfe) + t_0_0, t_0_1, t_1_0, t_1_1 <== div(0xffff, 0xfffe, 0, 0xff); + assert_eq t_0_0, t_0_1, t_1_0, t_1_1, 0x101, 0x100, 0, 0xfe; + + // 0xffffeff / 0xfffff = 0xff (remainder 0xffffe) + t_0_0, t_0_1, t_1_0, t_1_1 <== div(0xfff, 0xfeff, 0xf, 0xffff); + assert_eq t_0_0, t_0_1, t_1_0, t_1_1, 0, 0xff, 0xf, 0xfffe; + } +} From 10f59cd9e3aae2b9bb906d73bf43eecf5f3739b8 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Vella Date: Wed, 18 Sep 2024 21:52:30 +0200 Subject: [PATCH 04/16] SplitBB machine (#1811) Very copy-pastey stuff. --- pipeline/tests/powdr_std.rs | 8 ++++ std/machines/split/mod.asm | 3 +- std/machines/split/split_bb.asm | 77 +++++++++++++++++++++++++++++++++ test_data/std/split_bb_test.asm | 49 +++++++++++++++++++++ 4 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 std/machines/split/split_bb.asm create mode 100644 test_data/std/split_bb_test.asm diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index 0254b7bf1a..55e3d77719 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -58,6 +58,14 @@ fn split_gl_test() { gen_estark_proof(make_simple_prepared_pipeline(f)); } +#[cfg(feature = "plonky3")] +#[test] +#[ignore = "Too slow"] +fn split_bb_test() { + let f = "std/split_bb_test.asm"; + test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); +} + #[test] #[ignore = "Too slow"] fn arith_test() { diff --git a/std/machines/split/mod.asm b/std/machines/split/mod.asm index 866a5d4889..bfd89c2c71 100644 --- a/std/machines/split/mod.asm +++ b/std/machines/split/mod.asm @@ -1,5 +1,6 @@ mod split_bn254; mod split_gl; +mod split_bb; use std::utils::cross_product; @@ -21,4 +22,4 @@ machine ByteCompare with col fixed latch = [1]*; col fixed operation_id = [0]*; -} \ No newline at end of file +} diff --git a/std/machines/split/split_bb.asm b/std/machines/split/split_bb.asm new file mode 100644 index 0000000000..0f72854349 --- /dev/null +++ b/std/machines/split/split_bb.asm @@ -0,0 +1,77 @@ +use std::prelude::Query; +use super::ByteCompare; + +// Splits an arbitrary field element into two u16s, on the BabyBear field. +machine SplitBB(byte_compare: ByteCompare) with + latch: RESET, + // Allow this machine to be connected via a permutation + call_selectors: sel, +{ + operation split in_acc -> output_low, output_high; + + // Latch and operation ID + col fixed RESET(i) { if i % 4 == 3 { 1 } else { 0 } }; + + // 1. Decompose the input into bytes + + // The byte decomposition of the input, in little-endian order + // and shifted forward by one (to use the last row of the + // previous block) + // A hint is provided because automatic witness generation does not + // understand step 3 to figure out that the byte decomposition is unique. + let select_byte: fe, int -> fe = |input, byte| std::convert::fe((std::convert::int(input) >> (byte * 8)) & 0xff); + col witness bytes; + query |i| { + std::prover::provide_value(bytes, i, select_byte(std::prover::eval(in_acc'), (i + 1) % 4)); + }; + // Puts the bytes together to form the input + col witness in_acc; + // Factors to multiply the bytes by + col fixed FACTOR(i) { 1 << (((i + 1) % 4) * 8) }; + + in_acc' = (1 - RESET) * in_acc + bytes * FACTOR; + + // 2. Build the output, packing chunks of 2 bytes (i.e., 16 bit) into a field element + col witness output_low, output_high; + col fixed FACTOR_OUTPUT_LOW = [0x100, 0, 0, 1]*; + col fixed FACTOR_OUTPUT_HIGH = [0, 1, 0x100, 0]*; + output_low' = (1 - RESET) * output_low + bytes * FACTOR_OUTPUT_LOW; + output_high' = (1 - RESET) * output_high + bytes * FACTOR_OUTPUT_HIGH; + + // 3. Check that the byte decomposition does not overflow + // + // Skipping this step would work but it wouldn't be sound, because + // the 4-byte decomposition could overflow, since the BabyBear + // prime 2**31 - 2**27 + 1 is smaller than 2^32. + // + // The approach is to compare the byte decomposition with that of + // the maximum possible value (0x78000000) byte by byte, + // from most significant to least significant (i.e., going backwards). + // A byte can only be larger than that of the max value if any previous + // byte has been smaller. + + // This is an example for input 0x77ffffff: + // Row RESET bytes BYTES_MAX lt was_lt gt + // -1 0x1 0xff 0x0 0x0 0x1 0x1 + // 0 0x0 0xff 0x0 0x0 0x1 0x1 + // 1 0x0 0xff 0x0 0x0 0x1 0x1 + // 2 0x0 0x77 0x78 0x1 0x1 0x0 # 0x77 < 0x78, so now greater bytes are allowed + + // Bytes of the maximum value, in little endian order, rotated by one + col fixed BYTES_MAX = [0, 0, 0x78, 0]*; + + // Compare the current byte with the corresponding byte of the maximum value. + col witness lt; + col witness gt; + link => (lt, gt) = byte_compare.run(bytes, BYTES_MAX); + + // Compute whether the current or any previous byte has been less than + // the corresponding byte of the maximum value. + // This moves *backward* from the second to last row. + col witness was_lt; + was_lt = RESET' * lt + (1 - RESET') * (was_lt' + lt - was_lt' * lt); + + // If any byte is larger, but no previous byte was smaller, the byte + // decomposition has overflowed and should be rejected. + gt * (1 - was_lt) = 0; +} diff --git a/test_data/std/split_bb_test.asm b/test_data/std/split_bb_test.asm new file mode 100644 index 0000000000..fd9e0b2725 --- /dev/null +++ b/test_data/std/split_bb_test.asm @@ -0,0 +1,49 @@ +use std::machines::split::ByteCompare; +use std::machines::split::split_bb::SplitBB; + +machine Main with degree: 65536 { + reg pc[@pc]; + reg X0[<=]; + reg X1[<=]; + reg X2[<=]; + reg low; + reg high; + + ByteCompare byte_compare; + SplitBB split_machine(byte_compare); + + instr split X0 -> X1, X2 link ~> (X1, X2) = split_machine.split(X0); + + instr assert_eq X0, X1 { + X0 = X1 + } + + function main { + + // Min value + // Note that this has two byte decompositions, 0 and p = 0x78000001. + // The second would lead to a different split value, but should be ruled + // out by the overflow check. + low, high <== split(0); + assert_eq low, 0; + assert_eq high, 0; + + // Max value + // On BabyBear, this is 0x78000000. + low, high <== split(-1); + assert_eq low, 0; + assert_eq high, 0x7800; + + // Max low value + low, high <== split(0x77ffffff); + assert_eq low, 0xffff; + assert_eq high, 0x77ff; + + // Some other value + low, high <== split(0x42abcdef); + assert_eq low, 0xcdef; + assert_eq high, 0x42ab; + + return; + } +} From 0edb7686cc3401f46c7127eae4dee78dd1548c2d Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 19 Sep 2024 16:04:38 +0200 Subject: [PATCH 05/16] Check length. (#1806) --- std/array.asm | 6 ++++-- std/machines/hash/poseidon_bn254.asm | 3 ++- std/machines/hash/poseidon_gl.asm | 3 ++- std/machines/hash/poseidon_gl_memory.asm | 3 ++- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/std/array.asm b/std/array.asm index dbe271948d..d2eefdfdc4 100644 --- a/std/array.asm +++ b/std/array.asm @@ -36,8 +36,10 @@ let sum: T[] -> T = |arr| fold(arr, 0, |a, b| a + b); let product: T[] -> T = |arr| fold(arr, 1, |a, b| a * b); /// Zips two arrays -/// TODO: Assert that lengths are equal when expressions are supported. -let zip: T1[], T2[], (T1, T2 -> T3) -> T3[] = |array1, array2, fn| new(len(array1), |i| fn(array1[i], array2[i])); +let zip: T1[], T2[], (T1, T2 -> T3) -> T3[] = |array1, array2, fn| { + std::check::assert(len(array1) == len(array2), || "Array lengths do not match"); + new(len(array1), |i| fn(array1[i], array2[i])) +}; /// Returns f(i, arr[i]) for the first i where this is not None, or None if no such i exists. let find_map_enumerated: T1[], (int, T1 -> Option) -> Option = diff --git a/std/machines/hash/poseidon_bn254.asm b/std/machines/hash/poseidon_bn254.asm index 7336b8ef01..82b6971f31 100644 --- a/std/machines/hash/poseidon_bn254.asm +++ b/std/machines/hash/poseidon_bn254.asm @@ -89,7 +89,8 @@ machine PoseidonBN254 with array::zip(state, c, |state, c| (state' - c) * (1-LAST) = 0); // In the last row, the first OUTPUT_SIZE elements of the state should equal output - array::zip(output, state, |output, state| LASTBLOCK * (output - state) = 0); + let output_state = array::sub_array(state, 0, OUTPUT_SIZE); + array::zip(output, output_state, |output, state| LASTBLOCK * (output - state) = 0); // The output should stay constant in the block array::map(output, |c| unchanged_until(c, LAST)); diff --git a/std/machines/hash/poseidon_gl.asm b/std/machines/hash/poseidon_gl.asm index bb6e8dfb6c..827d31f893 100644 --- a/std/machines/hash/poseidon_gl.asm +++ b/std/machines/hash/poseidon_gl.asm @@ -103,7 +103,8 @@ machine PoseidonGL with array::zip(state, c, |state, c| (state' - c) * (1-LAST) = 0); // In the last row, the first OUTPUT_SIZE elements of the state should equal output - array::zip(output, state, |output, state| LASTBLOCK * (output - state) = 0); + let output_state = array::sub_array(state, 0, OUTPUT_SIZE); + array::zip(output, output_state, |output, state| LASTBLOCK * (output - state) = 0); // The output should stay constant in the block array::map(output, |c| unchanged_until(c, LAST)); diff --git a/std/machines/hash/poseidon_gl_memory.asm b/std/machines/hash/poseidon_gl_memory.asm index 18a351d1d9..bf048346b3 100644 --- a/std/machines/hash/poseidon_gl_memory.asm +++ b/std/machines/hash/poseidon_gl_memory.asm @@ -193,7 +193,8 @@ machine PoseidonGLMemory(mem: Memory, split_gl: SplitGL) with array::zip(state, c, |state, c| (state' - c) * (1-LAST) = 0); // In the last row, the first OUTPUT_SIZE elements of the state should equal output - array::zip(output, state, |output, state| LASTBLOCK * (output - state) = 0); + let output_state = array::sub_array(state, 0, OUTPUT_SIZE); + array::zip(output, output_state, |output, state| LASTBLOCK * (output - state) = 0); // The output should stay constant in the block array::map(output, |c| unchanged_until(c, LAST)); From 55a1748786391832fd1436ccee39f37ed49877d2 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 19 Sep 2024 19:13:29 +0200 Subject: [PATCH 06/16] Support capturing enum values and challenges (#1799) After this, we can process prover functions that capture enum values and challenges. This is needed to improve the lookup transformation functions. --- pil-analyzer/src/condenser.rs | 96 +++++++++++++---- pil-analyzer/src/evaluator.rs | 182 +++++++++++++++++++++++++------- pil-analyzer/tests/condenser.rs | 79 ++++++++++++++ pipeline/src/lib.rs | 1 + 4 files changed, 301 insertions(+), 57 deletions(-) diff --git a/pil-analyzer/src/condenser.rs b/pil-analyzer/src/condenser.rs index 5f16646d53..7749e9ba40 100644 --- a/pil-analyzer/src/condenser.rs +++ b/pil-analyzer/src/condenser.rs @@ -13,8 +13,8 @@ use num_traits::sign::Signed; use powdr_ast::{ analyzed::{ - self, AlgebraicExpression, AlgebraicReference, Analyzed, DegreeRange, Expression, - FunctionValueDefinition, Identity, IdentityKind, PolyID, PolynomialReference, + self, AlgebraicExpression, AlgebraicReference, Analyzed, Challenge, DegreeRange, + Expression, FunctionValueDefinition, Identity, IdentityKind, PolyID, PolynomialReference, PolynomialType, PublicDeclaration, Reference, SelectedExpressions, StatementIdentifier, Symbol, SymbolKind, }, @@ -24,15 +24,15 @@ use powdr_ast::{ display::format_type_scheme_around_name, types::{ArrayType, Type}, visitor::{AllChildren, ExpressionVisitable}, - ArrayLiteral, BlockExpression, FunctionKind, LambdaExpression, LetStatementInsideBlock, - Number, Pattern, TypedExpression, UnaryOperation, + ArrayLiteral, BlockExpression, FunctionCall, FunctionKind, LambdaExpression, + LetStatementInsideBlock, Number, Pattern, TypedExpression, UnaryOperation, }, }; use powdr_number::{BigUint, FieldElement}; use powdr_parser_util::SourceRef; use crate::{ - evaluator::{self, Closure, Definitions, EvalError, SymbolLookup, Value}, + evaluator::{self, Closure, Definitions, EnumValue, EvalError, SymbolLookup, Value}, statement_processor::Counters, }; @@ -625,8 +625,18 @@ fn to_constraint( source: SourceRef, counters: &mut Counters, ) -> AnalyzedIdentity { - match constraint { - Value::Enum("Identity", Some(fields)) => { + let Value::Enum(EnumValue { + enum_decl, + variant, + data, + }) = constraint + else { + panic!("Expected constraint but got {constraint}") + }; + assert_eq!(enum_decl.name, "std::prelude::Constr"); + let fields = data.as_ref().unwrap(); + match &**variant { + "Identity" => { assert_eq!(fields.len(), 2); AnalyzedIdentity::from_polynomial_identity( counters.dispense_identity_id(), @@ -634,9 +644,9 @@ fn to_constraint( to_expr(&fields[0]) - to_expr(&fields[1]), ) } - Value::Enum(kind @ "Lookup" | kind @ "Permutation", Some(fields)) => { + "Lookup" | "Permutation" => { assert_eq!(fields.len(), 2); - let kind = if *kind == "Lookup" { + let kind = if variant == &"Lookup" { IdentityKind::Plookup } else { IdentityKind::Permutation @@ -672,7 +682,7 @@ fn to_constraint( right: to_selected_exprs(sel_to, to), } } - Value::Enum("Connection", Some(fields)) => { + "Connection" => { assert_eq!(fields.len(), 1); let (from, to): (Vec<_>, Vec<_>) = if let Value::Array(a) = fields[0].as_ref() { @@ -719,9 +729,14 @@ fn to_selected_exprs<'a, T: Clone + Debug>( } fn to_option_expr(value: &Value<'_, T>) -> Option> { - match value { - Value::Enum("None", None) => None, - Value::Enum("Some", Some(fields)) => { + let Value::Enum(enum_value) = value else { + panic!("Expected option but got {value:?}") + }; + assert_eq!(enum_value.enum_decl.name, "std::prelude::Option"); + match enum_value.variant { + "None" => None, + "Some" => { + let fields = enum_value.data.as_ref().unwrap(); assert_eq!(fields.len(), 1); Some(to_expr(&fields[0])) } @@ -940,16 +955,31 @@ fn try_value_to_expression(value: &Value<'_, T>) -> Result try_closure_to_expression(c)?, - Value::TypeConstructor(c) => { + Value::TypeConstructor(type_constructor) => { return Err(EvalError::TypeError(format!( - "Type constructor as captured value not supported: {c}." + "Type constructor as captured value not supported: {type_constructor}.", ))) } - Value::Enum(variant, _items) => { - // The main problem is that we do not know the type of the enum. - return Err(EvalError::TypeError(format!( - "Enum as captured value not supported: {variant}." - ))); + Value::Enum(enum_value) => { + let variant_ref = Expression::Reference( + SourceRef::unknown(), + Reference::Poly(PolynomialReference { + name: format!("{}::{}", enum_value.enum_decl.name, enum_value.variant), + // We do not know the type args here. + type_args: None, + }), + ); + match &enum_value.data { + None => variant_ref, + Some(items) => FunctionCall { + function: Box::new(variant_ref), + arguments: items + .iter() + .map(|i| try_value_to_expression(i)) + .collect::>()?, + } + .into(), + } } Value::BuiltinFunction(_) => { return Err(EvalError::TypeError( @@ -968,6 +998,32 @@ fn try_value_to_expression(value: &Value<'_, T>) -> Result { + let function = Expression::Reference( + SourceRef::unknown(), + Reference::Poly(PolynomialReference { + name: "std::prelude::challenge".to_string(), + type_args: None, + }), + ) + .into(); + let arguments = [*stage as u64, *id] + .into_iter() + .map(|x| BigUint::from(x).into()) + .collect(); + Expression::FunctionCall( + SourceRef::unknown(), + FunctionCall { + function, + arguments, + }, + ) + } + AlgebraicExpression::Number(n) => Number { + value: n.to_arbitrary_integer(), + type_: Some(Type::Expr), + } + .into(), _ => { return Err(EvalError::TypeError(format!( "Algebraic expression as captured value not supported: {e}." diff --git a/pil-analyzer/src/evaluator.rs b/pil-analyzer/src/evaluator.rs index 4fba22f957..5c4adf85cd 100644 --- a/pil-analyzer/src/evaluator.rs +++ b/pil-analyzer/src/evaluator.rs @@ -16,9 +16,10 @@ use powdr_ast::{ parsed::{ display::quote, types::{ArrayType, Type, TypeScheme}, - ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, - IndexAccess, LambdaExpression, LetStatementInsideBlock, MatchArm, MatchExpression, Number, - Pattern, StatementInsideBlock, UnaryOperation, UnaryOperator, + ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, EnumDeclaration, + FunctionCall, IfExpression, IndexAccess, LambdaExpression, LetStatementInsideBlock, + MatchArm, MatchExpression, Number, Pattern, StatementInsideBlock, UnaryOperation, + UnaryOperator, }, }; use powdr_number::{BigInt, BigUint, FieldElement, LargeInt}; @@ -144,8 +145,8 @@ pub enum Value<'a, T> { Tuple(Vec>), Array(Vec>), Closure(Closure<'a, T>), - TypeConstructor(&'a str), - Enum(&'a str, Option>>), + TypeConstructor(TypeConstructorValue<'a>), + Enum(EnumValue<'a, T>), BuiltinFunction(BuiltinFunction), Expression(AlgebraicExpression), } @@ -221,8 +222,8 @@ impl<'a, T: FieldElement> Value<'a, T> { ) } Value::Closure(c) => c.type_formatted(), - Value::TypeConstructor(name) => format!("{name}_constructor"), - Value::Enum(name, _) => name.to_string(), + Value::TypeConstructor(tc) => tc.type_formatted(), + Value::Enum(enum_val) => enum_val.type_formatted(), Value::BuiltinFunction(b) => format!("builtin_{b:?}"), Value::Expression(_) => "expr".to_string(), } @@ -286,14 +287,14 @@ impl<'a, T: FieldElement> Value<'a, T> { } Pattern::Variable(_, _) => Some(vec![v.clone()]), Pattern::Enum(_, name, fields_pattern) => { - let Value::Enum(n, data) = v.as_ref() else { + let Value::Enum(enum_value) = v.as_ref() else { panic!() }; - if name.name() != n { + if name.name() != enum_value.variant { return None; } if let Some(fields) = fields_pattern { - Value::try_match_pattern_list(data.as_ref().unwrap(), fields) + Value::try_match_pattern_list(enum_value.data.as_ref().unwrap(), fields) } else { Some(vec![]) } @@ -318,6 +319,86 @@ impl<'a, T: FieldElement> Value<'a, T> { } } +/// An enum variant with its data as a value. +/// The enum declaration is provided to allow proper printing and other functions. +#[derive(Clone, Debug)] +pub struct EnumValue<'a, T> { + pub enum_decl: &'a EnumDeclaration, + pub variant: &'a str, + pub data: Option>>>, +} + +impl<'a, T: Display> EnumValue<'a, T> { + pub fn type_formatted(&self) -> String { + self.enum_decl.name.to_string() + } +} + +impl<'a, T: Display> Display for EnumValue<'a, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}::{}", self.enum_decl.name, self.variant)?; + if let Some(data) = &self.data { + write!(f, "({})", data.iter().format(", "))?; + } + Ok(()) + } +} + +/// An enum type constructor value, i.e. the value arising from referencing an +/// enum variant that takes data. +#[derive(Clone, Debug)] +pub struct TypeConstructorValue<'a> { + pub enum_decl: &'a EnumDeclaration, + pub variant: &'a str, +} + +impl<'a> TypeConstructorValue<'a> { + pub fn type_formatted(&self) -> String { + self.enum_decl.name.to_string() + } + + pub fn to_enum_value(&self, data: Vec>>) -> EnumValue<'a, T> { + EnumValue { + enum_decl: self.enum_decl, + variant: self.variant, + data: Some(data), + } + } +} + +impl<'a> Display for TypeConstructorValue<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}::{}", self.enum_decl.name, self.variant) + } +} + +// Some enums from the prelude. We can remove this once we implement the +// `$`, `in`, `is` and `connect` operators using traits. +// The declarations are wrong, but we only need their name for now. +lazy_static::lazy_static! { + static ref OPTION: EnumDeclaration = EnumDeclaration { name: "std::prelude::Option".to_string(), type_vars: Default::default(), variants: Default::default() }; + static ref SELECTED_EXPRS: EnumDeclaration = EnumDeclaration { name: "std::prelude::SelectedExprs".to_string(), type_vars: Default::default(), variants: Default::default() }; + static ref CONSTR: EnumDeclaration = EnumDeclaration { name: "std::prelude::Constr".to_string(), type_vars: Default::default(), variants: Default::default() }; +} + +/// Convenience functions to build an Option::Some value. +fn some_value(data: Arc>) -> Value<'_, T> { + Value::Enum(EnumValue { + enum_decl: &OPTION, + variant: "Some", + data: Some(vec![data]), + }) +} + +/// Convenience functions to build an Option::None value. +fn none_value<'a, T>() -> Value<'a, T> { + Value::Enum(EnumValue { + enum_decl: &OPTION, + variant: "None", + data: None, + }) +} + const BUILTINS: [(&str, BuiltinFunction); 20] = [ ("std::array::len", BuiltinFunction::ArrayLen), ("std::check::panic", BuiltinFunction::Panic), @@ -401,14 +482,8 @@ impl<'a, T: Display> Display for Value<'a, T> { Value::Tuple(items) => write!(f, "({})", items.iter().format(", ")), Value::Array(elements) => write!(f, "[{}]", elements.iter().format(", ")), Value::Closure(closure) => write!(f, "{closure}"), - Value::TypeConstructor(name) => write!(f, "{name}_constructor"), - Value::Enum(name, data) => { - write!(f, "{name}")?; - if let Some(data) = data { - write!(f, "({})", data.iter().format(", "))?; - } - Ok(()) - } + Value::TypeConstructor(tc) => write!(f, "{tc}"), + Value::Enum(enum_value) => write!(f, "{enum_value}"), Value::BuiltinFunction(b) => write!(f, "{b:?}"), Value::Expression(e) => write!(f, "{e}"), } @@ -492,11 +567,20 @@ impl<'a> Definitions<'a> { let type_args = type_arg_mapping(type_scheme, type_args); evaluate_generic(value, &type_args, symbols)? } - Some(FunctionValueDefinition::TypeConstructor(_type_name, variant)) => { + Some(FunctionValueDefinition::TypeConstructor(type_name, variant)) => { if variant.fields.is_none() { - Value::Enum(&variant.name, None).into() + Value::Enum(EnumValue { + enum_decl: type_name.as_ref(), + variant: &variant.name, + data: None, + }) + .into() } else { - Value::TypeConstructor(&variant.name).into() + Value::TypeConstructor(TypeConstructorValue { + enum_decl: type_name.as_ref(), + variant: &variant.name, + }) + .into() } } Some(FunctionValueDefinition::TraitFunction(_, _)) => { @@ -1055,9 +1139,9 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> { self.value_stack .push(evaluate_builtin_function(*b, arguments, self.symbols)?) } - Value::TypeConstructor(name) => self + Value::TypeConstructor(type_constructor) => self .value_stack - .push(Value::Enum(name, Some(arguments)).into()), + .push(Value::Enum(type_constructor.to_enum_value(arguments)).into()), Value::Closure(Closure { lambda, environment, @@ -1184,7 +1268,12 @@ fn evaluate_binary_operation<'a, T: FieldElement>( } } (l @ Value::Expression(_), BinaryOperator::Identity, r @ Value::Expression(_)) => { - Value::Enum("Identity", Some(vec![l.clone().into(), r.clone().into()])).into() + Value::Enum(EnumValue { + enum_decl: &CONSTR, + variant: "Identity", + data: Some(vec![l.clone().into(), r.clone().into()]), + }) + .into() } (Value::Expression(l), op, Value::Expression(r)) => match (l, r) { (AlgebraicExpression::Number(l), AlgebraicExpression::Number(r)) => { @@ -1201,9 +1290,12 @@ fn evaluate_binary_operation<'a, T: FieldElement>( )) .into(), }, - (Value::Expression(_), BinaryOperator::Select, Value::Array(_)) => { - Value::Enum("SelectedExprs", Some(vec![left, right])).into() - } + (Value::Expression(_), BinaryOperator::Select, Value::Array(_)) => Value::Enum(EnumValue { + enum_decl: &SELECTED_EXPRS, + variant: "SelectedExprs", + data: Some(vec![left, right]), + }) + .into(), (_, BinaryOperator::In | BinaryOperator::Is, _) => { let (left_sel, left_exprs) = to_selected_exprs_expanded(&left); let (right_sel, right_exprs) = to_selected_exprs_expanded(&right); @@ -1214,11 +1306,21 @@ fn evaluate_binary_operation<'a, T: FieldElement>( }; let selectors = Value::Tuple(vec![left_sel, right_sel]).into(); let expr_pairs = zip_expressions_for_op(op, left_exprs, right_exprs)?; - Value::Enum(name, Some(vec![selectors, expr_pairs])).into() + Value::Enum(EnumValue { + enum_decl: &CONSTR, + variant: name, + data: Some(vec![selectors, expr_pairs]), + }) + .into() } (Value::Array(left), BinaryOperator::Connect, Value::Array(right)) => { let expr_pairs = zip_expressions_for_op(op, left, right)?; - Value::Enum("Connection", Some(vec![expr_pairs])).into() + Value::Enum(EnumValue { + enum_decl: &CONSTR, + variant: "Connection", + data: Some(vec![expr_pairs]), + }) + .into() } (l, op, r) => Err(EvalError::TypeError(format!( "Operator \"{op}\" not supported on types: {l}: {}, {r}: {}", @@ -1256,17 +1358,23 @@ fn to_selected_exprs_expanded<'a, 'b, T>( ) -> (Arc>, &'a Vec>>) { match selected_exprs { // An array of expressions or a selected expressions without selector. - Value::Array(items) | Value::Enum("JustExprs", Some(items)) => { - (Value::Enum("None", None).into(), &items) - } + Value::Array(items) + | Value::Enum(EnumValue { + variant: "JustExprs", + data: Some(items), + .. + }) => (none_value().into(), items), // A selected expressions - Value::Enum("SelectedExprs", Some(items)) => { + Value::Enum(EnumValue { + variant: "SelectedExprs", + data: Some(items), + .. + }) => { let [sel, exprs] = &items[..] else { panic!() }; - let selector = Value::Enum("Some", Some(vec![sel.clone()])).into(); let Value::Array(exprs) = exprs.as_ref() else { panic!(); }; - (selector, exprs) + (some_value(sel.clone()).into(), exprs) } _ => panic!(), } @@ -1443,8 +1551,8 @@ fn evaluate_builtin_function<'a, T: FieldElement>( ), }; match result { - Ok(v) => Value::Enum("Some", Some(vec![v])), - Err(EvalError::DataNotAvailable) => Value::Enum("None", None), + Ok(v) => some_value(v), + Err(EvalError::DataNotAvailable) => none_value(), Err(e) => return Err(e), } .into() diff --git a/pil-analyzer/tests/condenser.rs b/pil-analyzer/tests/condenser.rs index 9b7577a809..ca506a2815 100644 --- a/pil-analyzer/tests/condenser.rs +++ b/pil-analyzer/tests/condenser.rs @@ -774,3 +774,82 @@ namespace N(16); let formatted = analyze_string(input).to_string(); assert_eq!(formatted, expected); } + +#[test] +fn capture_enums() { + let input = r#" + namespace N(16); + enum E { A(T), B, C(T, int), D() } + (|| { + let x = E::A("abc"); + let y = E::B::; + let z: E = E::C([1, 2], 9); + let w: E = E::D(); + query |_| { + let t = (x, y, z, w); + } + })(); + + "#; + let expected = r#"namespace N(16); + enum E { + A(T), + B, + C(T, int), + D(), + } + { + let x = N::E::A("abc"); + let y = N::E::B; + let z = N::E::C([1, 2], 9); + let w = N::E::D(); + query |_| { + let t: (N::E, N::E, N::E, N::E) = (x, y, z, w); + } + }; +"#; + let formatted = analyze_string(input).to_string(); + assert_eq!(formatted, expected); + let re_analyzed = analyze_string(&formatted); + assert_eq!(re_analyzed.to_string(), expected); +} + +#[test] +fn capture_challenges_and_numbers() { + let input = r#" + namespace std::prelude; + let challenge = 8; + namespace std::prover; + let provide_value = 9; + let eval = -1; + namespace N(16); + (constr || { + let x = std::prelude::challenge(0, 4); + let y; + let t = 2; + query |i| { + std::prover::provide_value(y, i, std::prover::eval(x) + t); + } + })(); + + "#; + let expected = r#"namespace std::prelude; + let challenge = 8; +namespace std::prover; + let provide_value = 9; + let eval = -1; + col witness y; + { + let x = std::prelude::challenge(0, 4); + let y = std::prover::y; + let t = 2; + query |i| { + std::prover::provide_value(y, i, std::prover::eval(x) + t); + } + }; +"#; + let formatted = analyze_string(input).to_string(); + assert_eq!(formatted, expected); + let re_analyzed = analyze_string(&formatted); + assert_eq!(re_analyzed.to_string(), expected); +} diff --git a/pipeline/src/lib.rs b/pipeline/src/lib.rs index 52a73274d1..4cdeaf54c3 100644 --- a/pipeline/src/lib.rs +++ b/pipeline/src/lib.rs @@ -77,6 +77,7 @@ impl HostContext { // TODO at some point, we could also just pass evaluator::Values around - would be much faster. pub fn parse_query(query: &str) -> Result<(&str, Vec<&str>), String> { // We are expecting an enum value + let query = query.strip_prefix("std::prelude::Query::").unwrap_or(query); if let Some(paren) = query.find('(') { let name = &query[..paren]; let data = query[paren + 1..].strip_suffix(')').ok_or_else(|| { From b381adb5fc4ef4fe98244fbbb02690b3f896d381 Mon Sep 17 00:00:00 2001 From: Leo Date: Thu, 19 Sep 2024 20:19:57 +0200 Subject: [PATCH 07/16] fix macro example (#1815) --- backend/src/field_filter.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/backend/src/field_filter.rs b/backend/src/field_filter.rs index 7ccd55911f..d5a9b675d0 100644 --- a/backend/src/field_filter.rs +++ b/backend/src/field_filter.rs @@ -5,10 +5,8 @@ /// the macro will forward the calls to the restricted factory. Otherwise, it will /// panic. /// -/// # Example -/// ``` +/// Example: /// generalize_factory!(Factory <- RestrictedFactory, [GoldilocksField, BabyBearField]); -/// ``` macro_rules! generalize_factory { ($general_factory:ident <- $restricted_factory:ident, [$($supported_type:ty),*]) => { pub(crate) struct $general_factory; From f5cb35e7cfb0718f95bb5cf22265b54dafa28045 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 20 Sep 2024 19:01:37 +0200 Subject: [PATCH 08/16] Improve error messages for converting values to expressions. (#1819) Co-authored-by: chriseth --- pil-analyzer/src/condenser.rs | 18 +++++++++------ pil-analyzer/tests/condenser.rs | 39 +++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/pil-analyzer/src/condenser.rs b/pil-analyzer/src/condenser.rs index 7749e9ba40..6d4973daa8 100644 --- a/pil-analyzer/src/condenser.rs +++ b/pil-analyzer/src/condenser.rs @@ -477,7 +477,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { let value = try_to_function_value_definition(value.as_ref(), FunctionKind::Pure) .map_err(|e| match e { EvalError::TypeError(e) => { - EvalError::TypeError(format!("Error creating fixed column {name}: {e}")) + EvalError::TypeError(format!("Error creating fixed column {name}:\n{e}")) } _ => e, })?; @@ -558,7 +558,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { try_to_function_value_definition(expr.as_ref(), FunctionKind::Query).map_err(|e| { match e { EvalError::TypeError(e) => { - EvalError::TypeError(format!("Error setting hint for column {col}: {e}")) + EvalError::TypeError(format!("Error setting hint for column {col}:\n{e}")) } _ => e, } @@ -591,7 +591,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { } Value::Closure(..) => { let e = try_value_to_expression(&constraints).map_err(|e| { - EvalError::TypeError(format!("Error adding prover function: {e}")) + EvalError::TypeError(format!("Error adding prover function:\n{e}")) })?; self.new_prover_functions.push(e); @@ -819,7 +819,11 @@ fn try_closure_to_expression( let statements = env_map .values() .map(|(new_id, name, value)| { - let mut expr = try_value_to_expression(value.as_ref())?; + let mut expr = try_value_to_expression(value.as_ref()).map_err(|e| { + EvalError::TypeError(format!( + "Error converting captured variable {name} to expression:\n{e}", + )) + })?; // The call to try_value_to_expression assumed a fresh environment, // but we already have `new_id` let statements at this point, // so we adjust the local variable references inside `expr` accordingly. @@ -957,7 +961,7 @@ fn try_value_to_expression(value: &Value<'_, T>) -> Result try_closure_to_expression(c)?, Value::TypeConstructor(type_constructor) => { return Err(EvalError::TypeError(format!( - "Type constructor as captured value not supported: {type_constructor}.", + "Converting type constructor to expression not supported: {type_constructor}.", ))) } Value::Enum(enum_value) => { @@ -983,7 +987,7 @@ fn try_value_to_expression(value: &Value<'_, T>) -> Result { return Err(EvalError::TypeError( - "Builtin function as captured value not supported.".to_string(), + "Converting builtin functions to expressions not supported.".to_string(), )) } Value::Expression(e) => match e { @@ -1026,7 +1030,7 @@ fn try_value_to_expression(value: &Value<'_, T>) -> Result { return Err(EvalError::TypeError(format!( - "Algebraic expression as captured value not supported: {e}." + "Converting complex algebraic expressions to expressions not supported: {e}." ))) } }, diff --git a/pil-analyzer/tests/condenser.rs b/pil-analyzer/tests/condenser.rs index ca506a2815..9e1ed32f28 100644 --- a/pil-analyzer/tests/condenser.rs +++ b/pil-analyzer/tests/condenser.rs @@ -853,3 +853,42 @@ namespace std::prover; let re_analyzed = analyze_string(&formatted); assert_eq!(re_analyzed.to_string(), expected); } + +#[test] +#[should_panic = "Converting complex algebraic expressions to expressions not supported: std::prover::x + std::prover::y"] +fn capture_not_supported() { + let input = r#" + namespace std::prover; + let provide_value = 9; + let eval = -1; + namespace N(16); + (constr || { + let x; + let y; + let t = x + y; + query |i| { + let _ = std::prover::eval(t); + } + })(); + + "#; + let expected = r#"namespace std::prelude; + let challenge = 8; +namespace std::prover; + let provide_value = 9; + let eval = -1; + col witness y; + { + let x = std::prelude::challenge(0, 4); + let y = std::prover::y; + let t = 2; + query |i| { + std::prover::provide_value(y, i, std::prover::eval(x) + t); + } + }; +"#; + let formatted = analyze_string(input).to_string(); + assert_eq!(formatted, expected); + let re_analyzed = analyze_string(&formatted); + assert_eq!(re_analyzed.to_string(), expected); +} From 06abb2a28ad72ee1bd233a5679e8f296e31ea915 Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 23 Sep 2024 14:53:25 +0200 Subject: [PATCH 09/16] Simplify parsed identities and introduce the concept of proof item. (#1820) At the on-site, we introduced operators for all lookup-like constraints. This means that `[x, y] in s $ [a, b];` is not parsed as a lookup identity any more, but just as a regular expression with binary operators. The whole concept of a lookup or identity as a rust type is now only present after the condenser has run. Because of that, we can remove a lot of code that was concerned with parsed identities. Since prover functions were introduced, we can have both constraints and prover functions at statement level. Because of that I extended the concept of "identity" (which we partly renamed to "constraint" already) to "proof item". A proof item is either a constraint or a prover function. Later on, we might also include fixed columns, challenges, etc. Co-authored-by: chriseth --- ast/src/analyzed/display.rs | 20 +---- ast/src/analyzed/mod.rs | 74 ++---------------- backend/src/composite/split.rs | 2 +- .../json_exporter/expression_counter.rs | 2 +- backend/src/estark/json_exporter/mod.rs | 2 +- executor/src/constant_evaluator/mod.rs | 2 +- pil-analyzer/src/condenser.rs | 75 ++++++------------- pil-analyzer/src/evaluator.rs | 14 ++-- pil-analyzer/src/pil_analyzer.rs | 66 ++++++---------- pil-analyzer/src/statement_processor.rs | 43 ++--------- 10 files changed, 72 insertions(+), 228 deletions(-) diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index e85b261a48..5e97f35779 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -143,7 +143,7 @@ impl Display for Analyzed { is_local.into(), )?; } - StatementIdentifier::Identity(i) => { + StatementIdentifier::ProofItem(i) => { writeln_indented(f, &self.identities[*i])?; } StatementIdentifier::ProverFunction(i) => { @@ -317,24 +317,6 @@ impl Display for SelectedExpressions { } } -impl Display for Identity> { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - match self.kind { - IdentityKind::Polynomial => { - let (left, right) = self.as_polynomial_identity(); - let right = right - .as_ref() - .map(|r| r.to_string()) - .unwrap_or_else(|| "0".into()); - write!(f, "{left} = {right};") - } - IdentityKind::Plookup => write!(f, "{} in {};", self.left, self.right), - IdentityKind::Permutation => write!(f, "{} is {};", self.left, self.right), - IdentityKind::Connect => write!(f, "{} connect {};", self.left, self.right), - } - } -} - impl Display for Identity>> { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self.kind { diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 394298e366..0f278be74b 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -20,8 +20,7 @@ use crate::parsed::visitor::{Children, ExpressionVisitable}; pub use crate::parsed::BinaryOperator; pub use crate::parsed::UnaryOperator; use crate::parsed::{ - self, ArrayExpression, ArrayLiteral, EnumDeclaration, EnumVariant, TraitDeclaration, - TraitFunction, + self, ArrayExpression, EnumDeclaration, EnumVariant, TraitDeclaration, TraitFunction, }; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] @@ -29,8 +28,8 @@ pub enum StatementIdentifier { /// Either an intermediate column or a definition. Definition(String), PublicDeclaration(String), - /// Index into the vector of identities. - Identity(usize), + /// Index into the vector of proof items. + ProofItem(usize), /// Index into the vector of prover functions. ProverFunction(usize), } @@ -208,7 +207,7 @@ impl Analyzed { ), ); self.source_order - .push(StatementIdentifier::Identity(self.identities.len() - 1)); + .push(StatementIdentifier::ProofItem(self.identities.len() - 1)); id } @@ -217,7 +216,7 @@ impl Analyzed { pub fn remove_identities(&mut self, to_remove: &BTreeSet) { let mut shift = 0; self.source_order.retain_mut(|s| { - if let StatementIdentifier::Identity(index) = s { + if let StatementIdentifier::ProofItem(index) = s { if to_remove.contains(index) { shift += 1; return false; @@ -735,6 +734,7 @@ pub struct Identity { pub right: SelectedExpressions, } +// TODO This is the only version of Identity left. impl Identity>> { /// Constructs an Identity from a polynomial identity (expression assumed to be identical zero). pub fn from_polynomial_identity( @@ -794,56 +794,6 @@ impl Identity>> { } } -impl Identity>> { - /// Constructs an Identity from a polynomial identity (expression assumed to be identical zero). - pub fn from_polynomial_identity( - id: u64, - source: SourceRef, - identity: parsed::Expression, - ) -> Self { - Identity { - id, - kind: IdentityKind::Polynomial, - source, - left: parsed::SelectedExpressions { - selector: Some(identity), - expressions: Box::new(ArrayLiteral { items: vec![] }.into()), - }, - right: Default::default(), - } - } - /// Returns the expression in case this is a polynomial identity. - pub fn expression_for_poly_id(&self) -> &parsed::Expression { - assert_eq!(self.kind, IdentityKind::Polynomial); - self.left.selector.as_ref().unwrap() - } - - /// Returns the expression in case this is a polynomial identity. - pub fn expression_for_poly_id_mut(&mut self) -> &mut parsed::Expression { - assert_eq!(self.kind, IdentityKind::Polynomial); - self.left.selector.as_mut().unwrap() - } - /// Either returns (a, Some(b)) if this is a - b or (a, None) - /// if it is a polynomial identity of a different structure. - /// Panics if it is a different kind of constraint. - pub fn as_polynomial_identity( - &self, - ) -> (&parsed::Expression, Option<&parsed::Expression>) { - assert_eq!(self.kind, IdentityKind::Polynomial); - match self.expression_for_poly_id() { - parsed::Expression::BinaryOperation( - _, - parsed::BinaryOperation { - left, - op: BinaryOperator::Sub, - right, - }, - ) => (left.as_ref(), Some(right.as_ref())), - a => (a, None), - } - } -} - impl Children> for Identity>> { fn children_mut(&mut self) -> Box> + '_> { Box::new(self.left.children_mut().chain(self.right.children_mut())) @@ -854,18 +804,6 @@ impl Children> for Identity Children> - for Identity>> -{ - fn children_mut(&mut self) -> Box> + '_> { - Box::new(self.left.children_mut().chain(self.right.children_mut())) - } - - fn children(&self) -> Box> + '_> { - Box::new(self.left.children().chain(self.right.children())) - } -} - #[derive( Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Serialize, Deserialize, JsonSchema, )] diff --git a/backend/src/composite/split.rs b/backend/src/composite/split.rs index 219dd4b243..15aa9c072a 100644 --- a/backend/src/composite/split.rs +++ b/backend/src/composite/split.rs @@ -184,7 +184,7 @@ fn split_by_namespace( // add `statement` to `namespace` Some((namespace, statement)) } - StatementIdentifier::Identity(i) => { + StatementIdentifier::ProofItem(i) => { let identity = &pil.identities[*i]; let namespaces = referenced_namespaces(identity); diff --git a/backend/src/estark/json_exporter/expression_counter.rs b/backend/src/estark/json_exporter/expression_counter.rs index ead49617da..659b380540 100644 --- a/backend/src/estark/json_exporter/expression_counter.rs +++ b/backend/src/estark/json_exporter/expression_counter.rs @@ -28,7 +28,7 @@ pub fn compute_intermediate_expression_ids(analyzed: &Analyzed) -> HashMap StatementIdentifier::PublicDeclaration(name) => { analyzed.public_declarations[name].expression_count() } - StatementIdentifier::Identity(id) => analyzed.identities[*id].expression_count(), + StatementIdentifier::ProofItem(id) => analyzed.identities[*id].expression_count(), StatementIdentifier::ProverFunction(_) => 0, } } diff --git a/backend/src/estark/json_exporter/mod.rs b/backend/src/estark/json_exporter/mod.rs index 939c4b8390..6c806b6044 100644 --- a/backend/src/estark/json_exporter/mod.rs +++ b/backend/src/estark/json_exporter/mod.rs @@ -80,7 +80,7 @@ pub fn export(analyzed: &Analyzed) -> PIL { name: name.clone(), }); } - StatementIdentifier::Identity(id) => { + StatementIdentifier::ProofItem(id) => { let identity = &analyzed.identities[*id]; // PILCOM strips the path from filenames, we do the same here for compatibility let file_name = identity diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index 0c1784ed72..47f0327c17 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -686,7 +686,7 @@ mod test { let f: -> () = || (); let g: col = |i| { // This returns an empty tuple, we check that this does not lead to - // a call to add_constraints() + // a call to add_proof_items() f(); i }; diff --git a/pil-analyzer/src/condenser.rs b/pil-analyzer/src/condenser.rs index 6d4973daa8..53910f6313 100644 --- a/pil-analyzer/src/condenser.rs +++ b/pil-analyzer/src/condenser.rs @@ -25,7 +25,7 @@ use powdr_ast::{ types::{ArrayType, Type}, visitor::{AllChildren, ExpressionVisitable}, ArrayLiteral, BlockExpression, FunctionCall, FunctionKind, LambdaExpression, - LetStatementInsideBlock, Number, Pattern, TypedExpression, UnaryOperation, + LetStatementInsideBlock, Number, Pattern, SourceReference, TypedExpression, UnaryOperation, }, }; use powdr_number::{BigUint, FieldElement}; @@ -36,14 +36,13 @@ use crate::{ statement_processor::Counters, }; -type ParsedIdentity = Identity>; type AnalyzedIdentity = Identity>>; pub fn condense( mut definitions: HashMap)>, solved_impls: HashMap, Arc>>, public_declarations: HashMap, - identities: &[ParsedIdentity], + proof_items: &[Expression], source_order: Vec, auto_added_symbols: HashSet, ) -> Analyzed { @@ -65,8 +64,8 @@ pub fn condense( condenser.set_namespace_and_degree(namespace, definitions[name].0.degree); } let statement = match s { - StatementIdentifier::Identity(index) => { - condenser.condense_identity(&identities[index]); + StatementIdentifier::ProofItem(index) => { + condenser.condense_proof_item(&proof_items[index]); None } StatementIdentifier::Definition(name) @@ -135,7 +134,7 @@ pub fn condense( .map(|identity| { let index = condensed_identities.len(); condensed_identities.push(identity); - StatementIdentifier::Identity(index) + StatementIdentifier::ProofItem(index) }) .collect::>(); @@ -240,34 +239,21 @@ impl<'a, T: FieldElement> Condenser<'a, T> { } } - pub fn condense_identity(&mut self, identity: &'a ParsedIdentity) { - if identity.kind == IdentityKind::Polynomial { - let expr = identity.expression_for_poly_id(); - evaluator::evaluate(expr, self) - .and_then(|expr| { - if let Value::Tuple(items) = expr.as_ref() { - assert!(items.is_empty()); - Ok(()) - } else { - self.add_constraints(expr, identity.source.clone()) - } - }) - .unwrap_or_else(|err| { - panic!( - "Error reducing expression to constraint:\nExpression: {expr}\nError: {err:?}" - ) - }); - } else { - let left = self.condense_selected_expressions(&identity.left); - let right = self.condense_selected_expressions(&identity.right); - self.new_constraints.push(Identity { - id: self.counters.dispense_identity_id(), - kind: identity.kind, - source: identity.source.clone(), - left, - right, + pub fn condense_proof_item(&mut self, item: &'a Expression) { + evaluator::evaluate(item, self) + .and_then(|expr| { + if let Value::Tuple(items) = expr.as_ref() { + assert!(items.is_empty()); + Ok(()) + } else { + self.add_proof_items(expr, item.source_reference().clone()) + } }) - } + .unwrap_or_else(|err| { + panic!( + "Error reducing expression to constraint:\nExpression: {item}\nError: {err:?}" + ) + }); } /// Sets the current namespace which will be used for newly generated witness columns. @@ -308,19 +294,6 @@ impl<'a, T: FieldElement> Condenser<'a, T> { std::mem::take(&mut self.new_prover_functions) } - fn condense_selected_expressions( - &mut self, - sel_expr: &'a parsed::SelectedExpressions, - ) -> SelectedExpressions> { - SelectedExpressions { - selector: sel_expr - .selector - .as_ref() - .map(|expr| self.condense_to_algebraic_expression(expr)), - expressions: self.condense_to_array_of_algebraic_expressions(&sel_expr.expressions), - } - } - /// Evaluates the expression and expects it to result in an algebraic expression. fn condense_to_algebraic_expression(&mut self, e: &'a Expression) -> AlgebraicExpression { let result = evaluator::evaluate(e, self).unwrap_or_else(|err| { @@ -574,12 +547,12 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { Ok(()) } - fn add_constraints( + fn add_proof_items( &mut self, - constraints: Arc>, + items: Arc>, source: SourceRef, ) -> Result<(), EvalError> { - match constraints.as_ref() { + match items.as_ref() { Value::Array(items) => { for item in items { self.new_constraints.push(to_constraint( @@ -590,7 +563,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { } } Value::Closure(..) => { - let e = try_value_to_expression(&constraints).map_err(|e| { + let e = try_value_to_expression(&items).map_err(|e| { EvalError::TypeError(format!("Error adding prover function:\n{e}")) })?; @@ -598,7 +571,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { } _ => self .new_constraints - .push(to_constraint(&constraints, source, &mut self.counters)), + .push(to_constraint(&items, source, &mut self.counters)), } Ok(()) } diff --git a/pil-analyzer/src/evaluator.rs b/pil-analyzer/src/evaluator.rs index 5c4adf85cd..44658f2996 100644 --- a/pil-analyzer/src/evaluator.rs +++ b/pil-analyzer/src/evaluator.rs @@ -714,13 +714,13 @@ pub trait SymbolLookup<'a, T: FieldElement> { )) } - fn add_constraints( + fn add_proof_items( &mut self, - _constraints: Arc>, + _items: Arc>, _source: SourceRef, ) -> Result<(), EvalError> { Err(EvalError::Unsupported( - "Tried to add constraints outside of statement context.".to_string(), + "Tried to add proof items outside of statement context.".to_string(), )) } @@ -772,7 +772,7 @@ enum Operation<'a, T> { /// Evaluate a let statement, adding matched pattern variables to the local variables. LetStatement(&'a LetStatementInsideBlock), /// Add a constraint to the constraint set. - AddConstraint, + AddProofItem, } /// We use a non-recursive algorithm to evaluate potentially recursive expressions. @@ -832,11 +832,11 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> { self.type_args = new_type_args; } Operation::LetStatement(s) => self.evaluate_let_statement(s)?, - Operation::AddConstraint => { + Operation::AddProofItem => { let result = self.value_stack.pop().unwrap(); match result.as_ref() { Value::Tuple(t) if t.is_empty() => {} - _ => self.symbols.add_constraints(result, SourceRef::unknown())?, + _ => self.symbols.add_proof_items(result, SourceRef::unknown())?, } } }; @@ -947,7 +947,7 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> { } } StatementInsideBlock::Expression(expr) => { - self.op_stack.push(Operation::AddConstraint); + self.op_stack.push(Operation::AddProofItem); self.op_stack.push(Operation::Expand(expr)); } } diff --git a/pil-analyzer/src/pil_analyzer.rs b/pil-analyzer/src/pil_analyzer.rs index 195252b441..3f7b229e82 100644 --- a/pil-analyzer/src/pil_analyzer.rs +++ b/pil-analyzer/src/pil_analyzer.rs @@ -10,18 +10,18 @@ use itertools::Itertools; use powdr_ast::parsed::asm::{ parse_absolute_path, AbsoluteSymbolPath, ModuleStatement, SymbolPath, }; -use powdr_ast::parsed::types::{ArrayType, Type}; +use powdr_ast::parsed::types::Type; use powdr_ast::parsed::visitor::{AllChildren, Children}; use powdr_ast::parsed::{ - self, FunctionKind, LambdaExpression, PILFile, PilStatement, SelectedExpressions, - SymbolCategory, TraitImplementation, + self, FunctionKind, LambdaExpression, PILFile, PilStatement, SymbolCategory, + TraitImplementation, }; use powdr_number::{FieldElement, GoldilocksField}; use powdr_ast::analyzed::{ - type_from_definition, Analyzed, DegreeRange, Expression, FunctionValueDefinition, Identity, - IdentityKind, PolynomialReference, PolynomialType, PublicDeclaration, Reference, - StatementIdentifier, Symbol, SymbolKind, TypedExpression, + type_from_definition, Analyzed, DegreeRange, Expression, FunctionValueDefinition, + PolynomialReference, PolynomialType, PublicDeclaration, Reference, StatementIdentifier, Symbol, + SymbolKind, TypedExpression, }; use powdr_parser::{parse, parse_module, parse_type}; use powdr_parser_util::Error; @@ -71,7 +71,8 @@ struct PILAnalyzer { /// Map of definitions, gradually being built up here. definitions: HashMap)>, public_declarations: HashMap, - identities: Vec>>, + /// The list of proof items, i.e. statements that evaluate to constraints or prover functions. + proof_items: Vec, /// The order in which definitions and identities /// appear in the source. source_order: Vec, @@ -238,14 +239,10 @@ impl PILAnalyzer { } } - // for all identities, check that they call pure or constr functions - for id in &self.identities { - id.children() - .try_for_each(|e| { - side_effect_checker::check(&self.definitions, FunctionKind::Constr, e) - }) - .unwrap_or_else(|err| errors.push(err)) - } + // for all proof items, check that they call pure or constr functions + errors.extend(self.proof_items.iter().filter_map(|e| { + side_effect_checker::check(&self.definitions, FunctionKind::Constr, e).err() + })); if errors.is_empty() { Ok(()) @@ -335,29 +332,9 @@ impl PILAnalyzer { Some((name.clone(), (type_scheme, expr))) }) .collect(); - for id in &mut self.identities { - if id.kind == IdentityKind::Polynomial { - // At statement level, we allow Constr, Constr[] or (). - expressions.push(( - id.expression_for_poly_id_mut(), - constr_function_statement_type(), - )); - } else { - for part in [&mut id.left, &mut id.right] { - if let Some(selector) = &mut part.selector { - expressions.push((selector, Type::Expr.into())) - } - - expressions.push(( - part.expressions.as_mut(), - Type::Array(ArrayType { - base: Box::new(Type::Expr), - length: None, - }) - .into(), - )) - } - } + for expr in &mut self.proof_items { + // At statement level, we allow Constr, Constr[], (int -> ()) or (). + expressions.push((expr, constr_function_statement_type())); } let inferred_types = infer_types(definitions, &mut expressions)?; @@ -408,7 +385,7 @@ impl PILAnalyzer { } } - for identity in &self.identities { + for identity in &self.proof_items { for expr in identity.all_children() { resolve_references(expr); } @@ -431,7 +408,7 @@ impl PILAnalyzer { self.definitions, solved_impls, self.public_declarations, - &self.identities, + &self.proof_items, self.source_order, self.auto_added_symbols, )) @@ -505,10 +482,11 @@ impl PILAnalyzer { self.source_order .push(StatementIdentifier::PublicDeclaration(name)); } - PILItem::Identity(identity) => { - let index = self.identities.len(); - self.source_order.push(StatementIdentifier::Identity(index)); - self.identities.push(identity) + PILItem::ProofItem(item) => { + let index = self.proof_items.len(); + self.source_order + .push(StatementIdentifier::ProofItem(index)); + self.proof_items.push(item) } PILItem::TraitImplementation(trait_impl) => self .implementations diff --git a/pil-analyzer/src/statement_processor.rs b/pil-analyzer/src/statement_processor.rs index 143514bfb6..8374f2d200 100644 --- a/pil-analyzer/src/statement_processor.rs +++ b/pil-analyzer/src/statement_processor.rs @@ -10,16 +10,15 @@ use powdr_ast::parsed::types::TupleType; use powdr_ast::parsed::{ self, types::{ArrayType, Type, TypeScheme}, - ArrayLiteral, EnumDeclaration, EnumVariant, FunctionDefinition, FunctionKind, LambdaExpression, - PilStatement, PolynomialName, SelectedExpressions, TraitDeclaration, TraitFunction, + EnumDeclaration, EnumVariant, FunctionDefinition, FunctionKind, LambdaExpression, PilStatement, + PolynomialName, TraitDeclaration, TraitFunction, }; use powdr_ast::parsed::{ArrayExpression, NamedExpression, SymbolCategory, TraitImplementation}; use powdr_parser_util::SourceRef; use std::str::FromStr; use powdr_ast::analyzed::{ - Expression, FunctionValueDefinition, Identity, IdentityKind, PolynomialType, PublicDeclaration, - Symbol, SymbolKind, + Expression, FunctionValueDefinition, PolynomialType, PublicDeclaration, Symbol, SymbolKind, }; use crate::type_processor::TypeProcessor; @@ -30,7 +29,7 @@ use crate::expression_processor::ExpressionProcessor; pub enum PILItem { Definition(Symbol, Option), PublicDeclaration(PublicDeclaration), - Identity(Identity>), + ProofItem(Expression), TraitImplementation(TraitImplementation), } @@ -209,7 +208,10 @@ where let trait_impl = self.process_trait_implementation(trait_impl); vec![PILItem::TraitImplementation(trait_impl)] } - _ => self.handle_identity_statement(statement), + PilStatement::Expression(_, expr) => vec![PILItem::ProofItem( + self.expression_processor(&Default::default()) + .process_expression(expr), + )], } } @@ -338,35 +340,6 @@ where } } - fn handle_identity_statement(&mut self, statement: PilStatement) -> Vec { - let (source, kind, left, right) = match statement { - PilStatement::Expression(source, expression) => ( - source, - IdentityKind::Polynomial, - SelectedExpressions { - selector: Some( - self.expression_processor(&Default::default()) - .process_expression(expression), - ), - expressions: Box::new(ArrayLiteral { items: vec![] }.into()), - }, - SelectedExpressions::default(), - ), - // TODO at some point, these should all be caught by the type checker. - _ => { - panic!("Only identities allowed at this point.") - } - }; - - vec![PILItem::Identity(Identity { - id: self.counters.dispense_identity_id(), - kind, - source, - left, - right, - })] - } - fn handle_polynomial_declarations( &mut self, source: SourceRef, From 4f2587e70fe064cbf29c992e88ac598ec7325ca2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gast=C3=B3n=20Zanitti?= Date: Mon, 23 Sep 2024 10:45:18 -0300 Subject: [PATCH 10/16] Rename trait functions (#1816) Since we also need name and type pairs for structs, it makes sense to rename `TraitFunction` to something more general (`NamedType` in this case). There are also `NamedExpressions` and everything could be combined into a `NamedElement` or something similar that covers both cases. Personally I prefer this option, but the discussion is open. --- ast/src/analyzed/mod.rs | 6 +++--- ast/src/parsed/display.rs | 2 +- ast/src/parsed/mod.rs | 10 +++++----- importer/src/path_canonicalizer.rs | 15 ++++++--------- parser/src/powdr.lalrpop | 15 ++++++++++++--- pil-analyzer/src/statement_processor.rs | 6 +++--- 6 files changed, 30 insertions(+), 24 deletions(-) diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 0f278be74b..fc32a43419 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -20,7 +20,7 @@ use crate::parsed::visitor::{Children, ExpressionVisitable}; pub use crate::parsed::BinaryOperator; pub use crate::parsed::UnaryOperator; use crate::parsed::{ - self, ArrayExpression, EnumDeclaration, EnumVariant, TraitDeclaration, TraitFunction, + self, ArrayExpression, EnumDeclaration, EnumVariant, NamedType, TraitDeclaration, }; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] @@ -623,7 +623,7 @@ pub enum FunctionValueDefinition { TypeDeclaration(EnumDeclaration), TypeConstructor(Arc, EnumVariant), TraitDeclaration(TraitDeclaration), - TraitFunction(Arc, TraitFunction), + TraitFunction(Arc, NamedType), } impl Children for FunctionValueDefinition { @@ -667,7 +667,7 @@ impl Children for TraitDeclaration { } } -impl Children for TraitFunction { +impl Children for NamedType { fn children(&self) -> Box + '_> { Box::new(empty()) } diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index 4098367eea..319d4e056d 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -587,7 +587,7 @@ impl Display for TraitDeclaration { } } -impl Display for TraitFunction { +impl Display for NamedType { fn fmt(&self, f: &mut Formatter<'_>) -> Result { write!(f, "{}: {}", self.name, self.ty) } diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index fd3e5fe708..f1953e526f 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -379,15 +379,15 @@ pub struct NamedExpression { pub struct TraitDeclaration { pub name: String, pub type_vars: Vec, - pub functions: Vec>, + pub functions: Vec>, } impl TraitDeclaration { - pub fn function_by_name(&self, name: &str) -> Option<&TraitFunction> { + pub fn function_by_name(&self, name: &str) -> Option<&NamedType> { self.functions.iter().find(|f| f.name == name) } - pub fn function_by_name_mut(&mut self, name: &str) -> Option<&mut TraitFunction> { + pub fn function_by_name_mut(&mut self, name: &str) -> Option<&mut NamedType> { self.functions.iter_mut().find(|f| f.name == name) } } @@ -402,12 +402,12 @@ impl Children> for TraitDeclaration> { } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] -pub struct TraitFunction { +pub struct NamedType { pub name: String, pub ty: Type, } -impl Children> for TraitFunction> { +impl Children> for NamedType> { fn children(&self) -> Box> + '_> { self.ty.children() } diff --git a/importer/src/path_canonicalizer.rs b/importer/src/path_canonicalizer.rs index 3863c160ee..5f7fdaec66 100644 --- a/importer/src/path_canonicalizer.rs +++ b/importer/src/path_canonicalizer.rs @@ -16,7 +16,7 @@ use powdr_ast::parsed::{ visitor::{Children, ExpressionVisitable}, ArrayLiteral, BinaryOperation, BlockExpression, EnumDeclaration, EnumVariant, Expression, FunctionCall, IndexAccess, LambdaExpression, LetStatementInsideBlock, MatchArm, - MatchExpression, Pattern, PilStatement, StatementInsideBlock, TraitDeclaration, TraitFunction, + MatchExpression, NamedType, Pattern, PilStatement, StatementInsideBlock, TraitDeclaration, TypedExpression, UnaryOperation, }; use powdr_parser_util::{Error, SourceRef}; @@ -972,14 +972,11 @@ fn check_trait_declaration( trait_decl .functions .iter() - .try_fold( - BTreeSet::default(), - |mut acc, TraitFunction { name, .. }| { - acc.insert(name.clone()).then_some(acc).ok_or(format!( - "Duplicate method `{name}` defined in trait `{location}`" - )) - }, - ) + .try_fold(BTreeSet::default(), |mut acc, NamedType { name, .. }| { + acc.insert(name.clone()).then_some(acc).ok_or(format!( + "Duplicate method `{name}` defined in trait `{location}`" + )) + }) .map_err(|e| SourceRef::unknown().with_error(e))?; let type_vars = trait_decl.type_vars.iter().collect(); diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 578e6fefb2..d59b2ae08e 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -674,6 +674,15 @@ NamedExpression: NamedExpression = { ":" => NamedExpression { name, body } } +NamedTypes: Vec> = { + => vec![], + "," )*> "," => { list.push(end); list } +} + +NamedType: NamedType = { + ":" > => NamedType { name, ty } +} + // ---------------------------- Pattern ----------------------------- Pattern: Pattern = { @@ -726,13 +735,13 @@ TraitVars: Vec = { "," )*> ","? => { list.push(end); list } } -TraitFunctions: Vec> = { +TraitFunctions: Vec> = { => vec![], "," )*> "," => { list.push(end); list } } -TraitFunction: TraitFunction = { - ":" > "->" > => TraitFunction { name, ty: Type::Function(FunctionType{params, value}) } +TraitFunction: NamedType = { + ":" > "->" > => NamedType { name, ty: Type::Function(FunctionType{params, value}) } } TraitImplementation: TraitImplementation = { diff --git a/pil-analyzer/src/statement_processor.rs b/pil-analyzer/src/statement_processor.rs index 8374f2d200..73c518162d 100644 --- a/pil-analyzer/src/statement_processor.rs +++ b/pil-analyzer/src/statement_processor.rs @@ -10,8 +10,8 @@ use powdr_ast::parsed::types::TupleType; use powdr_ast::parsed::{ self, types::{ArrayType, Type, TypeScheme}, - EnumDeclaration, EnumVariant, FunctionDefinition, FunctionKind, LambdaExpression, PilStatement, - PolynomialName, TraitDeclaration, TraitFunction, + EnumDeclaration, EnumVariant, FunctionDefinition, FunctionKind, LambdaExpression, NamedType, + PilStatement, PolynomialName, TraitDeclaration, }; use powdr_ast::parsed::{ArrayExpression, NamedExpression, SymbolCategory, TraitImplementation}; use powdr_parser_util::SourceRef; @@ -428,7 +428,7 @@ where let functions = trait_decl .functions .into_iter() - .map(|f| TraitFunction { + .map(|f| NamedType { name: f.name, ty: self.type_processor(&type_vars).process_type(f.ty), }) From 993c3f16a44e28c10f575e04cac0a7896aacdbb3 Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 23 Sep 2024 16:18:59 +0200 Subject: [PATCH 11/16] Simplify lookup (#1802) Co-authored-by: chriseth --- std/math/fp2.asm | 9 ++-- std/protocols/lookup.asm | 42 +++++++++++++++---- test_data/std/lookup_via_challenges.asm | 12 +----- test_data/std/lookup_via_challenges_ext.asm | 28 +------------ .../std/lookup_via_challenges_ext_simple.asm | 29 +------------ 5 files changed, 45 insertions(+), 75 deletions(-) diff --git a/std/math/fp2.asm b/std/math/fp2.asm index c6ce88399e..e1d1be39a6 100644 --- a/std/math/fp2.asm +++ b/std/math/fp2.asm @@ -115,9 +115,12 @@ let unpack_ext_array: Fp2 -> T[] = |a| match a { }; /// Whether we need to operate on the F_{p^2} extension field (because the current field is too small). -let needs_extension: -> bool = || match known_field() { - Option::Some(KnownField::Goldilocks) => true, - Option::Some(KnownField::BN254) => false, +let needs_extension: -> bool = || required_extension_size() > 1; + +/// How many field elements / field extensions are recommended for the current base field. +let required_extension_size: -> int = || match known_field() { + Option::Some(KnownField::Goldilocks) => 2, + Option::Some(KnownField::BN254) => 1, None => panic("The permutation/lookup argument is not implemented for the current field!") }; diff --git a/std/protocols/lookup.asm b/std/protocols/lookup.asm index ecd82ae2d2..8fd7d159d1 100644 --- a/std/protocols/lookup.asm +++ b/std/protocols/lookup.asm @@ -1,3 +1,4 @@ +use std::array; use std::array::fold; use std::array::len; use std::array::map; @@ -15,6 +16,8 @@ use std::math::fp2::eval_ext; use std::math::fp2::from_base; use std::math::fp2::fp2_from_array; use std::math::fp2::constrain_eq_ext; +use std::math::fp2::required_extension_size; +use std::math::fp2::needs_extension; use std::protocols::fingerprint::fingerprint; use std::utils::unwrap_or_else; @@ -51,21 +54,29 @@ let compute_next_z: Fp2, Fp2, Fp2, Constr, expr -> fe[] = quer unpack_ext_array(res) }; -/// Adds constraints that enforce that rhs is the lookup for lhs +/// Transfroms a single lookup constraint to identity constraint, challenges and +/// higher-stage witness columns. +/// Use this function if the backend does not support lookup constraints natively. +/// WARNING: This function can currently not be used multiple times since +/// the used challenges would overlap. +/// TODO: Implement this for an array of constraints. /// Arguments: -/// - acc: A phase-2 witness column to be used as the accumulator. If 2 are provided, computations -/// are done on the F_{p^2} extension field. -/// - alpha: A challenge used to compress the LHS and RHS values -/// - beta: A challenge used to update the accumulator /// - lookup_constraint: The lookup constraint -/// - multiplicities: The multiplicities which shows how many times each RHS value appears in the LHS -let lookup: expr[], Fp2, Fp2, Constr, expr -> () = constr |acc, alpha, beta, lookup_constraint, multiplicities| { +/// - multiplicities: A multiplicities column which shows how many times each row of the RHS value appears in the LHS +let lookup: Constr, expr -> () = constr |lookup_constraint, multiplicities| { + std::check::assert(required_extension_size() <= 2, || "Invalid extension size"); + // Alpha is used to compress the LHS and RHS arrays. + let alpha = fp2_from_array(array::new(required_extension_size(), |i| challenge(0, i + 1))); + // Beta is used to update the accumulator. + let beta = fp2_from_array(array::new(required_extension_size(), |i| challenge(0, i + 3))); let (lhs_selector, lhs, rhs_selector, rhs) = unpack_lookup_constraint(lookup_constraint); let lhs_denom = sub_ext(beta, fingerprint(lhs, alpha)); let rhs_denom = sub_ext(beta, fingerprint(rhs, alpha)); let m_ext = from_base(multiplicities); + + let acc = array::new(required_extension_size(), |i| std::prover::new_witness_col_at_stage("acc", 1)); let acc_ext = fp2_from_array(acc); let next_acc = next_ext(acc_ext); @@ -100,4 +111,21 @@ let lookup: expr[], Fp2, Fp2, Constr, expr -> () = constr |acc, alph is_first * acc_1 = 0; is_first * acc_2 = 0; constrain_eq_ext(update_expr, from_base(0)); + + // In the extension field, we need a prover function for the accumulator. + if needs_extension() { + // TODO: Helper columns, because we can't access the previous row in hints + let acc_next_col = std::array::map(acc, |_| std::prover::new_witness_col_at_stage("acc_next", 1)); + query |i| { + let _ = std::array::zip( + acc_next_col, + compute_next_z(acc_ext, alpha, beta, lookup_constraint, multiplicities), + |acc_next, hint_val| std::prover::provide_value(acc_next, i, hint_val) + ); + }; + std::array::zip(acc, acc_next_col, |acc_col, acc_next| { + acc_col' = acc_next + }); + } else { + } }; \ No newline at end of file diff --git a/test_data/std/lookup_via_challenges.asm b/test_data/std/lookup_via_challenges.asm index 9832a9136c..0c85357468 100644 --- a/test_data/std/lookup_via_challenges.asm +++ b/test_data/std/lookup_via_challenges.asm @@ -5,10 +5,6 @@ use std::math::fp2::from_base; use std::prover::challenge; machine Main with degree: 8 { - - let alpha = from_base(challenge(0, 1)); - let beta = from_base(challenge(0, 2)); - col fixed random_six = [1, 1, 1, 0, 1, 1, 1, 0]; col fixed first_seven = [1, 1, 1, 1, 1, 1, 1, 0]; @@ -25,9 +21,5 @@ machine Main with degree: 8 { let lookup_constraint = random_six $ [a1, a2, a3] in first_seven $ [b1, b2, b3]; - // TODO: Functions currently cannot add witness columns at later stages, - // so we have to manually create it here and pass it to permutation(). - col witness stage(1) z; - lookup([z], alpha, beta, lookup_constraint, m); - -} \ No newline at end of file + lookup(lookup_constraint, m); +} diff --git a/test_data/std/lookup_via_challenges_ext.asm b/test_data/std/lookup_via_challenges_ext.asm index 57d8bc055d..fcc1ffba40 100644 --- a/test_data/std/lookup_via_challenges_ext.asm +++ b/test_data/std/lookup_via_challenges_ext.asm @@ -6,14 +6,6 @@ use std::math::fp2::Fp2; use std::prover::challenge; machine Main with degree: 8 { - - let alpha1: expr = challenge(0, 1); - let alpha2: expr = challenge(0, 2); - let beta1: expr = challenge(0, 3); - let beta2: expr = challenge(0, 4); - let alpha = Fp2::Fp2(alpha1, alpha2); - let beta = Fp2::Fp2(beta1, beta2); - col fixed a_sel = [0, 1, 1, 1, 0, 1, 0, 0]; col fixed b_sel = [1, 1, 0, 1, 1, 1, 1, 0]; @@ -30,23 +22,5 @@ machine Main with degree: 8 { let lookup_constraint = a_sel $ [a1, a2, a3] in b_sel $ [b1, b2, b3]; - // TODO: Functions currently cannot add witness columns at later stages, - // so we have to manually create it here and pass it to lookup(). - col witness stage(1) z1; - col witness stage(1) z2; - let z = Fp2::Fp2(z1, z2); - - lookup([z1, z2], alpha, beta, lookup_constraint, m); - - // TODO: Helper columns, because we can't access the previous row in hints - col witness stage(1) z1_next; - col witness stage(1) z2_next; - query |i| { - let hint = compute_next_z(z, alpha, beta, lookup_constraint, m); - std::prover::provide_value(z1_next, i, hint[0]); - std::prover::provide_value(z2_next, i, hint[1]); - }; - - z1' = z1_next; - z2' = z2_next; + lookup(lookup_constraint, m); } \ No newline at end of file diff --git a/test_data/std/lookup_via_challenges_ext_simple.asm b/test_data/std/lookup_via_challenges_ext_simple.asm index 44525e636e..33194aa7f8 100644 --- a/test_data/std/lookup_via_challenges_ext_simple.asm +++ b/test_data/std/lookup_via_challenges_ext_simple.asm @@ -6,15 +6,6 @@ use std::math::fp2::Fp2; use std::prover::challenge; machine Main with degree: 8 { - - // We don't need an alpha here, because we only "fold" one element. - // Therefore, the optimizer will remove it, but the hint still accesses it... - let alpha = Fp2::Fp2(0, 0); - - let beta1: expr = challenge(0, 3); - let beta2: expr = challenge(0, 4); - let beta = Fp2::Fp2(beta1, beta2); - col fixed a = [1, 1, 4, 1, 1, 2, 1, 1]; let b; query |i| { @@ -24,23 +15,5 @@ machine Main with degree: 8 { let lookup_constraint = [a] in [b]; - // TODO: Functions currently cannot add witness columns at later stages, - // so we have to manually create it here and pass it to lookup(). - col witness stage(1) z1; - col witness stage(1) z2; - let z = Fp2::Fp2(z1, z2); - - lookup([z1, z2], alpha, beta, lookup_constraint, m); - - // TODO: Helper columns, because we can't access the previous row in hints - col witness stage(1) z1_next; - col witness stage(1) z2_next; - query |i| { - let hint = compute_next_z(z, alpha, beta, lookup_constraint, m); - std::prover::provide_value(z1_next, i, hint[0]); - std::prover::provide_value(z2_next, i, hint[1]); - }; - - z1' = z1_next; - z2' = z2_next; + lookup(lookup_constraint, m); } \ No newline at end of file From fbc8c8c927af62fa2afd7fa192866716cf3825c9 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Mon, 23 Sep 2024 16:44:04 +0200 Subject: [PATCH 12/16] Introduce `AlgebraicVariable` (#1755) (First 4 commits of #1650) This PR prepares witness generation for scalar publics (#1756). Scalar publics are similar to cells in the trace, but are global (i.e., independent on the row number). With this PR, affine expressions use a new `AlgebraicVariable` enum, that can be either a column reference (`&'a AlgebraicReference`, which was used previously), or a reference to a public. --- executor/Cargo.toml | 1 + executor/src/witgen/affine_expression.rs | 28 ++++++++++-- executor/src/witgen/block_processor.rs | 4 +- executor/src/witgen/eval_result.rs | 5 +-- executor/src/witgen/expression_evaluator.rs | 27 ++++++------ executor/src/witgen/fixed_evaluator.rs | 38 +++++++++------- executor/src/witgen/generator.rs | 5 ++- executor/src/witgen/global_constraints.rs | 43 ++++++++++++------- executor/src/witgen/identity_processor.rs | 8 ++-- executor/src/witgen/machines/block_machine.rs | 10 ++--- .../witgen/machines/fixed_lookup_machine.rs | 31 +++++++++---- .../witgen/machines/sorted_witness_machine.rs | 18 +++++--- executor/src/witgen/processor.rs | 34 +++++++++------ executor/src/witgen/query_processor.rs | 24 ++++++++--- executor/src/witgen/rows.rs | 34 +++++++++------ executor/src/witgen/symbolic_evaluator.rs | 18 +++++--- .../src/witgen/symbolic_witness_evaluator.rs | 37 ++++++++++------ executor/src/witgen/vm_processor.rs | 9 ++-- 18 files changed, 242 insertions(+), 132 deletions(-) diff --git a/executor/Cargo.toml b/executor/Cargo.toml index 0079cce9a9..bb7080b174 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -18,6 +18,7 @@ log = { version = "0.4.17" } rayon = "1.7.0" bit-vec = "0.6.3" num-traits = "0.2.15" +derive_more = "0.99.17" lazy_static = "1.4.0" indicatif = "0.17.7" serde = { version = "1.0", default-features = false, features = ["alloc", "derive", "rc"] } diff --git a/executor/src/witgen/affine_expression.rs b/executor/src/witgen/affine_expression.rs index e2d7ddf35b..4dabeb907a 100644 --- a/executor/src/witgen/affine_expression.rs +++ b/executor/src/witgen/affine_expression.rs @@ -3,6 +3,7 @@ use std::fmt::Display; use itertools::{Either, Itertools}; use num_traits::Zero; +use powdr_ast::analyzed::AlgebraicReference; use powdr_number::{FieldElement, LargeInt}; use super::global_constraints::RangeConstraintSet; @@ -19,6 +20,27 @@ pub enum AffineExpression { ManyVars(Vec<(K, T)>, T), } +/// A variable in an affine expression. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Copy, derive_more::Display)] +pub enum AlgebraicVariable<'a> { + /// Reference to a (witness) column + Column(&'a AlgebraicReference), + /// Reference to a public value + // TODO: This should be using the ID instead of the name, but we + // currently store the name in AlgebraicExpression::PublicReference. + Public(&'a str), +} + +impl AlgebraicVariable<'_> { + /// Returns the column reference if the variable is a column, otherwise None. + pub fn try_as_column(&self) -> Option<&AlgebraicReference> { + match self { + AlgebraicVariable::Column(r) => Some(r), + AlgebraicVariable::Public(_) => None, + } + } +} + pub type AffineResult = Result, IncompleteCause>; impl From for AffineExpression { @@ -30,7 +52,7 @@ impl From for AffineExpression { impl AffineExpression where - K: Copy + Ord, + K: Ord, T: FieldElement, { pub fn from_variable_id(var_id: K) -> AffineExpression { @@ -514,7 +536,7 @@ where impl std::ops::Neg for AffineExpression where - K: Copy + Ord, + K: Ord, T: FieldElement, { type Output = Self; @@ -569,7 +591,7 @@ impl std::ops::Mul for AffineExpression { impl Display for AffineExpression where - K: Copy + Ord + Display, + K: Ord + Display, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if self.is_constant() { diff --git a/executor/src/witgen/block_processor.rs b/executor/src/witgen/block_processor.rs index 34bf6c1c67..b3fa6396a1 100644 --- a/executor/src/witgen/block_processor.rs +++ b/executor/src/witgen/block_processor.rs @@ -1,9 +1,9 @@ -use powdr_ast::analyzed::AlgebraicReference; use powdr_number::{DegreeType, FieldElement}; use crate::Identity; use super::{ + affine_expression::AlgebraicVariable, data_structures::finalizable_data::FinalizableData, machines::MachineParts, processor::{OuterQuery, Processor}, @@ -63,7 +63,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> BlockProcessor<'a, 'b, 'c pub fn solve( &mut self, sequence_iterator: &mut ProcessingSequenceIterator, - ) -> Result, EvalError> { + ) -> Result, T>, EvalError> { let mut outer_assignments = vec![]; let mut is_identity_complete = diff --git a/executor/src/witgen/eval_result.rs b/executor/src/witgen/eval_result.rs index a03b467603..c48907a086 100644 --- a/executor/src/witgen/eval_result.rs +++ b/executor/src/witgen/eval_result.rs @@ -1,9 +1,8 @@ use std::fmt::{self, Debug}; -use powdr_ast::analyzed::AlgebraicReference; use powdr_number::FieldElement; -use super::range_constraints::RangeConstraint; +use super::{affine_expression::AlgebraicVariable, range_constraints::RangeConstraint}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum IncompleteCause { @@ -160,7 +159,7 @@ impl EvalValue { /// Result of evaluating an expression / lookup. /// New assignments or constraints for witness columns identified by an ID. -pub type EvalResult<'a, T, K = &'a AlgebraicReference> = Result, EvalError>; +pub type EvalResult<'a, T, K = AlgebraicVariable<'a>> = Result, EvalError>; #[derive(Clone, PartialEq)] pub enum EvalError { diff --git a/executor/src/witgen/expression_evaluator.rs b/executor/src/witgen/expression_evaluator.rs index ec46fe372c..8e6c8cf993 100644 --- a/executor/src/witgen/expression_evaluator.rs +++ b/executor/src/witgen/expression_evaluator.rs @@ -2,19 +2,22 @@ use std::marker::PhantomData; use powdr_ast::analyzed::{ AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression as Expression, - AlgebraicReference, AlgebraicUnaryOperation, AlgebraicUnaryOperator, Challenge, + AlgebraicUnaryOperation, AlgebraicUnaryOperator, Challenge, }; use powdr_number::FieldElement; -use super::{affine_expression::AffineResult, IncompleteCause}; +use super::{ + affine_expression::{AffineResult, AlgebraicVariable}, + IncompleteCause, +}; pub trait SymbolicVariables { - /// Value of a polynomial (fixed or witness). - fn value<'a>(&self, poly: &'a AlgebraicReference) -> AffineResult<&'a AlgebraicReference, T>; + /// Value of a polynomial (fixed or witness) or public. + fn value<'a>(&self, var: AlgebraicVariable<'a>) -> AffineResult, T>; /// Value of a challenge. - fn challenge<'a>(&self, _challenge: &'a Challenge) -> AffineResult<&'a AlgebraicReference, T> { + fn challenge<'a>(&self, _challenge: &'a Challenge) -> AffineResult, T> { // Only needed for evaluating identities, so we leave this unimplemented by default. unimplemented!() } @@ -36,14 +39,14 @@ where marker: PhantomData, } } - /// Tries to evaluate the expression to an expression affine in the witness polynomials, - /// taking current values of polynomials into account. - /// @returns an expression affine in the witness polynomials - pub fn evaluate<'a>(&self, expr: &'a Expression) -> AffineResult<&'a AlgebraicReference, T> { + /// Tries to evaluate the expression to an affine expression in the witness polynomials + /// or publics, taking their current values into account. + /// @returns an expression affine in the witness polynomials or publics. + pub fn evaluate<'a>(&self, expr: &'a Expression) -> AffineResult, T> { // @TODO if we iterate on processing the constraints in the same row, // we could store the simplified values. match expr { - Expression::Reference(poly) => self.variables.value(poly), + Expression::Reference(poly) => self.variables.value(AlgebraicVariable::Column(poly)), Expression::Number(n) => Ok((*n).into()), Expression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { self.evaluate_binary_operation(left, op, right) @@ -61,7 +64,7 @@ where left: &'a Expression, op: &AlgebraicBinaryOperator, right: &'a Expression, - ) -> AffineResult<&'a AlgebraicReference, T> { + ) -> AffineResult, T> { match op { AlgebraicBinaryOperator::Add => { let left_expr = self.evaluate(left)?; @@ -127,7 +130,7 @@ where &self, op: &AlgebraicUnaryOperator, expr: &'a Expression, - ) -> AffineResult<&'a AlgebraicReference, T> { + ) -> AffineResult, T> { self.evaluate(expr).map(|v| match op { AlgebraicUnaryOperator::Minus => -v, }) diff --git a/executor/src/witgen/fixed_evaluator.rs b/executor/src/witgen/fixed_evaluator.rs index aa97108d3b..d147a203a1 100644 --- a/executor/src/witgen/fixed_evaluator.rs +++ b/executor/src/witgen/fixed_evaluator.rs @@ -1,7 +1,6 @@ -use super::affine_expression::AffineResult; +use super::affine_expression::{AffineResult, AlgebraicVariable}; use super::expression_evaluator::SymbolicVariables; use super::FixedData; -use powdr_ast::analyzed::AlgebraicReference; use powdr_number::{DegreeType, FieldElement}; /// Evaluates only fixed columns on a specific row. @@ -22,19 +21,28 @@ impl<'a, T: FieldElement> FixedEvaluator<'a, T> { } impl<'a, T: FieldElement> SymbolicVariables for FixedEvaluator<'a, T> { - fn value<'b>(&self, poly: &'b AlgebraicReference) -> AffineResult<&'b AlgebraicReference, T> { + fn value<'b>(&self, poly: AlgebraicVariable<'b>) -> AffineResult, T> { // TODO arrays - assert!( - poly.is_fixed(), - "Can only access fixed columns in the fixed evaluator." - ); - let col_data = self.fixed_data.fixed_cols[&poly.poly_id].values(self.size); - let degree = col_data.len(); - let row = if poly.next { - (self.row + 1) % degree - } else { - self.row - }; - Ok(col_data[row].into()) + match poly { + AlgebraicVariable::Column(poly) => { + assert!( + poly.is_fixed(), + "Can only access fixed columns in the fixed evaluator, got column of type {:?}.", poly.poly_id.ptype + ); + let col_data = self.fixed_data.fixed_cols[&poly.poly_id].values(self.size); + let degree = col_data.len(); + let row = if poly.next { + (self.row + 1) % degree + } else { + self.row + }; + Ok(col_data[row].into()) + } + AlgebraicVariable::Public(public_name) => { + panic!( + "Can only access fixed columns in the fixed evaluator, got public: {public_name}" + ) + } + } } } diff --git a/executor/src/witgen/generator.rs b/executor/src/witgen/generator.rs index 92ddbb70df..98c465057f 100644 --- a/executor/src/witgen/generator.rs +++ b/executor/src/witgen/generator.rs @@ -1,4 +1,4 @@ -use powdr_ast::analyzed::{AlgebraicExpression as Expression, AlgebraicReference}; +use powdr_ast::analyzed::AlgebraicExpression as Expression; use powdr_number::{DegreeType, FieldElement}; use std::collections::HashMap; @@ -7,6 +7,7 @@ use crate::witgen::machines::profiling::{record_end, record_start}; use crate::witgen::processor::OuterQuery; use crate::witgen::EvalValue; +use super::affine_expression::AlgebraicVariable; use super::block_processor::BlockProcessor; use super::machines::{Machine, MachineParts}; use super::rows::{Row, RowIndex, RowPair}; @@ -15,7 +16,7 @@ use super::vm_processor::VmProcessor; use super::{EvalResult, FixedData, MutableState, QueryCallback}; struct ProcessResult<'a, T: FieldElement> { - eval_value: EvalValue<&'a AlgebraicReference, T>, + eval_value: EvalValue, T>, block: FinalizableData, } diff --git a/executor/src/witgen/global_constraints.rs b/executor/src/witgen/global_constraints.rs index a8a55c50e1..f4150ab865 100644 --- a/executor/src/witgen/global_constraints.rs +++ b/executor/src/witgen/global_constraints.rs @@ -13,6 +13,7 @@ use powdr_number::FieldElement; use crate::witgen::data_structures::column_map::{FixedColumnMap, WitnessColumnMap}; use crate::Identity; +use super::affine_expression::AlgebraicVariable; use super::expression_evaluator::ExpressionEvaluator; use super::range_constraints::RangeConstraint; use super::symbolic_evaluator::SymbolicEvaluator; @@ -28,12 +29,17 @@ pub struct SimpleRangeConstraintSet<'a, T: FieldElement> { range_constraints: &'a BTreeMap>, } -impl<'a, T: FieldElement> RangeConstraintSet<&AlgebraicReference, T> +impl<'a, T: FieldElement> RangeConstraintSet, T> for SimpleRangeConstraintSet<'a, T> { - fn range_constraint(&self, id: &AlgebraicReference) -> Option> { - assert!(!id.next); - self.range_constraints.get(&id.poly_id).cloned() + fn range_constraint(&self, id: AlgebraicVariable<'a>) -> Option> { + match id { + AlgebraicVariable::Column(id) => { + assert!(!id.next); + self.range_constraints.get(&id.poly_id).cloned() + } + AlgebraicVariable::Public(_) => unimplemented!(), + } } } @@ -301,11 +307,15 @@ fn is_binary_constraint(expr: &Expression) -> Option if let ([(id1, Constraint::Assignment(value1))], [(id2, Constraint::Assignment(value2))]) = (&left_root.constraints[..], &right_root.constraints[..]) { - if id1 != id2 || !id2.is_witness() { - return None; - } - if (value1.is_zero() && value2.is_one()) || (value1.is_one() && value2.is_zero()) { - return Some(id1.poly_id); + // We expect range constraints only on columns, because the verifier could easily + // check range constraints on publics themselves. + if let (AlgebraicVariable::Column(id1), AlgebraicVariable::Column(id2)) = (id1, id2) { + if id1 != id2 || !id2.is_witness() { + return None; + } + if (value1.is_zero() && value2.is_one()) || (value1.is_one() && value2.is_zero()) { + return Some(id1.poly_id); + } } } } @@ -338,13 +348,16 @@ fn try_transfer_constraints( result .constraints .into_iter() - .flat_map(|(poly, cons)| { - if let Constraint::RangeConstraint(cons) = cons { - assert!(!poly.next); - Some((poly.poly_id, cons)) - } else { - None + .flat_map(|(poly, cons)| match poly { + AlgebraicVariable::Column(poly) => { + if let Constraint::RangeConstraint(cons) = cons { + assert!(!poly.next); + Some((poly.poly_id, cons)) + } else { + None + } } + AlgebraicVariable::Public(_) => unimplemented!(), }) .collect() } diff --git a/executor/src/witgen/identity_processor.rs b/executor/src/witgen/identity_processor.rs index b99966c8ba..a990915e32 100644 --- a/executor/src/witgen/identity_processor.rs +++ b/executor/src/witgen/identity_processor.rs @@ -4,7 +4,7 @@ use std::{ }; use lazy_static::lazy_static; -use powdr_ast::analyzed::{AlgebraicExpression as Expression, AlgebraicReference, IdentityKind}; +use powdr_ast::analyzed::{AlgebraicExpression as Expression, IdentityKind}; use powdr_number::FieldElement; use crate::{ @@ -13,8 +13,8 @@ use crate::{ }; use super::{ - machines::KnownMachine, processor::OuterQuery, rows::RowPair, EvalResult, EvalValue, - IncompleteCause, MutableState, QueryCallback, + affine_expression::AlgebraicVariable, machines::KnownMachine, processor::OuterQuery, + rows::RowPair, EvalResult, EvalValue, IncompleteCause, MutableState, QueryCallback, }; /// A list of mutable references to machines. @@ -243,7 +243,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> IdentityProcessor<'a, 'b, &self, left_selector: &'a Expression, rows: &RowPair, - ) -> Option> { + ) -> Option, T>> { let value = match rows.evaluate(left_selector) { Err(incomplete_cause) => return Some(EvalValue::incomplete(incomplete_cause)), Ok(value) => value, diff --git a/executor/src/witgen/machines/block_machine.rs b/executor/src/witgen/machines/block_machine.rs index 149a224ee4..05e437de35 100644 --- a/executor/src/witgen/machines/block_machine.rs +++ b/executor/src/witgen/machines/block_machine.rs @@ -4,6 +4,7 @@ use std::iter::{self, once}; use super::{EvalResult, FixedData, MachineParts}; +use crate::witgen::affine_expression::AlgebraicVariable; use crate::witgen::block_processor::BlockProcessor; use crate::witgen::data_structures::finalizable_data::FinalizableData; use crate::witgen::processor::{OuterQuery, Processor}; @@ -17,19 +18,18 @@ use crate::witgen::{MutableState, QueryCallback}; use crate::Identity; use itertools::Itertools; use powdr_ast::analyzed::{ - AlgebraicExpression as Expression, AlgebraicReference, DegreeRange, IdentityKind, PolyID, - PolynomialType, + AlgebraicExpression as Expression, DegreeRange, IdentityKind, PolyID, PolynomialType, }; use powdr_ast::parsed::visitor::ExpressionVisitable; use powdr_number::{DegreeType, FieldElement}; enum ProcessResult<'a, T: FieldElement> { - Success(FinalizableData, EvalValue<&'a AlgebraicReference, T>), - Incomplete(EvalValue<&'a AlgebraicReference, T>), + Success(FinalizableData, EvalValue, T>), + Incomplete(EvalValue, T>), } impl<'a, T: FieldElement> ProcessResult<'a, T> { - fn new(data: FinalizableData, updates: EvalValue<&'a AlgebraicReference, T>) -> Self { + fn new(data: FinalizableData, updates: EvalValue, T>) -> Self { match updates.is_complete() { true => ProcessResult::Success(data, updates), false => ProcessResult::Incomplete(updates), diff --git a/executor/src/witgen/machines/fixed_lookup_machine.rs b/executor/src/witgen/machines/fixed_lookup_machine.rs index f8df195458..4f35bd0fae 100644 --- a/executor/src/witgen/machines/fixed_lookup_machine.rs +++ b/executor/src/witgen/machines/fixed_lookup_machine.rs @@ -7,7 +7,7 @@ use itertools::Itertools; use powdr_ast::analyzed::{AlgebraicReference, IdentityKind, PolyID, PolynomialType}; use powdr_number::FieldElement; -use crate::witgen::affine_expression::AffineExpression; +use crate::witgen::affine_expression::{AffineExpression, AlgebraicVariable}; use crate::witgen::global_constraints::{GlobalConstraints, RangeConstraintSet}; use crate::witgen::processor::OuterQuery; use crate::witgen::range_constraints::RangeConstraint; @@ -208,7 +208,7 @@ impl<'a, T: FieldElement> FixedLookup<'a, T> { fn process_plookup_internal( &mut self, rows: &RowPair<'_, '_, T>, - left: &[AffineExpression<&'a AlgebraicReference, T>], + left: &[AffineExpression, T>], mut right: Peekable>, ) -> EvalResult<'a, T> { if left.len() == 1 @@ -216,7 +216,11 @@ impl<'a, T: FieldElement> FixedLookup<'a, T> { && right.peek().unwrap().poly_id.ptype == PolynomialType::Constant { // Lookup of the form "c $ [ X ] in [ B ]". Might be a conditional range check. - return self.process_range_check(rows, left.first().unwrap(), right.peek().unwrap()); + return self.process_range_check( + rows, + left.first().unwrap(), + AlgebraicVariable::Column(right.peek().unwrap()), + ); } // split the fixed columns depending on whether their associated lookup variable is constant or not. Preserve the value of the constant arguments. @@ -292,8 +296,8 @@ impl<'a, T: FieldElement> FixedLookup<'a, T> { fn process_range_check<'b>( &self, rows: &RowPair<'_, '_, T>, - lhs: &AffineExpression<&'b AlgebraicReference, T>, - rhs: &'b AlgebraicReference, + lhs: &AffineExpression, T>, + rhs: AlgebraicVariable<'b>, ) -> EvalResult<'b, T> { // Use AffineExpression::solve_with_range_constraints to transfer range constraints // from the rhs to the lhs. @@ -309,7 +313,12 @@ impl<'a, T: FieldElement> FixedLookup<'a, T> { updates .constraints .into_iter() - .filter(|(poly, _)| poly.poly_id.ptype == PolynomialType::Committed) + .filter(|(poly, _)| match poly { + AlgebraicVariable::Column(poly) => { + poly.poly_id.ptype == PolynomialType::Committed + } + _ => unimplemented!(), + }) .collect(), IncompleteCause::NotConcrete, )) @@ -362,12 +371,16 @@ pub struct UnifiedRangeConstraints<'a, T: FieldElement> { global_constraints: &'a GlobalConstraints, } -impl RangeConstraintSet<&AlgebraicReference, T> +impl<'a, T: FieldElement> RangeConstraintSet, T> for UnifiedRangeConstraints<'_, T> { - fn range_constraint(&self, poly: &AlgebraicReference) -> Option> { + fn range_constraint(&self, var: AlgebraicVariable<'a>) -> Option> { + let poly = match var { + AlgebraicVariable::Column(poly) => poly, + _ => unimplemented!(), + }; match poly.poly_id.ptype { - PolynomialType::Committed => self.witness_constraints.range_constraint(poly), + PolynomialType::Committed => self.witness_constraints.range_constraint(var), PolynomialType::Constant => self.global_constraints.range_constraint(poly), PolynomialType::Intermediate => unimplemented!(), } diff --git a/executor/src/witgen/machines/sorted_witness_machine.rs b/executor/src/witgen/machines/sorted_witness_machine.rs index c29d14c8de..b07b03d160 100644 --- a/executor/src/witgen/machines/sorted_witness_machine.rs +++ b/executor/src/witgen/machines/sorted_witness_machine.rs @@ -1,10 +1,9 @@ use std::collections::{BTreeMap, HashMap}; -use itertools::Itertools; - use super::super::affine_expression::AffineExpression; use super::{EvalResult, FixedData}; use super::{Machine, MachineParts}; +use crate::witgen::affine_expression::AlgebraicVariable; use crate::witgen::rows::RowPair; use crate::witgen::{ expression_evaluator::ExpressionEvaluator, fixed_evaluator::FixedEvaluator, @@ -12,6 +11,7 @@ use crate::witgen::{ }; use crate::witgen::{EvalValue, IncompleteCause, MutableState, QueryCallback}; use crate::Identity; +use itertools::Itertools; use powdr_ast::analyzed::{ AlgebraicExpression as Expression, AlgebraicReference, IdentityKind, PolyID, }; @@ -149,13 +149,17 @@ fn check_constraint(constraint: &Expression) -> Option return None, }; let mut coeff = sort_constraint.nonzero_coefficients(); - let first = coeff.next()?; - let second = coeff.next()?; + let first = coeff + .next() + .and_then(|(k, v)| k.try_as_column().map(|k| (k, v)))?; + let second = coeff + .next() + .and_then(|(k, v)| k.try_as_column().map(|k| (k, v)))?; if coeff.next().is_some() { return None; } let key_column_id = match (first, second) { - ((key, _), _) | (_, (key, _)) if !key.next => *key, + ((key, _), _) | (_, (key, _)) if !key.next => key, _ => return None, }; if key_column_id.next || key_column_id.is_fixed() { @@ -165,8 +169,8 @@ fn check_constraint(constraint: &Expression) -> Option = Vec>; +type Left<'a, T> = Vec, T>>; /// Data needed to handle an outer query. #[derive(Clone)] @@ -287,7 +288,7 @@ Known values in current row (local: {row_index}, global {global_row_index}): pub fn process_outer_query( &mut self, row_index: usize, - ) -> Result<(bool, Constraints<&'a AlgebraicReference, T>), EvalError> { + ) -> Result<(bool, Constraints, T>), EvalError> { let mut progress = false; let right = &self.outer_query.as_ref().unwrap().connecting_identity.right; if let Some(selector) = right.selector.as_ref() { @@ -331,10 +332,13 @@ Known values in current row (local: {row_index}, global {global_row_index}): let outer_assignments = updates .constraints .into_iter() - .filter(|(poly, update)| match update { - Constraint::Assignment(_) => !self.is_relevant_witness[&poly.poly_id], + .filter(|(var, update)| match (var, update) { + (AlgebraicVariable::Column(poly), Constraint::Assignment(_)) => { + !self.is_relevant_witness[&poly.poly_id] + } + (AlgebraicVariable::Public(_), Constraint::Assignment(_)) => unimplemented!(), // Range constraints are currently not communicated between callee and caller. - Constraint::RangeConstraint(_) => false, + (_, Constraint::RangeConstraint(_)) => false, }) .collect::>(); @@ -351,13 +355,14 @@ Known values in current row (local: {row_index}, global {global_row_index}): for (poly_id, value) in self.inputs.iter() { if !self.data[row_index].value_is_known(poly_id) { input_updates.combine(EvalValue::complete(vec![( - &self.fixed_data.witness_cols[poly_id].poly, + AlgebraicVariable::Column(&self.fixed_data.witness_cols[poly_id].poly), Constraint::Assignment(*value), )])); } } - for (poly, _) in &input_updates.constraints { + for (var, _) in &input_updates.constraints { + let poly = var.try_as_column().expect("Expected column"); let poly_id = &poly.poly_id; if let Some(start_row) = self.previously_set_inputs.remove(poly_id) { log::trace!( @@ -369,7 +374,8 @@ Known values in current row (local: {row_index}, global {global_row_index}): } } } - for (poly, _) in &input_updates.constraints { + for (var, _) in &input_updates.constraints { + let poly = var.try_as_column().expect("Expected column"); self.previously_set_inputs.insert(poly.poly_id, row_index); } self.apply_updates(row_index, &input_updates, || "inputs".to_string()) @@ -382,7 +388,7 @@ Known values in current row (local: {row_index}, global {global_row_index}): expression: &'a Expression, value: T, name: impl Fn() -> String, - ) -> Result> { + ) -> Result>> { let row_pair = RowPair::new( &self.data[row_index], &self.data[row_index + 1], @@ -401,7 +407,7 @@ Known values in current row (local: {row_index}, global {global_row_index}): fn apply_updates( &mut self, row_index: usize, - updates: &EvalValue<&'a AlgebraicReference, T>, + updates: &EvalValue, T>, source_name: impl Fn() -> String, ) -> bool { if updates.constraints.is_empty() { @@ -411,7 +417,11 @@ Known values in current row (local: {row_index}, global {global_row_index}): log::trace!(" Updates from: {}", source_name()); let mut progress = false; - for (poly, c) in &updates.constraints { + for (var, c) in &updates.constraints { + let poly = match var { + AlgebraicVariable::Column(poly) => poly, + _ => unimplemented!(), + }; if self.parts.witnesses.contains(&poly.poly_id) { // Build RowUpdater // (a bit complicated, because we need two mutable @@ -426,7 +436,7 @@ Known values in current row (local: {row_index}, global {global_row_index}): let left = &mut self.outer_query.as_mut().unwrap().left; log::trace!(" => {} (outer) = {}", poly, v); for l in left.iter_mut() { - l.assign(poly, *v); + l.assign(*var, *v); } progress = true; }; diff --git a/executor/src/witgen/query_processor.rs b/executor/src/witgen/query_processor.rs index ddf1aa8a45..9b371d6770 100644 --- a/executor/src/witgen/query_processor.rs +++ b/executor/src/witgen/query_processor.rs @@ -7,6 +7,7 @@ use powdr_ast::parsed::types::Type; use powdr_number::{BigInt, DegreeType, FieldElement}; use powdr_pil_analyzer::evaluator::{self, Definitions, EvalError, SymbolLookup, Value}; +use super::affine_expression::AlgebraicVariable; use super::Constraints; use super::{rows::RowPair, Constraint, EvalResult, EvalValue, FixedData, IncompleteCause}; @@ -114,7 +115,10 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback> if let Some(value) = (self.query_callback)(&query_str).map_err(super::EvalError::ProverQueryError)? { - EvalValue::complete(vec![(poly, Constraint::Assignment(value))]) + EvalValue::complete(vec![( + AlgebraicVariable::Column(poly), + Constraint::Assignment(value), + )]) } else { EvalValue::incomplete(IncompleteCause::NoQueryAnswer( query_str, @@ -150,7 +154,7 @@ struct Symbols<'a, 'b, 'c, T: FieldElement, QueryCallback: Send + Sync> { fixed_data: &'a FixedData<'a, T>, rows: &'b RowPair<'b, 'a, T>, size: DegreeType, - updates: Constraints<&'a AlgebraicReference, T>, + updates: Constraints, T>, query_callback: &'c mut QueryCallback, } @@ -198,14 +202,18 @@ impl<'a, 'b, 'c, T: FieldElement, QueryCallback: super::QueryCallback> Symbol ) -> Result>, EvalError> { Ok(Value::FieldElement(match poly_ref.poly_id.ptype { PolynomialType::Committed | PolynomialType::Intermediate => { - if let Some((_, update)) = self.updates.iter().find(|(p, _)| p == &poly_ref) { + if let Some((_, update)) = self + .updates + .iter() + .find(|(p, _)| p.try_as_column().map(|p| p == poly_ref).unwrap_or_default()) + { let Constraint::Assignment(value) = update else { unreachable!() }; *value } else { self.rows - .get_value(poly_ref) + .get_value(AlgebraicVariable::Column(poly_ref)) .ok_or(EvalError::DataNotAvailable)? } } @@ -241,6 +249,7 @@ impl<'a, 'b, 'c, T: FieldElement, QueryCallback: super::QueryCallback> Symbol value: Arc>, ) -> Result<(), EvalError> { // TODO allow "next: true" in the future. + // TODO allow assigning to publics in the future let Value::Expression(AlgebraicExpression::Reference(AlgebraicReference { poly_id, next: false, @@ -278,7 +287,10 @@ impl<'a, 'b, 'c, T: FieldElement, QueryCallback: super::QueryCallback> Symbol } } Err(EvalError::DataNotAvailable) => { - self.updates.push((col, Constraint::Assignment(*value))); + self.updates.push(( + AlgebraicVariable::Column(col), + Constraint::Assignment(*value), + )); } Err(e) => return Err(e), } @@ -325,7 +337,7 @@ impl<'a, 'b, 'c, T: FieldElement, QueryCallback: super::QueryCallback> Symbol impl<'a, 'b, 'c, T: FieldElement, QueryCallback: Send + Sync> Symbols<'a, 'b, 'c, T, QueryCallback> { - fn updates(self) -> Constraints<&'a AlgebraicReference, T> { + fn updates(self) -> Constraints, T> { self.updates } } diff --git a/executor/src/witgen/rows.rs b/executor/src/witgen/rows.rs index 72c7909b1c..03e7e3006b 100644 --- a/executor/src/witgen/rows.rs +++ b/executor/src/witgen/rows.rs @@ -10,7 +10,7 @@ use powdr_number::{DegreeType, FieldElement}; use crate::witgen::Constraint; use super::{ - affine_expression::{AffineExpression, AffineResult}, + affine_expression::{AffineExpression, AffineResult, AlgebraicVariable}, data_structures::{column_map::WitnessColumnMap, finalizable_data::FinalizedRow}, expression_evaluator::ExpressionEvaluator, global_constraints::RangeConstraintSet, @@ -454,19 +454,24 @@ impl<'row, 'a, T: FieldElement> RowPair<'row, 'a, T> { } } - pub fn get_value(&self, poly: &AlgebraicReference) -> Option { - let row = self.get_row(poly.next); - if self.unknown_strategy == UnknownStrategy::Zero { - Some(row.value_or_zero(&poly.poly_id)) - } else { - row.value(&poly.poly_id) + pub fn get_value(&self, poly: AlgebraicVariable) -> Option { + match poly { + AlgebraicVariable::Column(poly) => { + let row = self.get_row(poly.next); + if self.unknown_strategy == UnknownStrategy::Zero { + Some(row.value_or_zero(&poly.poly_id)) + } else { + row.value(&poly.poly_id) + } + } + _ => todo!(), } } /// Tries to evaluate the expression to an expression affine in the witness polynomials, /// taking current values of polynomials into account. /// @returns an expression affine in the witness polynomials - pub fn evaluate<'b>(&self, expr: &'b Expression) -> AffineResult<&'b AlgebraicReference, T> { + pub fn evaluate<'b>(&self, expr: &'b Expression) -> AffineResult, T> { ExpressionEvaluator::new(SymbolicWitnessEvaluator::new( self.fixed_data, self.current_row_index.into(), @@ -478,7 +483,7 @@ impl<'row, 'a, T: FieldElement> RowPair<'row, 'a, T> { } impl WitnessColumnEvaluator for RowPair<'_, '_, T> { - fn value<'b>(&self, poly: &'b AlgebraicReference) -> AffineResult<&'b AlgebraicReference, T> { + fn value<'b>(&self, poly: AlgebraicVariable<'b>) -> AffineResult, T> { Ok(match self.get_value(poly) { Some(v) => v.into(), None => AffineExpression::from_variable_id(poly), @@ -486,8 +491,13 @@ impl WitnessColumnEvaluator for RowPair<'_, '_, T> { } } -impl RangeConstraintSet<&AlgebraicReference, T> for RowPair<'_, '_, T> { - fn range_constraint(&self, poly: &AlgebraicReference) -> Option> { - self.get_row(poly.next).range_constraint(&poly.poly_id) +impl<'a, T: FieldElement> RangeConstraintSet, T> for RowPair<'_, '_, T> { + fn range_constraint(&self, poly: AlgebraicVariable<'a>) -> Option> { + match poly { + AlgebraicVariable::Column(poly) => { + self.get_row(poly.next).range_constraint(&poly.poly_id) + } + _ => todo!(), + } } } diff --git a/executor/src/witgen/symbolic_evaluator.rs b/executor/src/witgen/symbolic_evaluator.rs index de85da92f2..97dc8548a1 100644 --- a/executor/src/witgen/symbolic_evaluator.rs +++ b/executor/src/witgen/symbolic_evaluator.rs @@ -1,8 +1,7 @@ -use super::affine_expression::{AffineExpression, AffineResult}; +use super::affine_expression::{AffineExpression, AffineResult, AlgebraicVariable}; use super::expression_evaluator::SymbolicVariables; use super::IncompleteCause; -use powdr_ast::analyzed::AlgebraicReference; use powdr_number::FieldElement; /// A purely symbolic evaluator, uses AlgebraicReference as keys @@ -11,16 +10,21 @@ use powdr_number::FieldElement; pub struct SymbolicEvaluator; impl SymbolicVariables for SymbolicEvaluator { - fn value<'b>(&self, poly: &'b AlgebraicReference) -> AffineResult<&'b AlgebraicReference, T> { - assert!(poly.is_fixed() || poly.is_witness()); - // TODO arrays - Ok(AffineExpression::from_variable_id(poly)) + fn value<'b>(&self, var: AlgebraicVariable<'b>) -> AffineResult, T> { + match var { + AlgebraicVariable::Column(poly) => { + assert!(poly.is_fixed() || poly.is_witness()); + // TODO arrays + Ok(AffineExpression::from_variable_id(var)) + } + _ => todo!(), + } } fn challenge<'a>( &self, _challenge: &'a powdr_ast::analyzed::Challenge, - ) -> AffineResult<&'a AlgebraicReference, T> { + ) -> AffineResult, T> { // TODO: Challenges can't be symbolically evaluated, because they can't be // represented as an AffineExpression<&AlgebraicReference, T>... Err(IncompleteCause::SymbolicEvaluationOfChallenge) diff --git a/executor/src/witgen/symbolic_witness_evaluator.rs b/executor/src/witgen/symbolic_witness_evaluator.rs index 679341263e..1c2c0dba11 100644 --- a/executor/src/witgen/symbolic_witness_evaluator.rs +++ b/executor/src/witgen/symbolic_witness_evaluator.rs @@ -1,13 +1,17 @@ -use powdr_ast::analyzed::{AlgebraicReference, Challenge}; +use powdr_ast::analyzed::Challenge; use powdr_number::{DegreeType, FieldElement}; -use super::{affine_expression::AffineResult, expression_evaluator::SymbolicVariables, FixedData}; +use super::{ + affine_expression::{AffineResult, AlgebraicVariable}, + expression_evaluator::SymbolicVariables, + FixedData, +}; pub trait WitnessColumnEvaluator { /// Returns a symbolic or concrete value for the given witness column and next flag. /// This function defines the mapping to IDs. /// It should be used together with a matching reverse mapping in WitnessColumnNamer. - fn value<'b>(&self, poly: &'b AlgebraicReference) -> AffineResult<&'b AlgebraicReference, T>; + fn value<'b>(&self, poly: AlgebraicVariable<'b>) -> AffineResult, T>; } /// An evaluator (to be used together with ExpressionEvaluator) that performs concrete @@ -46,20 +50,25 @@ impl<'a, T: FieldElement, WA> SymbolicVariables for SymbolicWitnessEvaluator< where WA: WitnessColumnEvaluator, { - fn value<'b>(&self, poly: &'b AlgebraicReference) -> AffineResult<&'b AlgebraicReference, T> { - // TODO arrays - if poly.is_witness() { - self.witness_access.value(poly) - } else { - // Constant polynomial (or something else) - let values = self.fixed_data.fixed_cols[&poly.poly_id].values(self.size); - let row = - if poly.next { self.row + 1 } else { self.row } % (values.len() as DegreeType); - Ok(values[row as usize].into()) + fn value<'b>(&self, var: AlgebraicVariable<'b>) -> AffineResult, T> { + match var { + AlgebraicVariable::Column(poly) => { + // TODO arrays + if poly.is_witness() { + self.witness_access.value(var) + } else { + // Constant polynomial (or something else) + let values = self.fixed_data.fixed_cols[&poly.poly_id].values(self.size); + let row = if poly.next { self.row + 1 } else { self.row } + % (values.len() as DegreeType); + Ok(values[row as usize].into()) + } + } + _ => todo!(), } } - fn challenge<'b>(&self, challenge: &'b Challenge) -> AffineResult<&'b AlgebraicReference, T> { + fn challenge<'b>(&self, challenge: &'b Challenge) -> AffineResult, T> { Ok(self .fixed_data .challenges diff --git a/executor/src/witgen/vm_processor.rs b/executor/src/witgen/vm_processor.rs index 9f815f08fb..22b9dc5236 100644 --- a/executor/src/witgen/vm_processor.rs +++ b/executor/src/witgen/vm_processor.rs @@ -1,6 +1,6 @@ use indicatif::{ProgressBar, ProgressStyle}; use itertools::Itertools; -use powdr_ast::analyzed::{AlgebraicReference, DegreeRange, IdentityKind}; +use powdr_ast::analyzed::{DegreeRange, IdentityKind}; use powdr_ast::indent; use powdr_number::{DegreeType, FieldElement}; use std::cmp::max; @@ -11,6 +11,7 @@ use crate::witgen::identity_processor::{self}; use crate::witgen::IncompleteCause; use crate::Identity; +use super::affine_expression::AlgebraicVariable; use super::data_structures::finalizable_data::FinalizableData; use super::machines::MachineParts; use super::processor::{OuterQuery, Processor}; @@ -123,7 +124,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> VmProcessor<'a, 'b, 'c, T /// Starting out with a single row (at a given offset), iteratively append rows /// until we have exhausted the rows or the latch expression (if available) evaluates to 1. - pub fn run(&mut self, is_main_run: bool) -> EvalValue<&'a AlgebraicReference, T> { + pub fn run(&mut self, is_main_run: bool) -> EvalValue, T> { assert!(self.processor.len() == 1); if is_main_run { @@ -273,7 +274,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> VmProcessor<'a, 'b, 'c, T } } - fn compute_row(&mut self, row_index: DegreeType) -> Constraints<&'a AlgebraicReference, T> { + fn compute_row(&mut self, row_index: DegreeType) -> Constraints, T> { log::trace!( "===== Starting to process row: {}", row_index + self.row_offset @@ -342,7 +343,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> VmProcessor<'a, 'b, 'c, T &mut self, row_index: DegreeType, identities: &mut CompletableIdentities<'a, T>, - ) -> Result, Vec>> { + ) -> Result, T>, Vec>> { let mut outer_assignments = vec![]; // The PC lookup fills most of the columns and enables hints thus it should be run first. From a1b1447a6ee6122a8f5af043e68afc55a34f068d Mon Sep 17 00:00:00 2001 From: Thibaut Schaeffer Date: Mon, 23 Sep 2024 17:04:08 +0200 Subject: [PATCH 13/16] Point plonky3 to previous version (#1824) We merged plonky3 challenges in our fork. This breaks main. Point plonky3 to the commit before the merge. This will be removed by #1737 --- number/Cargo.toml | 6 +++--- plonky3/Cargo.toml | 38 +++++++++++++++++++------------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/number/Cargo.toml b/number/Cargo.toml index 0a38c05798..79a8bad037 100644 --- a/number/Cargo.toml +++ b/number/Cargo.toml @@ -13,9 +13,9 @@ ark-bn254 = { version = "0.4.0", default-features = false, features = [ ] } ark-ff = "0.4.2" ark-serialize = "0.4.2" -p3-baby-bear = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-mersenne-31 = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-field = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } +p3-baby-bear = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-mersenne-31 = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-field = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } num-bigint = { version = "0.4.3", features = ["serde"] } num-traits = "0.2.15" csv = "1.3" diff --git a/plonky3/Cargo.toml b/plonky3/Cargo.toml index 488b6e3bfc..e1d894aa8e 100644 --- a/plonky3/Cargo.toml +++ b/plonky3/Cargo.toml @@ -13,31 +13,31 @@ rand = "0.8.5" powdr-analysis = { path = "../analysis" } powdr-executor = { path = "../executor" } -p3-air = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-matrix = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-field = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-uni-stark = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-commit = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main", features = [ +p3-air = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-matrix = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-field = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-uni-stark = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-commit = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75", features = [ "test-utils", ] } -p3-poseidon2 = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-poseidon = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-fri = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } +p3-poseidon2 = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-poseidon = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-fri = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } # We don't use p3-maybe-rayon directly, but it is a dependency of p3-uni-stark. # Activating the "parallel" feature gives us parallelism in the prover. -p3-maybe-rayon = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main", features = [ +p3-maybe-rayon = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75", features = [ "parallel", ] } -p3-mds = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-merkle-tree = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-mersenne-31 = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-circle = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-baby-bear = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-goldilocks = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-symmetric = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-dft = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-challenger = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } -p3-util = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } +p3-mds = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-merkle-tree = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-mersenne-31 = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-circle = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-baby-bear = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-goldilocks = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-symmetric = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-dft = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-challenger = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } +p3-util = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "6afe4f75" } lazy_static = "1.4.0" rand_chacha = "0.3.1" bincode = "1.3.3" From 80a7924e6f3333bc2e7a05c5e5fa9070c84c0bd9 Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 23 Sep 2024 18:17:57 +0200 Subject: [PATCH 14/16] Builtin: Capture stage (#1334) Implements part of https://github.com/powdr-labs/powdr/issues/424 --------- Co-authored-by: Georg Wiese --- ast/src/analyzed/display.rs | 10 +- ast/src/parsed/display.rs | 4 +- pil-analyzer/src/condenser.rs | 80 ++++++++-- pil-analyzer/src/evaluator.rs | 43 +++++- pil-analyzer/src/side_effect_checker.rs | 2 + pil-analyzer/src/type_builtins.rs | 5 + pil-analyzer/tests/condenser.rs | 189 +++++++++++++++++++++++- plonky3/src/stark.rs | 2 +- std/prover.asm | 10 +- 9 files changed, 320 insertions(+), 25 deletions(-) diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index 5e97f35779..7f8d00c885 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -162,12 +162,8 @@ fn format_fixed_column( definition: &Option, ) -> String { assert_eq!(symbol.kind, SymbolKind::Poly(PolynomialType::Constant)); - let stage = symbol - .stage - .map(|s| format!("stage({s}) ")) - .unwrap_or_default(); + assert!(symbol.stage.is_none()); if let Some(TypedExpression { type_scheme, e }) = try_to_simple_expression(definition) { - assert!(symbol.stage.is_none()); if symbol.length.is_some() { assert!(matches!( type_scheme, @@ -187,7 +183,7 @@ fn format_fixed_column( .as_ref() .map(ToString::to_string) .unwrap_or_default(); - format!("col fixed {stage}{name}{value};",) + format!("col fixed {name}{value};",) } } @@ -199,7 +195,7 @@ fn format_witness_column( assert_eq!(symbol.kind, SymbolKind::Poly(PolynomialType::Committed)); let stage = symbol .stage - .map(|s| format!("stage({s}) ")) + .and_then(|s| (s > 0).then(|| format!("stage({s}) "))) .unwrap_or_default(); let length = symbol .length diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index 319d4e056d..b13908122c 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -516,7 +516,9 @@ impl Display for PilStatement { f, format!( "pol commit {}{}{};", - stage.map(|s| format!("stage({s}) ")).unwrap_or_default(), + stage + .and_then(|s| (s > 0).then(|| format!("stage({s}) "))) + .unwrap_or_default(), names.iter().format(", "), value.as_ref().map(|v| format!("{v}")).unwrap_or_default() ), diff --git a/pil-analyzer/src/condenser.rs b/pil-analyzer/src/condenser.rs index 53910f6313..52fb958d61 100644 --- a/pil-analyzer/src/condenser.rs +++ b/pil-analyzer/src/condenser.rs @@ -32,7 +32,10 @@ use powdr_number::{BigUint, FieldElement}; use powdr_parser_util::SourceRef; use crate::{ - evaluator::{self, Closure, Definitions, EnumValue, EvalError, SymbolLookup, Value}, + evaluator::{ + self, evaluate_function_call, Closure, Definitions, EnumValue, EvalError, SymbolLookup, + Value, + }, statement_processor::Counters, }; @@ -57,12 +60,15 @@ pub fn condense( let source_order = source_order .into_iter() .flat_map(|s| { + // Potentially modify the current namespace. if let StatementIdentifier::Definition(name) = &s { let mut namespace = AbsoluteSymbolPath::default().join(SymbolPath::from_str(name).unwrap()); namespace.pop(); condenser.set_namespace_and_degree(namespace, definitions[name].0.degree); } + + // Condense identities and definitions. let statement = match s { StatementIdentifier::ProofItem(index) => { condenser.condense_proof_item(&proof_items[index]); @@ -213,8 +219,12 @@ pub struct Condenser<'a, T> { new_intermediate_column_values: HashMap>>, /// The names of all new columns ever generated, to avoid duplicates. new_symbols: HashSet, - new_constraints: Vec>, + /// Constraints added since the last extraction. The values should be enums of type `std::prelude::Constr`. + new_constraints: Vec<(Arc>, SourceRef)>, + /// Prover functions added since the last extraction. new_prover_functions: Vec, + /// The current stage. New columns are created at that stage. + stage: u32, } impl<'a, T: FieldElement> Condenser<'a, T> { @@ -236,6 +246,7 @@ impl<'a, T: FieldElement> Condenser<'a, T> { new_symbols: HashSet::new(), new_constraints: vec![], new_prover_functions: vec![], + stage: 0, } } @@ -286,7 +297,10 @@ impl<'a, T: FieldElement> Condenser<'a, T> { /// Returns the new constraints generated since the last call to this function. pub fn extract_new_constraints(&mut self) -> Vec> { - std::mem::take(&mut self.new_constraints) + self.new_constraints + .drain(..) + .map(|(item, source)| to_constraint(item.as_ref(), source, &mut self.counters)) + .collect() } /// Returns the new prover functions generated since the last call to this function. @@ -458,6 +472,25 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { self.new_column_values.insert(name.clone(), value); } + if self.stage != 0 && stage.is_some() { + return Err(EvalError::TypeError(format!( + "Tried to create a column with an explicit stage ({}) while the current stage was not zero, but {}.", + stage.unwrap(), self.stage + ))); + } + + let stage = if matches!( + kind, + SymbolKind::Poly(PolynomialType::Constant | PolynomialType::Intermediate) + ) { + // Fixed columns are pre-stage 0 and the stage of an intermediate column + // is the max of the stages in the value, so we omit it in both cases. + assert!(stage.is_none()); + None + } else { + Some(stage.unwrap_or(self.stage)) + }; + let symbol = Symbol { id: self.counters.dispense_symbol_id(kind, length), source, @@ -555,11 +588,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { match items.as_ref() { Value::Array(items) => { for item in items { - self.new_constraints.push(to_constraint( - item, - source.clone(), - &mut self.counters, - )) + self.new_constraints.push((item.clone(), source.clone())); } } Value::Closure(..) => { @@ -569,9 +598,38 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { self.new_prover_functions.push(e); } - _ => self - .new_constraints - .push(to_constraint(&items, source, &mut self.counters)), + _ => self.new_constraints.push((items, source)), + } + Ok(()) + } + + fn capture_constraints( + &mut self, + fun: Arc>, + ) -> Result>, EvalError> { + let existing_constraints = self.new_constraints.len(); + let result = evaluate_function_call(fun, vec![], self); + let constrs = self + .new_constraints + .drain(existing_constraints..) + .map(|(c, _)| c) + .collect(); + let result = result?; + assert!( + matches!(result.as_ref(), Value::Tuple(items) if items.is_empty()), + "Function should return ()" + ); + + Ok(Arc::new(Value::Array(constrs))) + } + + fn at_next_stage(&mut self, fun: Arc>) -> Result<(), EvalError> { + self.stage += 1; + let result = evaluate_function_call(fun, vec![], self); + self.stage -= 1; + let result = result?; + if !matches!(result.as_ref(), Value::Tuple(items) if items.is_empty()) { + panic!(); } Ok(()) } diff --git a/pil-analyzer/src/evaluator.rs b/pil-analyzer/src/evaluator.rs index 44658f2996..b47f809fe3 100644 --- a/pil-analyzer/src/evaluator.rs +++ b/pil-analyzer/src/evaluator.rs @@ -157,7 +157,7 @@ impl<'a, T: FieldElement> From for Value<'a, T> { } } -impl<'a, T: FieldElement> From> for Value<'a, T> { +impl<'a, T> From> for Value<'a, T> { fn from(value: AlgebraicExpression) -> Self { Value::Expression(value) } @@ -399,7 +399,7 @@ fn none_value<'a, T>() -> Value<'a, T> { }) } -const BUILTINS: [(&str, BuiltinFunction); 20] = [ +const BUILTINS: [(&str, BuiltinFunction); 21] = [ ("std::array::len", BuiltinFunction::ArrayLen), ("std::check::panic", BuiltinFunction::Panic), ("std::convert::expr", BuiltinFunction::ToExpr), @@ -407,6 +407,11 @@ const BUILTINS: [(&str, BuiltinFunction); 20] = [ ("std::convert::int", BuiltinFunction::ToInt), ("std::debug::print", BuiltinFunction::Print), ("std::field::modulus", BuiltinFunction::Modulus), + ( + "std::prover::capture_constraints", + BuiltinFunction::CaptureConstraints, + ), + ("std::prover::at_next_stage", BuiltinFunction::AtNextStage), ("std::prelude::challenge", BuiltinFunction::Challenge), ( "std::prover::new_witness_col_at_stage", @@ -419,7 +424,6 @@ const BUILTINS: [(&str, BuiltinFunction); 20] = [ ("std::prover::degree", BuiltinFunction::Degree), ("std::prover::eval", BuiltinFunction::Eval), ("std::prover::try_eval", BuiltinFunction::TryEval), - ("std::prover::try_eval", BuiltinFunction::TryEval), ("std::prover::get_input", BuiltinFunction::GetInput), ( "std::prover::get_input_from_channel", @@ -446,6 +450,13 @@ pub enum BuiltinFunction { ToInt, /// std::convert::fe: int/fe -> fe, converts int to fe ToFe, + /// std::prover::capture_constraints: (-> ()) -> Constr[] + /// Calls the argument and returns all constraints that it added to the global set + /// (Those are removed from the global set). + CaptureConstraints, + /// std::prover::at_next_stage: (-> ()) -> (), calls the argument at the next proof stage + /// and resets the stage again. + AtNextStage, /// std::prover::challenge: int, int -> expr, constructs a challenge with a given stage and ID. Challenge, /// std::prover::new_witness_col_at_stage: string, int -> expr, creates a new witness column at a certain proof stage. @@ -724,6 +735,21 @@ pub trait SymbolLookup<'a, T: FieldElement> { )) } + fn capture_constraints( + &mut self, + _fun: Arc>, + ) -> Result>, EvalError> { + Err(EvalError::Unsupported( + "The function capture_constraints is not allowed at this point.".to_string(), + )) + } + + fn at_next_stage(&mut self, _fun: Arc>) -> Result<(), EvalError> { + Err(EvalError::Unsupported( + "The function at_next_stage is not allowed at this point.".to_string(), + )) + } + fn provide_value( &mut self, _col: Arc>, @@ -1401,6 +1427,8 @@ fn evaluate_builtin_function<'a, T: FieldElement>( BuiltinFunction::MinDegree => 0, BuiltinFunction::MaxDegree => 0, BuiltinFunction::Degree => 0, + BuiltinFunction::CaptureConstraints => 1, + BuiltinFunction::AtNextStage => 1, BuiltinFunction::Eval => 1, BuiltinFunction::TryEval => 1, BuiltinFunction::GetInput => 1, @@ -1531,6 +1559,15 @@ fn evaluate_builtin_function<'a, T: FieldElement>( BuiltinFunction::MaxDegree => symbols.max_degree()?, BuiltinFunction::MinDegree => symbols.min_degree()?, BuiltinFunction::Degree => symbols.degree()?, + BuiltinFunction::CaptureConstraints => { + let fun = arguments.pop().unwrap(); + symbols.capture_constraints(fun)? + } + BuiltinFunction::AtNextStage => { + let fun = arguments.pop().unwrap(); + symbols.at_next_stage(fun)?; + Value::Tuple(vec![]).into() + } BuiltinFunction::Eval => { let arg = arguments.pop().unwrap(); match arg.as_ref() { diff --git a/pil-analyzer/src/side_effect_checker.rs b/pil-analyzer/src/side_effect_checker.rs index 5ac435533b..1aae2c8448 100644 --- a/pil-analyzer/src/side_effect_checker.rs +++ b/pil-analyzer/src/side_effect_checker.rs @@ -153,6 +153,8 @@ lazy_static! { ("std::convert::expr", FunctionKind::Pure), ("std::debug::print", FunctionKind::Pure), ("std::field::modulus", FunctionKind::Pure), + ("std::prover::capture_constraints", FunctionKind::Constr), + ("std::prover::at_next_stage", FunctionKind::Constr), ("std::prelude::challenge", FunctionKind::Constr), // strictly, only new_challenge would need "constr" ("std::prover::min_degree", FunctionKind::Pure), ("std::prover::max_degree", FunctionKind::Pure), diff --git a/pil-analyzer/src/type_builtins.rs b/pil-analyzer/src/type_builtins.rs index f375131d18..cecf9c4931 100644 --- a/pil-analyzer/src/type_builtins.rs +++ b/pil-analyzer/src/type_builtins.rs @@ -47,6 +47,11 @@ lazy_static! { ("", "string, int -> expr") ), ("std::prover::min_degree", ("", "-> int")), + ( + "std::prover::capture_constraints", + ("", "(-> ()) -> std::prelude::Constr[]") + ), + ("std::prover::at_next_stage", ("", "(-> ()) -> ()")), ("std::prover::max_degree", ("", "-> int")), ("std::prover::degree", ("", "-> int")), ( diff --git a/pil-analyzer/tests/condenser.rs b/pil-analyzer/tests/condenser.rs index 9e1ed32f28..dc89a0bbab 100644 --- a/pil-analyzer/tests/condenser.rs +++ b/pil-analyzer/tests/condenser.rs @@ -764,7 +764,7 @@ namespace N(16); let v: expr = std::prover::new_witness_col_at_stage("y", 2); let unused: expr = std::prover::new_witness_col_at_stage("z", 10); col witness y; - col witness stage(0) x_1; + col witness x_1; col witness stage(1) x_2; col witness stage(2) x_3; col witness stage(1) y_1; @@ -892,3 +892,190 @@ namespace std::prover; let re_analyzed = analyze_string(&formatted); assert_eq!(re_analyzed.to_string(), expected); } + +#[test] +pub fn capture_constraints_empty() { + let input = r#" + namespace std::prover; + let capture_constraints: (-> ()) -> Constr[] = 9; + + namespace Main; + let gen = || { }; + let a; + let b; + a = 1; + std::prover::capture_constraints(gen); + b = 2; + "#; + let formatted = analyze_string(input).to_string(); + let expected = "namespace std::prover; + let capture_constraints: (-> ()) -> std::prelude::Constr[] = 9; +namespace Main; + let gen: -> () = || { }; + col witness a; + col witness b; + Main::a = 1; + Main::b = 2; +"; + assert_eq!(formatted, expected); +} + +#[test] +pub fn capture_constraints_new_col_and_constr() { + let input = r#" + namespace std::prover; + let capture_constraints: (-> ()) -> Constr[] = 9; + + namespace Main; + let gen = constr || { + let x; + [x = 1, x = 2]; + x = 3; + }; + let a; + let b; + a = 1; + let constrs = std::prover::capture_constraints(gen); + // Ignore the second constraint + [constrs[0], constrs[2]]; + b = 2; + "#; + let formatted = analyze_string(input).to_string(); + let expected = "namespace std::prover; + let capture_constraints: (-> ()) -> std::prelude::Constr[] = 9; +namespace Main; + let gen: -> () = constr || { + let x: col; + [x = 1, x = 2]; + x = 3; + }; + col witness a; + col witness b; + Main::a = 1; + let constrs: std::prelude::Constr[] = std::prover::capture_constraints(Main::gen); + col witness x; + Main::x = 1; + Main::x = 3; + Main::b = 2; +"; + assert_eq!(formatted, expected); +} + +#[test] +pub fn capture_constraints_recursive() { + let input = r#" + namespace std::prover; + let capture_constraints: (-> ()) -> Constr[] = 9; + + namespace Main; + let a; + [a] in [b]; + let constrs = std::prover::capture_constraints(constr || { + let x; + [x = 1, x = 2]; + std::prover::capture_constraints(constr || { + let y; + [y = 1, [y] in [x]]; + y = 3; + })[1]; + x = 3; + }); + // Ignore the second constraint + [constrs[0], constrs[2], constrs[3]]; + let b; + b = 2; + "#; + let formatted = analyze_string(input).to_string(); + let expected = "namespace std::prover; + let capture_constraints: (-> ()) -> std::prelude::Constr[] = 9; +namespace Main; + col witness a; + [Main::a] in [Main::b]; + let constrs: std::prelude::Constr[] = std::prover::capture_constraints(constr || { + let x: col; + [x = 1, x = 2]; + std::prover::capture_constraints(constr || { + let y: col; + [y = 1, [y] in [x]]; + y = 3; + })[1]; + x = 3; + }); + col witness x; + col witness y; + Main::x = 1; + [Main::y] in [Main::x]; + Main::x = 3; + col witness b; + Main::b = 2; +"; + assert_eq!(formatted, expected); +} + +#[test] +pub fn at_next_stage() { + let input = r#" + namespace std::prover; + let at_next_stage: (-> ()) -> () = 9; + + namespace Main; + let a; + std::prover::at_next_stage(constr || { + let x; + std::prover::at_next_stage(constr || { + let y; + x = 1; + y = 2; + }); + let c; + x = a + c; + }); + let b; + "#; + let formatted = analyze_string(input).to_string(); + let expected = "namespace std::prover; + let at_next_stage: (-> ()) -> () = 9; +namespace Main; + col witness a; + col witness stage(1) x; + col witness stage(2) y; + col witness stage(1) c; + Main::x = 1; + Main::y = 2; + Main::x = Main::a + Main::c; + col witness b; +"; + assert_eq!(formatted, expected); +} + +#[test] +pub fn at_next_stage_intermediate_and_fixed() { + let input = r#" + namespace std::prover; + let at_next_stage: (-> ()) -> () = 9; + + namespace Main; + let a; + std::prover::at_next_stage(constr || { + let b: inter = a * a; + let c; + let first: col = |i| if i == 0 { 1 } else { 0 }; + let d: inter = a + c; + c' = first; + }); + let x; + "#; + let formatted = analyze_string(input).to_string(); + let expected = "namespace std::prover; + let at_next_stage: (-> ()) -> () = 9; +namespace Main; + col witness a; + col b = Main::a * Main::a; + col witness stage(1) c; + col fixed first(i) { if i == 0 { 1 } else { 0 } }; + col d = Main::a + Main::c; + Main::c' = Main::first; + col witness x; +"; + assert_eq!(formatted, expected); +} diff --git a/plonky3/src/stark.rs b/plonky3/src/stark.rs index aa760aaf16..ac63d85a78 100644 --- a/plonky3/src/stark.rs +++ b/plonky3/src/stark.rs @@ -405,7 +405,7 @@ mod tests { namespace Global(N); let beta: expr = std::prelude::challenge(0, 42); - col witness stage(0) x; + col witness x; col witness stage(1) y; x = y + beta; "#; diff --git a/std/prover.asm b/std/prover.asm index 0c610cc9a8..49f23bf4d7 100644 --- a/std/prover.asm +++ b/std/prover.asm @@ -60,4 +60,12 @@ let degree: -> int = []; let require_min_degree: int -> () = |m| std::check::assert(degree() >= m, || "Degree too small."); /// Asserts that the current degree or row count is at most m; -let require_max_degree: int -> () = |m| std::check::assert(degree() <= m, || "Degree too large."); \ No newline at end of file +let require_max_degree: int -> () = |m| std::check::assert(degree() <= m, || "Degree too large."); + +/// Calls the argument and returns all constraints that were generated during the call. +/// If the constraints are not added to the global set again, they are ignored. +let capture_constraints: (-> ()) -> Constr[] = []; + +/// Calls the argument with the current stage counter incremented. This means that columns created during +/// the call will be next-stage columns. The stage counter is reset afterwards. +let at_next_stage: (-> ()) -> () = []; From 5739ebcee8773f9d5f6b6e9bcd65d01d6ab320f5 Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Mon, 23 Sep 2024 15:37:25 -0400 Subject: [PATCH 15/16] Plonky3 Keccak (#1801) Adapted from Plonky3 Keccak. Uses prover functions for all witness generation. Inputs and outputs are both in 16-bit limbs, so works for BabyBear and small fields. For quick iteration, currently the operation is `preimage[0], preimage[1], ..., preimage[99] -> output[0], output[1], ..., output[99]`. Will change to only take an input pointer and an output pointer in another PR. Two things to note: 1. Main machine needs to be degree 64 or more, so that Keccak block can run at least twice (24 rows per block), or the machine is detected as an VM and weird bugs happen. 2. For some reason, when calling the Keccak submachine twice or more in the main machine, witgen marks existing value as `Incomplete` and generates a weird bug. --- pipeline/tests/powdr_std.rs | 7 + std/machines/hash/keccakf16.asm | 613 +++++++++++++++++++++++++++++++ std/machines/hash/mod.asm | 1 + test_data/std/keccakf16_test.asm | 330 +++++++++++++++++ 4 files changed, 951 insertions(+) create mode 100644 std/machines/hash/keccakf16.asm create mode 100644 test_data/std/keccakf16_test.asm diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index 55e3d77719..cb0da73303 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -43,6 +43,13 @@ fn poseidon_gl_memory_test() { gen_estark_proof(pipeline); } +#[test] +#[ignore = "Too slow"] +fn keccakf16_test() { + let f = "std/keccakf16_test.asm"; + test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); +} + #[test] #[ignore = "Too slow"] fn split_bn254_test() { diff --git a/std/machines/hash/keccakf16.asm b/std/machines/hash/keccakf16.asm new file mode 100644 index 0000000000..07698397d7 --- /dev/null +++ b/std/machines/hash/keccakf16.asm @@ -0,0 +1,613 @@ +use std::array; +use std::utils; +use std::utils::unchanged_until; +use std::utils::force_bool; +use std::convert::expr; +use std::convert::int; +use std::convert::fe; +use std::prelude::set_hint; +use std::prelude::Query; +use std::prover::eval; +use std::prover::provide_value; + +machine Keccakf16 with + latch: final_step, + operation_id: operation_id, + call_selectors: sel, +{ + // Adapted from Plonky3 implementation of Keccak: https://github.com/Plonky3/Plonky3/tree/main/keccak-air/src + + std::check::assert(std::field::modulus() >= 65535, || "The field modulo should be at least 2^16 - 1 to work in the keccakf16 machine."); + + // Expects input of 25 64-bit numbers decomposed to 25 chunks of 4 16-bit little endian limbs. + // The output is a_prime_prime_prime_0_0_limbs for the first 4 and a_prime_prime for the rest. + operation keccakf16<0> preimage[0], preimage[1], preimage[2], preimage[3], preimage[4], preimage[5], preimage[6], preimage[7], preimage[8], preimage[9], preimage[10], preimage[11], preimage[12], preimage[13], preimage[14], preimage[15], preimage[16], preimage[17], preimage[18], preimage[19], preimage[20], preimage[21], preimage[22], preimage[23], preimage[24], preimage[25], preimage[26], preimage[27], preimage[28], preimage[29], preimage[30], preimage[31], preimage[32], preimage[33], preimage[34], preimage[35], preimage[36], preimage[37], preimage[38], preimage[39], preimage[40], preimage[41], preimage[42], preimage[43], preimage[44], preimage[45], preimage[46], preimage[47], preimage[48], preimage[49], preimage[50], preimage[51], preimage[52], preimage[53], preimage[54], preimage[55], preimage[56], preimage[57], preimage[58], preimage[59], preimage[60], preimage[61], preimage[62], preimage[63], preimage[64], preimage[65], preimage[66], preimage[67], preimage[68], preimage[69], preimage[70], preimage[71], preimage[72], preimage[73], preimage[74], preimage[75], preimage[76], preimage[77], preimage[78], preimage[79], preimage[80], preimage[81], preimage[82], preimage[83], preimage[84], preimage[85], preimage[86], preimage[87], preimage[88], preimage[89], preimage[90], preimage[91], preimage[92], preimage[93], preimage[94], preimage[95], preimage[96], preimage[97], preimage[98], preimage[99] -> a_prime_prime_prime_0_0_limbs[0], a_prime_prime_prime_0_0_limbs[1], a_prime_prime_prime_0_0_limbs[2], a_prime_prime_prime_0_0_limbs[3], a_prime_prime[4], a_prime_prime[5], a_prime_prime[6], a_prime_prime[7], a_prime_prime[8], a_prime_prime[9], a_prime_prime[10], a_prime_prime[11], a_prime_prime[12], a_prime_prime[13], a_prime_prime[14], a_prime_prime[15], a_prime_prime[16], a_prime_prime[17], a_prime_prime[18], a_prime_prime[19], a_prime_prime[20], a_prime_prime[21], a_prime_prime[22], a_prime_prime[23], a_prime_prime[24], a_prime_prime[25], a_prime_prime[26], a_prime_prime[27], a_prime_prime[28], a_prime_prime[29], a_prime_prime[30], a_prime_prime[31], a_prime_prime[32], a_prime_prime[33], a_prime_prime[34], a_prime_prime[35], a_prime_prime[36], a_prime_prime[37], a_prime_prime[38], a_prime_prime[39], a_prime_prime[40], a_prime_prime[41], a_prime_prime[42], a_prime_prime[43], a_prime_prime[44], a_prime_prime[45], a_prime_prime[46], a_prime_prime[47], a_prime_prime[48], a_prime_prime[49], a_prime_prime[50], a_prime_prime[51], a_prime_prime[52], a_prime_prime[53], a_prime_prime[54], a_prime_prime[55], a_prime_prime[56], a_prime_prime[57], a_prime_prime[58], a_prime_prime[59], a_prime_prime[60], a_prime_prime[61], a_prime_prime[62], a_prime_prime[63], a_prime_prime[64], a_prime_prime[65], a_prime_prime[66], a_prime_prime[67], a_prime_prime[68], a_prime_prime[69], a_prime_prime[70], a_prime_prime[71], a_prime_prime[72], a_prime_prime[73], a_prime_prime[74], a_prime_prime[75], a_prime_prime[76], a_prime_prime[77], a_prime_prime[78], a_prime_prime[79], a_prime_prime[80], a_prime_prime[81], a_prime_prime[82], a_prime_prime[83], a_prime_prime[84], a_prime_prime[85], a_prime_prime[86], a_prime_prime[87], a_prime_prime[88], a_prime_prime[89], a_prime_prime[90], a_prime_prime[91], a_prime_prime[92], a_prime_prime[93], a_prime_prime[94], a_prime_prime[95], a_prime_prime[96], a_prime_prime[97], a_prime_prime[98], a_prime_prime[99]; + + col witness operation_id; + + let NUM_ROUNDS: int = 24; + + // pub struct KeccakCols { + // /// The `i`th value is set to 1 if we are in the `i`th round, otherwise 0. + // pub step_flags: [T; NUM_ROUNDS], + + // /// A register which indicates if a row should be exported, i.e. included in a multiset equality + // /// argument. Should be 1 only for certain rows which are final steps, i.e. with + // /// `step_flags[23] = 1`. + // pub export: T, + + // /// Permutation inputs, stored in y-major order. + // pub preimage: [[[T; U64_LIMBS]; 5]; 5], + + // pub a: [[[T; U64_LIMBS]; 5]; 5], + + // /// ```ignore + // /// C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]) + // /// ``` + // pub c: [[T; 64]; 5], + + // /// ```ignore + // /// C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]) + // /// ``` + // pub c_prime: [[T; 64]; 5], + + // // Note: D is inlined, not stored in the witness. + // /// ```ignore + // /// A'[x, y] = xor(A[x, y], D[x]) + // /// = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1)) + // /// ``` + // pub a_prime: [[[T; 64]; 5]; 5], + + // /// ```ignore + // /// A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). + // /// ``` + // pub a_prime_prime: [[[T; U64_LIMBS]; 5]; 5], + + // /// The bits of `A''[0, 0]`. + // pub a_prime_prime_0_0_bits: [T; 64], + + // /// ```ignore + // /// A'''[0, 0, z] = A''[0, 0, z] ^ RC[k, z] + // /// ``` + // pub a_prime_prime_prime_0_0_limbs: [T; U64_LIMBS], + // } + + pol commit preimage[5 * 5 * 4]; + pol commit a[5 * 5 * 4]; + pol commit c[5 * 64]; + array::map(c, |i| force_bool(i)); + pol commit c_prime[5 * 64]; + pol commit a_prime[5 * 5 * 64]; + array::map(a_prime, |i| force_bool(i)); + pol commit a_prime_prime[5 * 5 * 4]; + pol commit a_prime_prime_0_0_bits[64]; + array::map(a_prime_prime_0_0_bits, |i| force_bool(i)); + pol commit a_prime_prime_prime_0_0_limbs[4]; + + // Initially, the first step flag should be 1 while the others should be 0. + // builder.when_first_row().assert_one(local.step_flags[0]); + // for i in 1..NUM_ROUNDS { + // builder.when_first_row().assert_zero(local.step_flags[i]); + // } + // for i in 0..NUM_ROUNDS { + // let current_round_flag = local.step_flags[i]; + // let next_round_flag = next.step_flags[(i + 1) % NUM_ROUNDS]; + // builder + // .when_transition() + // .assert_eq(next_round_flag, current_round_flag); + // } + + let step_flags: col[NUM_ROUNDS] = array::new(NUM_ROUNDS, |i| |row| if row % NUM_ROUNDS == i { 1 } else { 0 } ); + + // let main = builder.main(); + // let (local, next) = (main.row_slice(0), main.row_slice(1)); + // let local: &KeccakCols = (*local).borrow(); + // let next: &KeccakCols = (*next).borrow(); + + // let first_step = local.step_flags[0]; + // let final_step = local.step_flags[NUM_ROUNDS - 1]; + // let not_final_step = AB::Expr::one() - final_step; + + let first_step: expr = step_flags[0]; // Aliasing instead of defining a new fixed column. + let final_step: expr = step_flags[NUM_ROUNDS - 1]; + col fixed is_last = [0]* + [1]; + + // // If this is the first step, the input A must match the preimage. + // for y in 0..5 { + // for x in 0..5 { + // for limb in 0..U64_LIMBS { + // builder + // .when(first_step) + // .assert_eq(local.preimage[y][x][limb], local.a[y][x][limb]); + // } + // } + // } + + array::zip(preimage, a, |p_i, a_i| first_step * (p_i - a_i) = 0); + + // // The export flag must be 0 or 1. + // builder.assert_bool(local.export); + + // force_bool(export); + + // // If this is not the final step, the export flag must be off. + // builder + // .when(not_final_step.clone()) + // .assert_zero(local.export); + + // not_final_step * export = 0; + + // // If this is not the final step, the local and next preimages must match. + // for y in 0..5 { + // for x in 0..5 { + // for limb in 0..U64_LIMBS { + // builder + // .when(not_final_step.clone()) + // .when_transition() + // .assert_eq(local.preimage[y][x][limb], next.preimage[y][x][limb]); + // } + // } + // } + + array::map(preimage, |p| unchanged_until(p, final_step + is_last)); + + // for x in 0..5 { + // for z in 0..64 { + // builder.assert_bool(local.c[x][z]); + // let xor = xor3_gen::( + // local.c[x][z].into(), + // local.c[(x + 4) % 5][z].into(), + // local.c[(x + 1) % 5][(z + 63) % 64].into(), + // ); + // let c_prime = local.c_prime[x][z]; + // builder.assert_eq(c_prime, xor); + // } + // } + + let andn: expr, expr -> expr = |a, b| (1 - a) * b; + let xor: expr, expr -> expr = |a, b| a + b - 2*a*b; + let xor3: expr, expr, expr -> expr = |a, b, c| xor(xor(a, b), c); + // a b c xor3 + // 0 0 0 0 + // 0 0 1 1 + // 0 1 0 1 + // 0 1 1 0 + // 1 0 0 1 + // 1 0 1 0 + // 1 1 0 0 + // 1 1 1 1 + + array::new(320, |i| { + let x = i / 64; + let z = i % 64; + c_prime[i] = xor3( + c[i], + c[((x + 4) % 5) * 64 + z], + c[((x + 1) % 5) * 64 + ((z + 63) % 64)] + ) + }); + + // // Check that the input limbs are consistent with A' and D. + // // A[x, y, z] = xor(A'[x, y, z], D[x, y, z]) + // // = xor(A'[x, y, z], C[x - 1, z], C[x + 1, z - 1]) + // // = xor(A'[x, y, z], C[x, z], C'[x, z]). + // // The last step is valid based on the identity we checked above. + // // It isn't required, but makes this check a bit cleaner. + // for y in 0..5 { + // for x in 0..5 { + // let get_bit = |z| { + // let a_prime: AB::Var = local.a_prime[y][x][z]; + // let c: AB::Var = local.c[x][z]; + // let c_prime: AB::Var = local.c_prime[x][z]; + // xor3_gen::(a_prime.into(), c.into(), c_prime.into()) + // }; + + // for limb in 0..U64_LIMBS { + // let a_limb = local.a[y][x][limb]; + // let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB) // bigger address correspond to more significant bit + // .rev() + // .fold(AB::Expr::zero(), |acc, z| { + // builder.assert_bool(local.a_prime[y][x][z]); + // acc.double() + get_bit(z) + // }); + // builder.assert_eq(computed_limb, a_limb); + // } + // } + // } + + let bits_to_value_be: expr[] -> expr = |bits_be| array::fold(bits_be, 0, |acc, e| (acc * 2 + e)); + + array::new(100, |i| { + let y = i / 20; + let x = (i / 4) % 5; + let limb = i % 4; + let get_bit: int -> expr = |z| xor3(a_prime[y * 320 + x * 64 + z], c[x * 64 + z], c_prime[x * 64 + z]); + + let limb_bits_be: expr[] = array::reverse(array::new(16, |z| get_bit(limb * 16 + z))); + a[i] = bits_to_value_be(limb_bits_be) + }); + + // // xor_{i=0}^4 A'[x, i, z] = C'[x, z], so for each x, z, + // // diff * (diff - 2) * (diff - 4) = 0, where + // // diff = sum_{i=0}^4 A'[x, i, z] - C'[x, z] + // for x in 0..5 { + // for z in 0..64 { + // let sum: AB::Expr = (0..5).map(|y| local.a_prime[y][x][z].into()).sum(); + // let diff = sum - local.c_prime[x][z]; + // let four = AB::Expr::from_canonical_u8(4); + // builder + // .assert_zero(diff.clone() * (diff.clone() - AB::Expr::two()) * (diff - four)); + // } + // } + + array::new(320, |i| { + let x = i / 64; + let z = i % 64; + let sum = utils::sum(5, |y| a_prime[y * 320 + i]); + let diff = sum - c_prime[i]; + diff * (diff - 2) * (diff - 4) = 0 + }); + + // // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). + // for y in 0..5 { + // for x in 0..5 { + // let get_bit = |z| { + // let andn = andn_gen::( + // local.b((x + 1) % 5, y, z).into(), + // local.b((x + 2) % 5, y, z).into(), + // ); + // xor_gen::(local.b(x, y, z).into(), andn) + // }; + + // for limb in 0..U64_LIMBS { + // let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB) + // .rev() + // .fold(AB::Expr::zero(), |acc, z| acc.double() + get_bit(z)); + // builder.assert_eq(computed_limb, local.a_prime_prime[y][x][limb]); + // } + // } + // } + + array::new(100, |i| { + let y = i / 20; + let x = (i / 4) % 5; + let limb = i % 4; + + let get_bit: int -> expr = |z| { + xor(b(x, y, z), andn(b((x + 1) % 5, y, z), b((x + 2) % 5, y, z))) + }; + let limb_bits_be: expr[] = array::reverse(array::new(16, |z| get_bit(limb * 16 + z))); + a_prime_prime[i] = bits_to_value_be(limb_bits_be) + }); + + // pub fn b(&self, x: usize, y: usize, z: usize) -> T { + // debug_assert!(x < 5); + // debug_assert!(y < 5); + // debug_assert!(z < 64); + + // // B is just a rotation of A', so these are aliases for A' registers. + // // From the spec, + // // B[y, (2x + 3y) % 5] = ROT(A'[x, y], r[x, y]) + // // So, + // // B[x, y] = f((x + 3y) % 5, x) + // // where f(a, b) = ROT(A'[a, b], r[a, b]) + // let a = (x + 3 * y) % 5; + // let b = x; + // let rot = R[a][b] as usize; + // self.a_prime[b][a][(z + 64 - rot) % 64] + // } + + let b: int, int, int -> expr = |x, y, z| { + let a: int = (x + 3 * y) % 5; + let rot: int = R[a * 5 + x]; // b = x + a_prime[x * 320 + a * 64 + (z + 64 - rot) % 64] + }; + + // // A'''[0, 0] = A''[0, 0] XOR RC + // for limb in 0..U64_LIMBS { + // let computed_a_prime_prime_0_0_limb = (limb * BITS_PER_LIMB + // ..(limb + 1) * BITS_PER_LIMB) + // .rev() + // .fold(AB::Expr::zero(), |acc, z| { + // builder.assert_bool(local.a_prime_prime_0_0_bits[z]); + // acc.double() + local.a_prime_prime_0_0_bits[z] + // }); + // let a_prime_prime_0_0_limb = local.a_prime_prime[0][0][limb]; + // builder.assert_eq(computed_a_prime_prime_0_0_limb, a_prime_prime_0_0_limb); + // } + + array::new(4, |limb| { + let limb_bits_be: expr[] = array::reverse(array::new(16, |z| a_prime_prime_0_0_bits[limb * 16 + z])); + a_prime_prime[limb] = bits_to_value_be(limb_bits_be) + }); + + // let get_xored_bit = |i| { + // let mut rc_bit_i = AB::Expr::zero(); + // for r in 0..NUM_ROUNDS { + // let this_round = local.step_flags[r]; + // let this_round_constant = AB::Expr::from_canonical_u8(rc_value_bit(r, i)); + // rc_bit_i += this_round * this_round_constant; + // } + + // xor_gen::(local.a_prime_prime_0_0_bits[i].into(), rc_bit_i) + // }; + + let get_xored_bit: int -> expr = |i| xor(a_prime_prime_0_0_bits[i], utils::sum(NUM_ROUNDS, |r| expr(RC_BITS[r * 64 + i]) * step_flags[r] )); + + // for limb in 0..U64_LIMBS { + // let a_prime_prime_prime_0_0_limb = local.a_prime_prime_prime_0_0_limbs[limb]; + // let computed_a_prime_prime_prime_0_0_limb = (limb * BITS_PER_LIMB + // ..(limb + 1) * BITS_PER_LIMB) + // .rev() + // .fold(AB::Expr::zero(), |acc, z| acc.double() + get_xored_bit(z)); + // builder.assert_eq( + // computed_a_prime_prime_prime_0_0_limb, + // a_prime_prime_prime_0_0_limb, + // ); + // } + + array::new(4, |limb| { + let limb_bits_be: expr[] = array::reverse(array::new(16, |z| get_xored_bit(limb * 16 + z))); + a_prime_prime_prime_0_0_limbs[limb] = bits_to_value_be(limb_bits_be) + }); + + // // Enforce that this round's output equals the next round's input. + // for x in 0..5 { + // for y in 0..5 { + // for limb in 0..U64_LIMBS { + // let output = local.a_prime_prime_prime(y, x, limb); + // let input = next.a[y][x][limb]; + // builder + // .when_transition() + // .when(not_final_step.clone()) + // .assert_eq(output, input); + // } + // } + // } + + // final_step and is_last should never be 1 at the same time, because final_step is 1 at multiples of 24 and can never be 1 at power of 2. + // (1 - final_step - is_last) is used to deactivate constraints that reference the next row, whenever we are at the latch row or the last row of the trace (so that we don't incorrectly cycle to the first row). + array::new(100, |i| { + let y = i / 20; + let x = (i / 4) % 5; + let limb = i % 4; + (1 - final_step - is_last) * (a_prime_prime_prime(y, x, limb) - a[i]') = 0 + }); + + // pub fn a_prime_prime_prime(&self, y: usize, x: usize, limb: usize) -> T { + // debug_assert!(y < 5); + // debug_assert!(x < 5); + // debug_assert!(limb < U64_LIMBS); + + // if y == 0 && x == 0 { + // self.a_prime_prime_prime_0_0_limbs[limb] + // } else { + // self.a_prime_prime[y][x][limb] + // } + // } + + let a_prime_prime_prime: int, int, int -> expr = |y, x, limb| if y == 0 && x == 0 { a_prime_prime_prime_0_0_limbs[limb] } else { a_prime_prime[y * 20 + x * 4 + limb] }; + + let R: int[] = [ + 0, 36, 3, 41, 18, + 1, 44, 10, 45, 2, + 62, 6, 43, 15, 61, + 28, 55, 25, 21, 56, + 27, 20, 39, 8, 14 + ]; + + let RC: int[] = [ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808A, + 0x8000000080008000, + 0x000000000000808B, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008A, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000A, + 0x000000008000808B, + 0x800000000000008B, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800A, + 0x800000008000000A, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008 + ]; + + let RC_BITS: int[] = array::new(24 * 64, |i| { + let rc_idx = i / 64; + let bit = i % 64; + RC[rc_idx] >> bit & 0x1 + }); + + // Prover function section (for witness generation). + + // // Populate C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]). + // for x in 0..5 { + // for z in 0..64 { + // let limb = z / BITS_PER_LIMB; + // let bit_in_limb = z % BITS_PER_LIMB; + // let a = (0..5).map(|y| { + // let a_limb = row.a[y][x][limb].as_canonical_u64() as u16; + // ((a_limb >> bit_in_limb) & 1) != 0 + // }); + // row.c[x][z] = F::from_bool(a.fold(false, |acc, x| acc ^ x)); + // } + // } + + let query_c: int, int, int -> int = query |x, limb, bit_in_limb| + utils::fold( + 5, + |y| (int(eval(a[y * 20 + x * 4 + limb])) >> bit_in_limb) & 0x1, + 0, + |acc, e| acc ^ e + ); + + query |row| { + let _ = array::map_enumerated(c, |i, c_i| { + let x = i / 64; + let z = i % 64; + let limb = z / 16; + let bit_in_limb = z % 16; + + provide_value(c_i, row, fe(query_c(x, limb, bit_in_limb))); + }); + }; + + // // Populate C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]). + // for x in 0..5 { + // for z in 0..64 { + // row.c_prime[x][z] = xor([ + // row.c[x][z], + // row.c[(x + 4) % 5][z], + // row.c[(x + 1) % 5][(z + 63) % 64], + // ]); + // } + // } + + let query_c_prime: int, int -> int = query |x, z| + int(eval(c[x * 64 + z])) ^ + int(eval(c[((x + 4) % 5) * 64 + z])) ^ + int(eval(c[((x + 1) % 5) * 64 + (z + 63) % 64])); + + query |row| { + let _ = array::map_enumerated(c_prime, |i, c_i| { + let x = i / 64; + let z = i % 64; + + provide_value(c_i, row, fe(query_c_prime(x, z))); + }); + }; + + // // Populate A'. To avoid shifting indices, we rewrite + // // A'[x, y, z] = xor(A[x, y, z], C[x - 1, z], C[x + 1, z - 1]) + // // as + // // A'[x, y, z] = xor(A[x, y, z], C[x, z], C'[x, z]). + // for x in 0..5 { + // for y in 0..5 { + // for z in 0..64 { + // let limb = z / BITS_PER_LIMB; + // let bit_in_limb = z % BITS_PER_LIMB; + // let a_limb = row.a[y][x][limb].as_canonical_u64() as u16; + // let a_bit = F::from_bool(((a_limb >> bit_in_limb) & 1) != 0); + // row.a_prime[y][x][z] = xor([a_bit, row.c[x][z], row.c_prime[x][z]]); + // } + // } + // } + + let query_a_prime: int, int, int, int, int -> int = query |x, y, z, limb, bit_in_limb| + ((int(eval(a[y * 20 + x * 4 + limb])) >> bit_in_limb) & 0x1) ^ + int(eval(c[x * 64 + z])) ^ + int(eval(c_prime[x * 64 + z])); + + query |row| { + let _ = array::map_enumerated(a_prime, |i, a_i| { + let y = i / 320; + let x = (i / 64) % 5; + let z = i % 64; + let limb = z / 16; + let bit_in_limb = z % 16; + + provide_value(a_i, row, fe(query_a_prime(x, y, z, limb, bit_in_limb))); + }); + }; + + // // Populate A''.P + // // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). + // for y in 0..5 { + // for x in 0..5 { + // for limb in 0..U64_LIMBS { + // row.a_prime_prime[y][x][limb] = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB) + // .rev() + // .fold(F::zero(), |acc, z| { + // let bit = xor([ + // row.b(x, y, z), + // andn(row.b((x + 1) % 5, y, z), row.b((x + 2) % 5, y, z)), + // ]); + // acc.double() + bit + // }); + // } + // } + // } + + let query_a_prime_prime: int, int, int -> int = query |x, y, limb| + utils::fold( + 16, + |z| + int(eval(b(x, y, (limb + 1) * 16 - 1 - z))) ^ + int(eval(andn(b((x + 1) % 5, y, (limb + 1) * 16 - 1 - z), + b((x + 2) % 5, y, (limb + 1) * 16 - 1 - z)))), + 0, + |acc, e| acc * 2 + e + ); + + query |row| { + let _ = array::map_enumerated(a_prime_prime, |i, a_i| { + let y = i / 20; + let x = (i / 4) % 5; + let limb = i % 4; + + provide_value(a_i, row, fe(query_a_prime_prime(x, y, limb))); + }); + }; + + // // For the XOR, we split A''[0, 0] to bits. + // let mut val = 0; // smaller address correspond to less significant limb + // for limb in 0..U64_LIMBS { + // let val_limb = row.a_prime_prime[0][0][limb].as_canonical_u64(); + // val |= val_limb << (limb * BITS_PER_LIMB); + // } + // let val_bits: Vec = (0..64) // smaller address correspond to less significant bit + // .scan(val, |acc, _| { + // let bit = (*acc & 1) != 0; + // *acc >>= 1; + // Some(bit) + // }) + // .collect(); + // for (i, bit) in row.a_prime_prime_0_0_bits.iter_mut().enumerate() { + // *bit = F::from_bool(val_bits[i]); + // } + + query |row| { + let _ = array::map_enumerated(a_prime_prime_0_0_bits, |i, a_i| { + let limb = i / 16; + let bit_in_limb = i % 16; + + provide_value( + a_i, + row, + fe((int(eval(a_prime_prime[limb])) >> bit_in_limb) & 0x1) + ); + }); + }; + + // // A''[0, 0] is additionally xor'd with RC. + // for limb in 0..U64_LIMBS { + // let rc_lo = rc_value_limb(round, limb); + // row.a_prime_prime_prime_0_0_limbs[limb] = + // F::from_canonical_u16(row.a_prime_prime[0][0][limb].as_canonical_u64() as u16 ^ rc_lo); + // } + + let query_a_prime_prime_prime_0_0_limbs: int, int -> int = query |round, limb| + int(eval(a_prime_prime[limb])) ^ + ((RC[round] >> (limb * 16)) & 0xffff); + + query |row| { + let _ = array::new(4, |limb| { + provide_value( + a_prime_prime_prime_0_0_limbs[limb], + row, + fe(query_a_prime_prime_prime_0_0_limbs(row % NUM_ROUNDS, limb) + )); + }); + }; +} diff --git a/std/machines/hash/mod.asm b/std/machines/hash/mod.asm index f173cc68ab..8c03f8d4b2 100644 --- a/std/machines/hash/mod.asm +++ b/std/machines/hash/mod.asm @@ -1,3 +1,4 @@ mod poseidon_bn254; mod poseidon_gl; mod poseidon_gl_memory; +mod keccakf16; diff --git a/test_data/std/keccakf16_test.asm b/test_data/std/keccakf16_test.asm new file mode 100644 index 0000000000..b449dbdb48 --- /dev/null +++ b/test_data/std/keccakf16_test.asm @@ -0,0 +1,330 @@ +use std::machines::hash::keccakf16::Keccakf16; + +machine Main with degree: 64 { + reg pc[@pc]; + + reg X0[<=]; + reg X1[<=]; + reg X2[<=]; + reg X3[<=]; + reg X4[<=]; + reg X5[<=]; + reg X6[<=]; + reg X7[<=]; + reg X8[<=]; + reg X9[<=]; + reg X10[<=]; + reg X11[<=]; + reg X12[<=]; + reg X13[<=]; + reg X14[<=]; + reg X15[<=]; + reg X16[<=]; + reg X17[<=]; + reg X18[<=]; + reg X19[<=]; + reg X20[<=]; + reg X21[<=]; + reg X22[<=]; + reg X23[<=]; + reg X24[<=]; + reg X25[<=]; + reg X26[<=]; + reg X27[<=]; + reg X28[<=]; + reg X29[<=]; + reg X30[<=]; + reg X31[<=]; + reg X32[<=]; + reg X33[<=]; + reg X34[<=]; + reg X35[<=]; + reg X36[<=]; + reg X37[<=]; + reg X38[<=]; + reg X39[<=]; + reg X40[<=]; + reg X41[<=]; + reg X42[<=]; + reg X43[<=]; + reg X44[<=]; + reg X45[<=]; + reg X46[<=]; + reg X47[<=]; + reg X48[<=]; + reg X49[<=]; + reg X50[<=]; + reg X51[<=]; + reg X52[<=]; + reg X53[<=]; + reg X54[<=]; + reg X55[<=]; + reg X56[<=]; + reg X57[<=]; + reg X58[<=]; + reg X59[<=]; + reg X60[<=]; + reg X61[<=]; + reg X62[<=]; + reg X63[<=]; + reg X64[<=]; + reg X65[<=]; + reg X66[<=]; + reg X67[<=]; + reg X68[<=]; + reg X69[<=]; + reg X70[<=]; + reg X71[<=]; + reg X72[<=]; + reg X73[<=]; + reg X74[<=]; + reg X75[<=]; + reg X76[<=]; + reg X77[<=]; + reg X78[<=]; + reg X79[<=]; + reg X80[<=]; + reg X81[<=]; + reg X82[<=]; + reg X83[<=]; + reg X84[<=]; + reg X85[<=]; + reg X86[<=]; + reg X87[<=]; + reg X88[<=]; + reg X89[<=]; + reg X90[<=]; + reg X91[<=]; + reg X92[<=]; + reg X93[<=]; + reg X94[<=]; + reg X95[<=]; + reg X96[<=]; + reg X97[<=]; + reg X98[<=]; + reg X99[<=]; + + reg Y0[<=]; + reg Y1[<=]; + reg Y2[<=]; + reg Y3[<=]; + reg Y4[<=]; + reg Y5[<=]; + reg Y6[<=]; + reg Y7[<=]; + reg Y8[<=]; + reg Y9[<=]; + reg Y10[<=]; + reg Y11[<=]; + reg Y12[<=]; + reg Y13[<=]; + reg Y14[<=]; + reg Y15[<=]; + reg Y16[<=]; + reg Y17[<=]; + reg Y18[<=]; + reg Y19[<=]; + reg Y20[<=]; + reg Y21[<=]; + reg Y22[<=]; + reg Y23[<=]; + reg Y24[<=]; + reg Y25[<=]; + reg Y26[<=]; + reg Y27[<=]; + reg Y28[<=]; + reg Y29[<=]; + reg Y30[<=]; + reg Y31[<=]; + reg Y32[<=]; + reg Y33[<=]; + reg Y34[<=]; + reg Y35[<=]; + reg Y36[<=]; + reg Y37[<=]; + reg Y38[<=]; + reg Y39[<=]; + reg Y40[<=]; + reg Y41[<=]; + reg Y42[<=]; + reg Y43[<=]; + reg Y44[<=]; + reg Y45[<=]; + reg Y46[<=]; + reg Y47[<=]; + reg Y48[<=]; + reg Y49[<=]; + reg Y50[<=]; + reg Y51[<=]; + reg Y52[<=]; + reg Y53[<=]; + reg Y54[<=]; + reg Y55[<=]; + reg Y56[<=]; + reg Y57[<=]; + reg Y58[<=]; + reg Y59[<=]; + reg Y60[<=]; + reg Y61[<=]; + reg Y62[<=]; + reg Y63[<=]; + reg Y64[<=]; + reg Y65[<=]; + reg Y66[<=]; + reg Y67[<=]; + reg Y68[<=]; + reg Y69[<=]; + reg Y70[<=]; + reg Y71[<=]; + reg Y72[<=]; + reg Y73[<=]; + reg Y74[<=]; + reg Y75[<=]; + reg Y76[<=]; + reg Y77[<=]; + reg Y78[<=]; + reg Y79[<=]; + reg Y80[<=]; + reg Y81[<=]; + reg Y82[<=]; + reg Y83[<=]; + reg Y84[<=]; + reg Y85[<=]; + reg Y86[<=]; + reg Y87[<=]; + reg Y88[<=]; + reg Y89[<=]; + reg Y90[<=]; + reg Y91[<=]; + reg Y92[<=]; + reg Y93[<=]; + reg Y94[<=]; + reg Y95[<=]; + reg Y96[<=]; + reg Y97[<=]; + reg Y98[<=]; + reg Y99[<=]; + + reg A0; + reg A1; + reg A2; + reg A3; + reg A4; + reg A5; + reg A6; + reg A7; + reg A8; + reg A9; + reg A10; + reg A11; + reg A12; + reg A13; + reg A14; + reg A15; + reg A16; + reg A17; + reg A18; + reg A19; + reg A20; + reg A21; + reg A22; + reg A23; + reg A24; + reg A25; + reg A26; + reg A27; + reg A28; + reg A29; + reg A30; + reg A31; + reg A32; + reg A33; + reg A34; + reg A35; + reg A36; + reg A37; + reg A38; + reg A39; + reg A40; + reg A41; + reg A42; + reg A43; + reg A44; + reg A45; + reg A46; + reg A47; + reg A48; + reg A49; + reg A50; + reg A51; + reg A52; + reg A53; + reg A54; + reg A55; + reg A56; + reg A57; + reg A58; + reg A59; + reg A60; + reg A61; + reg A62; + reg A63; + reg A64; + reg A65; + reg A66; + reg A67; + reg A68; + reg A69; + reg A70; + reg A71; + reg A72; + reg A73; + reg A74; + reg A75; + reg A76; + reg A77; + reg A78; + reg A79; + reg A80; + reg A81; + reg A82; + reg A83; + reg A84; + reg A85; + reg A86; + reg A87; + reg A88; + reg A89; + reg A90; + reg A91; + reg A92; + reg A93; + reg A94; + reg A95; + reg A96; + reg A97; + reg A98; + reg A99; + + Keccakf16 keccakf16; + + instr keccakf16 X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12, X13, X14, X15, X16, X17, X18, X19, X20, X21, X22, X23, X24, X25, X26, X27, X28, X29, X30, X31, X32, X33, X34, X35, X36, X37, X38, X39, X40, X41, X42, X43, X44, X45, X46, X47, X48, X49, X50, X51, X52, X53, X54, X55, X56, X57, X58, X59, X60, X61, X62, X63, X64, X65, X66, X67, X68, X69, X70, X71, X72, X73, X74, X75, X76, X77, X78, X79, X80, X81, X82, X83, X84, X85, X86, X87, X88, X89, X90, X91, X92, X93, X94, X95, X96, X97, X98, X99 -> Y0, Y1, Y2, Y3, Y4, Y5, Y6, Y7, Y8, Y9, Y10, Y11, Y12, Y13, Y14, Y15, Y16, Y17, Y18, Y19, Y20, Y21, Y22, Y23, Y24, Y25, Y26, Y27, Y28, Y29, Y30, Y31, Y32, Y33, Y34, Y35, Y36, Y37, Y38, Y39, Y40, Y41, Y42, Y43, Y44, Y45, Y46, Y47, Y48, Y49, Y50, Y51, Y52, Y53, Y54, Y55, Y56, Y57, Y58, Y59, Y60, Y61, Y62, Y63, Y64, Y65, Y66, Y67, Y68, Y69, Y70, Y71, Y72, Y73, Y74, Y75, Y76, Y77, Y78, Y79, Y80, Y81, Y82, Y83, Y84, Y85, Y86, Y87, Y88, Y89, Y90, Y91, Y92, Y93, Y94, Y95, Y96, Y97, Y98, Y99 link => (Y0, Y1, Y2, Y3, Y4, Y5, Y6, Y7, Y8, Y9, Y10, Y11, Y12, Y13, Y14, Y15, Y16, Y17, Y18, Y19, Y20, Y21, Y22, Y23, Y24, Y25, Y26, Y27, Y28, Y29, Y30, Y31, Y32, Y33, Y34, Y35, Y36, Y37, Y38, Y39, Y40, Y41, Y42, Y43, Y44, Y45, Y46, Y47, Y48, Y49, Y50, Y51, Y52, Y53, Y54, Y55, Y56, Y57, Y58, Y59, Y60, Y61, Y62, Y63, Y64, Y65, Y66, Y67, Y68, Y69, Y70, Y71, Y72, Y73, Y74, Y75, Y76, Y77, Y78, Y79, Y80, Y81, Y82, Y83, Y84, Y85, Y86, Y87, Y88, Y89, Y90, Y91, Y92, Y93, Y94, Y95, Y96, Y97, Y98, Y99) = keccakf16.keccakf16(X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12, X13, X14, X15, X16, X17, X18, X19, X20, X21, X22, X23, X24, X25, X26, X27, X28, X29, X30, X31, X32, X33, X34, X35, X36, X37, X38, X39, X40, X41, X42, X43, X44, X45, X46, X47, X48, X49, X50, X51, X52, X53, X54, X55, X56, X57, X58, X59, X60, X61, X62, X63, X64, X65, X66, X67, X68, X69, X70, X71, X72, X73, X74, X75, X76, X77, X78, X79, X80, X81, X82, X83, X84, X85, X86, X87, X88, X89, X90, X91, X92, X93, X94, X95, X96, X97, X98, X99); + + instr assert_eq X0, X1 { + X0 = X1 + } + + function main { + // 0 for all 25 64-bit inputs except setting the second 64-bit input to 1. All 64-bit inputs in chunks of 4 16-bit little endian limbs. + A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, A23, A24, A25, A26, A27, A28, A29, A30, A31, A32, A33, A34, A35, A36, A37, A38, A39, A40, A41, A42, A43, A44, A45, A46, A47, A48, A49, A50, A51, A52, A53, A54, A55, A56, A57, A58, A59, A60, A61, A62, A63, A64, A65, A66, A67, A68, A69, A70, A71, A72, A73, A74, A75, A76, A77, A78, A79, A80, A81, A82, A83, A84, A85, A86, A87, A88, A89, A90, A91, A92, A93, A94, A95, A96, A97, A98, A99 <== keccakf16(0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + // Selectively checking a few registers only. + assert_eq A0, 0x405f; + assert_eq A3, 0xfdbb; + assert_eq A92, 0x8f6e; + assert_eq A95, 0x3e10; + assert_eq A96, 0xeb35; + assert_eq A99, 0xeac9; + + return; + } +} From 64e7bf140269c7f1694166519517a70ece8068fb Mon Sep 17 00:00:00 2001 From: Thibaut Schaeffer Date: Tue, 24 Sep 2024 14:54:20 +0200 Subject: [PATCH 16/16] Refactor asm ast (#1736) --- airgen/src/lib.rs | 64 +++-- analysis/src/machine_check.rs | 39 +-- analysis/src/vm/batcher.rs | 13 +- analysis/src/vm/inference.rs | 42 +-- asm-to-pil/src/lib.rs | 83 +++--- asm-to-pil/src/romgen.rs | 16 +- asm-to-pil/src/vm_to_constrained.rs | 56 +--- ast/src/asm_analysis/display.rs | 162 ++++++----- ast/src/asm_analysis/mod.rs | 127 ++++++--- ast/src/object/display.rs | 26 +- ast/src/object/mod.rs | 3 +- ast/src/parsed/asm.rs | 34 +-- ast/src/parsed/display.rs | 84 +++--- ast/src/parsed/folder.rs | 34 +-- ast/src/parsed/types.rs | 56 ++++ ast/src/parsed/visitor.rs | 2 + backend/src/estark/starky_wrapper.rs | 2 +- importer/src/path_canonicalizer.rs | 252 +++++++++++------- importer/src/powdr_std.rs | 13 +- importer/test_data/instruction.expected.asm | 8 +- importer/test_data/trait_implementation.asm | 26 ++ .../trait_implementation.expected.asm | 19 ++ linker/src/lib.rs | 67 ++--- parser/src/powdr.lalrpop | 26 +- parser/src/test_utils.rs | 13 +- pil-analyzer/src/expression_processor.rs | 1 + pil-analyzer/src/pil_analyzer.rs | 16 +- riscv-executor/src/lib.rs | 17 +- riscv/src/continuations.rs | 4 +- riscv/src/continuations/bootloader.rs | 2 + 30 files changed, 679 insertions(+), 628 deletions(-) create mode 100644 importer/test_data/trait_implementation.asm create mode 100644 importer/test_data/trait_implementation.expected.asm diff --git a/airgen/src/lib.rs b/airgen/src/lib.rs index a8ffdc8457..3cab99d4ce 100644 --- a/airgen/src/lib.rs +++ b/airgen/src/lib.rs @@ -5,10 +5,8 @@ use std::collections::BTreeMap; use powdr_ast::{ - asm_analysis::{self, combine_flags, AnalysisASMFile, Item, LinkDefinition}, - object::{ - Link, LinkFrom, LinkTo, Location, Machine, Object, Operation, PILGraph, TypeOrExpression, - }, + asm_analysis::{self, combine_flags, AnalysisASMFile, LinkDefinition}, + object::{Link, LinkFrom, LinkTo, Location, Machine, Object, Operation, PILGraph}, parsed::{ asm::{parse_absolute_path, AbsoluteSymbolPath, CallableRef, MachineParams}, Expression, PilStatement, @@ -47,7 +45,7 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph { main, entry_points: Default::default(), objects: [(main_location, Default::default())].into(), - definitions: utility_functions(input), + statements: utility_functions(input), }; } // if there is a single machine, treat it as main @@ -55,7 +53,7 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph { // otherwise, use the machine called `MAIN` _ => { let p = parse_absolute_path(MAIN_MACHINE); - assert!(input.items.contains_key(&p)); + assert!(input.get_machine(&p).is_some()); p } }; @@ -67,7 +65,7 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph { let mut instances = BTreeMap::default(); while let Some((location, ty, args)) = queue.pop() { - let machine = input.items.get(&ty).unwrap().try_to_machine().unwrap(); + let machine = &input.get_machine(&ty).unwrap(); queue.extend(machine.submachines.iter().map(|def| { ( @@ -127,9 +125,7 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph { } } - let Item::Machine(main_ty) = input.items.get(&main_ty).unwrap() else { - panic!() - }; + let main_ty = &input.get_machine(&main_ty).unwrap(); let main = powdr_ast::object::Machine { location: main_location, @@ -150,7 +146,7 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph { main, entry_points, objects, - definitions: utility_functions(input), + statements: utility_functions(input), } } @@ -178,14 +174,32 @@ fn resolve_submachine_arg( } } -fn utility_functions(asm_file: AnalysisASMFile) -> BTreeMap { +fn utility_functions(asm_file: AnalysisASMFile) -> BTreeMap> { asm_file - .items + .modules .into_iter() - .filter_map(|(n, v)| match v { - Item::Expression(e) => Some((n, TypeOrExpression::Expression(e))), - Item::TypeDeclaration(type_decl) => Some((n, TypeOrExpression::Type(type_decl))), - _ => None, + .map(|(module_path, module)| { + ( + module_path, + module + .into_inner() + .1 + .into_iter() + .filter(|s| match s { + PilStatement::EnumDeclaration(..) | PilStatement::LetStatement(..) => true, + PilStatement::Include(..) => false, + PilStatement::Namespace(..) => false, + PilStatement::PolynomialDefinition(..) => false, + PilStatement::PublicDeclaration(..) => false, + PilStatement::PolynomialConstantDeclaration(..) => false, + PilStatement::PolynomialConstantDefinition(..) => false, + PilStatement::PolynomialCommitDeclaration(..) => false, + PilStatement::TraitImplementation(..) => false, + PilStatement::TraitDeclaration(..) => false, + PilStatement::Expression(..) => false, + }) + .collect(), + ) }) .collect() } @@ -205,7 +219,7 @@ struct ASMPILConverter<'a> { /// Current machine instance location: &'a Location, /// Input definitions and machines. - items: &'a BTreeMap, + input: &'a AnalysisASMFile, /// Pil statements generated for the machine pil: Vec, /// Submachine instances accessible to the machine (includes those passed as a parameter) @@ -224,7 +238,7 @@ impl<'a> ASMPILConverter<'a> { Self { instances, location, - items: &input.items, + input, pil: Default::default(), submachines: Default::default(), incoming_permutations, @@ -247,9 +261,7 @@ impl<'a> ASMPILConverter<'a> { fn convert_machine_inner(mut self) -> Object { let (ty, args) = self.instances.get(self.location).as_ref().unwrap(); // TODO: This clone doubles the current memory usage - let Item::Machine(input) = self.items.get(ty).unwrap().clone() else { - panic!(); - }; + let input = self.input.get_machine(ty).unwrap().clone(); let degree = input.degree; @@ -331,9 +343,7 @@ impl<'a> ASMPILConverter<'a> { panic!("could not find submachine named `{instance}` in machine `{ty}`"); }); // get the machine type from the machine map - let Item::Machine(instance_ty) = self.items.get(&instance.ty).unwrap() else { - panic!(); - }; + let instance_ty = &self.input.get_machine(&instance.ty).unwrap(); // check that the operation exists and that it has the same number of inputs/outputs as the link let operation = instance_ty @@ -513,8 +523,8 @@ impl<'a> ASMPILConverter<'a> { for (param, value) in params.iter().zip(values) { let ty = AbsoluteSymbolPath::default().join(param.ty.clone().unwrap()); - match self.items.get(&ty) { - Some(Item::Machine(_)) => self.submachines.push(SubmachineRef { + match self.input.get_machine(&ty) { + Some(_) => self.submachines.push(SubmachineRef { location: value.clone(), name: param.name.clone(), ty, diff --git a/analysis/src/machine_check.rs b/analysis/src/machine_check.rs index e8a48f95f8..32f67a4600 100644 --- a/analysis/src/machine_check.rs +++ b/analysis/src/machine_check.rs @@ -6,7 +6,7 @@ use powdr_ast::{ asm_analysis::{ AnalysisASMFile, AssignmentStatement, CallableSymbolDefinitions, DebugDirective, FunctionBody, FunctionStatements, FunctionSymbol, InstructionDefinitionStatement, - InstructionStatement, Item, LabelStatement, LinkDefinition, Machine, MachineDegree, + InstructionStatement, LabelStatement, LinkDefinition, Machine, MachineDegree, Module, OperationSymbol, RegisterDeclarationStatement, RegisterTy, Return, SubmachineDeclaration, }, parsed::{ @@ -23,10 +23,8 @@ use powdr_ast::{ /// Also transfers generic PIL definitions but does not verify anything about them. pub fn check(file: ASMProgram) -> Result> { let ctx = AbsoluteSymbolPath::default(); - let machines = TypeChecker::default().check_module(file.main, &ctx)?; - Ok(AnalysisASMFile { - items: machines.into_iter().collect(), - }) + let modules = TypeChecker::default().check_module(file.main, &ctx)?; + Ok(AnalysisASMFile { modules }) } #[derive(Default)] @@ -312,10 +310,11 @@ impl TypeChecker { &mut self, module: ASMModule, ctx: &AbsoluteSymbolPath, - ) -> Result, Vec> { + ) -> Result, Vec> { let mut errors = vec![]; - let mut res: BTreeMap = BTreeMap::default(); + let mut checked_module = Module::default(); + let mut res = BTreeMap::default(); for m in module.statements { match m { @@ -327,7 +326,7 @@ impl TypeChecker { errors.extend(e); } Ok(machine) => { - res.insert(ctx.with_part(&name), Item::Machine(machine)); + checked_module.push_machine(name, machine); } }; } @@ -343,6 +342,8 @@ impl TypeChecker { asm::Module::Local(m) => m, }; + checked_module.push_module(name); + match self.check_module(m, &ctx) { Err(err) => { errors.extend(err); @@ -352,29 +353,17 @@ impl TypeChecker { } }; } - asm::SymbolValue::Expression(e) => { - res.insert(ctx.clone().with_part(&name), Item::Expression(e)); - } - asm::SymbolValue::TypeDeclaration(enum_decl) => { - res.insert( - ctx.clone().with_part(&name), - Item::TypeDeclaration(enum_decl), - ); - } - asm::SymbolValue::TraitDeclaration(trait_decl) => { - res.insert( - ctx.clone().with_part(&name), - Item::TraitDeclaration(trait_decl), - ); - } } } - ModuleStatement::TraitImplementation(trait_impl) => { - res.insert(ctx.clone(), Item::TraitImplementation(trait_impl)); + ModuleStatement::PilStatement(s) => { + checked_module.push_pil_statement(s); } } } + // add this module to the map of modules found inside it + res.insert(ctx.clone(), checked_module); + if !errors.is_empty() { Err(errors) } else { diff --git a/analysis/src/vm/batcher.rs b/analysis/src/vm/batcher.rs index c3d131a8cf..35a7f917b3 100644 --- a/analysis/src/vm/batcher.rs +++ b/analysis/src/vm/batcher.rs @@ -3,8 +3,7 @@ use itertools::Itertools; use powdr_ast::{ asm_analysis::{ - AnalysisASMFile, BatchMetadata, FunctionStatement, Incompatible, IncompatibleSet, Item, - Machine, + AnalysisASMFile, BatchMetadata, FunctionStatement, Incompatible, IncompatibleSet, Machine, }, parsed::asm::AbsoluteSymbolPath, }; @@ -132,14 +131,8 @@ impl RomBatcher { } pub fn batch(&mut self, mut asm_file: AnalysisASMFile) -> AnalysisASMFile { - for (name, machine) in asm_file.items.iter_mut().filter_map(|(n, m)| match m { - Item::Machine(m) => Some((n, m)), - Item::Expression(_) - | Item::TypeDeclaration(_) - | Item::TraitDeclaration(_) - | Item::TraitImplementation(_) => None, - }) { - self.extract_batches(name, machine); + for (name, machine) in asm_file.machines_mut() { + self.extract_batches(&name, machine); } asm_file diff --git a/analysis/src/vm/inference.rs b/analysis/src/vm/inference.rs index b78d5b0936..a1825734a4 100644 --- a/analysis/src/vm/inference.rs +++ b/analysis/src/vm/inference.rs @@ -1,41 +1,29 @@ //! Infer assignment registers in asm statements use powdr_ast::{ - asm_analysis::{AnalysisASMFile, Expression, FunctionStatement, Item, Machine}, + asm_analysis::{AnalysisASMFile, Expression, FunctionStatement, Machine}, parsed::asm::AssignmentRegister, }; -pub fn infer(file: AnalysisASMFile) -> Result> { +pub fn infer(mut file: AnalysisASMFile) -> Result> { let mut errors = vec![]; - let items = file - .items - .into_iter() - .filter_map(|(name, m)| match m { - Item::Machine(m) => match infer_machine(m) { - Ok(m) => Some((name, Item::Machine(m))), - Err(e) => { - errors.extend(e); - None - } - }, - Item::Expression(e) => Some((name, Item::Expression(e))), - Item::TypeDeclaration(enum_decl) => Some((name, Item::TypeDeclaration(enum_decl))), - Item::TraitImplementation(trait_impl) => { - Some((name, Item::TraitImplementation(trait_impl))) + file.machines_mut() + .for_each(|(_, m)| match infer_machine(m) { + Ok(()) => {} + Err(e) => { + errors.extend(e); } - Item::TraitDeclaration(trait_decl) => Some((name, Item::TraitDeclaration(trait_decl))), - }) - .collect(); + }); if !errors.is_empty() { Err(errors) } else { - Ok(AnalysisASMFile { items }) + Ok(file) } } -fn infer_machine(mut machine: Machine) -> Result> { +fn infer_machine(machine: &mut Machine) -> Result<(), Vec> { let mut errors = vec![]; for f in machine.callable.functions_mut() { @@ -96,7 +84,7 @@ fn infer_machine(mut machine: Machine) -> Result> { if !errors.is_empty() { Err(errors) } else { - Ok(machine) + Ok(()) } } @@ -127,9 +115,7 @@ mod tests { let file = infer_str(file).unwrap(); - let machine = &file.items[&parse_absolute_path("::Machine")] - .try_to_machine() - .unwrap(); + let machine = &file.get_machine(&parse_absolute_path("::Machine")).unwrap(); if let FunctionStatement::Assignment(AssignmentStatement { lhs_with_reg, .. }) = machine .functions() .next() @@ -168,9 +154,7 @@ mod tests { let file = infer_str(file).unwrap(); - let machine = &file.items[&parse_absolute_path("::Machine")] - .try_to_machine() - .unwrap(); + let machine = &file.get_machine(&parse_absolute_path("::Machine")).unwrap(); if let FunctionStatement::Assignment(AssignmentStatement { lhs_with_reg, .. }) = &machine .functions() .next() diff --git a/asm-to-pil/src/lib.rs b/asm-to-pil/src/lib.rs index 12dcc32579..e87ee30c56 100644 --- a/asm-to-pil/src/lib.rs +++ b/asm-to-pil/src/lib.rs @@ -1,6 +1,8 @@ #![deny(clippy::print_stdout)] -use powdr_ast::asm_analysis::{AnalysisASMFile, Item, SubmachineDeclaration}; +use std::collections::BTreeMap; + +use powdr_ast::asm_analysis::{AnalysisASMFile, Module, StatementReference, SubmachineDeclaration}; use powdr_number::FieldElement; use romgen::generate_machine_rom; use vm_to_constrained::ROM_SUBMACHINE_NAME; @@ -10,45 +12,58 @@ mod vm_to_constrained; pub const ROM_SUFFIX: &str = "ROM"; -/// Remove all ASM from the machine tree. Takes a tree of virtual or constrained machines and returns a tree of constrained machines -pub fn compile(file: AnalysisASMFile) -> AnalysisASMFile { - AnalysisASMFile { - items: file - .items +/// Remove all ASM from the machine tree, leaving only constrained machines +pub fn compile(mut file: AnalysisASMFile) -> AnalysisASMFile { + for (path, module) in &mut file.modules { + let mut new_machines = BTreeMap::default(); + let (mut machines, statements, ordering) = std::mem::take(module).into_inner(); + let ordering = ordering .into_iter() - .flat_map(|(name, m)| match m { - Item::Machine(m) => { - let (m, rom) = generate_machine_rom::(m); - let (mut m, rom_machine) = vm_to_constrained::convert_machine::(m, rom); - - match rom_machine { - // in the absence of ROM, simply return the machine - None => vec![(name, Item::Machine(m))], - Some(rom_machine) => { - // introduce a new name for the ROM machine, based on the original name - let mut rom_name = name.clone(); - let machine_name = rom_name.pop().unwrap(); - rom_name.push(format!("{machine_name}{ROM_SUFFIX}")); - - // add the ROM as a submachine - m.submachines.push(SubmachineDeclaration { - name: ROM_SUBMACHINE_NAME.into(), - ty: rom_name.clone(), - args: vec![], - }); - - // return both the machine and the rom - vec![ - (name, Item::Machine(m)), - (rom_name, Item::Machine(rom_machine)), - ] + .flat_map(|r| { + match r { + StatementReference::MachineDeclaration(name) => { + let m = machines.remove(&name).unwrap(); + let (m, rom) = generate_machine_rom::(m); + let (mut m, rom_machine) = vm_to_constrained::convert_machine::(m, rom); + + match rom_machine { + // in the absence of ROM, simply return the machine + None => { + new_machines.insert(name.clone(), m); + vec![name] + } + Some(rom_machine) => { + // introduce a new name for the ROM machine, based on the original name + let rom_name = format!("{name}{ROM_SUFFIX}"); + let mut ty = path.clone(); + ty.push(rom_name.clone()); + + // add the ROM as a submachine + m.submachines.push(SubmachineDeclaration { + name: ROM_SUBMACHINE_NAME.into(), + ty, + args: vec![], + }); + + new_machines.insert(name.clone(), m); + new_machines.insert(rom_name.clone(), rom_machine); + + // return both the machine and the rom + vec![name, rom_name] + } } + .into_iter() + .map(StatementReference::MachineDeclaration) + .collect() } + r => vec![r], } - item => vec![(name, item)], }) - .collect(), + .collect(); + machines.extend(new_machines); + *module = Module::new(machines, statements, ordering); } + file } pub mod utils { diff --git a/asm-to-pil/src/romgen.rs b/asm-to-pil/src/romgen.rs index 29979eefba..3bc840dc2c 100644 --- a/asm-to-pil/src/romgen.rs +++ b/asm-to-pil/src/romgen.rs @@ -250,10 +250,7 @@ pub fn generate_machine_rom(mut machine: Machine) -> (Machine, mod tests { use std::collections::BTreeMap; - use powdr_ast::{ - asm_analysis::Item, - parsed::asm::{parse_absolute_path, AbsoluteSymbolPath}, - }; + use powdr_ast::parsed::asm::{parse_absolute_path, AbsoluteSymbolPath}; use powdr_number::Bn254Field; use pretty_assertions::assert_eq; @@ -264,15 +261,8 @@ mod tests { let parsed = powdr_parser::parse_asm(None, src).unwrap(); let checked = powdr_analysis::machine_check::check(parsed).unwrap(); checked - .items - .into_iter() - .filter_map(|(name, m)| match m { - Item::Machine(m) => Some((name, generate_machine_rom::(m))), - Item::Expression(_) - | Item::TypeDeclaration(_) - | Item::TraitDeclaration(_) - | Item::TraitImplementation(_) => None, - }) + .into_machines() + .map(|(name, m)| (name, generate_machine_rom::(m))) .collect() } diff --git a/asm-to-pil/src/vm_to_constrained.rs b/asm-to-pil/src/vm_to_constrained.rs index 070c3ff685..75aa33426c 100644 --- a/asm-to-pil/src/vm_to_constrained.rs +++ b/asm-to-pil/src/vm_to_constrained.rs @@ -1323,14 +1323,7 @@ fn try_extract_update(expr: &Expression) -> Option<(String, Expression)> { #[cfg(test)] mod test { - use powdr_ast::{ - asm_analysis::{AnalysisASMFile, Item}, - parsed::{ - asm::{parse_absolute_path, Part, SymbolPath}, - types::{FunctionType, Type}, - TraitDeclaration, - }, - }; + use powdr_ast::asm_analysis::AnalysisASMFile; use powdr_importer::load_dependencies_and_resolve_str; use powdr_number::{FieldElement, GoldilocksField}; @@ -1363,51 +1356,4 @@ machine Main { "; parse_analyze_and_compile::(asm); } - - #[test] - fn trait_parsing() { - let asm = r" - mod types { - enum DoubleOpt { - None, - Some(T, T) - } - - trait ArraySum { - array_sum: T[4 + 1] -> DoubleOpt, - } - } - - machine Empty { - col witness w; - w = w * w; - } - "; - - let analyzed = parse_analyze_and_compile::(asm); - let arraysum = parse_absolute_path("::types::ArraySum"); - let trait_decl = analyzed.items.get(&arraysum).unwrap(); - if let Item::TraitDeclaration(TraitDeclaration { functions, .. }) = trait_decl { - assert_eq!(functions.len(), 1); - let func_ty = &functions.iter().next().unwrap().ty; - match func_ty { - Type::Function(FunctionType { value, .. }) => { - assert_eq!( - value.as_ref(), - &Type::NamedType( - SymbolPath::from_parts( - ["types", "DoubleOpt"] - .iter() - .map(|arg| Part::Named(arg.to_string())) - ), - Some(vec![Type::TypeVar("T".to_string())]) - ) - ); - } - _ => panic!("Expected function type"), - } - } else { - panic!("Expected trait declaration"); - } - } } diff --git a/ast/src/asm_analysis/display.rs b/ast/src/asm_analysis/display.rs index d36cd49c17..ecf4e133f9 100644 --- a/ast/src/asm_analysis/display.rs +++ b/ast/src/asm_analysis/display.rs @@ -6,74 +6,59 @@ use std::{ use itertools::Itertools; use crate::{ - asm_analysis::combine_flags, + asm_analysis::{combine_flags, Module, StatementReference}, indent, - parsed::{ - asm::{AbsoluteSymbolPath, Part}, - display::format_type_scheme_around_name, - TypedExpression, - }, - write_indented_by, write_items_indented, + parsed::asm::{AbsoluteSymbolPath, SymbolPath}, + write_indented_by, write_items_indented, writeln_indented_by, }; use super::{ AnalysisASMFile, AssignmentStatement, CallableSymbol, CallableSymbolDefinitionRef, DebugDirective, FunctionBody, FunctionStatement, FunctionStatements, Incompatible, - IncompatibleSet, InstructionDefinitionStatement, InstructionStatement, Item, LabelStatement, + IncompatibleSet, InstructionDefinitionStatement, InstructionStatement, LabelStatement, LinkDefinition, Machine, MachineDegree, RegisterDeclarationStatement, RegisterTy, Return, Rom, SubmachineDeclaration, }; impl Display for AnalysisASMFile { fn fmt(&self, f: &mut Formatter<'_>) -> Result { - let mut current_path = AbsoluteSymbolPath::default(); - - for (path, item) in &self.items { - let relative_path = path.relative_to(¤t_path); - let name = relative_path.name(); - // Skip the name (last) part - for part in relative_path.parts().rev().skip(1).rev() { - match part { - Part::Super => { - current_path.pop(); - write_indented_by(f, "}\n", current_path.len())?; - } - Part::Named(m) => { - write_indented_by(f, format!("mod {m} {{\n"), current_path.len())?; - current_path.push(m.clone()); - } - } - } + write_module(f, self, &AbsoluteSymbolPath::default(), 0) + } +} - match item { - Item::Machine(machine) => { - write_indented_by(f, format!("machine {name}{machine}"), current_path.len())?; - } - Item::Expression(TypedExpression { e, type_scheme }) => write_indented_by( - f, - format!( - "let{} = {e};\n", - format_type_scheme_around_name(name, type_scheme) - ), - current_path.len(), - )?, - Item::TypeDeclaration(enum_decl) => { - write_indented_by(f, enum_decl, current_path.len())? - } - Item::TraitImplementation(trait_impl) => { - write_indented_by(f, trait_impl, current_path.len())? - } - Item::TraitDeclaration(trait_decl) => { - write_indented_by(f, trait_decl, current_path.len())? - } - } - } - for i in (0..current_path.len()).rev() { - write_indented_by(f, "}\n", i)?; - } +fn write_module( + f: &mut Formatter<'_>, + file: &AnalysisASMFile, + module_path: &AbsoluteSymbolPath, + indentation: usize, +) -> Result { + let module: &Module = &file.modules[module_path]; + let mut pil = module.statements.iter(); - Ok(()) + for r in &module.ordering { + match r { + StatementReference::MachineDeclaration(name) => write_indented_by( + f, + format!("machine {name}{}", module.machines[name]), + indentation, + ), + StatementReference::Pil => { + writeln_indented_by(f, format!("{}", pil.next().unwrap()), indentation) + } + StatementReference::Module(name) => { + let path = module_path + .clone() + .join(SymbolPath::from_identifier(name.to_string())); + writeln_indented_by(f, format!("mod {name} {{"), indentation)?; + write_module(f, file, &path, indentation + 1)?; + writeln_indented_by(f, "}", indentation) + } + }?; } + + assert!(pil.next().is_none()); + + Ok(()) } impl Display for MachineDegree { @@ -321,23 +306,74 @@ impl Display for IncompatibleSet { #[cfg(test)] mod test { use super::*; - use crate::parsed::asm::parse_absolute_path; + use crate::{ + asm_analysis::{Module, StatementReference}, + parsed::asm::parse_absolute_path, + }; use pretty_assertions::assert_eq; #[test] fn display_asm_analysis_file() { let file = AnalysisASMFile { - items: [ - "::x::Y", - "::x::r::T", - "::x::f::Y", - "::M", - "::t::x::y::R", - "::t::F", - "::X", + modules: [ + ( + "::", + vec![ + StatementReference::MachineDeclaration("M".into()), + StatementReference::MachineDeclaration("X".into()), + StatementReference::Module("t".into()), + StatementReference::Module("x".into()), + ], + ), + ( + "::t", + vec![ + StatementReference::MachineDeclaration("F".into()), + StatementReference::Module("x".into()), + ], + ), + ( + "::x", + vec![ + StatementReference::MachineDeclaration("Y".into()), + StatementReference::Module("f".into()), + StatementReference::Module("r".into()), + ], + ), + ("::t::x", vec![StatementReference::Module("y".into())]), + ( + "::t::x::y", + vec![StatementReference::MachineDeclaration("R".into())], + ), + ( + "::x::f", + vec![StatementReference::MachineDeclaration("Y".into())], + ), + ( + "::x::r", + vec![StatementReference::MachineDeclaration("T".into())], + ), ] .into_iter() - .map(|s| (parse_absolute_path(s), Item::Machine(Machine::default()))) + .map(|(path, ordering)| { + ( + parse_absolute_path(path), + Module { + machines: ordering + .iter() + .filter_map(|r| match r { + StatementReference::MachineDeclaration(name) => { + Some((name.clone(), Machine::default())) + } + StatementReference::Pil => unimplemented!(), + StatementReference::Module(_) => None, + }) + .collect(), + ordering, + ..Default::default() + }, + ) + }) .collect(), }; assert_eq!( diff --git a/ast/src/asm_analysis/mod.rs b/ast/src/asm_analysis/mod.rs index d485931a6e..bc8ca67253 100644 --- a/ast/src/asm_analysis/mod.rs +++ b/ast/src/asm_analysis/mod.rs @@ -19,8 +19,7 @@ use crate::parsed::{ MachineParams, OperationId, OperationParams, }, visitor::{ExpressionVisitable, VisitOrder}, - EnumDeclaration, NamespacedPolynomialReference, PilStatement, TraitDeclaration, - TraitImplementation, TypedExpression, + NamespacedPolynomialReference, PilStatement, }; pub use crate::parsed::Expression; @@ -674,29 +673,6 @@ pub struct SubmachineDeclaration { pub args: Vec, } -/// An item that is part of the module tree after all modules, -/// imports and references have been resolved. -#[derive(Clone, Debug)] -pub enum Item { - Machine(Machine), - Expression(TypedExpression), - TypeDeclaration(EnumDeclaration), - TraitImplementation(TraitImplementation), - TraitDeclaration(TraitDeclaration), -} - -impl Item { - pub fn try_to_machine(&self) -> Option<&Machine> { - match self { - Item::Machine(m) => Some(m), - Item::Expression(_) - | Item::TypeDeclaration(_) - | Item::TraitImplementation(_) - | Item::TraitDeclaration(_) => None, - } - } -} - #[derive(Default, Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub struct MachineDegree { pub min: Option, @@ -830,28 +806,97 @@ pub struct Rom { #[derive(Default, Clone, Debug)] pub struct AnalysisASMFile { - pub items: BTreeMap, + pub modules: BTreeMap, } - impl AnalysisASMFile { - pub fn machines(&self) -> impl Iterator { - self.items.iter().filter_map(|(n, m)| match m { - Item::Machine(m) => Some((n, m)), - Item::Expression(_) - | Item::TypeDeclaration(_) - | Item::TraitDeclaration(_) - | Item::TraitImplementation(_) => None, + pub fn machines_mut(&mut self) -> impl Iterator { + self.modules.iter_mut().flat_map(|(module_path, module)| { + module.machines.iter_mut().map(move |(name, machine)| { + let mut machine_path = module_path.clone(); + machine_path.push(name.to_string()); + (machine_path, machine) + }) }) } - pub fn machines_mut(&mut self) -> impl Iterator { - self.items.iter_mut().filter_map(|(n, m)| match m { - Item::Machine(m) => Some((n, m)), - Item::Expression(_) - | Item::TypeDeclaration(_) - | Item::TraitDeclaration(_) - | Item::TraitImplementation(_) => None, + + pub fn machines(&self) -> impl Iterator { + self.modules.iter().flat_map(|(module_path, module)| { + module.machines.iter().map(move |(name, machine)| { + let mut machine_path = module_path.clone(); + machine_path.push(name.to_string()); + (machine_path, machine) + }) + }) + } + + pub fn into_machines(self) -> impl Iterator { + self.modules.into_iter().flat_map(|(module_path, module)| { + module.machines.into_iter().map(move |(name, machine)| { + let mut machine_path = module_path.clone(); + machine_path.push(name.to_string()); + (machine_path, machine) + }) }) } + + pub fn get_machine(&self, ty: &AbsoluteSymbolPath) -> Option<&Machine> { + let mut path = ty.clone(); + let name = path.pop().unwrap(); + self.modules[&path].machines.get(&name) + } +} + +#[derive(Clone, Debug)] +pub enum StatementReference { + MachineDeclaration(String), + Pil, + Module(String), +} + +#[derive(Default, Clone, Debug)] +pub struct Module { + machines: BTreeMap, + statements: Vec, + ordering: Vec, +} + +impl Module { + pub fn new( + machines: BTreeMap, + statements: Vec, + ordering: Vec, + ) -> Self { + Self { + machines, + statements, + ordering, + } + } + + pub fn push_machine(&mut self, name: String, machine: Machine) { + self.machines.insert(name.clone(), machine); + self.ordering + .push(StatementReference::MachineDeclaration(name)); + } + + pub fn push_pil_statement(&mut self, s: PilStatement) { + self.statements.push(s); + self.ordering.push(StatementReference::Pil); + } + + pub fn push_module(&mut self, name: String) { + self.ordering.push(StatementReference::Module(name)); + } + + pub fn into_inner( + self, + ) -> ( + BTreeMap, + Vec, + Vec, + ) { + (self.machines, self.statements, self.ordering) + } } #[derive(Default, Debug, Clone)] diff --git a/ast/src/object/display.rs b/ast/src/object/display.rs index f728b1b305..8470f907e5 100644 --- a/ast/src/object/display.rs +++ b/ast/src/object/display.rs @@ -1,13 +1,8 @@ use std::fmt::{Display, Formatter, Result}; -use crate::{ - asm_analysis::combine_flags, - parsed::{display::format_type_scheme_around_name, TypedExpression}, -}; +use crate::{asm_analysis::combine_flags, write_items_indented}; -use super::{ - Link, LinkFrom, LinkTo, Location, Machine, Object, Operation, PILGraph, TypeOrExpression, -}; +use super::{Link, LinkFrom, LinkTo, Location, Machine, Object, Operation, PILGraph}; impl Display for Location { fn fmt(&self, f: &mut Formatter<'_>) -> Result { @@ -18,19 +13,10 @@ impl Display for Location { impl Display for PILGraph { fn fmt(&self, f: &mut Formatter<'_>) -> Result { writeln!(f, "// Utilities")?; - for (name, utility) in &self.definitions { - match utility { - TypeOrExpression::Expression(TypedExpression { e, type_scheme }) => { - writeln!( - f, - "let{} = {e};", - format_type_scheme_around_name(&name.to_string(), type_scheme) - )?; - } - TypeOrExpression::Type(enum_decl) => { - writeln!(f, "{enum_decl}",)?; - } - } + for (module_path, statements) in &self.statements { + writeln!(f, "mod {module_path} {{")?; + write_items_indented(f, statements)?; + writeln!(f, "}}")?; } for (location, object) in &self.objects { writeln!(f, "// Object {location}")?; diff --git a/ast/src/object/mod.rs b/ast/src/object/mod.rs index f8f78f52c9..e5fdd70ad9 100644 --- a/ast/src/object/mod.rs +++ b/ast/src/object/mod.rs @@ -44,7 +44,8 @@ pub struct PILGraph { pub main: Machine, pub entry_points: Vec, pub objects: BTreeMap, - pub definitions: BTreeMap, + /// PIL utilities by module path + pub statements: BTreeMap>, } #[derive(Clone)] diff --git a/ast/src/parsed/asm.rs b/ast/src/parsed/asm.rs index 7d7e4d51b3..8274ad5213 100644 --- a/ast/src/parsed/asm.rs +++ b/ast/src/parsed/asm.rs @@ -15,8 +15,8 @@ use serde::{Deserialize, Serialize}; use crate::parsed::{BinaryOperation, BinaryOperator}; use super::{ - visitor::Children, EnumDeclaration, EnumVariant, Expression, PilStatement, SourceReference, - TraitDeclaration, TraitImplementation, TypedExpression, + types::TypeScheme, visitor::Children, EnumDeclaration, EnumVariant, Expression, PilStatement, + SourceReference, TraitDeclaration, }; #[derive(Default, Clone, Debug, PartialEq, Eq)] @@ -29,26 +29,19 @@ pub struct ASMModule { pub statements: Vec, } -impl ASMModule { - pub fn symbol_definitions(&self) -> impl Iterator { - self.statements.iter().filter_map(|s| match s { - ModuleStatement::SymbolDefinition(d) => Some(d), - ModuleStatement::TraitImplementation(_) => None, - }) - } -} - #[derive(Debug, Clone, PartialEq, Eq, From)] pub enum ModuleStatement { SymbolDefinition(SymbolDefinition), - TraitImplementation(TraitImplementation), + PilStatement(PilStatement), } impl ModuleStatement { - pub fn defined_names(&self) -> Option<&String> { + pub fn defined_names(&self) -> Box + '_> { match self { - ModuleStatement::SymbolDefinition(d) => Some(&d.name), - ModuleStatement::TraitImplementation(_) => None, + ModuleStatement::SymbolDefinition(d) => Box::new(once(&d.name)), + ModuleStatement::PilStatement(s) => { + Box::new(s.symbol_definition_names().map(|(name, _)| name)) + } } } } @@ -67,12 +60,6 @@ pub enum SymbolValue { Import(Import), /// A module definition Module(Module), - /// A generic symbol / function. - Expression(TypedExpression), - /// A type declaration (currently only enums) - TypeDeclaration(EnumDeclaration), - /// A trait declaration - TraitDeclaration(TraitDeclaration), } impl SymbolValue { @@ -81,9 +68,6 @@ impl SymbolValue { SymbolValue::Machine(machine) => SymbolValueRef::Machine(machine), SymbolValue::Import(i) => SymbolValueRef::Import(i), SymbolValue::Module(m) => SymbolValueRef::Module(m.as_ref()), - SymbolValue::Expression(e) => SymbolValueRef::Expression(e), - SymbolValue::TypeDeclaration(t) => SymbolValueRef::TypeDeclaration(t), - SymbolValue::TraitDeclaration(t) => SymbolValueRef::TraitDeclaration(t), } } } @@ -97,7 +81,7 @@ pub enum SymbolValueRef<'a> { /// A module definition Module(ModuleRef<'a>), /// A generic symbol / function. - Expression(&'a TypedExpression), + Expression(&'a Option, &'a Option>), /// A type declaration (currently only enums) TypeDeclaration(&'a EnumDeclaration), /// A type constructor of an enum. diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index b13908122c..5cf0ce4c34 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -5,7 +5,7 @@ use itertools::Itertools; use crate::{ indent, parsed::{BinaryOperator, UnaryOperator}, - write_indented_by, write_items, write_items_indented, + write_indented_by, write_items, write_items_indented, writeln_indented, }; use self::types::{ArrayType, FunctionType, TupleType, TypeBounds}; @@ -14,7 +14,13 @@ use super::{asm::*, *}; impl Display for PILFile { fn fmt(&self, f: &mut Formatter<'_>) -> Result { - write_items(f, &self.0) + for statement in &self.0 { + match statement { + PilStatement::Namespace(..) => writeln!(f, "{statement}")?, + _ => writeln_indented(f, statement.to_string())?, + } + } + Ok(()) } } @@ -34,7 +40,7 @@ impl Display for ModuleStatement { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self { ModuleStatement::SymbolDefinition(symbol_def) => write!(f, "{symbol_def}"), - ModuleStatement::TraitImplementation(trait_impl) => write!(f, "{trait_impl}"), + ModuleStatement::PilStatement(s) => write!(f, "{s}"), } } } @@ -55,15 +61,6 @@ impl Display for SymbolDefinition { SymbolValue::Module(m @ Module::Local(_)) => { write!(f, "mod {name} {m}") } - SymbolValue::Expression(TypedExpression { e, type_scheme }) => { - write!( - f, - "let{} = {e};", - format_type_scheme_around_name(name, type_scheme) - ) - } - SymbolValue::TypeDeclaration(ty) => write!(f, "{ty}"), - SymbolValue::TraitDeclaration(trait_decl) => write!(f, "{trait_decl}"), } } } @@ -478,56 +475,47 @@ impl Display for PilStatement { } } } - PilStatement::LetStatement(_, pattern, type_scheme, value) => write_indented_by( + PilStatement::LetStatement(_, pattern, type_scheme, value) => write!( f, - format!( - "let{}{};", - format_type_scheme_around_name(pattern, type_scheme), - value - .as_ref() - .map(|value| format!(" = {value}")) - .unwrap_or_default() - ), - 1, + "let{}{};", + format_type_scheme_around_name(pattern, type_scheme), + value + .as_ref() + .map(|value| format!(" = {value}")) + .unwrap_or_default() ), PilStatement::PolynomialDefinition(_, name, value) => { - write_indented_by(f, format!("pol {name} = {value};"), 1) + write!(f, "pol {name} = {value};") } PilStatement::PublicDeclaration(_, name, poly, array_index, index) => { - write_indented_by( + write!( f, - format!( - "public {name} = {poly}{}({index});", - array_index - .as_ref() - .map(|i| format!("[{i}]")) - .unwrap_or_default() - ), - 1, + "public {name} = {poly}{}({index});", + array_index + .as_ref() + .map(|i| format!("[{i}]")) + .unwrap_or_default() ) } PilStatement::PolynomialConstantDeclaration(_, names) => { - write_indented_by(f, format!("pol constant {};", names.iter().format(", ")), 1) + write!(f, "pol constant {};", names.iter().format(", ")) } PilStatement::PolynomialConstantDefinition(_, name, definition) => { - write_indented_by(f, format!("pol constant {name}{definition};"), 1) + write!(f, "pol constant {name}{definition};") } - PilStatement::PolynomialCommitDeclaration(_, stage, names, value) => write_indented_by( + PilStatement::PolynomialCommitDeclaration(_, stage, names, value) => write!( f, - format!( - "pol commit {}{}{};", - stage - .and_then(|s| (s > 0).then(|| format!("stage({s}) "))) - .unwrap_or_default(), - names.iter().format(", "), - value.as_ref().map(|v| format!("{v}")).unwrap_or_default() - ), - 1, + "pol commit {}{}{};", + stage + .and_then(|s| (s > 0).then(|| format!("stage({s}) "))) + .unwrap_or_default(), + names.iter().format(", "), + value.as_ref().map(|v| format!("{v}")).unwrap_or_default() ), - PilStatement::Expression(_, e) => write_indented_by(f, format!("{e};"), 1), - PilStatement::EnumDeclaration(_, enum_decl) => write_indented_by(f, enum_decl, 1), - PilStatement::TraitImplementation(_, trait_impl) => write_indented_by(f, trait_impl, 1), - PilStatement::TraitDeclaration(_, trait_decl) => write_indented_by(f, trait_decl, 1), + PilStatement::Expression(_, e) => write!(f, "{e};"), + PilStatement::EnumDeclaration(_, enum_decl) => write!(f, "{enum_decl}"), + PilStatement::TraitImplementation(_, trait_impl) => write!(f, "{trait_impl}"), + PilStatement::TraitDeclaration(_, trait_decl) => write!(f, "{trait_decl}"), } } } diff --git a/ast/src/parsed/folder.rs b/ast/src/parsed/folder.rs index 0015c76989..c465c0f256 100644 --- a/ast/src/parsed/folder.rs +++ b/ast/src/parsed/folder.rs @@ -3,7 +3,7 @@ use super::{ ASMModule, ASMProgram, Import, Machine, Module, ModuleStatement, SymbolDefinition, SymbolValue, }, - EnumDeclaration, Expression, TraitDeclaration, TraitImplementation, + PilStatement, }; pub trait Folder { @@ -24,18 +24,9 @@ pub trait Folder { SymbolValue::Machine(machine) => self.fold_machine(machine).map(From::from), SymbolValue::Import(import) => self.fold_import(import).map(From::from), SymbolValue::Module(module) => self.fold_module(module).map(From::from), - SymbolValue::Expression(e) => Ok(SymbolValue::Expression(e)), - SymbolValue::TypeDeclaration(ty) => { - self.fold_type_declaration(ty).map(From::from) - } - SymbolValue::TraitDeclaration(trait_decl) => { - self.fold_trait_declaration(trait_decl).map(From::from) - } } .map(|value| ModuleStatement::SymbolDefinition(SymbolDefinition { value, ..d })), - ModuleStatement::TraitImplementation(trait_impl) => { - self.fold_trait_implementation(trait_impl).map(From::from) - } + ModuleStatement::PilStatement(s) => self.fold_pil_statement(s).map(From::from), }) .collect::, _>>()?; @@ -57,24 +48,7 @@ pub trait Folder { Ok(import) } - fn fold_type_declaration( - &mut self, - ty: EnumDeclaration, - ) -> Result, Self::Error> { - Ok(ty) - } - - fn fold_trait_implementation( - &mut self, - trait_impl: TraitImplementation, - ) -> Result, Self::Error> { - Ok(trait_impl) - } - - fn fold_trait_declaration( - &mut self, - trait_decl: TraitDeclaration, - ) -> Result, Self::Error> { - Ok(trait_decl) + fn fold_pil_statement(&mut self, statement: PilStatement) -> Result { + Ok(statement) } } diff --git a/ast/src/parsed/types.rs b/ast/src/parsed/types.rs index 6b086beb07..484a5ac001 100644 --- a/ast/src/parsed/types.rs +++ b/ast/src/parsed/types.rs @@ -136,6 +136,62 @@ impl Type { Type::Tuple(TupleType { items: vec![] }) } } + +impl Type { + pub fn contained_expressions_mut(&mut self) -> Box + '_> { + match self { + Type::Array(ArrayType { base, length }) => Box::new( + length + .as_mut() + .and_then(|l| l.try_to_expression_mut()) + .into_iter() + .chain(base.contained_expressions_mut()), + ), + t => Box::new(t.children_mut().flat_map(|t| t.contained_expressions_mut())), + } + } + + pub fn contained_expressions(&self) -> Box + '_> { + match self { + Type::Array(ArrayType { base, length }) => Box::new( + length + .as_ref() + .and_then(|l| l.try_to_expression()) + .into_iter() + .chain(base.contained_expressions()), + ), + t => Box::new(t.children().flat_map(|t| t.contained_expressions())), + } + } +} + +/// A trait to operate over the possible types for the array type lengths +pub trait ExpressionInArrayLength: std::fmt::Display + std::fmt::Debug { + fn try_to_expression_mut(&mut self) -> Option<&mut Expression>; + + fn try_to_expression(&self) -> Option<&Expression>; +} + +impl ExpressionInArrayLength for Expression { + fn try_to_expression_mut(&mut self) -> Option<&mut Expression> { + Some(self) + } + + fn try_to_expression(&self) -> Option<&Expression> { + Some(self) + } +} + +impl ExpressionInArrayLength for u64 { + fn try_to_expression_mut(&mut self) -> Option<&mut Expression> { + None + } + + fn try_to_expression(&self) -> Option<&Expression> { + None + } +} + impl Type { /// Substitutes all occurrences of the given type variables with the given types. /// Does not apply the substitutions inside the replacements. diff --git a/ast/src/parsed/visitor.rs b/ast/src/parsed/visitor.rs index 10a04f07bc..cec85e55cc 100644 --- a/ast/src/parsed/visitor.rs +++ b/ast/src/parsed/visitor.rs @@ -3,6 +3,7 @@ use std::{iter, ops::ControlFlow}; use super::Expression; /// Generic trait that allows to iterate over sub-structures. +/// /// It is only meant to iterate non-recursively over the direct children. /// Self and O do not have to be the same type and we can also have /// Children and Children implemented for the same type, @@ -27,6 +28,7 @@ pub enum VisitOrder { } /// A trait to be implemented by an AST node. +/// /// The idea is that it calls a callback function on each of the sub-nodes /// that are expressions. /// The difference to the Children trait is that ExpressionVisitable diff --git a/backend/src/estark/starky_wrapper.rs b/backend/src/estark/starky_wrapper.rs index 908b48e583..fca0329a88 100644 --- a/backend/src/estark/starky_wrapper.rs +++ b/backend/src/estark/starky_wrapper.rs @@ -181,7 +181,7 @@ impl EStark { impl Backend for EStark { fn verify(&self, proof: &[u8], instances: &[Vec]) -> Result<(), Error> { let proof: StarkProof = - serde_json::from_str(&String::from_utf8(proof.to_vec()).unwrap()).unwrap(); + serde_json::from_str(core::str::from_utf8(proof).unwrap()).unwrap(); self.verify_stark_gl_with_publics(&proof, instances) } diff --git a/importer/src/path_canonicalizer.rs b/importer/src/path_canonicalizer.rs index 5f7fdaec66..d049d5e23c 100644 --- a/importer/src/path_canonicalizer.rs +++ b/importer/src/path_canonicalizer.rs @@ -12,12 +12,12 @@ use powdr_ast::parsed::{ SymbolValue, SymbolValueRef, }, folder::Folder, - types::{Type, TypeScheme}, + types::{ExpressionInArrayLength, Type, TypeScheme}, visitor::{Children, ExpressionVisitable}, ArrayLiteral, BinaryOperation, BlockExpression, EnumDeclaration, EnumVariant, Expression, FunctionCall, IndexAccess, LambdaExpression, LetStatementInsideBlock, MatchArm, MatchExpression, NamedType, Pattern, PilStatement, StatementInsideBlock, TraitDeclaration, - TypedExpression, UnaryOperation, + UnaryOperation, }; use powdr_parser_util::{Error, SourceRef}; @@ -79,47 +79,16 @@ impl<'a> Folder for Canonicalizer<'a> { .map(Some) .transpose(), }, - SymbolValue::Expression(mut exp) => { - if let Some(type_scheme) = &mut exp.type_scheme { - canonicalize_inside_type_scheme( - type_scheme, - &self.path, - self.paths, - ); - } - canonicalize_inside_expression(&mut exp.e, &self.path, self.paths); - Some(Ok(SymbolValue::Expression(exp))) - } - SymbolValue::TypeDeclaration(mut enum_decl) => { - let type_vars = enum_decl.type_vars.vars().collect(); - for variant in &mut enum_decl.variants { - if let Some(fields) = &mut variant.fields { - for field in fields { - canonicalize_inside_type( - field, &type_vars, &self.path, self.paths, - ); - } - } - } - Some(Ok(SymbolValue::TypeDeclaration(enum_decl))) - } - SymbolValue::TraitDeclaration(mut trait_decl) => { - let type_vars = trait_decl.type_vars.iter().collect(); - for f in &mut trait_decl.functions { - canonicalize_inside_type( - &mut f.ty, &type_vars, &self.path, self.paths, - ); - } - Some(Ok(SymbolValue::TraitDeclaration(trait_decl))) - } } .map(|value| value.map(|value| SymbolDefinition { name, value }.into())) } - ModuleStatement::TraitImplementation(mut trait_impl) => { - for f in trait_impl.children_mut() { - canonicalize_inside_expression(f, &self.path, self.paths) - } - Some(Ok(ModuleStatement::TraitImplementation(trait_impl))) + ModuleStatement::PilStatement(mut pil_statement) => { + canonicalize_inside_pil_statement( + &mut pil_statement, + &self.path, + self.paths, + ); + Some(Ok(ModuleStatement::PilStatement(pil_statement))) } }) .collect::>()?, @@ -263,6 +232,46 @@ fn free_inputs_in_expression_mut<'a>( } } +fn canonicalize_inside_pil_statement( + statement: &mut PilStatement, + path: &AbsoluteSymbolPath, + paths: &'_ PathMap, +) { + match statement { + PilStatement::LetStatement(_, _, type_scheme, e) => { + if let Some(type_scheme) = type_scheme { + canonicalize_inside_type_scheme(type_scheme, path, paths); + } + if let Some(e) = e { + canonicalize_inside_expression(e, path, paths); + } + } + PilStatement::EnumDeclaration(_, enum_decl) => { + let type_vars = enum_decl.type_vars.vars().collect(); + for variant in &mut enum_decl.variants { + if let Some(fields) = &mut variant.fields { + for field in fields { + canonicalize_inside_type(field, &type_vars, path, paths); + } + } + } + } + PilStatement::TraitImplementation(_, trait_impl) => { + canonicalize_inside_type_scheme(&mut trait_impl.type_scheme, path, paths); + for f in trait_impl.children_mut() { + canonicalize_inside_expression(f, path, paths) + } + } + PilStatement::TraitDeclaration(_, trait_decl) => { + let type_vars = trait_decl.type_vars.iter().collect(); + for f in &mut trait_decl.functions { + canonicalize_inside_type(&mut f.ty, &type_vars, path, paths); + } + } + _ => unreachable!("unexpected at module level, make this enum more strict"), + } +} + fn canonicalize_inside_expression( e: &mut Expression, path: &AbsoluteSymbolPath, @@ -330,8 +339,8 @@ fn canonicalize_inside_pattern( } } -fn canonicalize_inside_type_scheme( - type_scheme: &mut TypeScheme, +fn canonicalize_inside_type_scheme( + type_scheme: &mut TypeScheme, path: &AbsoluteSymbolPath, paths: &'_ PathMap, ) { @@ -343,19 +352,23 @@ fn canonicalize_inside_type_scheme( ); } -fn canonicalize_inside_type( - ty: &mut Type, +fn canonicalize_inside_type( + ty: &mut Type, type_vars: &HashSet<&String>, path: &AbsoluteSymbolPath, paths: &'_ PathMap, ) { + // replace type vars recursively ty.map_to_type_vars(type_vars); + + // canonicalize names recursively for p in ty.contained_named_types_mut() { let abs = paths.get(&path.clone().join(p.clone())).unwrap(); *p = abs.relative_to(&Default::default()).clone(); } - for tne in ty.children_mut() { + // canonicalize contained expressions recursively + for tne in ty.contained_expressions_mut() { canonicalize_inside_expression(tne, path, paths); } } @@ -465,27 +478,42 @@ fn check_path_internal<'a>( match value { // machines, expressions and enum variants do not expose symbols SymbolValueRef::Machine(_) - | SymbolValueRef::Expression(_) + | SymbolValueRef::Expression(_, _) | SymbolValueRef::TypeConstructor(_) | SymbolValueRef::TraitDeclaration(_) => { Err(format!("symbol not found in `{location}`: `{member}`")) } // modules expose symbols SymbolValueRef::Module(ModuleRef::Local(module)) => module - .symbol_definitions() - .find_map(|SymbolDefinition { name, value }| { - (name == member).then_some(value) + .statements + .iter() + .find_map(|s| match s { + ModuleStatement::SymbolDefinition(SymbolDefinition { name, value }) => { + (name == member).then_some(value.as_ref()) + } + // some pil statements introduce names + ModuleStatement::PilStatement(s) => match s { + PilStatement::EnumDeclaration(_, d) => { + (d.name == member).then_some(SymbolValueRef::TypeDeclaration(d)) + } + PilStatement::LetStatement(_, name, type_scheme, e) => (name + == member) + .then_some(SymbolValueRef::Expression(e, type_scheme)), + PilStatement::TraitDeclaration(_, d) => (d.name == member) + .then_some(SymbolValueRef::TraitDeclaration(d)), + _s => None, + }, }) .ok_or_else(|| format!("symbol not found in `{location}`: `{member}`")) .and_then(|symbol| { match symbol { - SymbolValue::Import(p) => { + SymbolValueRef::Import(p) => { // if we found an import, check it and continue from there check_path_internal(location.join(p.path.clone()), state, chain) } symbol => { // if we found any other symbol, continue from there - Ok((location.with_part(member), symbol.as_ref(), chain)) + Ok((location.with_part(member), symbol, chain)) } } }), @@ -565,53 +593,83 @@ fn check_module( state: &mut State<'_>, ) -> Result<(), Error> { module - .symbol_definitions() - .try_fold( - BTreeSet::default(), - |mut acc, SymbolDefinition { name, .. }| { - // TODO we should store source refs in symbol definitions. - acc.insert(name.clone()) - .then_some(acc) - .ok_or(format!("Duplicate name `{name}` in module `{location}`")) - }, - ) + .statements + .iter() + .flat_map(|s| s.defined_names()) + .try_fold(BTreeSet::default(), |mut acc, name| { + // TODO we should store source refs in symbol definitions. + acc.insert(name.clone()) + .then_some(acc) + .ok_or(format!("Duplicate name `{name}` in module `{location}`")) + }) .map_err(|e| SourceRef::default().with_error(e))?; - for SymbolDefinition { name, value } in module.symbol_definitions() { + for statement in &module.statements { // start with the initial state // update the state - match value { - SymbolValue::Machine(machine) => { - check_machine(location.with_part(name), machine, state)?; - } - SymbolValue::Module(module) => { - let m = match module { - Module::External(_) => unreachable!(), - Module::Local(m) => m, - }; - check_module(location.with_part(name), m, state)?; + match statement { + ModuleStatement::PilStatement(p) => { + check_pil_statement_inside_module(location.clone(), p, state)?; } - SymbolValue::Import(s) => check_import(location.clone(), s.clone(), state) - .map_err(|e| SourceRef::default().with_error(e))?, - SymbolValue::Expression(TypedExpression { e, type_scheme }) => { - if let Some(type_scheme) = type_scheme { - check_type_scheme(&location, type_scheme, state, &Default::default())?; + ModuleStatement::SymbolDefinition(SymbolDefinition { name, value }) => match value { + SymbolValue::Machine(machine) => { + check_machine(location.with_part(name), machine, state)?; } - let type_vars = type_scheme - .as_ref() - .map(|ts| ts.vars.vars().collect()) - .unwrap_or_default(); - check_expression(&location, e, state, &type_vars, &HashSet::default())? + SymbolValue::Module(module) => { + let m = match module { + Module::External(_) => unreachable!(), + Module::Local(m) => m, + }; + check_module(location.with_part(name), m, state)?; + } + SymbolValue::Import(s) => check_import(location.clone(), s.clone(), state) + .map_err(|e| SourceRef::default().with_error(e))?, + }, + } + } + Ok(()) +} + +fn check_pil_statement_inside_module( + location: AbsoluteSymbolPath, + s: &PilStatement, + state: &mut State<'_>, +) -> Result<(), Error> { + match s { + PilStatement::LetStatement(_, _, type_scheme, e) => { + if let Some(type_scheme) = type_scheme { + check_type_scheme(&location, type_scheme, state, &Default::default())?; } - SymbolValue::TypeDeclaration(enum_decl) => { - check_type_declaration(&location, enum_decl, state)? + let type_vars = type_scheme + .as_ref() + .map(|ts| ts.vars.vars().collect()) + .unwrap_or_default(); + if let Some(e) = e { + check_expression(&location, e, state, &type_vars, &HashSet::default())?; } - SymbolValue::TraitDeclaration(trait_decl) => { - check_trait_declaration(&location, trait_decl, state)? + Ok(()) + } + PilStatement::EnumDeclaration(_, enum_decl) => { + check_type_declaration(&location, enum_decl, state) + } + PilStatement::TraitImplementation(_, trait_impl) => { + check_type_scheme( + &location, + &trait_impl.type_scheme, + state, + &Default::default(), + )?; + let type_vars = trait_impl.type_scheme.vars.vars().collect::>(); + for f in &trait_impl.functions { + check_expression(&location, &f.body, state, &type_vars, &Default::default())?; } + Ok(()) + } + PilStatement::TraitDeclaration(_, trait_decl) => { + check_trait_declaration(&location, trait_decl, state) } + s => unreachable!("the parser should not produce statement {s} inside a module"), } - Ok(()) } /// Checks a machine, checking the paths it contains, in particular paths to the types of submachines @@ -925,9 +983,9 @@ fn check_type_declaration( .try_for_each(|ty| check_type(location, ty, state, &type_vars, &Default::default())) } -fn check_type_scheme( +fn check_type_scheme( location: &AbsoluteSymbolPath, - type_scheme: &TypeScheme, + type_scheme: &TypeScheme, state: &mut State<'_>, local_variables: &HashSet, ) -> Result<(), Error> { @@ -941,16 +999,13 @@ fn check_type_scheme( ) } -fn check_type( +fn check_type( location: &AbsoluteSymbolPath, - ty: &Type, + ty: &Type, state: &mut State<'_>, type_vars: &HashSet<&String>, local_variables: &HashSet, -) -> Result<(), Error> -where - Type: Children, -{ +) -> Result<(), Error> { for p in ty.contained_named_types() { if let Some(id) = p.try_to_identifier() { if type_vars.contains(id) { @@ -960,7 +1015,7 @@ where check_path_try_prelude(location.clone(), p.clone(), state) .map_err(|e| SourceRef::unknown().with_error(e))?; } - ty.children() + ty.contained_expressions() .try_for_each(|e| check_expression(location, e, state, type_vars, local_variables)) } @@ -1149,4 +1204,9 @@ mod tests { fn degree_not_found() { expect("degree_not_found", Err("symbol not found in `::`: `N`")) } + + #[test] + fn trait_implementation() { + expect("trait_implementation", Ok(())) + } } diff --git a/importer/src/powdr_std.rs b/importer/src/powdr_std.rs index d9972cd36b..d1ecc4910a 100644 --- a/importer/src/powdr_std.rs +++ b/importer/src/powdr_std.rs @@ -91,18 +91,9 @@ impl Folder for StdAdder { StdAdder::fold_import(self, import).map(From::from) } SymbolValue::Module(module) => self.fold_module(module).map(From::from), - SymbolValue::Expression(e) => Ok(SymbolValue::Expression(e)), - SymbolValue::TypeDeclaration(ty) => { - self.fold_type_declaration(ty).map(From::from) - } - SymbolValue::TraitDeclaration(trait_decl) => { - self.fold_trait_declaration(trait_decl).map(From::from) - } } .map(|value| ModuleStatement::SymbolDefinition(SymbolDefinition { value, ..d })), - ModuleStatement::TraitImplementation(trait_impl) => { - self.fold_trait_implementation(trait_impl).map(From::from) - } + ModuleStatement::PilStatement(pil) => self.fold_pil_statement(pil).map(From::from), }) .collect::, _>>()?; @@ -110,7 +101,7 @@ impl Folder for StdAdder { // (E.g. the main module) let has_std = statements .iter() - .filter_map(|m| m.defined_names()) + .flat_map(|m| m.defined_names()) .any(|n| n == "std"); if !has_std { diff --git a/importer/test_data/instruction.expected.asm b/importer/test_data/instruction.expected.asm index d9e90acff6..a235c2af2b 100644 --- a/importer/test_data/instruction.expected.asm +++ b/importer/test_data/instruction.expected.asm @@ -1,15 +1,15 @@ let identity: expr -> expr = |expr| expr; machine Id { operation id<0> x, y; - pol commit x; - pol commit y; - x = y; + pol commit x; + pol commit y; + x = y; } machine Main { ::Id id; reg pc[@pc]; reg X[<=]; reg Y[<=]; - instr id X, l: label -> Y link => X = id.id(identity(l)) link => Y = id.id(identity(Y)){ Y = identity(X) } + instr id X, l: label -> Y link => X = id.id(identity(l)) link => Y = id.id(identity(Y)){ Y = identity(X) } link => X = id.id(identity(X)); } diff --git a/importer/test_data/trait_implementation.asm b/importer/test_data/trait_implementation.asm new file mode 100644 index 0000000000..14e5916f9b --- /dev/null +++ b/importer/test_data/trait_implementation.asm @@ -0,0 +1,26 @@ +mod other { + enum E0 { + + } + enum E1 { + A, + } + enum E2 { + A + } +} + +use other::E0; +use other::E1; +use other::E2; + +trait Foo { + foo: E1 -> E2, +} + +impl Foo { + foo: |e| { + e = E1::A; + other::E2::A + } +} \ No newline at end of file diff --git a/importer/test_data/trait_implementation.expected.asm b/importer/test_data/trait_implementation.expected.asm new file mode 100644 index 0000000000..d24bd4e4fe --- /dev/null +++ b/importer/test_data/trait_implementation.expected.asm @@ -0,0 +1,19 @@ +mod other { + enum E0 { + } + enum E1 { + A, + } + enum E2 { + A, + } +} +trait Foo { + foo: other::E1 -> other::E2, +} +impl Foo { + foo: |e| { + e = other::E1::A; + other::E2::A + }, +} diff --git a/linker/src/lib.rs b/linker/src/lib.rs index 43f2272f4b..73ab93ca6a 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -4,17 +4,15 @@ use lazy_static::lazy_static; use powdr_analysis::utils::parse_pil_statement; use powdr_ast::{ asm_analysis::{combine_flags, MachineDegree}, - object::{Link, Location, PILGraph, TypeOrExpression}, + object::{Link, Location, PILGraph}, parsed::{ asm::{AbsoluteSymbolPath, SymbolPath}, build::{index_access, lookup, namespaced_reference, permutation, selected}, - ArrayLiteral, Expression, NamespaceDegree, PILFile, PilStatement, TypedExpression, + ArrayLiteral, Expression, NamespaceDegree, PILFile, PilStatement, }, }; use powdr_parser_util::SourceRef; -use std::collections::BTreeMap; - -use itertools::Itertools; +use std::{collections::BTreeMap, iter::once}; const MAIN_OPERATION_NAME: &str = "main"; /// The log of the default minimum degree @@ -59,7 +57,7 @@ pub fn link(graph: PILGraph) -> Result> { .degree .clone(); - let mut pil = process_definitions(graph.definitions); + let mut pil = process_definitions(graph.statements); for (location, object) in graph.objects.into_iter() { // create a namespace for this object @@ -110,49 +108,26 @@ pub fn link(graph: PILGraph) -> Result> { // Extract the utilities and sort them into namespaces where possible. fn process_definitions( - definitions: BTreeMap, + mut definitions: BTreeMap>, ) -> Vec { - let mut current_namespace = Default::default(); - definitions - .into_iter() - .sorted_by_cached_key(|(namespace, _)| { - let mut namespace = namespace.clone(); - let name = namespace.pop(); - // Group by namespace and then sort by name. - (namespace, name) - }) - .flat_map(|(mut namespace, type_or_expr)| { - let name = namespace.pop().unwrap(); - let statement = match type_or_expr { - TypeOrExpression::Expression(TypedExpression { e, type_scheme }) => { - PilStatement::LetStatement( - SourceRef::unknown(), - name.to_string(), - type_scheme, - Some(e), - ) - } - TypeOrExpression::Type(enum_decl) => { - PilStatement::EnumDeclaration(SourceRef::unknown(), enum_decl) - } - }; - - // If there is a namespace change, insert a namespace statement. - if current_namespace != namespace { - current_namespace = namespace.clone(); - vec![ - PilStatement::Namespace( + // definitions at the root do not require a namespace statement, so we put them first + let root = definitions.remove(&Default::default()); + + root.into_iter() + .flatten() + .chain( + definitions + .into_iter() + .flat_map(|(module_path, statements)| { + once(PilStatement::Namespace( SourceRef::unknown(), - namespace.relative_to(&AbsoluteSymbolPath::default()), + module_path.relative_to(&Default::default()), None, - ), - statement, - ] - } else { - vec![statement] - } - }) - .collect::>() + )) + .chain(statements) + }), + ) + .collect() } fn process_link(link: Link) -> PilStatement { diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index d59b2ae08e..a9d3976365 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -29,16 +29,7 @@ pub ASMModule: ASMModule = { ModuleStatement: ModuleStatement = { => ModuleStatement::SymbolDefinition(<>), - => ModuleStatement::SymbolDefinition(<>), - => ModuleStatement::SymbolDefinition(SymbolDefinition { - name: <>.name.clone(), - value: SymbolValue::TypeDeclaration(<>), - }), - => ModuleStatement::TraitImplementation(<>), - => ModuleStatement::SymbolDefinition(SymbolDefinition { - name: <>.name.clone(), - value: SymbolValue::TraitDeclaration(<>), - }), + => ModuleStatement::PilStatement(<>), => ModuleStatement::SymbolDefinition(<>), => ModuleStatement::SymbolDefinition(<>), } @@ -100,16 +91,15 @@ TypeSymbolPathPart: Part = { => Part::Named(name), } -LetStatementAtModuleLevel: SymbolDefinition = { - "let" "=" ";" => - SymbolDefinition { - name: name.0, - value: SymbolValue::Expression(TypedExpression{ e: value, type_scheme: name.1 }) - } -} - // ---------------------------- PIL part ----------------------------- +pub PilStatementAtModuleLevel = { + LetStatement, + => PilStatement::EnumDeclaration(ctx.source_ref(start, end), decl), + => PilStatement::TraitImplementation(ctx.source_ref(start, end), impl_), + => PilStatement::TraitDeclaration(ctx.source_ref(start, end), decl), +} + pub PilStatement = { Include, Namespace, diff --git a/parser/src/test_utils.rs b/parser/src/test_utils.rs index 1156c014d1..3e9269d142 100644 --- a/parser/src/test_utils.rs +++ b/parser/src/test_utils.rs @@ -71,17 +71,10 @@ impl ClearSourceRefs for ModuleStatement { .for_each(ClearSourceRefs::clear_source_refs); } SymbolValue::Module(Module::External(_)) | SymbolValue::Import(_) => {} - SymbolValue::Expression(e) => e.e.clear_source_refs(), - SymbolValue::TypeDeclaration(decl) => decl - .children_mut() - .for_each(ClearSourceRefs::clear_source_refs), - SymbolValue::TraitDeclaration(trait_decl) => trait_decl - .children_mut() - .for_each(ClearSourceRefs::clear_source_refs), }, - ModuleStatement::TraitImplementation(trait_impl) => trait_impl - .children_mut() - .for_each(ClearSourceRefs::clear_source_refs), + ModuleStatement::PilStatement(s) => { + s.clear_source_refs(); + } } } } diff --git a/pil-analyzer/src/expression_processor.rs b/pil-analyzer/src/expression_processor.rs index d4cc46c47e..4648ad5670 100644 --- a/pil-analyzer/src/expression_processor.rs +++ b/pil-analyzer/src/expression_processor.rs @@ -18,6 +18,7 @@ use std::{ use crate::{type_processor::TypeProcessor, AnalysisDriver}; /// The ExpressionProcessor turns parsed expressions into analyzed expressions. +/// /// Its main job is to resolve references: /// It turns simple references into fully namespaced references and resolves local function variables. pub struct ExpressionProcessor<'a, D: AnalysisDriver> { diff --git a/pil-analyzer/src/pil_analyzer.rs b/pil-analyzer/src/pil_analyzer.rs index 3f7b229e82..43aefab8e5 100644 --- a/pil-analyzer/src/pil_analyzer.rs +++ b/pil-analyzer/src/pil_analyzer.rs @@ -187,12 +187,18 @@ impl PILAnalyzer { let missing_symbols = module .statements .into_iter() - .filter_map(|s| match s { - ModuleStatement::SymbolDefinition(s) => missing_symbols - .contains(&s.name.as_str()) - .then_some(format!("{s}")), - ModuleStatement::TraitImplementation(_) => None, + .filter_map(|s| { + match &s { + ModuleStatement::SymbolDefinition(s) => { + missing_symbols.contains(&s.name.as_str()) + } + ModuleStatement::PilStatement(s) => s + .symbol_definition_names() + .any(|(name, _)| missing_symbols.contains(&name.as_str())), + } + .then_some(vec![format!("{s}")]) }) + .flatten() .join("\n"); parse(None, &format!("namespace std::prelude;\n{missing_symbols}")).unwrap() }) diff --git a/riscv-executor/src/lib.rs b/riscv-executor/src/lib.rs index 8521ee3359..f0d8f4f106 100644 --- a/riscv-executor/src/lib.rs +++ b/riscv-executor/src/lib.rs @@ -19,11 +19,10 @@ use builder::TraceBuilder; use itertools::Itertools; use powdr_ast::{ - asm_analysis::{ - AnalysisASMFile, CallableSymbol, FunctionStatement, Item, LabelStatement, Machine, - }, + asm_analysis::{AnalysisASMFile, CallableSymbol, FunctionStatement, LabelStatement, Machine}, parsed::{ - asm::DebugDirective, BinaryOperation, Expression, FunctionCall, Number, UnaryOperation, + asm::{parse_absolute_path, DebugDirective}, + BinaryOperation, Expression, FunctionCall, Number, UnaryOperation, }, }; use powdr_number::{FieldElement, LargeInt}; @@ -551,15 +550,7 @@ mod builder { } pub fn get_main_machine(program: &AnalysisASMFile) -> &Machine { - for (name, m) in program.items.iter() { - if name.len() == 1 && name.parts().next() == Some("Main") { - let Item::Machine(m) = m else { - panic!(); - }; - return m; - } - } - panic!(); + program.get_machine(&parse_absolute_path("::Main")).unwrap() } struct PreprocessedMain<'a, T: FieldElement> { diff --git a/riscv/src/continuations.rs b/riscv/src/continuations.rs index 826d442e30..d160e13099 100644 --- a/riscv/src/continuations.rs +++ b/riscv/src/continuations.rs @@ -221,9 +221,7 @@ pub fn rust_continuations_dry_run( let mut register_values = default_register_values(); let program = pipeline.compute_analyzed_asm().unwrap().clone(); - let main_machine = program.items[&parse_absolute_path("::Main")] - .try_to_machine() - .unwrap(); + let main_machine = program.get_machine(&parse_absolute_path("::Main")).unwrap(); sanity_check(main_machine); log::info!("Initializing memory merkle tree..."); diff --git a/riscv/src/continuations/bootloader.rs b/riscv/src/continuations/bootloader.rs index 2c7f98bd65..18f6dcbd7c 100644 --- a/riscv/src/continuations/bootloader.rs +++ b/riscv/src/continuations/bootloader.rs @@ -144,6 +144,7 @@ pub fn bootloader_preamble() -> String { } /// The bootloader: An assembly program that can be executed at the beginning of RISC-V execution. +/// /// It lets the prover provide arbitrary memory pages and writes them to memory, as well as values for /// the registers (including the PC, which is set last). /// This can be used to implement continuations. Note that this is completely non-sound. Progress to @@ -740,6 +741,7 @@ pub const REGISTER_NAMES: [&str; 37] = [ pub const PC_INDEX: usize = REGISTER_NAMES.len() - 1; /// The default PC that can be used in first chunk, will just continue with whatever comes after the bootloader. +/// /// The value is 3, because we added a jump instruction at the beginning of the code. /// Specifically, the first instructions are: /// 0: reset