From ca4157d05dadd173ef0ddacb642d77ae1b5dcc45 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Mon, 15 Jul 2024 13:03:45 +0200 Subject: [PATCH 01/24] Small refactoring of challenge-based protocols (#1567) Some small refactorings of `lookup.asm` and `permutation.asm`. --- pipeline/tests/powdr_std.rs | 2 +- std/math/fp2.asm | 16 ++++++++++++++-- std/protocols/lookup.asm | 29 ++++++----------------------- std/protocols/permutation.asm | 34 ++++++---------------------------- 4 files changed, 27 insertions(+), 54 deletions(-) diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index f241f5053..c457e906d 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -104,7 +104,7 @@ fn permutation_via_challenges_bn() { } #[test] -#[should_panic = "Error reducing expression to constraint:\nExpression: std::protocols::permutation::permutation(main.is_first, [main.z], main.alpha, main.beta, main.permutation_constraint)\nError: FailedAssertion(\"The Goldilocks field is too small and needs to move to the extension field. Pass two accumulators instead!\""] +#[should_panic = "Error reducing expression to constraint:\nExpression: std::protocols::permutation::permutation(main.is_first, [main.z], main.alpha, main.beta, main.permutation_constraint)\nError: FailedAssertion(\"The field is too small and needs to move to the extension field. Pass two elements instead!\")"] fn permutation_via_challenges_gl() { let f = "std/permutation_via_challenges.asm"; Pipeline::::default() diff --git a/std/math/fp2.asm b/std/math/fp2.asm index c12d37b7f..241a2705e 100644 --- a/std/math/fp2.asm +++ b/std/math/fp2.asm @@ -91,11 +91,16 @@ let next_ext: Fp2 -> Fp2 = |a| match a { Fp2::Fp2(a0, a1) => Fp2::Fp2(a0', a1') }; -/// Returns the two components of the extension field element +/// Returns the two components of the extension field element as a tuple let unpack_ext: Fp2 -> (T, T) = |a| match a { Fp2::Fp2(a0, a1) => (a0, a1) }; +/// Returns the two components of the extension field element as an array +let unpack_ext_array: Fp2 -> T[] = |a| match a { + Fp2::Fp2(a0, a1) => [a0, a1] +}; + /// Whether we need to operate on the F_{p^2} extension field (because the current field is too small). let needs_extension: -> bool = || match known_field() { Option::Some(KnownField::Goldilocks) => true, @@ -111,7 +116,14 @@ let is_extension = |arr| match len(arr) { }; /// Constructs an extension field element `a0 + a1 * X` from either `[a0, a1]` or `[a0]` (setting `a1`to zero in that case) -let fp2_from_array = |arr| if is_extension(arr) { Fp2::Fp2(arr[0], arr[1]) } else { from_base(arr[0]) }; +let fp2_from_array = |arr| { + if is_extension(arr) { + Fp2::Fp2(arr[0], arr[1]) + } else { + let _ = assert(!needs_extension(), || "The field is too small and needs to move to the extension field. Pass two elements instead!"); + from_base(arr[0]) + } +}; mod test { use super::Fp2; diff --git a/std/protocols/lookup.asm b/std/protocols/lookup.asm index 354cca2a6..021b4064c 100644 --- a/std/protocols/lookup.asm +++ b/std/protocols/lookup.asm @@ -8,13 +8,12 @@ use std::math::fp2::add_ext; use std::math::fp2::sub_ext; use std::math::fp2::mul_ext; use std::math::fp2::unpack_ext; +use std::math::fp2::unpack_ext_array; use std::math::fp2::next_ext; use std::math::fp2::inv_ext; use std::math::fp2::eval_ext; use std::math::fp2::from_base; -use std::math::fp2::is_extension; use std::math::fp2::fp2_from_array; -use std::math::fp2::needs_extension; use std::math::fp2::constrain_eq_ext; use std::protocols::fingerprint::fingerprint; use std::utils::unwrap_or_else; @@ -49,9 +48,7 @@ let compute_next_z: Fp2, Fp2, Fp2, Constr, expr -> fe[] = quer eval_ext(from_base(rhs_selector)) ) )); - match res { - Fp2::Fp2(a0_fe, a1_fe) => [a0_fe, a1_fe] - } + unpack_ext_array(res) }; // Adds constraints that enforce that rhs is the lookup for lhs @@ -67,26 +64,11 @@ let lookup: expr, expr[], Fp2, Fp2, Constr, expr -> Constr[] = |is_f let (lhs_selector, lhs, rhs_selector, rhs) = unpack_lookup_constraint(lookup_constraint); - let _ = assert(len(lhs) == len(rhs), || "LHS and RHS should have equal length"); - let _ = if !is_extension(acc) { - assert(!needs_extension(), || "The Goldilocks field is too small and needs to move to the extension field. Pass two accumulators instead!") - } else { }; - - // On the extension field, we'll need two field elements to represent the challenge. - // If we don't need an extension field, we can simply set the second component to 0, - // in which case the operations below effectively only operate on the first component. - let acc_ext = fp2_from_array(acc); - let lhs_denom = sub_ext(beta, fingerprint(lhs, alpha)); let rhs_denom = sub_ext(beta, fingerprint(rhs, alpha)); let m_ext = from_base(multiplicities); - - let next_acc = if is_extension(acc) { - next_ext(acc_ext) - } else { - // The second component is 0, but the next operator is not defined on it... - from_base(acc[0]') - }; + let acc_ext = fp2_from_array(acc); + let next_acc = next_ext(acc_ext); // Update rule: // acc' * (beta - A) * (beta - B) + m * rhs_selector * (beta - A) = acc * (beta - A) * (beta - B) + lhs_selector * (beta - B) @@ -114,8 +96,9 @@ let lookup: expr, expr[], Fp2, Fp2, Constr, expr -> Constr[] = |is_f let (acc_1, acc_2) = unpack_ext(acc_ext); [ + // First and last acc needs to be 0 + // (because of wrapping, the acc[0] and acc[N] are the same) is_first * acc_1 = 0, - is_first * acc_2 = 0 ] + constrain_eq_ext(update_expr, from_base(0)) }; \ No newline at end of file diff --git a/std/protocols/permutation.asm b/std/protocols/permutation.asm index 274469b58..cba1373d7 100644 --- a/std/protocols/permutation.asm +++ b/std/protocols/permutation.asm @@ -7,12 +7,11 @@ use std::math::fp2::add_ext; use std::math::fp2::sub_ext; use std::math::fp2::mul_ext; use std::math::fp2::unpack_ext; +use std::math::fp2::unpack_ext_array; use std::math::fp2::next_ext; use std::math::fp2::inv_ext; use std::math::fp2::eval_ext; use std::math::fp2::from_base; -use std::math::fp2::needs_extension; -use std::math::fp2::is_extension; use std::math::fp2::fp2_from_array; use std::math::fp2::constrain_eq_ext; use std::protocols::fingerprint::fingerprint; @@ -50,9 +49,7 @@ let compute_next_z: Fp2, Fp2, Fp2, Constr -> fe[] = query |acc inv_ext(eval_ext(rhs_folded)) ); - match res { - Fp2::Fp2(a0_fe, a1_fe) => [a0_fe, a1_fe] - } + unpack_ext_array(res) }; /// Returns constraints that enforce that lhs is a permutation of rhs @@ -84,29 +81,13 @@ let permutation: expr, expr[], Fp2, Fp2, Constr -> Constr[] = |is_fi let (lhs_selector, lhs, rhs_selector, rhs) = unpack_permutation_constraint(permutation_constraint); - let _ = assert(len(lhs) == len(rhs), || "LHS and RHS should have equal length"); - let _ = if !is_extension(acc) { - assert(!needs_extension(), || "The Goldilocks field is too small and needs to move to the extension field. Pass two accumulators instead!") - } else { }; - - // On the extension field, we'll need two field elements to represent the challenge. - // If we don't need an extension field, we can simply set the second component to 0, - // in which case the operations below effectively only operate on the first component. - let fp2_from_array = |arr| if is_extension(acc) { Fp2::Fp2(arr[0], arr[1]) } else { from_base(arr[0]) }; - let acc_ext = fp2_from_array(acc); - // If the selector is 1, contribute a factor of `beta - fingerprint(lhs)` to accumulator. // If the selector is 0, contribute a factor of 1 to the accumulator. // Implemented as: folded = selector * (beta - fingerprint(values) - 1) + 1; let lhs_folded = selected_or_one(lhs_selector, sub_ext(beta, fingerprint(lhs, alpha))); let rhs_folded = selected_or_one(rhs_selector, sub_ext(beta, fingerprint(rhs, alpha))); - - let next_acc = if is_extension(acc) { - next_ext(acc_ext) - } else { - // The second component is 0, but the next operator is not defined on it... - from_base(acc[0]') - }; + let acc_ext = fp2_from_array(acc); + let next_acc = next_ext(acc_ext); // Update rule: // acc' = acc * lhs_folded / rhs_folded @@ -119,12 +100,9 @@ let permutation: expr, expr[], Fp2, Fp2, Constr -> Constr[] = |is_fi let (acc_1, acc_2) = unpack_ext(acc_ext); [ - // First and last z needs to be 1 - // (because of wrapping, the z[0] and z[N] are the same) + // First and last acc needs to be 1 + // (because of wrapping, the acc[0] and acc[N] are the same) is_first * (acc_1 - 1) = 0, - - // Note that if with_extension is false, this generates 0 = 0 and is removed - // by the optimizer. is_first * acc_2 = 0 ] + constrain_eq_ext(update_expr, from_base(0)) }; \ No newline at end of file From 97aea6da9e79580a447837ceb583967fc74eae77 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Vella Date: Mon, 15 Jul 2024 13:43:52 +0100 Subject: [PATCH 02/24] Using mainline raki (#1575) The fix has been merged. --- riscv/Cargo.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/riscv/Cargo.toml b/riscv/Cargo.toml index c53055965..d316d7b0f 100644 --- a/riscv/Cargo.toml +++ b/riscv/Cargo.toml @@ -31,9 +31,7 @@ lalrpop-util = { version = "^0.19", features = ["lexer"] } log = "0.4.17" mktemp = "0.5.0" num-traits = "0.2.15" -# Use the patched version of raki until the fix is merged. -# Fixes the name of "mulhsu" instruction. -raki = { git = "https://github.com/powdr-labs/raki.git", branch = "patch-1" } +raki = "0.1.4" serde_json = "1.0" # This is only here to work around https://github.com/lalrpop/lalrpop/issues/750 # It should be removed once that workaround is no longer needed. From ff8f81e8ce59524407c7b36db0686ef4f55c34f1 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Mon, 15 Jul 2024 16:44:06 +0200 Subject: [PATCH 03/24] Block machine witgen: Always run default sequence after the cached sequence (#1562) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #1559 (alternative to #1560) With this PR, we always run the "default" sequence iterator, even if we have a cached sequence (which is still run before). This way, if the cached sequence was not sufficient to solve the entire block, the default solving sequence will have another attempt. The reason this doesn't lead to a dramatic performance degradation is because since #1528, we skip identities that have been completed. In the typical case, the cached sequence will have completed most identities. The reason this is better than #1560 is because the cached sequence typically makes progress with every identity, whereas the default iterator does not. ## Benchmark I ran the RISC-V Keccak example 3 times. It looks like the time spent in block machines increases by roughly 20%. Given that this fixes a bug and the overall time spent in block machines is small (even in the bitwise-heavy Keccak example), I think it's worth it! ### Main ``` == Witgen profile (1766802 events) 44.5% ( 4.4s): FixedLookup 36.3% ( 3.6s): Main Machine 9.7% ( 964.7ms): Secondary machine 0: main_binary (BlockMachine) 6.8% ( 669.3ms): witgen (outer code) 2.4% ( 236.0ms): Secondary machine 2: main_shift (BlockMachine) 0.3% ( 32.6ms): Secondary machine 1: main_memory (DoubleSortedWitnesses) 0.0% ( 183.7µs): Secondary machine 3: main_split_gl (BlockMachine) --------------------------- ==> Total: 9.907174375s == Witgen profile (1766802 events) 41.0% ( 3.8s): FixedLookup 39.0% ( 3.6s): Main Machine 10.2% ( 951.7ms): Secondary machine 0: main_binary (BlockMachine) 6.9% ( 644.2ms): witgen (outer code) 2.5% ( 231.5ms): Secondary machine 2: main_shift (BlockMachine) 0.3% ( 32.1ms): Secondary machine 1: main_memory (DoubleSortedWitnesses) 0.0% ( 183.4µs): Secondary machine 3: main_split_gl (BlockMachine) --------------------------- ==> Total: 9.295457333s == Witgen profile (1766802 events) 43.7% ( 4.2s): FixedLookup 37.0% ( 3.6s): Main Machine 10.0% ( 963.6ms): Secondary machine 0: main_binary (BlockMachine) 6.6% ( 636.3ms): witgen (outer code) 2.4% ( 234.7ms): Secondary machine 2: main_shift (BlockMachine) 0.3% ( 29.0ms): Secondary machine 1: main_memory (DoubleSortedWitnesses) 0.0% ( 190.8µs): Secondary machine 3: main_split_gl (BlockMachine) --------------------------- ==> Total: 9.677017958s ``` ### This branch ``` == Witgen profile (1986686 events) 43.3% ( 4.3s): FixedLookup 36.2% ( 3.6s): Main Machine 11.5% ( 1.1s): Secondary machine 0: main_binary (BlockMachine) 6.0% ( 600.2ms): witgen (outer code) 2.7% ( 273.3ms): Secondary machine 2: main_shift (BlockMachine) 0.3% ( 28.6ms): Secondary machine 1: main_memory (DoubleSortedWitnesses) 0.0% ( 203.4µs): Secondary machine 3: main_split_gl (BlockMachine) --------------------------- ==> Total: 9.975125084s == Witgen profile (1986686 events) 40.4% ( 3.9s): FixedLookup 37.2% ( 3.6s): Main Machine 12.1% ( 1.2s): Secondary machine 0: main_binary (BlockMachine) 7.1% ( 687.7ms): witgen (outer code) 2.9% ( 276.7ms): Secondary machine 2: main_shift (BlockMachine) 0.3% ( 30.3ms): Secondary machine 1: main_memory (DoubleSortedWitnesses) 0.0% ( 197.3µs): Secondary machine 3: main_split_gl (BlockMachine) --------------------------- ==> Total: 9.619824375s == Witgen profile (1986686 events) 42.4% ( 4.2s): FixedLookup 36.1% ( 3.6s): Main Machine 11.9% ( 1.2s): Secondary machine 0: main_binary (BlockMachine) 6.6% ( 654.9ms): witgen (outer code) 2.8% ( 276.6ms): Secondary machine 2: main_shift (BlockMachine) 0.3% ( 29.1ms): Secondary machine 1: main_memory (DoubleSortedWitnesses) 0.0% ( 202.3µs): Secondary machine 3: main_split_gl (BlockMachine) --------------------------- ==> Total: 9.957052209s ``` --- executor/src/witgen/machines/block_machine.rs | 19 ---------- executor/src/witgen/sequence_iterator.rs | 35 ++++++++++++------- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/executor/src/witgen/machines/block_machine.rs b/executor/src/witgen/machines/block_machine.rs index b78a8f2af..33e444e91 100644 --- a/executor/src/witgen/machines/block_machine.rs +++ b/executor/src/witgen/machines/block_machine.rs @@ -34,13 +34,6 @@ impl<'a, T: FieldElement> ProcessResult<'a, T> { false => ProcessResult::Incomplete(updates), } } - - fn is_success(&self) -> bool { - match self { - ProcessResult::Success(_, _) => true, - ProcessResult::Incomplete(_) => false, - } - } } fn collect_fixed_cols( @@ -517,18 +510,6 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> { let process_result = self.process(mutable_state, &mut sequence_iterator, outer_query.clone())?; - let process_result = if sequence_iterator.is_cached() && !process_result.is_success() { - log::debug!("The cached sequence did not complete the block machine. \ - This can happen if the machine's execution steps depend on the input or constant values. \ - We'll try again with the default sequence."); - let mut sequence_iterator = self - .processing_sequence_cache - .get_default_sequence_iterator(); - self.process(mutable_state, &mut sequence_iterator, outer_query.clone())? - } else { - process_result - }; - match process_result { ProcessResult::Success(new_block, updates) => { log::trace!( diff --git a/executor/src/witgen/sequence_iterator.rs b/executor/src/witgen/sequence_iterator.rs index b70786845..fed7b78d7 100644 --- a/executor/src/witgen/sequence_iterator.rs +++ b/executor/src/witgen/sequence_iterator.rs @@ -170,7 +170,10 @@ pub enum ProcessingSequenceIterator { /// The default strategy Default(DefaultSequenceIterator), /// The machine has been run successfully before and the sequence is cached. - Cached( as IntoIterator>::IntoIter), + Cached( + as IntoIterator>::IntoIter, + DefaultSequenceIterator, + ), /// The machine has been run before, but did not succeed. There is no point in trying again. Incomplete, } @@ -179,24 +182,17 @@ impl ProcessingSequenceIterator { pub fn report_progress(&mut self, progress_in_last_step: bool) { match self { Self::Default(it) => it.report_progress(progress_in_last_step), - Self::Cached(_) => {} // Progress is ignored + Self::Cached(_, _) => {} // Progress is ignored Self::Incomplete => unreachable!(), } } pub fn has_steps(&self) -> bool { match self { - Self::Default(_) | Self::Cached(_) => true, + Self::Default(_) | Self::Cached(_, _) => true, Self::Incomplete => false, } } - - pub fn is_cached(&self) -> bool { - match self { - Self::Default(_) => false, - Self::Cached(_) | Self::Incomplete => true, - } - } } impl Iterator for ProcessingSequenceIterator { @@ -205,7 +201,13 @@ impl Iterator for ProcessingSequenceIterator { fn next(&mut self) -> Option { match self { Self::Default(it) => it.next(), - Self::Cached(it) => it.next(), + // After the cached iterator is exhausted, run the default iterator again. + // This is because the order in which the identities should be processed *might* + // depend on the concrete input values. + // In the typical scenario, most identities will be completed at this point and + // the block processor will skip them. But if an identity was not completed before, + // it will try again. + Self::Cached(it, default_iterator) => it.next().or_else(|| default_iterator.next()), Self::Incomplete => unreachable!(), } } @@ -246,7 +248,14 @@ impl ProcessingSequenceCache { match self.cache.get(&left.into()) { Some(CacheEntry::Complete(cached_sequence)) => { log::trace!("Using cached sequence"); - ProcessingSequenceIterator::Cached(cached_sequence.clone().into_iter()) + ProcessingSequenceIterator::Cached( + cached_sequence.clone().into_iter(), + DefaultSequenceIterator::new( + self.block_size, + self.identities_count, + Some(self.outer_query_row as i64), + ), + ) } Some(CacheEntry::Incomplete) => ProcessingSequenceIterator::Incomplete, None => { @@ -291,7 +300,7 @@ impl ProcessingSequenceCache { .is_none()); } ProcessingSequenceIterator::Incomplete => unreachable!(), - ProcessingSequenceIterator::Cached(_) => {} // Already cached, do nothing + ProcessingSequenceIterator::Cached(_, _) => {} // Already cached, do nothing } } } From 5acb5a91a0ba66c948765d98d448643217748aa2 Mon Sep 17 00:00:00 2001 From: Leo Date: Mon, 15 Jul 2024 17:42:13 +0200 Subject: [PATCH 04/24] fix riscv tests verification when using a composite backend (#1576) Needed for registers in memory. This will make some RISCV tests slower, but we won't use the normal EStarkDump backend anymore for such tests but rather composites backend only which should make them faster again. --- backend/src/lib.rs | 28 ++++++++++++++++++++++++++++ pipeline/src/test_util.rs | 16 +++++++++++++++- riscv/tests/riscv.rs | 3 ++- 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/backend/src/lib.rs b/backend/src/lib.rs index e490fa9bb..6e4b4af93 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -92,6 +92,34 @@ impl BackendType { } } } + + pub fn is_composite(&self) -> bool { + match self { + #[cfg(feature = "halo2")] + BackendType::Halo2 => false, + #[cfg(feature = "halo2")] + BackendType::Halo2Composite => true, + #[cfg(feature = "halo2")] + BackendType::Halo2Mock => false, + #[cfg(feature = "halo2")] + BackendType::Halo2MockComposite => true, + #[cfg(feature = "estark-polygon")] + BackendType::EStarkPolygon => false, + #[cfg(feature = "estark-polygon")] + BackendType::EStarkPolygonComposite => true, + BackendType::EStarkStarky => false, + BackendType::EStarkStarkyComposite => true, + BackendType::EStarkDump => false, + BackendType::EStarkDumpComposite => true, + #[cfg(feature = "plonky3")] + BackendType::Plonky3 => false, + #[cfg(feature = "plonky3")] + BackendType::Plonky3Composite => true, + // We explicitly do not use a wildcard here + // so that a new composite backend needs to be + // added here too. + } + } } #[derive(thiserror::Error, Debug)] diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index bbbff59cc..d9e68615c 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -2,6 +2,7 @@ use powdr_ast::analyzed::Analyzed; use powdr_backend::BackendType; use powdr_number::{buffered_write_file, BigInt, Bn254Field, FieldElement, GoldilocksField}; use powdr_pil_analyzer::evaluator::{self, SymbolLookup}; +use std::fs; use std::path::PathBuf; use std::sync::Arc; @@ -73,7 +74,20 @@ pub fn verify_pipeline( pipeline.compute_proof().unwrap(); - verify(pipeline.output_dir().as_ref().unwrap()) + let out_dir = pipeline.output_dir().as_ref().unwrap(); + if backend.is_composite() { + // traverse all subdirs of the given output dir and verify each subproof + for entry in fs::read_dir(out_dir).unwrap() { + let entry = entry.unwrap(); + let path = entry.path(); + if path.is_dir() { + verify(&path)?; + } + } + Ok(()) + } else { + verify(out_dir) + } } /// Makes a new pipeline for the given file and inputs. All steps until witness generation are diff --git a/riscv/tests/riscv.rs b/riscv/tests/riscv.rs index 196a45283..64e5ba94c 100644 --- a/riscv/tests/riscv.rs +++ b/riscv/tests/riscv.rs @@ -368,7 +368,8 @@ fn many_chunks_memory() { } fn verify_riscv_crate(case: &str, inputs: Vec, runtime: &Runtime) { - verify_riscv_crate_with_backend(case, inputs, runtime, BackendType::EStarkDump) + verify_riscv_crate_with_backend(case, inputs.clone(), runtime, BackendType::EStarkDump); + verify_riscv_crate_with_backend(case, inputs, runtime, BackendType::EStarkDumpComposite); } fn verify_riscv_crate_with_backend( From bd4ac9c059c48ba773090a554f9b19d97e8ab3f7 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Tue, 16 Jul 2024 11:52:19 +0200 Subject: [PATCH 05/24] Fix vm_to_block_different_length test (#1580) Fixes one of the tests that failed in the [last nightly run](https://github.com/powdr-labs/powdr/actions/runs/9949789019/job/27486609397). To reproduce: ``` IS_NIGHTLY_TEST=true PILCOM=$(pwd)/pilcom/ cargo nextest run --all --features halo2 -E 'test(=vm_to_block_different_length)' ``` --- pipeline/src/test_util.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index d9e68615c..fe647ff8b 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -212,10 +212,11 @@ pub fn gen_halo2_proof(pipeline: Pipeline, backend: BackendVariant) // Setup let output_dir = pipeline.output_dir().clone().unwrap(); let setup_file_path = output_dir.join("params.bin"); + let max_degree = pil.degrees().into_iter().max().unwrap(); buffered_write_file(&setup_file_path, |writer| { powdr_backend::BackendType::Halo2 .factory::() - .generate_setup(pil.degree(), writer) + .generate_setup(max_degree, writer) .unwrap() }) .unwrap(); From eb3bdfbe60bcd8285292e4b612b9d22a15e87571 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Vella Date: Tue, 16 Jul 2024 13:02:32 +0100 Subject: [PATCH 06/24] Fixing atomic instructions in ELF translate. (#1582) Fixes EVM test. --- riscv/src/elf.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/riscv/src/elf.rs b/riscv/src/elf.rs index 95451de61..075bc31df 100644 --- a/riscv/src/elf.rs +++ b/riscv/src/elf.rs @@ -898,11 +898,10 @@ impl TwoOrOneMapper for InstructionLifter<'_> { }, }; - // For some reason, atomic instructions come with the immediate set to - // zero instead of None (maybe to mimic assembly syntax? Who knows). We - // must fix this: + // The acquire and release bits of an atomic instructions are decoded as + // the immediate value, but we don't need the bits and an immediate is + // not expected, so we must remove it. if let Extensions::A = insn.extension { - assert!(matches!(imm, HighLevelImmediate::Value(0))); imm = HighLevelImmediate::None; } From e6bc15d491967de3715551bf6fa48d65f69af6db Mon Sep 17 00:00:00 2001 From: Leo Date: Wed, 17 Jul 2024 11:35:41 +0200 Subject: [PATCH 07/24] change riscv tests to use composite backends (#1583) Also in preparation for registers in memory. Removing the `public`s for continuations as discussed with @georgwiese , since those are not sound anyway and we might change it completely soon. --- pipeline/src/test_util.rs | 1 - riscv/src/continuations/bootloader.rs | 20 ++++++++++---------- riscv/tests/common/mod.rs | 2 +- riscv/tests/instructions.rs | 2 +- riscv/tests/riscv.rs | 18 ++++++++++-------- 5 files changed, 22 insertions(+), 21 deletions(-) diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index fe647ff8b..ef81f3452 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -65,7 +65,6 @@ pub fn verify_pipeline( pipeline: Pipeline, backend: BackendType, ) -> Result<(), String> { - // TODO: Also test Composite variants let mut pipeline = pipeline.with_backend(backend, None); if pipeline.output_dir().is_none() { diff --git a/riscv/src/continuations/bootloader.rs b/riscv/src/continuations/bootloader.rs index cc9c7cc99..d2c5af48c 100644 --- a/riscv/src/continuations/bootloader.rs +++ b/riscv/src/continuations/bootloader.rs @@ -73,22 +73,22 @@ pub fn bootloader_preamble() -> String { for (i, reg) in REGISTER_NAMES.iter().enumerate() { let reg = reg.strip_prefix("main.").unwrap(); preamble.push_str(&format!( - " public initial_{reg} = main_bootloader_inputs.value({i});\n" + " //public initial_{reg} = main_bootloader_inputs.value({i});\n" )); } for (i, reg) in REGISTER_NAMES.iter().enumerate() { let reg = reg.strip_prefix("main.").unwrap(); preamble.push_str(&format!( - " public final_{reg} = main_bootloader_inputs.value({});\n", + " //public final_{reg} = main_bootloader_inputs.value({});\n", i + REGISTER_NAMES.len() )); } preamble.push_str(&format!( r#" - public initial_memory_hash_1 = main_bootloader_inputs.value({}); - public initial_memory_hash_2 = main_bootloader_inputs.value({}); - public initial_memory_hash_3 = main_bootloader_inputs.value({}); - public initial_memory_hash_4 = main_bootloader_inputs.value({}); + //public initial_memory_hash_1 = main_bootloader_inputs.value({}); + //public initial_memory_hash_2 = main_bootloader_inputs.value({}); + //public initial_memory_hash_3 = main_bootloader_inputs.value({}); + //public initial_memory_hash_4 = main_bootloader_inputs.value({}); "#, MEMORY_HASH_START_INDEX, MEMORY_HASH_START_INDEX + 1, @@ -97,10 +97,10 @@ pub fn bootloader_preamble() -> String { )); preamble.push_str(&format!( r#" - public final_memory_hash_1 = main_bootloader_inputs.value({}); - public final_memory_hash_2 = main_bootloader_inputs.value({}); - public final_memory_hash_3 = main_bootloader_inputs.value({}); - public final_memory_hash_4 = main_bootloader_inputs.value({}); + //public final_memory_hash_1 = main_bootloader_inputs.value({}); + //public final_memory_hash_2 = main_bootloader_inputs.value({}); + //public final_memory_hash_3 = main_bootloader_inputs.value({}); + //public final_memory_hash_4 = main_bootloader_inputs.value({}); "#, MEMORY_HASH_START_INDEX + 4, MEMORY_HASH_START_INDEX + 5, diff --git a/riscv/tests/common/mod.rs b/riscv/tests/common/mod.rs index d356cbac2..d34ab86cb 100644 --- a/riscv/tests/common/mod.rs +++ b/riscv/tests/common/mod.rs @@ -93,6 +93,6 @@ pub fn verify_riscv_asm_file(asm_file: &Path, runtime: &Runtime, use_pie: bool) &powdr_asm, &[], None, - BackendType::EStarkDump, + BackendType::EStarkDumpComposite, ); } diff --git a/riscv/tests/instructions.rs b/riscv/tests/instructions.rs index a0c8ee2b8..f077efe81 100644 --- a/riscv/tests/instructions.rs +++ b/riscv/tests/instructions.rs @@ -33,7 +33,7 @@ mod instruction_tests { &powdr_asm, Default::default(), None, - BackendType::EStarkDump, + BackendType::EStarkDumpComposite, ); } diff --git a/riscv/tests/riscv.rs b/riscv/tests/riscv.rs index 64e5ba94c..1f1e0623b 100644 --- a/riscv/tests/riscv.rs +++ b/riscv/tests/riscv.rs @@ -4,7 +4,7 @@ use common::{verify_riscv_asm_file, verify_riscv_asm_string}; use mktemp::Temp; use powdr_backend::BackendType; use powdr_number::GoldilocksField; -use powdr_pipeline::{verify::verify, Pipeline}; +use powdr_pipeline::{test_util::verify_pipeline, Pipeline}; use std::path::{Path, PathBuf}; use test_log::test; @@ -47,11 +47,7 @@ fn run_continuations_test(case: &str, powdr_asm: String) { .with_prover_inputs(Default::default()) .with_output(tmp_dir.to_path_buf(), false); let pipeline_callback = |pipeline: Pipeline| -> Result<(), ()> { - // Can't use `verify_pipeline`, because the pipeline was renamed in the middle of after - // computing the constants file. - let mut pipeline = pipeline.with_backend(BackendType::EStarkDump, None); - pipeline.compute_proof().unwrap(); - verify(pipeline.output_dir().as_ref().unwrap()).unwrap(); + verify_pipeline(pipeline, BackendType::EStarkDumpComposite).unwrap(); Ok(()) }; @@ -154,7 +150,7 @@ fn vec_median_estark_polygon() { .map(|x| x.into()) .collect(), &Runtime::base(), - BackendType::EStarkPolygon, + BackendType::EStarkPolygonComposite, ); } @@ -387,7 +383,13 @@ fn verify_riscv_crate_with_data( runtime: &Runtime, data: Vec<(u32, S)>, ) { - verify_riscv_crate_from_both_paths(case, inputs, runtime, Some(data), BackendType::EStarkDump) + verify_riscv_crate_from_both_paths( + case, + inputs, + runtime, + Some(data), + BackendType::EStarkDumpComposite, + ) } fn verify_riscv_crate_from_both_paths( From 718459ba69633fb88aab92aa668c262b10f91dce Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Wed, 17 Jul 2024 11:38:18 +0200 Subject: [PATCH 08/24] `CompositeBackend`: Improve logging (#1581) This should help us to collect more data and get a better understanding of the factors that drive proof times and sizes. Example: ``` $ cargo run pil test_data/asm/block_to_block.asm -o output -f --field bn254 --prove-with halo2-composite ... Instantiating a composite backend with 2 machines: * main: * Number of witness columns: 3 * Number of fixed columns: 4 * Number of identities: * Polynomial: 3 * main_arith: * Number of witness columns: 4 * Number of fixed columns: 0 * Number of identities: * Polynomial: 2 == Proving machine: main (size 8) Starting proof generation... Generating PK for snark... Generating proof... Time taken: 151.890834ms Proof generation done. ==> Machine proof of 1753 bytes computed in 292.9045ms == Proving machine: main_arith (size 8) Starting proof generation... Generating PK for snark... Generating proof... Time taken: 154.678333ms Proof generation done. ==> Machine proof of 1625 bytes computed in 276.353375ms Proof generation took 0.5694918s Proof size: 3402 bytes Writing output/block_to_block_proof.bin. ``` --- backend/src/composite/mod.rs | 42 +++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/backend/src/composite/mod.rs b/backend/src/composite/mod.rs index 0ccbe20d9..e162e2466 100644 --- a/backend/src/composite/mod.rs +++ b/backend/src/composite/mod.rs @@ -77,6 +77,30 @@ impl> BackendFactory for CompositeBacke }) .verification_keys; + log::info!( + "Instantiating a composite backend with {} machines:", + pils.len() + ); + for (machine_name, pil) in pils.iter() { + let num_witness_columns = pil.committed_polys_in_source_order().len(); + let num_fixed_columns = pil.constant_polys_in_source_order().len(); + let num_identities_by_kind = pil + .identities + .iter() + .map(|i| i.kind) + .counts() + .into_iter() + .collect::>(); + + log::info!("* {}:", machine_name); + log::info!(" * Number of witness columns: {}", num_witness_columns); + log::info!(" * Number of fixed columns: {}", num_fixed_columns); + log::info!(" * Number of identities:"); + for (kind, count) in num_identities_by_kind { + log::info!(" * {:?}: {}", kind, count); + } + } + let machine_data = pils .into_iter() .zip_eq(verification_keys.into_iter()) @@ -157,9 +181,25 @@ impl<'a, F: FieldElement> Backend<'a, F> for CompositeBackend<'a, F> { log::info!("== Proving machine: {} (size {})", machine, pil.degree()); log::debug!("PIL:\n{}", pil); + let start = std::time::Instant::now(); + let witness = machine_witness_columns(witness, pil, machine); - backend.prove(&witness, None, witgen_callback) + let proof = backend.prove(&witness, None, witgen_callback); + + match &proof { + Ok(proof) => { + log::info!( + "==> Machine proof of {} bytes computed in {:?}", + proof.len(), + start.elapsed() + ); + } + Err(e) => { + log::error!("==> Machine proof failed: {:?}", e); + } + }; + proof }) .collect::>()?, }; From 9483a26f413c911fcea2039b484596fba8b7c261 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Wed, 17 Jul 2024 11:51:10 +0200 Subject: [PATCH 09/24] Block machine: report side effect (fixes TODO) (#1563) Since #1385 is resolved, we can fix this TODO. --- executor/src/witgen/machines/block_machine.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/executor/src/witgen/machines/block_machine.rs b/executor/src/witgen/machines/block_machine.rs index 33e444e91..8e8f2575d 100644 --- a/executor/src/witgen/machines/block_machine.rs +++ b/executor/src/witgen/machines/block_machine.rs @@ -518,10 +518,7 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> { ); self.append_block(new_block)?; - // TODO: This would be the right thing to do, but currently leads to failing tests - // due to #1385 ("Witgen: Block machines "forget" that they already completed a block"): - // https://github.com/powdr-labs/powdr/issues/1385 - // let updates = updates.report_side_effect(); + let updates = updates.report_side_effect(); // We solved the query, so report it to the cache. self.processing_sequence_cache From 4f873e6205d8c95ae3b33328223fcc933fe5f6b1 Mon Sep 17 00:00:00 2001 From: onurinanc Date: Wed, 17 Jul 2024 13:50:57 +0300 Subject: [PATCH 10/24] Implement Basic Bus (#1566) Related to the issue of implementing basic bus (#1497), I have implemented basic bus together with an example (`permutation_via_bus.asm`) as specified inside the issue. Currently, `test_data/std/bus_permutation_via_challenges.asm` works as intended (To make it sound, stage(1) witness columns need to be exposed publicly and verifier needs to check such as `out_z1 + out_z2 = 0`) We can now check using RUST_LOG=trace and adding the final z1 and z2 is equal to 0. However, `test_data/std/bus_permutation_via_challenges_ext.asm` is not working correctly as intended. This will be fixed with the following commits. --- pipeline/tests/powdr_std.rs | 12 ++ std/protocols/bus.asm | 104 ++++++++++++++++++ std/protocols/fingerprint.asm | 5 +- std/protocols/lookup.asm | 20 ++-- std/protocols/mod.asm | 4 +- std/protocols/permutation.asm | 3 + std/protocols/permutation_via_bus.asm | 26 +++++ .../std/bus_permutation_via_challenges.asm | 34 ++++++ .../bus_permutation_via_challenges_ext.asm | 61 ++++++++++ 9 files changed, 257 insertions(+), 12 deletions(-) create mode 100644 std/protocols/bus.asm create mode 100644 std/protocols/permutation_via_bus.asm create mode 100644 test_data/std/bus_permutation_via_challenges.asm create mode 100644 test_data/std/bus_permutation_via_challenges_ext.asm diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index c457e906d..a6a5c48e6 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -156,6 +156,18 @@ fn lookup_via_challenges_ext_simple() { .unwrap(); } +#[test] +fn bus_permutation_via_challenges_bn() { + let f = "std/bus_permutation_via_challenges.asm"; + test_halo2(f, Default::default()); +} + +#[test] +fn bus_permutation_via_challenges_ext_bn() { + let f = "std/bus_permutation_via_challenges_ext.asm"; + test_halo2(f, Default::default()); +} + #[test] fn write_once_memory_test() { let f = "std/write_once_memory_test.asm"; diff --git a/std/protocols/bus.asm b/std/protocols/bus.asm new file mode 100644 index 000000000..34e941602 --- /dev/null +++ b/std/protocols/bus.asm @@ -0,0 +1,104 @@ +use std::check::assert; +use std::check::panic; +use std::math::fp2::Fp2; +use std::math::fp2::add_ext; +use std::math::fp2::sub_ext; +use std::math::fp2::mul_ext; +use std::math::fp2::inv_ext; +use std::math::fp2::eval_ext; +use std::math::fp2::unpack_ext; +use std::math::fp2::unpack_ext_array; +use std::math::fp2::next_ext; +use std::math::fp2::from_base; +use std::math::fp2::needs_extension; +use std::math::fp2::is_extension; +use std::math::fp2::fp2_from_array; +use std::math::fp2::constrain_eq_ext; +use std::protocols::fingerprint::fingerprint_with_id; +use std::prover::eval; + +/// Sends the tuple (id, tuple...) to the bus by adding +/// `multiplicity / (beta - fingerprint(id, tuple...))` to `acc` +/// It is the callers responsibility to properly constrain the multiplicity (e.g. constrain +/// it to be boolean) if needed. +/// +/// # Arguments: +/// +/// - is_first: A column that is 1 for the first row and 0 for the rest +/// - id: Interaction Id +/// - tuple: An array of columns to be sent to the bus +/// - multiplicity: The multiplicity which shows how many times a column will be sent +/// - acc: A phase-2 witness column to be used as the accumulator. If 2 are provided, computations +/// are done on the F_{p^2} extension field. +/// - alpha: A challenge used to compress id and tuple +/// - beta: A challenge used to update the accumulator +/// +/// # Returns: +/// +/// - Constraints to be added to enforce the bus +let bus_interaction: expr, expr, expr[], expr, expr[], Fp2, Fp2 -> Constr[] = |is_first, id, tuple, multiplicity, acc, alpha, beta| { + + // Implemented as: folded = (beta - fingerprint(id, tuple...)); + let folded = sub_ext(beta, fingerprint_with_id(id, tuple, alpha)); + let folded_next = next_ext(folded); + + let m_ext = from_base(multiplicity); + let m_ext_next = next_ext(m_ext); + + let acc_ext = fp2_from_array(acc); + let next_acc = next_ext(acc_ext); + + let is_first_next = from_base(is_first'); + + // Update rule: + // acc' = acc * (1 - is_first') + multiplicity' / folded' + // or equivalently: + // folded' * (acc' - acc * (1 - is_first')) - multiplicity' = 0 + let update_expr = sub_ext( + mul_ext(folded_next, sub_ext(next_acc, mul_ext(acc_ext, sub_ext(from_base(1), is_first_next)))), m_ext_next + ); + + constrain_eq_ext(update_expr, from_base(0)) +}; + +/// Compute acc' = acc * (1 - is_first') + multiplicity' / fingerprint_with_id(id, (a1', a2')), +/// using extension field arithmetic. +/// This is intended to be used as a hint in the extension field case; for the base case +/// automatic witgen is smart enough to figure out the value of the accumulator. +let compute_next_z_send: expr, expr, expr[], expr, Fp2, Fp2, Fp2 -> fe[] = query |is_first, id, tuple, multiplicity, acc, alpha, beta| { + // Implemented as: folded = (beta - fingerprint(id, tuple...)); + // `multiplicity / (beta - fingerprint(id, tuple...))` to `acc` + let folded = sub_ext(beta, fingerprint_with_id(id, tuple, alpha)); + let folded_next = next_ext(folded); + + let m_ext = from_base(multiplicity); + let m_ext_next = next_ext(m_ext); + + let is_first_next = eval(is_first'); + let current_acc = if is_first_next == 1 {from_base(0)} else {eval_ext(acc)}; + + // acc' = acc * (1 - is_first') + multiplicity / fingerprint_with_id(id, (a1', a2')) + let res = add_ext( + current_acc, + mul_ext(eval_ext(m_ext_next), inv_ext(eval_ext(folded_next))) + ); + + unpack_ext_array(res) +}; + +/// Compute acc' = acc * (1 - is_first') - multiplicity' / fingerprint_with_id(id, (a1', a2')), +/// using extension field arithmetic. +/// This is intended to be used as a hint in the extension field case; for the base case +/// automatic witgen is smart enough to figure out the value of the accumulator. +let compute_next_z_receive: expr, expr, expr[], expr, Fp2, Fp2, Fp2 -> fe[] = query |is_first, id, tuple, multiplicity, acc, alpha, beta| + compute_next_z_send(is_first, id, tuple, -multiplicity, acc, alpha, beta); + +/// Convenience function for bus interaction to send columns +let bus_send: expr, expr, expr[], expr, expr[], Fp2, Fp2 -> Constr[] = |is_first, id, tuple, multiplicity, acc, alpha, beta| { + bus_interaction(is_first, id, tuple, multiplicity, acc, alpha, beta) +}; + +/// Convenience function for bus interaction to receive columns +let bus_receive: expr, expr, expr[], expr, expr[], Fp2, Fp2 -> Constr[] = |is_first, id, tuple, multiplicity, acc, alpha, beta| { + bus_interaction(is_first, id, tuple, -1 * multiplicity, acc, alpha, beta) +}; \ No newline at end of file diff --git a/std/protocols/fingerprint.asm b/std/protocols/fingerprint.asm index cff54175e..1ac5577ff 100644 --- a/std/protocols/fingerprint.asm +++ b/std/protocols/fingerprint.asm @@ -9,4 +9,7 @@ let fingerprint: T[], Fp2 -> Fp2 = |expr_array expr_array, from_base(0), |sum_acc, el| add_ext(mul_ext(alpha, sum_acc), from_base(el)) -); \ No newline at end of file +); + +/// Maps [id, x_1, x_2, ..., x_n] to its Read-Solomon fingerprint, using a challenge alpha: $\sum_{i=1}^n alpha**{(n - i)} * x_i$ +let fingerprint_with_id: T, T[], Fp2 -> Fp2 = |id, expr_array, alpha| fingerprint([id] + expr_array, alpha); \ No newline at end of file diff --git a/std/protocols/lookup.asm b/std/protocols/lookup.asm index 021b4064c..deb62b3f2 100644 --- a/std/protocols/lookup.asm +++ b/std/protocols/lookup.asm @@ -28,7 +28,7 @@ let unpack_lookup_constraint: Constr -> (expr, expr[], expr, expr[]) = |lookup_c _ => panic("Expected lookup constraint") }; -// Compute z' = z + 1/(beta-a_i) * lhs_selector - m_i/(beta-b_i) * rhs_selector, using extension field arithmetic +/// Compute z' = z + 1/(beta-a_i) * lhs_selector - m_i/(beta-b_i) * rhs_selector, using extension field arithmetic let compute_next_z: Fp2, Fp2, Fp2, Constr, expr -> fe[] = query |acc, alpha, beta, lookup_constraint, multiplicities| { let (lhs_selector, lhs, rhs_selector, rhs) = unpack_lookup_constraint(lookup_constraint); @@ -51,15 +51,15 @@ let compute_next_z: Fp2, Fp2, Fp2, Constr, expr -> fe[] = quer unpack_ext_array(res) }; -// Adds constraints that enforce that rhs is the lookup for lhs -// Arguments: -// - is_first: A column that is 1 for the first row and 0 for the rest -// - alpha: A challenge used to compress the LHS and RHS values -// - beta: A challenge used to update the accumulator -// - acc: A phase-2 witness column to be used as the accumulator. If 2 are provided, computations -// are done on the F_{p^2} extension field. -// - lookup_constraint: The lookup constraint -// - multiplicities: The multiplicities which shows how many times each RHS value appears in the LHS +/// Adds constraints that enforce that rhs is the lookup for lhs +/// Arguments: +/// - is_first: A column that is 1 for the first row and 0 for the rest +/// - acc: A phase-2 witness column to be used as the accumulator. If 2 are provided, computations +/// are done on the F_{p^2} extension field. +/// - alpha: A challenge used to compress the LHS and RHS values +/// - beta: A challenge used to update the accumulator +/// - lookup_constraint: The lookup constraint +/// - multiplicities: The multiplicities which shows how many times each RHS value appears in the LHS let lookup: expr, expr[], Fp2, Fp2, Constr, expr -> Constr[] = |is_first, acc, alpha, beta, lookup_constraint, multiplicities| { let (lhs_selector, lhs, rhs_selector, rhs) = unpack_lookup_constraint(lookup_constraint); diff --git a/std/protocols/mod.asm b/std/protocols/mod.asm index 90d583ae0..64256fa08 100644 --- a/std/protocols/mod.asm +++ b/std/protocols/mod.asm @@ -1,3 +1,5 @@ +mod bus; mod fingerprint; mod lookup; -mod permutation; \ No newline at end of file +mod permutation; +mod permutation_via_bus; \ No newline at end of file diff --git a/std/protocols/permutation.asm b/std/protocols/permutation.asm index cba1373d7..06fb0a5c5 100644 --- a/std/protocols/permutation.asm +++ b/std/protocols/permutation.asm @@ -55,8 +55,11 @@ let compute_next_z: Fp2, Fp2, Fp2, Constr -> fe[] = query |acc /// Returns constraints that enforce that lhs is a permutation of rhs /// /// # Arguments: +/// - is_first: A column that is 1 for the first row and 0 for the rest /// - acc: A phase-2 witness column to be used as the accumulator. If 2 are provided, computations /// are done on the F_{p^2} extension field. +/// - alpha: A challenge used to compress the LHS and RHS values +/// - beta: A challenge used to update the accumulator /// - permutation_constraint: The permutation constraint /// /// # Returns: diff --git a/std/protocols/permutation_via_bus.asm b/std/protocols/permutation_via_bus.asm new file mode 100644 index 000000000..8f6aaa548 --- /dev/null +++ b/std/protocols/permutation_via_bus.asm @@ -0,0 +1,26 @@ +use std::array::len; +use std::check::assert; +use std::protocols::bus::bus_send; +use std::protocols::bus::bus_receive; +use std::protocols::bus::compute_next_z_send; +use std::protocols::bus::compute_next_z_receive; +use std::protocols::permutation::unpack_permutation_constraint; +use std::math::fp2::Fp2; + +// Example usage of the bus: Implement a permutation constraint +// To make this sound, the last values of `acc_lhs` and `acc_rhs` need to be +// exposed as publics, and the verifier needs to assert that they sum to 0. +let permutation: expr, expr, expr[], expr[], Fp2, Fp2, Constr -> Constr[] = |is_first, id, acc_lhs, acc_rhs, alpha, beta, permutation_constraint| { + let (lhs_selector, lhs, rhs_selector, rhs) = unpack_permutation_constraint(permutation_constraint); + bus_send(is_first, id, lhs, lhs_selector, acc_lhs, alpha, beta) + bus_receive(is_first, id, rhs, rhs_selector, acc_rhs, alpha, beta) +}; + +let compute_next_z_send_permutation: expr, expr, Fp2, Fp2, Fp2, Constr -> fe[] = query |is_first, id, acc, alpha, beta, permutation_constraint| { + let (lhs_selector, lhs, rhs_selector, rhs) = unpack_permutation_constraint(permutation_constraint); + compute_next_z_send(is_first, id, lhs, lhs_selector, acc, alpha, beta) +}; + +let compute_next_z_receive_permutation: expr, expr, Fp2, Fp2, Fp2, Constr -> fe[] = query |is_first, id, acc, alpha, beta, permutation_constraint| { + let (lhs_selector, lhs, rhs_selector, rhs) = unpack_permutation_constraint(permutation_constraint); + compute_next_z_receive(is_first, id, rhs, rhs_selector, acc, alpha, beta) +}; \ No newline at end of file diff --git a/test_data/std/bus_permutation_via_challenges.asm b/test_data/std/bus_permutation_via_challenges.asm new file mode 100644 index 000000000..d325f0f31 --- /dev/null +++ b/test_data/std/bus_permutation_via_challenges.asm @@ -0,0 +1,34 @@ +use std::prover::Query; +use std::convert::fe; +use std::protocols::permutation_via_bus::permutation; +use std::math::fp2::from_base; +use std::prover::challenge; + +machine Main with degree: 8 { + let alpha = from_base(challenge(0, 1)); + let beta = from_base(challenge(0, 2)); + + col fixed first_four = [1, 1, 1, 1, 0, 0, 0, 0]; + + // Two pairs of witness columns, claimed to be permutations of one another + // (when selected by first_four and (1 - first_four), respectively) + col witness a1(i) query Query::Hint(fe(i)); + col witness a2(i) query Query::Hint(fe(i + 42)); + col witness b1(i) query Query::Hint(fe(7 - i)); + col witness b2(i) query Query::Hint(fe(7 - i + 42)); + + let permutation_constraint = Constr::Permutation( + (Option::Some(first_four), Option::Some(1 - first_four)), + [(a1, b1), (a2, b2)] + ); + + // TODO: Functions currently cannot add witness columns at later stages, + // so we have to manually create it here and pass it to permutation(). + col witness stage(1) z; + col witness stage(1) u; + + let is_first: col = std::well_known::is_first; + permutation(is_first, 1, [z], [u], alpha, beta, permutation_constraint); + + is_first' * (z + u) = 0; +} diff --git a/test_data/std/bus_permutation_via_challenges_ext.asm b/test_data/std/bus_permutation_via_challenges_ext.asm new file mode 100644 index 000000000..d69962ff8 --- /dev/null +++ b/test_data/std/bus_permutation_via_challenges_ext.asm @@ -0,0 +1,61 @@ +use std::prover::Query; +use std::convert::fe; +use std::protocols::permutation_via_bus::permutation; +use std::protocols::permutation_via_bus::compute_next_z_send_permutation; +use std::protocols::permutation_via_bus::compute_next_z_receive_permutation; +use std::math::fp2::Fp2; +use std::prover::challenge; + +machine Main with degree: 8 { + + let alpha1: expr = challenge(0, 1); + let alpha2: expr = challenge(0, 2); + let beta1: expr = challenge(0, 3); + let beta2: expr = challenge(0, 4); + let alpha = Fp2::Fp2(alpha1, alpha2); + let beta = Fp2::Fp2(beta1, beta2); + + col fixed first_four = [1, 1, 1, 1, 0, 0, 0, 0]; + + // Two pairs of witness columns, claimed to be permutations of one another + // (when selected by first_four and (1 - first_four), respectively) + col witness a1(i) query Query::Hint(fe(i)); + col witness a2(i) query Query::Hint(fe(i + 42)); + col witness b1(i) query Query::Hint(fe(7 - i)); + col witness b2(i) query Query::Hint(fe(7 - i + 42)); + + let permutation_constraint = Constr::Permutation( + (Option::Some(first_four), Option::Some(1 - first_four)), + [(a1, b1), (a2, b2)] + ); + + // TODO: Functions currently cannot add witness columns at later stages, + // so we have to manually create it here and pass it to permutation(). + col witness stage(1) z1; + col witness stage(1) z2; + let z = Fp2::Fp2(z1, z2); + + col witness stage(1) u1; + col witness stage(1) u2; + let u = Fp2::Fp2(u1, u2); + + let is_first: col = std::well_known::is_first; + permutation(is_first, 1, [z1,z2], [u1, u2], alpha, beta, permutation_constraint); + + let hint_send = query |i| Query::Hint(compute_next_z_send_permutation(is_first, 1, z, alpha, beta, permutation_constraint)[i]); + col witness stage(1) z1_next(i) query hint_send(0); + col witness stage(1) z2_next(i) query hint_send(1); + + z1' = z1_next; + z2' = z2_next; + + let hint_receive = query |i| Query::Hint(compute_next_z_receive_permutation(is_first, 1, u, alpha, beta, permutation_constraint)[i]); + col witness stage(1) u1_next(i) query hint_receive(0); + col witness stage(1) u2_next(i) query hint_receive(1); + + u1' = u1_next; + u2' = u2_next; + + is_first' * (z1 + u1) = 0; + is_first' * (z2 + u2) = 0; +} From e63c59be090b64025ac95e2a899cdffe11d0f126 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Wed, 17 Jul 2024 15:40:08 +0200 Subject: [PATCH 11/24] Always test PILCOM + Composite (#1587) This PR contains the following changes: - I renamed a bunch of functions `verify_*` to `run_pilcom_*`, because I think it better describes what they do - They all call `run_pilcom_with_backend_variant`, which works analogous to `gen_estark_proof_with_backend_variant` and `test_halo2_with_backend_variant` - In functions like `run_pilcom_test_file` (previously `verify_test_file`), which are used for many tests, we now test both the monolithic and composite backend variant (but share the generated witness & constants). This is anlogous to `gen_estark_proof` and `test_halo2` - In the RISC-V tests, we only test the composite variant, because with registers in memory (#1443) we don't expect the monolithic backend variant to work anymore. --- backend/src/lib.rs | 28 ------------ pipeline/src/test_util.rs | 85 +++++++++++++++++++++---------------- pipeline/tests/asm.rs | 20 ++++++--- pipeline/tests/pil.rs | 27 ++++++++---- pipeline/tests/powdr_std.rs | 31 ++++++++------ riscv/tests/common/mod.rs | 19 +++------ riscv/tests/instructions.rs | 9 +--- riscv/tests/riscv.rs | 34 ++++----------- 8 files changed, 114 insertions(+), 139 deletions(-) diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 6e4b4af93..e490fa9bb 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -92,34 +92,6 @@ impl BackendType { } } } - - pub fn is_composite(&self) -> bool { - match self { - #[cfg(feature = "halo2")] - BackendType::Halo2 => false, - #[cfg(feature = "halo2")] - BackendType::Halo2Composite => true, - #[cfg(feature = "halo2")] - BackendType::Halo2Mock => false, - #[cfg(feature = "halo2")] - BackendType::Halo2MockComposite => true, - #[cfg(feature = "estark-polygon")] - BackendType::EStarkPolygon => false, - #[cfg(feature = "estark-polygon")] - BackendType::EStarkPolygonComposite => true, - BackendType::EStarkStarky => false, - BackendType::EStarkStarkyComposite => true, - BackendType::EStarkDump => false, - BackendType::EStarkDumpComposite => true, - #[cfg(feature = "plonky3")] - BackendType::Plonky3 => false, - #[cfg(feature = "plonky3")] - BackendType::Plonky3Composite => true, - // We explicitly do not use a wildcard here - // so that a new composite backend needs to be - // added here too. - } - } } #[derive(thiserror::Error, Debug)] diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index ef81f3452..e864e1876 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -30,19 +30,35 @@ pub fn execute_test_file( .map(|_| ()) } -pub fn verify_test_file( +/// Makes a new pipeline for the given file and inputs. All steps until witness generation are +/// already computed, so that the test can branch off from there, without having to re-compute +/// these steps. +pub fn make_prepared_pipeline( file_name: &str, - inputs: Vec, - external_witness_values: Vec<(String, Vec)>, -) -> Result<(), String> { - let pipeline = Pipeline::default() + inputs: Vec, + external_witness_values: Vec<(String, Vec)>, +) -> Pipeline { + let mut pipeline = Pipeline::default() + .with_tmp_output() .from_file(resolve_test_file(file_name)) .with_prover_inputs(inputs) .add_external_witness_values(external_witness_values); - verify_pipeline(pipeline, BackendType::EStarkDump) + pipeline.compute_witness().unwrap(); + pipeline +} + +pub fn run_pilcom_test_file( + file_name: &str, + inputs: Vec, + external_witness_values: Vec<(String, Vec)>, +) -> Result<(), String> { + let pipeline = make_prepared_pipeline(file_name, inputs, external_witness_values); + run_pilcom_with_backend_variant(pipeline.clone(), BackendVariant::Monolithic)?; + run_pilcom_with_backend_variant(pipeline, BackendVariant::Composite)?; + Ok(()) } -pub fn verify_asm_string( +pub fn run_pilcom_asm_string( file_name: &str, contents: &str, inputs: Vec, @@ -57,14 +73,20 @@ pub fn verify_asm_string( if let Some(data) = data { pipeline = pipeline.add_data_vec(&data); } + pipeline.compute_witness().unwrap(); - verify_pipeline(pipeline, BackendType::EStarkDump).unwrap(); + run_pilcom_with_backend_variant(pipeline.clone(), BackendVariant::Monolithic).unwrap(); + run_pilcom_with_backend_variant(pipeline, BackendVariant::Composite).unwrap(); } -pub fn verify_pipeline( +pub fn run_pilcom_with_backend_variant( pipeline: Pipeline, - backend: BackendType, + backend_variant: BackendVariant, ) -> Result<(), String> { + let backend = match backend_variant { + BackendVariant::Monolithic => BackendType::EStarkDump, + BackendVariant::Composite => BackendType::EStarkDumpComposite, + }; let mut pipeline = pipeline.with_backend(backend, None); if pipeline.output_dir().is_none() { @@ -74,35 +96,24 @@ pub fn verify_pipeline( pipeline.compute_proof().unwrap(); let out_dir = pipeline.output_dir().as_ref().unwrap(); - if backend.is_composite() { - // traverse all subdirs of the given output dir and verify each subproof - for entry in fs::read_dir(out_dir).unwrap() { - let entry = entry.unwrap(); - let path = entry.path(); - if path.is_dir() { - verify(&path)?; + match backend_variant { + BackendVariant::Composite => { + // traverse all subdirs of the given output dir and verify each subproof + for entry in fs::read_dir(out_dir).unwrap() { + let entry = entry.unwrap(); + let path = entry.path(); + if path.is_dir() { + verify(&path)?; + } } + Ok(()) } - Ok(()) - } else { - verify(out_dir) + BackendVariant::Monolithic => verify(out_dir), } } -/// Makes a new pipeline for the given file and inputs. All steps until witness generation are -/// already computed, so that the test can branch off from there, without having to re-compute -/// these steps. -pub fn make_prepared_pipeline(file_name: &str, inputs: Vec) -> Pipeline { - let mut pipeline = Pipeline::default() - .with_tmp_output() - .from_file(resolve_test_file(file_name)) - .with_prover_inputs(inputs); - pipeline.compute_witness().unwrap(); - pipeline -} - pub fn gen_estark_proof(file_name: &str, inputs: Vec) { - let pipeline = make_prepared_pipeline(file_name, inputs); + let pipeline = make_prepared_pipeline(file_name, inputs, Vec::new()); gen_estark_proof_with_backend_variant(pipeline.clone(), BackendVariant::Monolithic); gen_estark_proof_with_backend_variant(pipeline, BackendVariant::Composite); } @@ -146,7 +157,7 @@ pub fn gen_estark_proof_with_backend_variant( } pub fn test_halo2(file_name: &str, inputs: Vec) { - let pipeline = make_prepared_pipeline(file_name, inputs); + let pipeline = make_prepared_pipeline(file_name, inputs, Vec::new()); test_halo2_with_backend_variant(pipeline.clone(), BackendVariant::Monolithic); test_halo2_with_backend_variant(pipeline, BackendVariant::Composite); } @@ -338,12 +349,14 @@ pub fn assert_proofs_fail_for_invalid_witnesses_pilcom( file_name: &str, witness: &[(String, Vec)], ) { - let pipeline = Pipeline::::default() + let mut pipeline = Pipeline::::default() .with_tmp_output() .from_file(resolve_test_file(file_name)) .set_witness(convert_witness(witness)); + pipeline.compute_witness().unwrap(); - assert!(verify_pipeline(pipeline.clone(), BackendType::EStarkDump).is_err()); + assert!(run_pilcom_with_backend_variant(pipeline.clone(), BackendVariant::Monolithic).is_err()); + assert!(run_pilcom_with_backend_variant(pipeline, BackendVariant::Composite).is_err()); } pub fn assert_proofs_fail_for_invalid_witnesses_estark( diff --git a/pipeline/tests/asm.rs b/pipeline/tests/asm.rs index 66d69592c..a022eb368 100644 --- a/pipeline/tests/asm.rs +++ b/pipeline/tests/asm.rs @@ -3,8 +3,8 @@ use powdr_number::{Bn254Field, FieldElement, GoldilocksField}; use powdr_pipeline::{ test_util::{ gen_estark_proof, gen_estark_proof_with_backend_variant, make_prepared_pipeline, - resolve_test_file, test_halo2, test_halo2_with_backend_variant, verify_test_file, - BackendVariant, + resolve_test_file, run_pilcom_test_file, run_pilcom_with_backend_variant, test_halo2, + test_halo2_with_backend_variant, BackendVariant, }, util::{read_poly_set, FixedPolySet, WitnessPolySet}, Pipeline, @@ -12,7 +12,7 @@ use powdr_pipeline::{ use test_log::test; fn verify_asm(file_name: &str, inputs: Vec) { - verify_test_file(file_name, inputs, vec![]).unwrap(); + run_pilcom_test_file(file_name, inputs, vec![]).unwrap(); } fn slice_to_vec(arr: &[i32]) -> Vec { @@ -82,7 +82,7 @@ fn mem_write_once_external_write() { mem[17] = GoldilocksField::from(42); mem[62] = GoldilocksField::from(123); mem[255] = GoldilocksField::from(-1); - verify_test_file( + run_pilcom_test_file( f, Default::default(), vec![("main_memory.value".to_string(), mem)], @@ -228,9 +228,17 @@ fn vm_to_block_different_length() { let f = "asm/vm_to_block_different_length.asm"; // Because machines have different lengths, this can only be proven // with a composite proof. - test_halo2_with_backend_variant(make_prepared_pipeline(f, vec![]), BackendVariant::Composite); + run_pilcom_with_backend_variant( + make_prepared_pipeline(f, vec![], vec![]), + BackendVariant::Composite, + ) + .unwrap(); + test_halo2_with_backend_variant( + make_prepared_pipeline(f, vec![], vec![]), + BackendVariant::Composite, + ); gen_estark_proof_with_backend_variant( - make_prepared_pipeline(f, vec![]), + make_prepared_pipeline(f, vec![], vec![]), BackendVariant::Composite, ); } diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index 073739705..b13470072 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -5,14 +5,15 @@ use powdr_pipeline::test_util::{ assert_proofs_fail_for_invalid_witnesses, assert_proofs_fail_for_invalid_witnesses_estark, assert_proofs_fail_for_invalid_witnesses_halo2, assert_proofs_fail_for_invalid_witnesses_pilcom, gen_estark_proof, - gen_estark_proof_with_backend_variant, make_prepared_pipeline, test_halo2, - test_halo2_with_backend_variant, test_plonky3, verify_test_file, BackendVariant, + gen_estark_proof_with_backend_variant, make_prepared_pipeline, run_pilcom_test_file, + run_pilcom_with_backend_variant, test_halo2, test_halo2_with_backend_variant, test_plonky3, + BackendVariant, }; use test_log::test; pub fn verify_pil(file_name: &str, inputs: Vec) { - verify_test_file(file_name, inputs, vec![]).unwrap(); + run_pilcom_test_file(file_name, inputs, vec![]).unwrap(); } #[test] @@ -145,14 +146,14 @@ fn external_witgen_fails_if_none_provided() { fn external_witgen_a_provided() { let f = "pil/external_witgen.pil"; let external_witness = vec![("main.a".to_string(), vec![GoldilocksField::from(3); 16])]; - verify_test_file(f, Default::default(), external_witness).unwrap(); + run_pilcom_test_file(f, Default::default(), external_witness).unwrap(); } #[test] fn external_witgen_b_provided() { let f = "pil/external_witgen.pil"; let external_witness = vec![("main.b".to_string(), vec![GoldilocksField::from(4); 16])]; - verify_test_file(f, Default::default(), external_witness).unwrap(); + run_pilcom_test_file(f, Default::default(), external_witness).unwrap(); } #[test] @@ -162,7 +163,7 @@ fn external_witgen_both_provided() { ("main.a".to_string(), vec![GoldilocksField::from(3); 16]), ("main.b".to_string(), vec![GoldilocksField::from(4); 16]), ]; - verify_test_file(f, Default::default(), external_witness).unwrap(); + run_pilcom_test_file(f, Default::default(), external_witness).unwrap(); } #[test] @@ -174,7 +175,7 @@ fn external_witgen_fails_on_conflicting_external_witness() { // Does not satisfy b = a + 1 ("main.b".to_string(), vec![GoldilocksField::from(3); 16]), ]; - verify_test_file(f, Default::default(), external_witness).unwrap(); + run_pilcom_test_file(f, Default::default(), external_witness).unwrap(); } #[test] @@ -313,9 +314,17 @@ fn different_degrees() { let f = "pil/different_degrees.pil"; // Because machines have different lengths, this can only be proven // with a composite proof. - test_halo2_with_backend_variant(make_prepared_pipeline(f, vec![]), BackendVariant::Composite); + run_pilcom_with_backend_variant( + make_prepared_pipeline(f, vec![], vec![]), + BackendVariant::Composite, + ) + .unwrap(); + test_halo2_with_backend_variant( + make_prepared_pipeline(f, vec![], vec![]), + BackendVariant::Composite, + ); gen_estark_proof_with_backend_variant( - make_prepared_pipeline(f, vec![]), + make_prepared_pipeline(f, vec![], vec![]), BackendVariant::Composite, ); } diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index a6a5c48e6..93047b0f5 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -6,8 +6,8 @@ use powdr_pil_analyzer::evaluator::Value; use powdr_pipeline::{ test_util::{ evaluate_function, evaluate_integer_function, execute_test_file, gen_estark_proof, - gen_halo2_proof, make_prepared_pipeline, resolve_test_file, std_analyzed, test_halo2, - verify_test_file, BackendVariant, + gen_halo2_proof, make_prepared_pipeline, resolve_test_file, run_pilcom_test_file, + std_analyzed, test_halo2, BackendVariant, }, Pipeline, }; @@ -22,23 +22,26 @@ fn poseidon_bn254_test() { // This makes sure we test the whole proof generation for one example // file even in the PR tests. gen_halo2_proof( - make_prepared_pipeline(f, vec![]), + make_prepared_pipeline(f, vec![], vec![]), BackendVariant::Monolithic, ); - gen_halo2_proof(make_prepared_pipeline(f, vec![]), BackendVariant::Composite); + gen_halo2_proof( + make_prepared_pipeline(f, vec![], vec![]), + BackendVariant::Composite, + ); } #[test] fn poseidon_gl_test() { let f = "std/poseidon_gl_test.asm"; - verify_test_file(f, Default::default(), vec![]).unwrap(); + run_pilcom_test_file(f, Default::default(), vec![]).unwrap(); gen_estark_proof(f, Default::default()); } #[test] fn poseidon_gl_memory_test() { let f = "std/poseidon_gl_memory_test.asm"; - verify_test_file(f, Default::default(), vec![]).unwrap(); + run_pilcom_test_file(f, Default::default(), vec![]).unwrap(); gen_estark_proof(f, Default::default()); } @@ -51,7 +54,7 @@ fn split_bn254_test() { #[test] fn split_gl_test() { let f = "std/split_gl_test.asm"; - verify_test_file(f, Default::default(), vec![]).unwrap(); + run_pilcom_test_file(f, Default::default(), vec![]).unwrap(); gen_estark_proof(f, Default::default()); } @@ -59,7 +62,7 @@ fn split_gl_test() { #[ignore = "Too slow"] fn arith_test() { let f = "std/arith_test.asm"; - verify_test_file(f, Default::default(), vec![]).unwrap(); + run_pilcom_test_file(f, Default::default(), vec![]).unwrap(); // Running gen_estark_proof(f, Default::default()) // is too slow for the PR tests. This will only create a single @@ -76,7 +79,7 @@ fn arith_test() { #[test] fn memory_test() { let f = "std/memory_test.asm"; - verify_test_file(f, Default::default(), vec![]).unwrap(); + run_pilcom_test_file(f, Default::default(), vec![]).unwrap(); gen_estark_proof(f, Default::default()); test_halo2(f, Default::default()); } @@ -84,7 +87,7 @@ fn memory_test() { #[test] fn memory_with_bootloader_write_test() { let f = "std/memory_with_bootloader_write_test.asm"; - verify_test_file(f, Default::default(), vec![]).unwrap(); + run_pilcom_test_file(f, Default::default(), vec![]).unwrap(); gen_estark_proof(f, Default::default()); test_halo2(f, Default::default()); } @@ -92,7 +95,7 @@ fn memory_with_bootloader_write_test() { #[test] fn memory_test_parallel_accesses() { let f = "std/memory_test_parallel_accesses.asm"; - verify_test_file(f, Default::default(), vec![]).unwrap(); + run_pilcom_test_file(f, Default::default(), vec![]).unwrap(); gen_estark_proof(f, Default::default()); test_halo2(f, Default::default()); } @@ -171,7 +174,7 @@ fn bus_permutation_via_challenges_ext_bn() { #[test] fn write_once_memory_test() { let f = "std/write_once_memory_test.asm"; - verify_test_file(f, Default::default(), vec![]).unwrap(); + run_pilcom_test_file(f, Default::default(), vec![]).unwrap(); gen_estark_proof(f, Default::default()); test_halo2(f, Default::default()); } @@ -179,14 +182,14 @@ fn write_once_memory_test() { #[test] fn binary_test() { let f = "std/binary_test.asm"; - verify_test_file(f, Default::default(), vec![]).unwrap(); + run_pilcom_test_file(f, Default::default(), vec![]).unwrap(); test_halo2(f, Default::default()); } #[test] fn shift_test() { let f = "std/shift_test.asm"; - verify_test_file(f, Default::default(), vec![]).unwrap(); + run_pilcom_test_file(f, Default::default(), vec![]).unwrap(); test_halo2(f, Default::default()); } diff --git a/riscv/tests/common/mod.rs b/riscv/tests/common/mod.rs index d34ab86cb..274d927bc 100644 --- a/riscv/tests/common/mod.rs +++ b/riscv/tests/common/mod.rs @@ -1,20 +1,21 @@ use mktemp::Temp; -use powdr_backend::BackendType; use powdr_number::GoldilocksField; -use powdr_pipeline::{test_util::verify_pipeline, Pipeline}; +use powdr_pipeline::{ + test_util::{run_pilcom_with_backend_variant, BackendVariant}, + Pipeline, +}; use powdr_riscv::Runtime; use std::{ path::{Path, PathBuf}, process::Command, }; -/// Like compiler::test_util::verify_asm_string, but also runs RISCV executor. +/// Like compiler::test_util::run_pilcom_asm_string, but also runs RISCV executor. pub fn verify_riscv_asm_string( file_name: &str, contents: &str, inputs: &[GoldilocksField], data: Option<&[(u32, S)]>, - backend: BackendType, ) { let temp_dir = mktemp::Temp::new_dir().unwrap().release(); @@ -38,7 +39,7 @@ pub fn verify_riscv_asm_string( powdr_riscv_executor::ExecMode::Fast, Default::default(), ); - verify_pipeline(pipeline, backend).unwrap(); + run_pilcom_with_backend_variant(pipeline, BackendVariant::Composite).unwrap(); } fn find_assembler() -> &'static str { @@ -88,11 +89,5 @@ pub fn verify_riscv_asm_file(asm_file: &Path, runtime: &Runtime, use_pie: bool) let case_name = asm_file.file_stem().unwrap().to_str().unwrap(); let powdr_asm = powdr_riscv::elf::translate::(&executable, runtime, false); - verify_riscv_asm_string::<()>( - &format!("{case_name}.asm"), - &powdr_asm, - &[], - None, - BackendType::EStarkDumpComposite, - ); + verify_riscv_asm_string::<()>(&format!("{case_name}.asm"), &powdr_asm, &[], None); } diff --git a/riscv/tests/instructions.rs b/riscv/tests/instructions.rs index f077efe81..5dc811062 100644 --- a/riscv/tests/instructions.rs +++ b/riscv/tests/instructions.rs @@ -4,7 +4,6 @@ mod instruction_tests { use std::path::Path; use crate::common::{verify_riscv_asm_file, verify_riscv_asm_string}; - use powdr_backend::BackendType; use powdr_number::GoldilocksField; use powdr_riscv::asm::compile; use powdr_riscv::Runtime; @@ -28,13 +27,7 @@ mod instruction_tests { false, ); - verify_riscv_asm_string::<()>( - &format!("{name}.asm"), - &powdr_asm, - Default::default(), - None, - BackendType::EStarkDumpComposite, - ); + verify_riscv_asm_string::<()>(&format!("{name}.asm"), &powdr_asm, Default::default(), None); } include!(concat!(env!("OUT_DIR"), "/instruction_tests.rs")); diff --git a/riscv/tests/riscv.rs b/riscv/tests/riscv.rs index 1f1e0623b..546b3aa3d 100644 --- a/riscv/tests/riscv.rs +++ b/riscv/tests/riscv.rs @@ -2,9 +2,11 @@ mod common; use common::{verify_riscv_asm_file, verify_riscv_asm_string}; use mktemp::Temp; -use powdr_backend::BackendType; use powdr_number::GoldilocksField; -use powdr_pipeline::{test_util::verify_pipeline, Pipeline}; +use powdr_pipeline::{ + test_util::{run_pilcom_with_backend_variant, BackendVariant}, + Pipeline, +}; use std::path::{Path, PathBuf}; use test_log::test; @@ -47,7 +49,7 @@ fn run_continuations_test(case: &str, powdr_asm: String) { .with_prover_inputs(Default::default()) .with_output(tmp_dir.to_path_buf(), false); let pipeline_callback = |pipeline: Pipeline| -> Result<(), ()> { - verify_pipeline(pipeline, BackendType::EStarkDumpComposite).unwrap(); + run_pilcom_with_backend_variant(pipeline, BackendVariant::Composite).unwrap(); Ok(()) }; @@ -143,14 +145,13 @@ fn keccak() { #[ignore = "Too slow"] fn vec_median_estark_polygon() { let case = "vec_median"; - verify_riscv_crate_with_backend( + verify_riscv_crate( case, [5, 11, 15, 75, 6, 5, 1, 4, 7, 3, 2, 9, 2] .into_iter() .map(|x| x.into()) .collect(), &Runtime::base(), - BackendType::EStarkPolygonComposite, ); } @@ -364,17 +365,7 @@ fn many_chunks_memory() { } fn verify_riscv_crate(case: &str, inputs: Vec, runtime: &Runtime) { - verify_riscv_crate_with_backend(case, inputs.clone(), runtime, BackendType::EStarkDump); - verify_riscv_crate_with_backend(case, inputs, runtime, BackendType::EStarkDumpComposite); -} - -fn verify_riscv_crate_with_backend( - case: &str, - inputs: Vec, - runtime: &Runtime, - backend: BackendType, -) { - verify_riscv_crate_from_both_paths::<()>(case, inputs, runtime, None, backend) + verify_riscv_crate_from_both_paths::<()>(case, inputs, runtime, None) } fn verify_riscv_crate_with_data( @@ -383,13 +374,7 @@ fn verify_riscv_crate_with_data( runtime: &Runtime, data: Vec<(u32, S)>, ) { - verify_riscv_crate_from_both_paths( - case, - inputs, - runtime, - Some(data), - BackendType::EStarkDumpComposite, - ) + verify_riscv_crate_from_both_paths(case, inputs, runtime, Some(data)) } fn verify_riscv_crate_from_both_paths( @@ -397,7 +382,6 @@ fn verify_riscv_crate_from_both_paths, runtime: &Runtime, data: Option>, - backend: BackendType, ) { let temp_dir = Temp::new_dir().unwrap(); let compiled = powdr_riscv::compile_rust_crate_to_riscv( @@ -416,7 +400,6 @@ fn verify_riscv_crate_from_both_paths Date: Wed, 17 Jul 2024 15:54:51 +0200 Subject: [PATCH 12/24] Turn range check lookups into links (#1557) --- std/machines/byte2.asm | 1 - std/machines/memory.asm | 7 +++++-- std/machines/memory_with_bootloader_write.asm | 8 +++++--- std/machines/split/mod.asm | 20 ++++++++++++++++++- std/machines/split/split_bn254.asm | 15 ++++---------- std/machines/split/split_gl.asm | 15 ++++---------- 6 files changed, 37 insertions(+), 29 deletions(-) diff --git a/std/machines/byte2.asm b/std/machines/byte2.asm index 074c43fa8..7e19e90eb 100644 --- a/std/machines/byte2.asm +++ b/std/machines/byte2.asm @@ -1,6 +1,5 @@ /// A machine to check that a field element represents two bytes. It uses an exhaustive lookup table. machine Byte2 with - degree: 65536, latch: latch, operation_id: operation_id { diff --git a/std/machines/memory.asm b/std/machines/memory.asm index 6e0599939..1a0793e0c 100644 --- a/std/machines/memory.asm +++ b/std/machines/memory.asm @@ -1,4 +1,5 @@ use std::array; +use std::machines::byte2::Byte2; // A read/write memory, similar to that of Polygon: // https://github.com/0xPolygonHermez/zkevm-proverjs/blob/main/pil/mem.pil @@ -9,6 +10,8 @@ machine Memory with { // lower bound degree is 65536 + Byte2 byte2; + operation mload<0> m_addr, m_step -> m_value; operation mstore<1> m_addr, m_step, m_value ->; @@ -53,8 +56,8 @@ machine Memory with col fixed STEP(i) { i }; col fixed BIT16(i) { i & 0xffff }; - [m_diff_lower] in [BIT16]; - [m_diff_upper] in [BIT16]; + link => byte2.check(m_diff_lower); + link => byte2.check(m_diff_upper); std::utils::force_bool(m_change); diff --git a/std/machines/memory_with_bootloader_write.asm b/std/machines/memory_with_bootloader_write.asm index 40289aa0d..4b81b12be 100644 --- a/std/machines/memory_with_bootloader_write.asm +++ b/std/machines/memory_with_bootloader_write.asm @@ -1,4 +1,5 @@ use std::array; +use std::machines::byte2::Byte2; /// This machine is a slightly extended version of std::machines::memory::Memory, /// where in addition to mstore, there is an mstore_bootloader operation. It behaves @@ -11,6 +12,8 @@ machine MemoryWithBootloaderWrite with { // lower bound degree is 65536 + Byte2 byte2; + operation mload<0> m_addr, m_step -> m_value; operation mstore<1> m_addr, m_step, m_value ->; operation mstore_bootloader<2> m_addr, m_step, m_value ->; @@ -61,10 +64,9 @@ machine MemoryWithBootloaderWrite with col fixed FIRST = [1] + [0]*; let LAST = FIRST'; col fixed STEP(i) { i }; - col fixed BIT16(i) { i & 0xffff }; - [m_diff_lower] in [BIT16]; - [m_diff_upper] in [BIT16]; + link => byte2.check(m_diff_lower); + link => byte2.check(m_diff_upper); std::utils::force_bool(m_change); diff --git a/std/machines/split/mod.asm b/std/machines/split/mod.asm index c08ecdf61..27f4ecc8e 100644 --- a/std/machines/split/mod.asm +++ b/std/machines/split/mod.asm @@ -1,2 +1,20 @@ mod split_bn254; -mod split_gl; \ No newline at end of file +mod split_gl; + +use std::utils::cross_product; + +// Byte comparison block machine +machine ByteCompare with latch: latch, operation_id: operation_id { + let inputs = cross_product([256, 256]); + let a: int -> int = inputs[0]; + let b: int -> int = inputs[1]; + let P_A: col = a; + let P_B: col = b; + col fixed P_LT(i) { if a(i) < b(i) { 1 } else { 0 } }; + col fixed P_GT(i) { if a(i) > b(i) { 1 } else { 0 } }; + + operation run<0> P_A, P_B -> P_LT, P_GT; + + col fixed latch = [1]*; + col fixed operation_id = [0]*; +} \ No newline at end of file diff --git a/std/machines/split/split_bn254.asm b/std/machines/split/split_bn254.asm index e439ca76e..f19f58e39 100644 --- a/std/machines/split/split_bn254.asm +++ b/std/machines/split/split_bn254.asm @@ -1,5 +1,5 @@ -use std::utils::cross_product; use std::prover::Query; +use super::ByteCompare; // Splits an arbitrary field element into 8 u32s (in little endian order), on the BN254 field. machine SplitBN254 with @@ -7,6 +7,8 @@ machine SplitBN254 with // Allow this machine to be connected via a permutation call_selectors: sel, { + ByteCompare byte_compare; + operation split in_acc -> o1, o2, o3, o4, o5, o6, o7, o8; // Latch and operation ID @@ -66,19 +68,10 @@ machine SplitBN254 with // so the maximum value is 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000. col fixed BYTES_MAX = [0x00, 0x00, 0xf0, 0x93, 0xf5, 0xe1, 0x43, 0x91, 0x70, 0xb9, 0x79, 0x48, 0xe8, 0x33, 0x28, 0x5d, 0x58, 0x81, 0x81, 0xb6, 0x45, 0x50, 0xb8, 0x29, 0xa0, 0x31, 0xe1, 0x72, 0x4e, 0x64, 0x30, 0x00]*; - // Byte comparison block machine - let compare_inputs = cross_product([256, 256]); - let a = compare_inputs[1]; - let b = compare_inputs[0]; - let P_A: col = a; - let P_B: col = b; - col fixed P_LT(i) { if a(i) < b(i) { 1 } else { 0 } }; - col fixed P_GT(i) { if a(i) > b(i) { 1 } else { 0 } }; - // Compare the current byte with the corresponding byte of the maximum value. col witness lt; col witness gt; - [ bytes, BYTES_MAX, lt, gt ] in [ P_A, P_B, P_LT, P_GT ]; + link => (lt, gt) = byte_compare.run(bytes, BYTES_MAX); // Compute whether the current or any previous byte has been less than // the corresponding byte of the maximum value. diff --git a/std/machines/split/split_gl.asm b/std/machines/split/split_gl.asm index dce42c09a..e958c7581 100644 --- a/std/machines/split/split_gl.asm +++ b/std/machines/split/split_gl.asm @@ -1,5 +1,5 @@ -use std::utils::cross_product; use std::prover::Query; +use super::ByteCompare; // Splits an arbitrary field element into two u32s, on the Goldilocks field. machine SplitGL with @@ -7,6 +7,8 @@ machine SplitGL with // Allow this machine to be connected via a permutation call_selectors: sel, { + ByteCompare byte_compare; + operation split in_acc -> output_low, output_high; // Latch and operation ID @@ -62,19 +64,10 @@ machine SplitGL with // Bytes of the maximum value, in little endian order, rotated by one col fixed BYTES_MAX = [0, 0, 0, 0xff, 0xff, 0xff, 0xff, 0]*; - // Byte comparison block machine - let inputs = cross_product([256, 256]); - let a: int -> int = inputs[0]; - let b: int -> int = inputs[1]; - let P_A: col = a; - let P_B: col = b; - col fixed P_LT(i) { if a(i) < b(i) { 1 } else { 0 } }; - col fixed P_GT(i) { if a(i) > b(i) { 1 } else { 0 } }; - // Compare the current byte with the corresponding byte of the maximum value. col witness lt; col witness gt; - [ bytes, BYTES_MAX, lt, gt ] in [ P_A, P_B, P_LT, P_GT ]; + link => (lt, gt) = byte_compare.run(bytes, BYTES_MAX); // Compute whether the current or any previous byte has been less than // the corresponding byte of the maximum value. From ebc4d56e7b496e0e6577ec03c0794ed359a06c91 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Wed, 17 Jul 2024 19:48:57 +0200 Subject: [PATCH 13/24] `CompositeBackend`: Print maximum degree per machine (#1589) Computes & prints the maximum degree of any identity for every machine. Example output for Keccak: ``` Instantiating a composite backend with 7 machines: * main: * Number of witness columns: 247 * Number of fixed columns: 195 * Maximum identity degree: 4 * Number of identities: * Polynomial: 68 * Plookup: 16 * main_binary: * Number of witness columns: 9 * Number of fixed columns: 2 * Maximum identity degree: 2 * Number of identities: * Polynomial: 8 * main_binary_byte_binary: * Number of witness columns: 1 * Number of fixed columns: 4 * Maximum identity degree: 1 * Number of identities: * Polynomial: 1 * main_memory: * Number of witness columns: 9 * Number of fixed columns: 2 * Maximum identity degree: 3 * Number of identities: * Polynomial: 12 * Plookup: 2 * main_shift: * Number of witness columns: 8 * Number of fixed columns: 3 * Maximum identity degree: 2 * Number of identities: * Polynomial: 7 * main_shift_byte_shift: * Number of witness columns: 1 * Number of fixed columns: 5 * Maximum identity degree: 1 * Number of identities: * Polynomial: 1 * main_split_gl: * Number of witness columns: 9 * Number of fixed columns: 9 * Maximum identity degree: 3 * Number of identities: * Polynomial: 7 * Plookup: 1 ``` --- ast/src/analyzed/mod.rs | 53 ++++++++++++++++++++++++++++++++++++ backend/src/composite/mod.rs | 49 +++++++++++++++++++++------------ 2 files changed, 85 insertions(+), 17 deletions(-) diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 66a3ad30c..cbe0a3c08 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -783,6 +783,10 @@ impl Identity>> { a => (a, None), } } + + pub fn degree(&self) -> usize { + self.children().map(|e| e.degree()).max().unwrap_or(0) + } } impl Identity>> { @@ -1120,6 +1124,22 @@ impl AlgebraicExpression { } } } + + /// Returns the degree of the expressions + pub fn degree(&self) -> usize { + match self { + // One for each column + AlgebraicExpression::Reference(_) => 1, + // Multiplying two expressions adds their degrees + AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { + op: AlgebraicBinaryOperator::Mul, + left, + right, + }) => left.degree() + right.degree(), + // In all other cases, we take the maximum of the degrees of the children + _ => self.children().map(|e| e.degree()).max().unwrap_or(0), + } + } } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] @@ -1337,6 +1357,8 @@ impl Display for PolynomialType { mod tests { use powdr_parser_util::SourceRef; + use crate::analyzed::{AlgebraicReference, PolyID, PolynomialType}; + use super::{AlgebraicExpression, Analyzed}; #[test] @@ -1376,4 +1398,35 @@ mod tests { assert_eq!(pil.identities, pil_result.identities); assert_eq!(pil.source_order, pil_result.source_order); } + + #[test] + fn test_degree() { + let column = AlgebraicExpression::::Reference(AlgebraicReference { + name: "column".to_string(), + poly_id: PolyID { + id: 0, + ptype: PolynomialType::Committed, + }, + next: false, + }); + let one = AlgebraicExpression::Number(1); + + let expr = one.clone() + one.clone() * one.clone(); + assert_eq!(expr.degree(), 0); + + let expr = column.clone() + one.clone() * one.clone(); + assert_eq!(expr.degree(), 1); + + let expr = column.clone() + one.clone() * column.clone(); + assert_eq!(expr.degree(), 1); + + let expr = column.clone() + column.clone() * column.clone(); + assert_eq!(expr.degree(), 2); + + let expr = column.clone() + column.clone() * (column.clone() + one.clone()); + assert_eq!(expr.degree(), 2); + + let expr = column.clone() * column.clone() * column.clone(); + assert_eq!(expr.degree(), 3); + } } diff --git a/backend/src/composite/mod.rs b/backend/src/composite/mod.rs index e162e2466..535bee188 100644 --- a/backend/src/composite/mod.rs +++ b/backend/src/composite/mod.rs @@ -82,23 +82,7 @@ impl> BackendFactory for CompositeBacke pils.len() ); for (machine_name, pil) in pils.iter() { - let num_witness_columns = pil.committed_polys_in_source_order().len(); - let num_fixed_columns = pil.constant_polys_in_source_order().len(); - let num_identities_by_kind = pil - .identities - .iter() - .map(|i| i.kind) - .counts() - .into_iter() - .collect::>(); - - log::info!("* {}:", machine_name); - log::info!(" * Number of witness columns: {}", num_witness_columns); - log::info!(" * Number of fixed columns: {}", num_fixed_columns); - log::info!(" * Number of identities:"); - for (kind, count) in num_identities_by_kind { - log::info!(" * {:?}: {}", kind, count); - } + log_machine_stats(machine_name, pil) } let machine_data = pils @@ -143,6 +127,37 @@ impl> BackendFactory for CompositeBacke } } +fn log_machine_stats(machine_name: &str, pil: &Analyzed) { + let num_witness_columns = pil.committed_polys_in_source_order().len(); + let num_fixed_columns = pil.constant_polys_in_source_order().len(); + let max_identity_degree = pil + .identities_with_inlined_intermediate_polynomials() + .iter() + .map(|i| i.degree()) + .max() + .unwrap_or(0); + let uses_next_operator = pil.identities.iter().any(|i| i.contains_next_ref()); + // This assumes that we'll always at least once reference the current row + let number_of_rotations = 1 + if uses_next_operator { 1 } else { 0 }; + let num_identities_by_kind = pil + .identities + .iter() + .map(|i| i.kind) + .counts() + .into_iter() + .collect::>(); + + log::info!("* {}:", machine_name); + log::info!(" * Number of witness columns: {}", num_witness_columns); + log::info!(" * Number of fixed columns: {}", num_fixed_columns); + log::info!(" * Maximum identity degree: {}", max_identity_degree); + log::info!(" * Number of rotations: {}", number_of_rotations); + log::info!(" * Number of identities:"); + for (kind, count) in num_identities_by_kind { + log::info!(" * {:?}: {}", kind, count); + } +} + struct MachineData<'a, F> { pil: Arc>, backend: Box + 'a>, From 7936a5a09d9450231aaa0f516431fc56a36fbfad Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Fri, 19 Jul 2024 15:06:16 +0200 Subject: [PATCH 14/24] Test re-parsing of all test files after optimization (with blacklist) (#1590) A new iteration of #1476. These added tests demonstrate that #1488 is *mostly* solved (see [this comment](https://github.com/powdr-labs/powdr/issues/1488#issuecomment-2220759508) though): We parse, optimize, serialize, and re-parse all test files. A small number of test files need to be blacklisted for reasons explained in the comments. --- pipeline/build.rs | 30 +++++++++++++++++++++++------- pipeline/src/test_util.rs | 28 ++++++++++++++++++++++++++++ pipeline/tests/asm.rs | 19 +++++++++++++++++++ pipeline/tests/pil.rs | 6 ++++++ pipeline/tests/powdr_std.rs | 23 +++++++++++++++++++++++ 5 files changed, 99 insertions(+), 7 deletions(-) diff --git a/pipeline/build.rs b/pipeline/build.rs index f7eb75b7d..e958858dd 100644 --- a/pipeline/build.rs +++ b/pipeline/build.rs @@ -8,20 +8,36 @@ use walkdir::WalkDir; fn main() { build_book_tests("asm"); build_book_tests("pil"); + build_reparse_test("asm", "asm"); + build_reparse_test("pil", "pil"); + build_reparse_test("asm", "std"); } -#[allow(clippy::print_stdout)] fn build_book_tests(kind: &str) { + build_tests(kind, kind, "book", "book") +} + +fn build_reparse_test(kind: &str, dir: &str) { + build_tests(kind, dir, "", "reparse") +} + +#[allow(clippy::print_stdout)] +fn build_tests(kind: &str, dir: &str, sub_dir: &str, name: &str) { + let sub_dir = if sub_dir.is_empty() { + "".to_string() + } else { + format!("{sub_dir}/") + }; let out_dir = env::var("OUT_DIR").unwrap(); - let destination = Path::new(&out_dir).join(format!("{kind}_book_tests.rs")); + let destination = Path::new(&out_dir).join(format!("{dir}_{name}_tests.rs")); let mut test_file = BufWriter::new(File::create(destination).unwrap()); - let dir = format!("../test_data/{kind}/book/"); - for file in WalkDir::new(&dir) { + let full_dir = format!("../test_data/{dir}/{sub_dir}"); + for file in WalkDir::new(&full_dir) { let file = file.unwrap(); let relative_name = file .path() - .strip_prefix(&dir) + .strip_prefix(&full_dir) .unwrap() .to_str() .unwrap() @@ -30,13 +46,13 @@ fn build_book_tests(kind: &str) { .replace('/', "_sub_") .strip_suffix(&format!(".{kind}")) { - println!("cargo:rerun-if-changed={dir}/{relative_name}"); + println!("cargo:rerun-if-changed={full_dir}/{relative_name}"); write!( test_file, r#" #[test] fn {test_name}() {{ - run_book_test("{kind}/book/{relative_name}"); + run_{name}_test("{dir}/{sub_dir}{relative_name}"); }} "#, ) diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index e864e1876..3a3bb5df2 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -409,3 +409,31 @@ pub fn assert_proofs_fail_for_invalid_witnesses(file_name: &str, witness: &[(Str #[cfg(feature = "halo2")] assert_proofs_fail_for_invalid_witnesses_halo2(file_name, witness); } + +pub fn run_reparse_test(file: &str) { + run_reparse_test_with_blacklist(file, &[]); +} + +pub fn run_reparse_test_with_blacklist(file: &str, blacklist: &[&str]) { + if blacklist.contains(&file) { + return; + } + + // Load file + let pipeline = Pipeline::::default(); + let mut pipeline = if file.ends_with(".asm") { + pipeline.from_asm_file(resolve_test_file(file)) + } else { + pipeline.from_pil_file(resolve_test_file(file)) + }; + + // Compute the optimized PIL + let optimized_pil = pipeline.compute_optimized_pil().unwrap(); + + // Run the pipeline using the string serialization of the optimized PIL. + // This panics if the re-parsing fails. + Pipeline::::default() + .from_pil_string(optimized_pil.to_string()) + .compute_optimized_pil() + .unwrap(); +} diff --git a/pipeline/tests/asm.rs b/pipeline/tests/asm.rs index a022eb368..44027edec 100644 --- a/pipeline/tests/asm.rs +++ b/pipeline/tests/asm.rs @@ -527,6 +527,25 @@ fn vm_args_two_levels() { gen_estark_proof(f, Default::default()); } +mod reparse { + + use powdr_pipeline::test_util::run_reparse_test_with_blacklist; + use test_log::test; + + /// Files that we don't expect to parse, analyze, and optimize without error. + const BLACKLIST: [&str; 4] = [ + "asm/failing_assertion.asm", + "asm/multi_return_wrong_assignment_register_length.asm", + "asm/multi_return_wrong_assignment_registers.asm", + "asm/permutations/incoming_needs_selector.asm", + ]; + + fn run_reparse_test(file: &str) { + run_reparse_test_with_blacklist(file, &BLACKLIST) + } + include!(concat!(env!("OUT_DIR"), "/asm_reparse_tests.rs")); +} + mod book { use super::*; use test_log::test; diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index b13470072..b855563ab 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -349,6 +349,12 @@ fn serialize_deserialize_optimized_pil() { assert_eq!(input_pil_file, output_pil_file); } +mod reparse { + use powdr_pipeline::test_util::run_reparse_test; + use test_log::test; + include!(concat!(env!("OUT_DIR"), "/pil_reparse_tests.rs")); +} + mod book { use super::*; use test_log::test; diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index 93047b0f5..9268bf539 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -381,3 +381,26 @@ fn btree() { let f = "std/btree_test.asm"; execute_test_file(f, Default::default(), vec![]).unwrap(); } + +mod reparse { + + use powdr_pipeline::test_util::run_reparse_test_with_blacklist; + use test_log::test; + + /// For convenience, all re-parsing tests run with the Goldilocks field, + /// but these tests panic if the field is too small. This is *probably* + /// fine, because all of these tests have a similar variant that does + /// run on Goldilocks. + const BLACKLIST: [&str; 5] = [ + "std/bus_permutation_via_challenges.asm", + "std/permutation_via_challenges.asm", + "std/lookup_via_challenges.asm", + "std/poseidon_bn254_test.asm", + "std/split_bn254_test.asm", + ]; + + fn run_reparse_test(file: &str) { + run_reparse_test_with_blacklist(file, &BLACKLIST); + } + include!(concat!(env!("OUT_DIR"), "/std_reparse_tests.rs")); +} From 8b32dcaa503fbccde3881e45373c215dd6bdfe00 Mon Sep 17 00:00:00 2001 From: Thibaut Schaeffer Date: Mon, 22 Jul 2024 12:42:32 +0200 Subject: [PATCH 15/24] Extract ROM machine (#1555) This PR continues the task to replace lookups by links. The advantages of this approach are: - with a link, the two machines can have different degrees and be proven separately - with a lookup, only a monolithic proof works - some backends such as Plonky3 do not support lookups - this is backwards compatible, since in the monolithic setting, links are turned into lookups In the compiler, we currently reduce each machine to another machine which has only pil code and links. To this end, in the case of virtual machines, we encode the program in fixed columns. This change introduces a separate machine to store the ROM. Therefore, each VM gets turned into not one, but two machines: ``` machine MyVM { } ``` becomes ``` machine MyVM { MyVMROM _rom; } machine MyVMROM { } ``` We introduce a new name alongside the original name, which pollutes the module. When raised, it was decided that we should not currently allow defining the ROM machine *inside* the VM. A better long term solution would be to have a generic `ROM` machine in the stdlib which can be instantiated with the fixed columns which encode the program, using them in the `get_line` operation. There are a few missing pieces in the asm language to enable that. --- airgen/src/lib.rs | 7 +- asm-to-pil/src/lib.rs | 50 ++++++-- asm-to-pil/src/vm_to_constrained.rs | 185 +++++++++++++++++++--------- linker/src/lib.rs | 98 +++++++++------ 4 files changed, 229 insertions(+), 111 deletions(-) diff --git a/airgen/src/lib.rs b/airgen/src/lib.rs index 13dfd5bd2..a8ffdc845 100644 --- a/airgen/src/lib.rs +++ b/airgen/src/lib.rs @@ -26,13 +26,14 @@ const MAIN_FUNCTION: &str = "main"; pub fn compile(input: AnalysisASMFile) -> PILGraph { let main_location = Location::main(); - let non_std_machines = input + let non_std_non_rom_machines = input .machines() .filter(|(k, _)| k.parts().next() != Some("std")) + .filter(|(k, _)| !k.parts().last().unwrap().ends_with("ROM")) .collect::>(); // we start from the main machine - let main_ty = match non_std_machines.len() { + let main_ty = match non_std_non_rom_machines.len() { 0 => { // There is no machine. Create an empty main machine but retain // all PIL utility definitions. @@ -50,7 +51,7 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph { }; } // if there is a single machine, treat it as main - 1 => (*non_std_machines.keys().next().unwrap()).clone(), + 1 => (*non_std_non_rom_machines.keys().next().unwrap()).clone(), // otherwise, use the machine called `MAIN` _ => { let p = parse_absolute_path(MAIN_MACHINE); diff --git a/asm-to-pil/src/lib.rs b/asm-to-pil/src/lib.rs index b70680831..581e7ce57 100644 --- a/asm-to-pil/src/lib.rs +++ b/asm-to-pil/src/lib.rs @@ -1,31 +1,55 @@ #![deny(clippy::print_stdout)] -use powdr_ast::asm_analysis::{AnalysisASMFile, Item}; +use powdr_ast::asm_analysis::{AnalysisASMFile, Item, SubmachineDeclaration}; use powdr_number::FieldElement; use romgen::generate_machine_rom; +use vm_to_constrained::ROM_SUBMACHINE_NAME; mod common; mod romgen; mod vm_to_constrained; +pub const ROM_SUFFIX: &str = "ROM"; + /// Remove all ASM from the machine tree. Takes a tree of virtual or constrained machines and returns a tree of constrained machines pub fn compile(file: AnalysisASMFile) -> AnalysisASMFile { AnalysisASMFile { items: file .items .into_iter() - .map(|(name, m)| { - ( - name, - match m { - Item::Machine(m) => { - let (m, rom) = generate_machine_rom::(m); - Item::Machine(vm_to_constrained::convert_machine::(m, rom)) + .flat_map(|(name, m)| match m { + Item::Machine(m) => { + let (m, rom) = generate_machine_rom::(m); + let (mut m, rom_machine) = vm_to_constrained::convert_machine::(m, rom); + + match rom_machine { + // in the absence of ROM, simply return the machine + None => vec![(name, Item::Machine(m))], + Some(rom_machine) => { + // introduce a new name for the ROM machine, based on the original name + let mut rom_name = name.clone(); + let machine_name = rom_name.pop().unwrap(); + rom_name.push(format!("{machine_name}{ROM_SUFFIX}")); + + // add the ROM as a submachine + m.submachines.push(SubmachineDeclaration { + name: ROM_SUBMACHINE_NAME.into(), + ty: rom_name.clone(), + args: vec![], + }); + + // return both the machine and the rom + vec![ + (name, Item::Machine(m)), + (rom_name, Item::Machine(rom_machine)), + ] } - Item::Expression(e) => Item::Expression(e), - Item::TypeDeclaration(enum_decl) => Item::TypeDeclaration(enum_decl), - Item::TraitDeclaration(trait_decl) => Item::TraitDeclaration(trait_decl), - }, - ) + } + } + Item::Expression(e) => vec![(name, Item::Expression(e))], + Item::TypeDeclaration(enum_decl) => vec![(name, Item::TypeDeclaration(enum_decl))], + Item::TraitDeclaration(trait_decl) => { + vec![(name, Item::TraitDeclaration(trait_decl))] + } }) .collect(), } diff --git a/asm-to-pil/src/vm_to_constrained.rs b/asm-to-pil/src/vm_to_constrained.rs index ac2de17c2..b7d5e2e09 100644 --- a/asm-to-pil/src/vm_to_constrained.rs +++ b/asm-to-pil/src/vm_to_constrained.rs @@ -1,29 +1,42 @@ //! Compilation from powdr assembly to PIL -use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::{ + collections::{BTreeMap, BTreeSet, HashMap}, + iter::once, +}; use powdr_ast::{ asm_analysis::{ - combine_flags, AssignmentStatement, Batch, DebugDirective, FunctionStatement, - InstructionDefinitionStatement, InstructionStatement, LabelStatement, LinkDefinition, - Machine, RegisterDeclarationStatement, RegisterTy, Rom, + combine_flags, AssignmentStatement, Batch, CallableSymbol, CallableSymbolDefinitions, + DebugDirective, FunctionStatement, InstructionDefinitionStatement, InstructionStatement, + LabelStatement, LinkDefinition, Machine, OperationSymbol, RegisterDeclarationStatement, + RegisterTy, Rom, }, parsed::{ self, - asm::{CallableRef, InstructionBody, InstructionParams, LinkDeclaration}, + asm::{ + CallableParams, CallableRef, InstructionBody, InstructionParams, LinkDeclaration, + OperationId, Param, Params, + }, build::{self, absolute_reference, direct_reference, next_reference}, visitor::ExpressionVisitable, - ArrayExpression, ArrayLiteral, BinaryOperation, BinaryOperator, Expression, FunctionCall, + ArrayExpression, BinaryOperation, BinaryOperator, Expression, FunctionCall, FunctionDefinition, FunctionKind, LambdaExpression, MatchArm, MatchExpression, Number, - Pattern, PilStatement, PolynomialName, SelectedExpressions, UnaryOperation, UnaryOperator, + Pattern, PilStatement, PolynomialName, UnaryOperation, UnaryOperator, }, }; use powdr_number::{BigUint, FieldElement, LargeInt}; use powdr_parser_util::SourceRef; -use crate::common::{instruction_flag, return_instruction, RETURN_NAME}; +use crate::{ + common::{instruction_flag, return_instruction, RETURN_NAME}, + utils::parse_pil_statement, +}; -pub fn convert_machine(machine: Machine, rom: Option) -> Machine { +pub fn convert_machine( + machine: Machine, + rom: Option, +) -> (Machine, Option) { let output_count = machine .operations() .map(|f| f.params.outputs.len()) @@ -43,11 +56,64 @@ pub enum LiteralKind { UnsignedConstant, } +const ROM_OPERATION_ID: &str = "operation_id"; +const ROM_LATCH: &str = "latch"; +pub const ROM_SUBMACHINE_NAME: &str = "_rom"; +const ROM_ENTRY_POINT: &str = "get_line"; + +fn rom_machine<'a>( + mut pil: Vec, + mut line_lookup: impl Iterator, +) -> Machine { + Machine { + operation_id: Some(ROM_OPERATION_ID.into()), + latch: Some(ROM_LATCH.into()), + pil: { + pil.extend([ + parse_pil_statement(&format!("pol fixed {ROM_OPERATION_ID} = [0]*;")), + parse_pil_statement(&format!("pol fixed {ROM_LATCH} = [1]*;")), + ]); + pil + }, + callable: CallableSymbolDefinitions( + once(( + ROM_ENTRY_POINT.into(), + CallableSymbol::Operation(OperationSymbol { + source: SourceRef::unknown(), + id: OperationId { + id: Some(0u32.into()), + }, + params: Params { + inputs: (&mut line_lookup) + .take(1) + .map(|x| Param { + name: x.to_string(), + index: None, + ty: None, + }) + .collect(), + outputs: line_lookup + .map(|x| Param { + name: x.to_string(), + index: None, + ty: None, + }) + .collect(), + }, + }), + )) + .collect(), + ), + ..Default::default() + } +} + /// Component that turns a virtual machine into a constrained machine. /// TODO check if the conversion really depends on the finite field. #[derive(Default)] struct VMConverter { pil: Vec, + rom_pil: Vec, pc_name: Option, assignment_register_names: Vec, registers: BTreeMap, @@ -70,10 +136,14 @@ impl VMConverter { } } - fn convert_machine(mut self, mut input: Machine, rom: Option) -> Machine { + fn convert_machine( + mut self, + mut input: Machine, + rom: Option, + ) -> (Machine, Option) { if !input.has_pc() { assert!(rom.is_none()); - return input; + return (input, None); } // store the names of all assignment registers: we need them to generate assignment columns for other registers. @@ -177,41 +247,38 @@ impl VMConverter { self.translate_code_lines(); - self.pil.push(PilStatement::PlookupIdentity( - SourceRef::unknown(), - SelectedExpressions { - selector: None, - expressions: Box::new( - ArrayLiteral { - items: self - .line_lookup - .iter() - .map(|x| direct_reference(&x.0)) - .collect(), - } - .into(), - ), - }, - SelectedExpressions { - selector: None, - expressions: Box::new( - ArrayLiteral { - items: self - .line_lookup - .iter() - .map(|x| direct_reference(&x.1)) - .collect(), - } - .into(), - ), + input.links.push(LinkDefinition { + source: SourceRef::unknown(), + instr_flag: None, + link_flag: Expression::from(1u32), + to: CallableRef { + instance: ROM_SUBMACHINE_NAME.to_string(), + callable: ROM_ENTRY_POINT.to_string(), + params: CallableParams { + inputs: self.line_lookup[..1] + .iter() + .map(|x| direct_reference(&x.0)) + .collect(), + outputs: self.line_lookup[1..] + .iter() + .map(|x| direct_reference(&x.0)) + .collect(), + }, }, - )); + is_permutation: false, + }); if !self.pil.is_empty() { input.pil.extend(self.pil); } - input + ( + input, + Some(rom_machine( + self.rom_pil, + self.line_lookup.iter().map(|(_, x)| x.as_ref()), + )), + ) } fn handle_batch(&mut self, batch: Batch) { @@ -843,19 +910,20 @@ impl VMConverter { /// Translates the code lines to fixed column but also fills /// the query hints for the free inputs. fn translate_code_lines(&mut self) { - self.pil.push(PilStatement::PolynomialConstantDefinition( - SourceRef::unknown(), - "p_line".to_string(), - FunctionDefinition::Array( - ArrayExpression::Value( - (0..self.code_lines.len()) - .map(|i| BigUint::from(i as u64).into()) - .collect(), - ) - .pad_with_last() - .unwrap_or_else(|| ArrayExpression::RepeatedValue(vec![0.into()])), - ), - )); + self.rom_pil + .push(PilStatement::PolynomialConstantDefinition( + SourceRef::unknown(), + "p_line".to_string(), + FunctionDefinition::Array( + ArrayExpression::Value( + (0..self.code_lines.len()) + .map(|i| BigUint::from(i as u64).into()) + .collect(), + ) + .pad_with_last() + .unwrap_or_else(|| ArrayExpression::RepeatedValue(vec![0.into()])), + ), + )); // TODO check that all of them are matched against execution trace witnesses. let mut rom_constants = self .rom_constant_names @@ -994,11 +1062,12 @@ impl VMConverter { .pad_with_last() .unwrap_or_else(|| ArrayExpression::RepeatedValue(vec![0.into()])) }; - self.pil.push(PilStatement::PolynomialConstantDefinition( - SourceRef::unknown(), - name.clone(), - FunctionDefinition::Array(array_expression), - )); + self.rom_pil + .push(PilStatement::PolynomialConstantDefinition( + SourceRef::unknown(), + name.clone(), + FunctionDefinition::Array(array_expression), + )); } } diff --git a/linker/src/lib.rs b/linker/src/lib.rs index cbd0172a4..c2eecd2bd 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -271,12 +271,15 @@ mod test { pol constant first_step = [1] + [0]*; pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; + 1 $ [0, pc, instr__jump_to_operation, instr__reset, instr__loop, instr_return] in main__rom.latch $ [main__rom.operation_id, main__rom.p_line, main__rom.p_instr__jump_to_operation, main__rom.p_instr__reset, main__rom.p_instr__loop, main__rom.p_instr_return]; +namespace main__rom(4 + 4); pol constant p_line = [0, 1, 2] + [2]*; pol constant p_instr__jump_to_operation = [0, 1, 0] + [0]*; pol constant p_instr__loop = [0, 0, 1] + [1]*; pol constant p_instr__reset = [1, 0, 0] + [0]*; pol constant p_instr_return = [0]*; - [pc, instr__jump_to_operation, instr__reset, instr__loop, instr_return] in [p_line, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return]; + pol constant operation_id = [0]*; + pol constant latch = [1]*; "#; let file_name = format!( @@ -341,9 +344,16 @@ mod test { A' = reg_write_X_A * X + reg_write_Y_A * Y + instr__reset * 0 + (1 - (reg_write_X_A + reg_write_Y_A + instr__reset)) * A; pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; - pol constant p_line = [0, 1, 2, 3, 4] + [4]*; pol commit X_free_value; pol commit Y_free_value; + 1 $ [0, pc, reg_write_X_A, reg_write_Y_A, instr_identity, instr_one, instr_nothing, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_pc, Y_const, Y_read_free, read_Y_A, read_Y_pc] in main__rom.latch $ [main__rom.operation_id, main__rom.p_line, main__rom.p_reg_write_X_A, main__rom.p_reg_write_Y_A, main__rom.p_instr_identity, main__rom.p_instr_one, main__rom.p_instr_nothing, main__rom.p_instr__jump_to_operation, main__rom.p_instr__reset, main__rom.p_instr__loop, main__rom.p_instr_return, main__rom.p_X_const, main__rom.p_X_read_free, main__rom.p_read_X_A, main__rom.p_read_X_pc, main__rom.p_Y_const, main__rom.p_Y_read_free, main__rom.p_read_Y_A, main__rom.p_read_Y_pc]; + instr_identity $ [2, X, Y] in main_sub.instr_return $ [main_sub._operation_id, main_sub._input_0, main_sub._output_0]; + instr_nothing $ [3] in main_sub.instr_return $ [main_sub._operation_id]; + instr_one $ [4, Y] in main_sub.instr_return $ [main_sub._operation_id, main_sub._output_0]; + pol constant _linker_first_step = [1] + [0]*; + _linker_first_step * (_operation_id - 2) = 0; +namespace main__rom(16); + pol constant p_line = [0, 1, 2, 3, 4] + [4]*; pol constant p_X_const = [0]*; pol constant p_X_read_free = [0]*; pol constant p_Y_const = [0]*; @@ -361,12 +371,8 @@ mod test { pol constant p_read_Y_pc = [0]*; pol constant p_reg_write_X_A = [0]*; pol constant p_reg_write_Y_A = [0, 0, 1, 0, 0] + [0]*; - [pc, reg_write_X_A, reg_write_Y_A, instr_identity, instr_one, instr_nothing, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_pc, Y_const, Y_read_free, read_Y_A, read_Y_pc] in [p_line, p_reg_write_X_A, p_reg_write_Y_A, p_instr_identity, p_instr_one, p_instr_nothing, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return, p_X_const, p_X_read_free, p_read_X_A, p_read_X_pc, p_Y_const, p_Y_read_free, p_read_Y_A, p_read_Y_pc]; - instr_identity $ [2, X, Y] in main_sub.instr_return $ [main_sub._operation_id, main_sub._input_0, main_sub._output_0]; - instr_nothing $ [3] in main_sub.instr_return $ [main_sub._operation_id]; - instr_one $ [4, Y] in main_sub.instr_return $ [main_sub._operation_id, main_sub._output_0]; - pol constant _linker_first_step = [1] + [0]*; - _linker_first_step * (_operation_id - 2) = 0; + pol constant operation_id = [0]*; + pol constant latch = [1]*; namespace main_sub(16); pol commit _operation_id(i) query std::prover::Query::Hint(5); pol constant _block_enforcer_last_step = [0]* + [1]; @@ -388,8 +394,10 @@ namespace main_sub(16); (1 - instr__reset) * (_input_0' - _input_0) = 0; pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; - pol constant p_line = [0, 1, 2, 3, 4, 5] + [5]*; pol commit _output_0_free_value; + 1 $ [0, pc, instr__jump_to_operation, instr__reset, instr__loop, instr_return, _output_0_const, _output_0_read_free, read__output_0_pc, read__output_0__input_0] in main_sub__rom.latch $ [main_sub__rom.operation_id, main_sub__rom.p_line, main_sub__rom.p_instr__jump_to_operation, main_sub__rom.p_instr__reset, main_sub__rom.p_instr__loop, main_sub__rom.p_instr_return, main_sub__rom.p__output_0_const, main_sub__rom.p__output_0_read_free, main_sub__rom.p_read__output_0_pc, main_sub__rom.p_read__output_0__input_0]; +namespace main_sub__rom(16); + pol constant p_line = [0, 1, 2, 3, 4, 5] + [5]*; pol constant p__output_0_const = [0, 0, 0, 0, 1, 0] + [0]*; pol constant p__output_0_read_free = [0]*; pol constant p_instr__jump_to_operation = [0, 1, 0, 0, 0, 0] + [0]*; @@ -398,7 +406,8 @@ namespace main_sub(16); pol constant p_instr_return = [0, 0, 1, 1, 1, 0] + [0]*; pol constant p_read__output_0__input_0 = [0, 0, 1, 0, 0, 0] + [0]*; pol constant p_read__output_0_pc = [0]*; - [pc, instr__jump_to_operation, instr__reset, instr__loop, instr_return, _output_0_const, _output_0_read_free, read__output_0_pc, read__output_0__input_0] in [p_line, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return, p__output_0_const, p__output_0_read_free, p_read__output_0_pc, p_read__output_0__input_0]; + pol constant operation_id = [0]*; + pol constant latch = [1]*; "#; let file_name = format!( "{}/../test_data/asm/different_signatures.asm", @@ -452,13 +461,17 @@ namespace main_sub(16); CNT' = reg_write_X_CNT * X + instr_dec_CNT * (CNT - 1) + instr__reset * 0 + (1 - (reg_write_X_CNT + instr_dec_CNT + instr__reset)) * CNT; pol pc_update = instr_jmpz * (instr_jmpz_pc_update + instr_jmpz_pc_update_1) + instr_jmp * instr_jmp_param_l + instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr_jmpz + instr_jmp + instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; - pol constant p_line = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + [10]*; pol commit X_free_value(__i) query match std::prover::eval(pc) { 2 => std::prover::Query::Input(1), 4 => std::prover::Query::Input(std::convert::int(std::prover::eval(CNT) + 1)), 7 => std::prover::Query::Input(0), _ => std::prover::Query::None, }; + 1 $ [0, pc, reg_write_X_A, reg_write_X_CNT, instr_jmpz, instr_jmpz_param_l, instr_jmp, instr_jmp_param_l, instr_dec_CNT, instr_assert_zero, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_CNT, read_X_pc] in main__rom.latch $ [main__rom.operation_id, main__rom.p_line, main__rom.p_reg_write_X_A, main__rom.p_reg_write_X_CNT, main__rom.p_instr_jmpz, main__rom.p_instr_jmpz_param_l, main__rom.p_instr_jmp, main__rom.p_instr_jmp_param_l, main__rom.p_instr_dec_CNT, main__rom.p_instr_assert_zero, main__rom.p_instr__jump_to_operation, main__rom.p_instr__reset, main__rom.p_instr__loop, main__rom.p_instr_return, main__rom.p_X_const, main__rom.p_X_read_free, main__rom.p_read_X_A, main__rom.p_read_X_CNT, main__rom.p_read_X_pc]; + pol constant _linker_first_step = [1] + [0]*; + _linker_first_step * (_operation_id - 2) = 0; +namespace main__rom(1024); + pol constant p_line = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + [10]*; pol constant p_X_const = [0]*; pol constant p_X_read_free = [0, 0, 1, 0, 1, 0, 0, 18446744069414584320, 0, 0, 0] + [0]*; pol constant p_instr__jump_to_operation = [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0]*; @@ -476,9 +489,8 @@ namespace main_sub(16); pol constant p_read_X_pc = [0]*; pol constant p_reg_write_X_A = [0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0] + [0]*; pol constant p_reg_write_X_CNT = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0] + [0]*; - [pc, reg_write_X_A, reg_write_X_CNT, instr_jmpz, instr_jmpz_param_l, instr_jmp, instr_jmp_param_l, instr_dec_CNT, instr_assert_zero, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_CNT, read_X_pc] in [p_line, p_reg_write_X_A, p_reg_write_X_CNT, p_instr_jmpz, p_instr_jmpz_param_l, p_instr_jmp, p_instr_jmp_param_l, p_instr_dec_CNT, p_instr_assert_zero, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return, p_X_const, p_X_read_free, p_read_X_A, p_read_X_CNT, p_read_X_pc]; - pol constant _linker_first_step = [1] + [0]*; - _linker_first_step * (_operation_id - 2) = 0; + pol constant operation_id = [0]*; + pol constant latch = [1]*; "#; let file_name = format!( "{}/../test_data/asm/simple_sum.asm", @@ -527,6 +539,10 @@ machine Machine { fp' = instr_inc_fp * (fp + instr_inc_fp_param_amount) + instr_adjust_fp * (fp + instr_adjust_fp_param_amount) + instr__reset * 0 + (1 - (instr_inc_fp + instr_adjust_fp + instr__reset)) * fp; pol pc_update = instr_adjust_fp * instr_adjust_fp_param_t + instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr_adjust_fp + instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; + 1 $ [0, pc, instr_inc_fp, instr_inc_fp_param_amount, instr_adjust_fp, instr_adjust_fp_param_amount, instr_adjust_fp_param_t, instr__jump_to_operation, instr__reset, instr__loop, instr_return] in main__rom.latch $ [main__rom.operation_id, main__rom.p_line, main__rom.p_instr_inc_fp, main__rom.p_instr_inc_fp_param_amount, main__rom.p_instr_adjust_fp, main__rom.p_instr_adjust_fp_param_amount, main__rom.p_instr_adjust_fp_param_t, main__rom.p_instr__jump_to_operation, main__rom.p_instr__reset, main__rom.p_instr__loop, main__rom.p_instr_return]; + pol constant _linker_first_step = [1] + [0]*; + _linker_first_step * (_operation_id - 2) = 0; +namespace main__rom(1024); pol constant p_line = [0, 1, 2, 3, 4] + [4]*; pol constant p_instr__jump_to_operation = [0, 1, 0, 0, 0] + [0]*; pol constant p_instr__loop = [0, 0, 0, 0, 1] + [1]*; @@ -537,9 +553,8 @@ machine Machine { pol constant p_instr_inc_fp = [0, 0, 1, 0, 0] + [0]*; pol constant p_instr_inc_fp_param_amount = [0, 0, 7, 0, 0] + [0]*; pol constant p_instr_return = [0]*; - [pc, instr_inc_fp, instr_inc_fp_param_amount, instr_adjust_fp, instr_adjust_fp_param_amount, instr_adjust_fp_param_t, instr__jump_to_operation, instr__reset, instr__loop, instr_return] in [p_line, p_instr_inc_fp, p_instr_inc_fp_param_amount, p_instr_adjust_fp, p_instr_adjust_fp_param_amount, p_instr_adjust_fp_param_t, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return]; - pol constant _linker_first_step = [1] + [0]*; - _linker_first_step * (_operation_id - 2) = 0; + pol constant operation_id = [0]*; + pol constant latch = [1]*; "#; let graph = parse_analyze_and_compile::(source); let pil = link(graph).unwrap(); @@ -617,8 +632,13 @@ machine Main { A' = reg_write_X_A * X + instr_add5_into_A * A' + instr__reset * 0 + (1 - (reg_write_X_A + instr_add5_into_A + instr__reset)) * A; pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; - pol constant p_line = [0, 1, 2, 3] + [3]*; pol commit X_free_value; + instr_add5_into_A $ [0, X, A'] in main_vm.latch $ [main_vm.operation_id, main_vm.x, main_vm.y]; + 1 $ [0, pc, reg_write_X_A, instr_add5_into_A, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_pc] in main__rom.latch $ [main__rom.operation_id, main__rom.p_line, main__rom.p_reg_write_X_A, main__rom.p_instr_add5_into_A, main__rom.p_instr__jump_to_operation, main__rom.p_instr__reset, main__rom.p_instr__loop, main__rom.p_instr_return, main__rom.p_X_const, main__rom.p_X_read_free, main__rom.p_read_X_A, main__rom.p_read_X_pc]; + pol constant _linker_first_step = [1] + [0]*; + _linker_first_step * (_operation_id - 2) = 0; +namespace main__rom(1024); + pol constant p_line = [0, 1, 2, 3] + [3]*; pol constant p_X_const = [0, 0, 10, 0] + [0]*; pol constant p_X_read_free = [0]*; pol constant p_instr__jump_to_operation = [0, 1, 0, 0] + [0]*; @@ -629,10 +649,8 @@ machine Main { pol constant p_read_X_A = [0]*; pol constant p_read_X_pc = [0]*; pol constant p_reg_write_X_A = [0]*; - [pc, reg_write_X_A, instr_add5_into_A, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_pc] in [p_line, p_reg_write_X_A, p_instr_add5_into_A, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return, p_X_const, p_X_read_free, p_read_X_A, p_read_X_pc]; - instr_add5_into_A $ [0, X, A'] in main_vm.latch $ [main_vm.operation_id, main_vm.x, main_vm.y]; - pol constant _linker_first_step = [1] + [0]*; - _linker_first_step * (_operation_id - 2) = 0; + pol constant operation_id = [0]*; + pol constant latch = [1]*; namespace main_vm(1024); pol commit operation_id; pol constant latch = [1]*; @@ -695,10 +713,16 @@ namespace main_vm(1024); B' = reg_write_X_B * X + reg_write_Y_B * Y + reg_write_Z_B * Z + instr_or_into_B * B' + instr__reset * 0 + (1 - (reg_write_X_B + reg_write_Y_B + reg_write_Z_B + instr_or_into_B + instr__reset)) * B; pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; - pol constant p_line = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + [13]*; pol commit X_free_value; pol commit Y_free_value; pol commit Z_free_value; + instr_or_into_B $ [0, X, Y, B'] is main_bin.latch * main_bin.sel[0] $ [main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C]; + 1 $ [0, pc, reg_write_X_A, reg_write_Y_A, reg_write_Z_A, reg_write_X_B, reg_write_Y_B, reg_write_Z_B, instr_or, instr_or_into_B, instr_assert_eq, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_B, read_X_pc, Y_const, Y_read_free, read_Y_A, read_Y_B, read_Y_pc, Z_const, Z_read_free, read_Z_A, read_Z_B, read_Z_pc] in main__rom.latch $ [main__rom.operation_id, main__rom.p_line, main__rom.p_reg_write_X_A, main__rom.p_reg_write_Y_A, main__rom.p_reg_write_Z_A, main__rom.p_reg_write_X_B, main__rom.p_reg_write_Y_B, main__rom.p_reg_write_Z_B, main__rom.p_instr_or, main__rom.p_instr_or_into_B, main__rom.p_instr_assert_eq, main__rom.p_instr__jump_to_operation, main__rom.p_instr__reset, main__rom.p_instr__loop, main__rom.p_instr_return, main__rom.p_X_const, main__rom.p_X_read_free, main__rom.p_read_X_A, main__rom.p_read_X_B, main__rom.p_read_X_pc, main__rom.p_Y_const, main__rom.p_Y_read_free, main__rom.p_read_Y_A, main__rom.p_read_Y_B, main__rom.p_read_Y_pc, main__rom.p_Z_const, main__rom.p_Z_read_free, main__rom.p_read_Z_A, main__rom.p_read_Z_B, main__rom.p_read_Z_pc]; + instr_or $ [0, X, Y, Z] is main_bin.latch * main_bin.sel[1] $ [main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C]; + pol constant _linker_first_step = [1] + [0]*; + _linker_first_step * (_operation_id - 2) = 0; +namespace main__rom(65536); + pol constant p_line = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + [13]*; pol constant p_X_const = [0, 0, 2, 0, 1, 0, 3, 0, 2, 0, 1, 0, 0, 0] + [0]*; pol constant p_X_read_free = [0]*; pol constant p_Y_const = [0, 0, 3, 3, 2, 3, 4, 7, 3, 3, 2, 3, 0, 0] + [0]*; @@ -727,11 +751,8 @@ namespace main_vm(1024); pol constant p_reg_write_Y_B = [0]*; pol constant p_reg_write_Z_A = [0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0] + [0]*; pol constant p_reg_write_Z_B = [0]*; - [pc, reg_write_X_A, reg_write_Y_A, reg_write_Z_A, reg_write_X_B, reg_write_Y_B, reg_write_Z_B, instr_or, instr_or_into_B, instr_assert_eq, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_B, read_X_pc, Y_const, Y_read_free, read_Y_A, read_Y_B, read_Y_pc, Z_const, Z_read_free, read_Z_A, read_Z_B, read_Z_pc] in [p_line, p_reg_write_X_A, p_reg_write_Y_A, p_reg_write_Z_A, p_reg_write_X_B, p_reg_write_Y_B, p_reg_write_Z_B, p_instr_or, p_instr_or_into_B, p_instr_assert_eq, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return, p_X_const, p_X_read_free, p_read_X_A, p_read_X_B, p_read_X_pc, p_Y_const, p_Y_read_free, p_read_Y_A, p_read_Y_B, p_read_Y_pc, p_Z_const, p_Z_read_free, p_read_Z_A, p_read_Z_B, p_read_Z_pc]; - instr_or_into_B $ [0, X, Y, B'] is main_bin.latch * main_bin.sel[0] $ [main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C]; - instr_or $ [0, X, Y, Z] is main_bin.latch * main_bin.sel[1] $ [main_bin.operation_id, main_bin.A, main_bin.B, main_bin.C]; - pol constant _linker_first_step = [1] + [0]*; - _linker_first_step * (_operation_id - 2) = 0; + pol constant operation_id = [0]*; + pol constant latch = [1]*; namespace main_bin(65536); pol commit operation_id; pol constant latch(i) { if i % 4 == 3 { 1 } else { 0 } }; @@ -840,11 +861,20 @@ namespace main_bin(65536); C' = reg_write_X_C * X + reg_write_Y_C * Y + reg_write_Z_C * Z + reg_write_W_C * W + instr__reset * 0 + (1 - (reg_write_X_C + reg_write_Y_C + reg_write_Z_C + reg_write_W_C + instr__reset)) * C; pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; - pol constant p_line = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18] + [18]*; pol commit X_free_value; pol commit Y_free_value; pol commit Z_free_value; pol commit W_free_value; + instr_add_to_A $ [0, X, Y, A'] in main_submachine.latch $ [main_submachine.operation_id, main_submachine.x, main_submachine.y, main_submachine.z]; + instr_add_BC_to_A $ [0, B, C, A'] in main_submachine.latch $ [main_submachine.operation_id, main_submachine.x, main_submachine.y, main_submachine.z]; + 1 $ [0, pc, reg_write_X_A, reg_write_Y_A, reg_write_Z_A, reg_write_W_A, reg_write_X_B, reg_write_Y_B, reg_write_Z_B, reg_write_W_B, reg_write_X_C, reg_write_Y_C, reg_write_Z_C, reg_write_W_C, instr_add, instr_sub_with_add, instr_addAB, instr_add3, instr_add_to_A, instr_add_BC_to_A, instr_sub, instr_add_with_sub, instr_assert_eq, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_B, read_X_C, read_X_pc, Y_const, Y_read_free, read_Y_A, read_Y_B, read_Y_C, read_Y_pc, Z_const, Z_read_free, read_Z_A, read_Z_B, read_Z_C, read_Z_pc, W_const, W_read_free, read_W_A, read_W_B, read_W_C, read_W_pc] in main__rom.latch $ [main__rom.operation_id, main__rom.p_line, main__rom.p_reg_write_X_A, main__rom.p_reg_write_Y_A, main__rom.p_reg_write_Z_A, main__rom.p_reg_write_W_A, main__rom.p_reg_write_X_B, main__rom.p_reg_write_Y_B, main__rom.p_reg_write_Z_B, main__rom.p_reg_write_W_B, main__rom.p_reg_write_X_C, main__rom.p_reg_write_Y_C, main__rom.p_reg_write_Z_C, main__rom.p_reg_write_W_C, main__rom.p_instr_add, main__rom.p_instr_sub_with_add, main__rom.p_instr_addAB, main__rom.p_instr_add3, main__rom.p_instr_add_to_A, main__rom.p_instr_add_BC_to_A, main__rom.p_instr_sub, main__rom.p_instr_add_with_sub, main__rom.p_instr_assert_eq, main__rom.p_instr__jump_to_operation, main__rom.p_instr__reset, main__rom.p_instr__loop, main__rom.p_instr_return, main__rom.p_X_const, main__rom.p_X_read_free, main__rom.p_read_X_A, main__rom.p_read_X_B, main__rom.p_read_X_C, main__rom.p_read_X_pc, main__rom.p_Y_const, main__rom.p_Y_read_free, main__rom.p_read_Y_A, main__rom.p_read_Y_B, main__rom.p_read_Y_C, main__rom.p_read_Y_pc, main__rom.p_Z_const, main__rom.p_Z_read_free, main__rom.p_read_Z_A, main__rom.p_read_Z_B, main__rom.p_read_Z_C, main__rom.p_read_Z_pc, main__rom.p_W_const, main__rom.p_W_read_free, main__rom.p_read_W_A, main__rom.p_read_W_B, main__rom.p_read_W_C, main__rom.p_read_W_pc]; + instr_add + instr_add3 + instr_addAB + instr_sub_with_add $ [0, X * instr_add + X * instr_add3 + A * instr_addAB + X * instr_sub_with_add, Y * instr_add + Y * instr_add3 + B * instr_addAB + Z * instr_sub_with_add, Z * instr_add + tmp * instr_add3 + X * instr_addAB + Y * instr_sub_with_add] in main_submachine.latch $ [main_submachine.operation_id, main_submachine.x, main_submachine.y, main_submachine.z]; + instr_add3 $ [0, tmp, Z, W] in main_submachine.latch $ [main_submachine.operation_id, main_submachine.x, main_submachine.y, main_submachine.z]; + instr_add_with_sub + instr_sub $ [1, X * instr_add_with_sub + X * instr_sub, Y * instr_add_with_sub + Y * instr_sub, Z * instr_add_with_sub + Z * instr_sub] in main_submachine.latch $ [main_submachine.operation_id, main_submachine.x, main_submachine.z, main_submachine.y]; + pol constant _linker_first_step = [1] + [0]*; + _linker_first_step * (_operation_id - 2) = 0; +namespace main__rom(1024); + pol constant p_line = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18] + [18]*; pol constant p_W_const = [0]*; pol constant p_W_read_free = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0] + [0]*; pol constant p_X_const = [0, 0, 2, 0, 6, 0, 6, 0, 6, 0, 20, 0, 0, 1, 0, 1, 0, 0, 0] + [0]*; @@ -894,14 +924,8 @@ namespace main_bin(65536); pol constant p_reg_write_Z_A = [0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] + [0]*; pol constant p_reg_write_Z_B = [0]*; pol constant p_reg_write_Z_C = [0]*; - [pc, reg_write_X_A, reg_write_Y_A, reg_write_Z_A, reg_write_W_A, reg_write_X_B, reg_write_Y_B, reg_write_Z_B, reg_write_W_B, reg_write_X_C, reg_write_Y_C, reg_write_Z_C, reg_write_W_C, instr_add, instr_sub_with_add, instr_addAB, instr_add3, instr_add_to_A, instr_add_BC_to_A, instr_sub, instr_add_with_sub, instr_assert_eq, instr__jump_to_operation, instr__reset, instr__loop, instr_return, X_const, X_read_free, read_X_A, read_X_B, read_X_C, read_X_pc, Y_const, Y_read_free, read_Y_A, read_Y_B, read_Y_C, read_Y_pc, Z_const, Z_read_free, read_Z_A, read_Z_B, read_Z_C, read_Z_pc, W_const, W_read_free, read_W_A, read_W_B, read_W_C, read_W_pc] in [p_line, p_reg_write_X_A, p_reg_write_Y_A, p_reg_write_Z_A, p_reg_write_W_A, p_reg_write_X_B, p_reg_write_Y_B, p_reg_write_Z_B, p_reg_write_W_B, p_reg_write_X_C, p_reg_write_Y_C, p_reg_write_Z_C, p_reg_write_W_C, p_instr_add, p_instr_sub_with_add, p_instr_addAB, p_instr_add3, p_instr_add_to_A, p_instr_add_BC_to_A, p_instr_sub, p_instr_add_with_sub, p_instr_assert_eq, p_instr__jump_to_operation, p_instr__reset, p_instr__loop, p_instr_return, p_X_const, p_X_read_free, p_read_X_A, p_read_X_B, p_read_X_C, p_read_X_pc, p_Y_const, p_Y_read_free, p_read_Y_A, p_read_Y_B, p_read_Y_C, p_read_Y_pc, p_Z_const, p_Z_read_free, p_read_Z_A, p_read_Z_B, p_read_Z_C, p_read_Z_pc, p_W_const, p_W_read_free, p_read_W_A, p_read_W_B, p_read_W_C, p_read_W_pc]; - instr_add_to_A $ [0, X, Y, A'] in main_submachine.latch $ [main_submachine.operation_id, main_submachine.x, main_submachine.y, main_submachine.z]; - instr_add_BC_to_A $ [0, B, C, A'] in main_submachine.latch $ [main_submachine.operation_id, main_submachine.x, main_submachine.y, main_submachine.z]; - instr_add + instr_add3 + instr_addAB + instr_sub_with_add $ [0, X * instr_add + X * instr_add3 + A * instr_addAB + X * instr_sub_with_add, Y * instr_add + Y * instr_add3 + B * instr_addAB + Z * instr_sub_with_add, Z * instr_add + tmp * instr_add3 + X * instr_addAB + Y * instr_sub_with_add] in main_submachine.latch $ [main_submachine.operation_id, main_submachine.x, main_submachine.y, main_submachine.z]; - instr_add3 $ [0, tmp, Z, W] in main_submachine.latch $ [main_submachine.operation_id, main_submachine.x, main_submachine.y, main_submachine.z]; - instr_add_with_sub + instr_sub $ [1, X * instr_add_with_sub + X * instr_sub, Y * instr_add_with_sub + Y * instr_sub, Z * instr_add_with_sub + Z * instr_sub] in main_submachine.latch $ [main_submachine.operation_id, main_submachine.x, main_submachine.z, main_submachine.y]; - pol constant _linker_first_step = [1] + [0]*; - _linker_first_step * (_operation_id - 2) = 0; + pol constant operation_id = [0]*; + pol constant latch = [1]*; namespace main_submachine(1024); pol commit operation_id; pol constant latch = [1]*; From 60acaec5fcffb37a42a0dd47adb5af0ab7d6f7f4 Mon Sep 17 00:00:00 2001 From: onurinanc Date: Mon, 22 Jul 2024 14:50:09 +0300 Subject: [PATCH 16/24] Remove some unused code in some tests & std (#1592) --- std/machines/memory.asm | 2 -- std/machines/memory_with_bootloader_write.asm | 1 - test_data/std/poseidon_gl_memory_test.asm | 10 ---------- 3 files changed, 13 deletions(-) diff --git a/std/machines/memory.asm b/std/machines/memory.asm index 1a0793e0c..c68fe8bff 100644 --- a/std/machines/memory.asm +++ b/std/machines/memory.asm @@ -53,8 +53,6 @@ machine Memory with col fixed FIRST = [1] + [0]*; let LAST = FIRST'; - col fixed STEP(i) { i }; - col fixed BIT16(i) { i & 0xffff }; link => byte2.check(m_diff_lower); link => byte2.check(m_diff_upper); diff --git a/std/machines/memory_with_bootloader_write.asm b/std/machines/memory_with_bootloader_write.asm index 4b81b12be..3782e2471 100644 --- a/std/machines/memory_with_bootloader_write.asm +++ b/std/machines/memory_with_bootloader_write.asm @@ -63,7 +63,6 @@ machine MemoryWithBootloaderWrite with col fixed FIRST = [1] + [0]*; let LAST = FIRST'; - col fixed STEP(i) { i }; link => byte2.check(m_diff_lower); link => byte2.check(m_diff_upper); diff --git a/test_data/std/poseidon_gl_memory_test.asm b/test_data/std/poseidon_gl_memory_test.asm index 0166ee70c..124688957 100644 --- a/test_data/std/poseidon_gl_memory_test.asm +++ b/test_data/std/poseidon_gl_memory_test.asm @@ -6,16 +6,6 @@ machine Main with degree: 65536 { reg pc[@pc]; reg X1[<=]; reg X2[<=]; - reg X3[<=]; - reg X4[<=]; - reg X5[<=]; - reg X6[<=]; - reg X7[<=]; - reg X8[<=]; - reg X9[<=]; - reg X10[<=]; - reg X11[<=]; - reg X12[<=]; reg ADDR1[<=]; reg ADDR2[<=]; From d8218a2cae60a044ff19a0964f72f50413193cd7 Mon Sep 17 00:00:00 2001 From: Leo Date: Mon, 22 Jul 2024 15:17:02 +0200 Subject: [PATCH 17/24] Registers in memory (#1443) Fixes #1492 --------- Co-authored-by: schaeff --- cli-rs/src/main.rs | 2 +- pipeline/src/lib.rs | 36 + pipeline/src/pipeline.rs | 8 +- riscv-executor/src/lib.rs | 485 ++++++-- riscv-syscalls/src/lib.rs | 12 - riscv/src/code_gen.rs | 1054 ++++++++++++----- riscv/src/continuations.rs | 85 +- riscv/src/continuations/bootloader.rs | 242 ++-- riscv/src/runtime.rs | 301 ++--- riscv/tests/riscv.rs | 39 + riscv/tests/riscv_data/read_slice/Cargo.toml | 11 + .../riscv_data/read_slice/rust-toolchain.toml | 4 + riscv/tests/riscv_data/read_slice/src/lib.rs | 20 + std/machines/memory.asm | 210 ++++ 14 files changed, 1838 insertions(+), 671 deletions(-) create mode 100644 riscv/tests/riscv_data/read_slice/Cargo.toml create mode 100644 riscv/tests/riscv_data/read_slice/rust-toolchain.toml create mode 100644 riscv/tests/riscv_data/read_slice/src/lib.rs diff --git a/cli-rs/src/main.rs b/cli-rs/src/main.rs index 9bf28ec49..f09a8e668 100644 --- a/cli-rs/src/main.rs +++ b/cli-rs/src/main.rs @@ -435,7 +435,7 @@ fn execute( (false, false) => { let mut pipeline = pipeline.with_prover_inputs(inputs); let program = pipeline.compute_asm_string().unwrap().clone(); - let (trace, _mem) = powdr_riscv_executor::execute::( + let (trace, _mem, _reg_mem) = powdr_riscv_executor::execute::( &program.1, powdr_riscv_executor::MemoryState::new(), pipeline.data_callback().unwrap(), diff --git a/pipeline/src/lib.rs b/pipeline/src/lib.rs index 01541ebea..25252d912 100644 --- a/pipeline/src/lib.rs +++ b/pipeline/src/lib.rs @@ -148,6 +148,42 @@ pub fn serde_data_to_query_callback( + dict: BTreeMap>, +) -> impl QueryCallback { + move |query: &str| -> Result, String> { + let (id, data) = parse_query(query)?; + match id { + "None" => Ok(None), + "DataIdentifier" => { + let [index, cb_channel] = data[..] else { + panic!() + }; + let cb_channel = cb_channel + .parse::() + .map_err(|e| format!("Error parsing callback data channel: {e})"))?; + + if !dict.contains_key(&cb_channel) { + return Err("Callback channel mismatch".to_string()); + } + + let index = index + .parse::() + .map_err(|e| format!("Error parsing index: {e})"))?; + + let bytes = dict.get(&cb_channel).unwrap(); + + // query index 0 means the length + Ok(Some(match index { + 0 => (bytes.len() as u64).into(), + index => bytes[index - 1], + })) + } + _ => Err(format!("Unsupported query: {query}")), + } + } +} + pub fn inputs_to_query_callback(inputs: Vec) -> impl QueryCallback { move |query: &str| -> Result, String> { let (id, data) = parse_query(query)?; diff --git a/pipeline/src/pipeline.rs b/pipeline/src/pipeline.rs index 8e4a282cb..9836a10ad 100644 --- a/pipeline/src/pipeline.rs +++ b/pipeline/src/pipeline.rs @@ -30,9 +30,11 @@ use powdr_number::{write_polys_csv_file, write_polys_file, CsvRenderMode, FieldE use powdr_schemas::SerializedAnalyzed; use crate::{ - handle_simple_queries_callback, inputs_to_query_callback, serde_data_to_query_callback, + dict_data_to_query_callback, handle_simple_queries_callback, inputs_to_query_callback, + serde_data_to_query_callback, util::{read_poly_set, FixedPolySet, WitnessPolySet}, }; +use std::collections::BTreeMap; type Columns = Vec<(String, Vec)>; @@ -276,6 +278,10 @@ impl Pipeline { self.add_query_callback(Arc::new(inputs_to_query_callback(inputs))) } + pub fn with_prover_dict_inputs(self, inputs: BTreeMap>) -> Self { + self.add_query_callback(Arc::new(dict_data_to_query_callback(inputs))) + } + pub fn with_backend(mut self, backend: BackendType, options: Option) -> Self { self.arguments.backend = Some(backend); self.arguments.backend_options = options.unwrap_or_default(); diff --git a/riscv-executor/src/lib.rs b/riscv-executor/src/lib.rs index d35ba9436..7b8d8d83b 100644 --- a/riscv-executor/src/lib.rs +++ b/riscv-executor/src/lib.rs @@ -27,7 +27,6 @@ use powdr_ast::{ }, }; use powdr_number::{FieldElement, LargeInt}; -use powdr_riscv_syscalls::SYSCALL_REGISTERS; pub use profiler::ProfilerOptions; pub mod arith; @@ -110,6 +109,46 @@ impl Elem { Self::Field(f) => f.is_zero(), } } + + fn add(&self, other: &Self) -> Self { + match (self, other) { + (Self::Binary(a), Self::Binary(b)) => Self::Binary(a.checked_add(*b).unwrap()), + (Self::Field(a), Self::Field(b)) => Self::Field(*a + *b), + (Self::Binary(a), Self::Field(b)) => Self::Field(F::from(*a) + *b), + (Self::Field(a), Self::Binary(b)) => Self::Field(*a + F::from(*b)), + } + } + + fn sub(&self, other: &Self) -> Self { + match (self, other) { + (Self::Binary(a), Self::Binary(b)) => Self::Binary(a.checked_sub(*b).unwrap()), + (Self::Field(a), Self::Field(b)) => Self::Field(*a - *b), + (Self::Binary(a), Self::Field(b)) => Self::Field(F::from(*a) - *b), + (Self::Field(a), Self::Binary(b)) => Self::Field(*a - F::from(*b)), + } + } + + fn mul(&self, other: &Self) -> Self { + match (self, other) { + (Self::Binary(a), Self::Binary(b)) => match a.checked_mul(*b) { + Some(v) => Self::Binary(v), + None => { + let a = F::from(*a); + let b = F::from(*b); + Self::Field(a * b) + } + }, + (Self::Field(a), Self::Field(b)) => Self::Field(*a * *b), + (Self::Binary(a), Self::Field(b)) => Self::Field(F::from(*a) * *b), + (Self::Field(a), Self::Binary(b)) => Self::Field(*a * F::from(*b)), + } + } +} + +impl From for Elem { + fn from(value: i64) -> Self { + Self::Binary(value) + } } impl From for Elem { @@ -140,6 +179,7 @@ impl Display for Elem { } pub type MemoryState = HashMap; +pub type RegisterMemoryState = HashMap>; #[derive(Debug)] pub enum MemOperationKind { @@ -237,7 +277,7 @@ mod builder { use crate::{ Elem, ExecMode, ExecutionTrace, MemOperation, MemOperationKind, MemoryState, RegWrite, - PC_INITIAL_VAL, + RegisterMemoryState, PC_INITIAL_VAL, }; fn register_names(main: &Machine) -> Vec<&str> { @@ -260,7 +300,6 @@ mod builder { max_rows: usize, // index of special case registers to look after: - x0_idx: u16, pc_idx: u16, /// The value of PC at the start of the execution of the current row. @@ -280,6 +319,9 @@ mod builder { /// Current memory. mem: HashMap, + /// Separate register memory. + reg_mem: HashMap>, + /// The execution mode we running. /// Fast: do not save the register's trace and memory accesses. /// Trace: save everything - needed for continuations. @@ -292,13 +334,14 @@ mod builder { /// May fail if max_rows_len is too small or if the main machine is /// empty. In this case, the final (empty) execution trace is returned /// in Err. + #[allow(clippy::type_complexity)] pub fn new( main: &'a Machine, mem: MemoryState, batch_to_line_map: &'b [u32], max_rows_len: usize, mode: ExecMode, - ) -> Result, MemoryState)>> { + ) -> Result, MemoryState, RegisterMemoryState)>> { let reg_map = register_names(main) .into_iter() .enumerate() @@ -324,7 +367,6 @@ mod builder { regs[pc_idx as usize] = PC_INITIAL_VAL.into(); let mut ret = Self { - x0_idx: reg_map["x0"], pc_idx, curr_pc: PC_INITIAL_VAL.into(), trace: ExecutionTrace { @@ -338,6 +380,7 @@ mod builder { max_rows: max_rows_len, regs, mem, + reg_mem: Default::default(), mode, }; @@ -360,6 +403,9 @@ mod builder { /// get current value of register by register index instead of name fn get_reg_idx(&self, idx: u16) -> Elem { + if idx == self.pc_idx { + return self.get_pc(); + } self.regs[idx as usize] } @@ -380,9 +426,6 @@ mod builder { fn set_reg_impl(&mut self, idx: &str, value: Elem) { let idx = self.trace.reg_map[idx]; assert!(idx != self.pc_idx); - if idx == self.x0_idx { - return; - } self.set_reg_idx(idx, value); } @@ -446,8 +489,23 @@ mod builder { *self.mem.get(&addr).unwrap_or(&0) } - pub fn finish(self) -> (ExecutionTrace, MemoryState) { - (self.trace, self.mem) + pub(crate) fn set_reg_mem(&mut self, addr: u32, val: Elem) { + if addr != 0 { + self.reg_mem.insert(addr, val); + } + } + + pub(crate) fn get_reg_mem(&mut self, addr: u32) -> Elem { + let zero: Elem = 0u32.into(); + if addr == 0 { + zero + } else { + *self.reg_mem.get(&addr).unwrap_or(&zero) + } + } + + pub fn finish(self) -> (ExecutionTrace, MemoryState, RegisterMemoryState) { + (self.trace, self.mem, self.reg_mem) } /// Should we stop the execution because the maximum number of rows has @@ -600,40 +658,111 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { .collect::>(); match name { + "set_reg" => { + let addr = args[0].u(); + self.proc.set_reg_mem(addr, args[1]); + + Vec::new() + } + "get_reg" => { + let addr = args[0].u(); + let val = self.proc.get_reg_mem(addr); + + vec![val] + } + "move_reg" => { + let val = self.proc.get_reg_mem(args[0].u()); + let write_reg = args[1].u(); + let factor = args[2]; + let offset = args[3]; + + let val = val.mul(&factor).add(&offset); + + self.proc.set_reg_mem(write_reg, val); + + Vec::new() + } + "mstore" | "mstore_bootloader" => { - let addr = args[0].bin() as u32; + let addr1 = self.proc.get_reg_mem(args[0].u()); + let addr2 = self.proc.get_reg_mem(args[1].u()); + let offset = args[2].bin(); + let value = self.proc.get_reg_mem(args[3].u()); + + let addr = addr1.bin() - addr2.bin() + offset; + let addr = addr as u32; assert_eq!(addr % 4, 0); - self.proc.set_mem(addr, args[1].u()); + self.proc.set_mem(addr, value.u()); Vec::new() } "mload" => { - let addr = args[0].bin() as u32; - let val = self.proc.get_mem(addr & 0xfffffffc); + let addr1 = self.proc.get_reg_mem(args[0].u()); + let offset = args[1].bin(); + let write_addr1 = args[2].u(); + let write_addr2 = args[3].u(); + + let addr = addr1.bin() + offset; + + let val = self.proc.get_mem(addr as u32 & 0xfffffffc); let rem = addr % 4; - vec![val.into(), rem.into()] + self.proc.set_reg_mem(write_addr1, val.into()); + self.proc.set_reg_mem(write_addr2, rem.into()); + + Vec::new() } "load_bootloader_input" => { - let addr = args[0].bin() as usize; - let val = self.bootloader_inputs[addr]; + let addr = self.proc.get_reg_mem(args[0].u()); + let write_addr = args[1].u(); + let factor = args[2].bin(); + let offset = args[3].bin(); - vec![val] + let addr = addr.bin() * factor + offset; + let val = self.bootloader_inputs[addr as usize]; + + self.proc.set_reg_mem(write_addr, val); + + Vec::new() } "assert_bootloader_input" => { - let addr = args[0].bin() as usize; - let actual_val = self.bootloader_inputs[addr]; + let addr = self.proc.get_reg_mem(args[0].u()); + let val = self.proc.get_reg_mem(args[1].u()); + let factor = args[2].bin(); + let offset = args[3].bin(); - assert_eq!(args[1], actual_val); + let actual_val = self.bootloader_inputs[(addr.bin() * factor + offset) as usize]; - vec![] + assert_eq!(val, actual_val); + + Vec::new() } - "load_label" => args, - "jump" | "jump_dyn" => { + "load_label" => { + let write_reg = args[0].u(); + self.proc.set_reg_mem(write_reg, args[1]); + + Vec::new() + } + "jump" => { let next_pc = self.proc.get_pc().u() + 1; + let write_reg = args[1].u(); + + self.proc.set_reg_mem(write_reg, next_pc.into()); + self.proc.set_pc(args[0]); - vec![next_pc.into()] + Vec::new() + } + "jump_dyn" => { + let addr = self.proc.get_reg_mem(args[0].u()); + let next_pc = self.proc.get_pc().u() + 1; + let write_reg = args[1].u(); + + self.proc.set_reg_mem(write_reg, next_pc.into()); + + self.proc.set_pc(addr); + + Vec::new() } "jump_to_bootloader_input" => { let bootloader_input_idx = args[0].bin() as usize; @@ -642,84 +771,162 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { Vec::new() } - "branch_if_nonzero" => { - if !args[0].is_zero() { - self.proc.set_pc(args[1]); + "branch_if_diff_nonzero" => { + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + let val: Elem = val1.sub(&val2); + if !val.is_zero() { + self.proc.set_pc(args[2]); } Vec::new() } "branch_if_zero" => { - if args[0].is_zero() { - self.proc.set_pc(args[1]); + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + let offset = args[2]; + let val: Elem = val1.sub(&val2.add(&offset)); + if val.is_zero() { + self.proc.set_pc(args[3]); } Vec::new() } "skip_if_zero" => { - if args[0].is_zero() { + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + let offset = args[2]; + let val: Elem = val1.sub(&val2).add(&offset); + + if val.is_zero() { let pc = self.proc.get_pc().s(); - self.proc.set_pc((pc + args[1].s() + 1).into()); + self.proc.set_pc((pc + args[3].s() + 1).into()); } Vec::new() } "branch_if_positive" => { - if args[0].bin() > 0 { - self.proc.set_pc(args[1]); + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + let offset = args[2]; + let val: Elem = val1.sub(&val2).add(&offset); + if val.bin() > 0 { + self.proc.set_pc(args[3]); } Vec::new() } "is_positive" => { - let r = if args[0].bin() > 0 { 1 } else { 0 }; + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + + let offset = args[2]; + let write_reg = args[3].u(); + let val = val1.sub(&val2).add(&offset); - vec![r.into()] + let r = if val.bin() > 0 { 1 } else { 0 }; + self.proc.set_reg_mem(write_reg, r.into()); + + Vec::new() } "is_equal_zero" => { - let r = if args[0].is_zero() { 1 } else { 0 }; + let val = self.proc.get_reg_mem(args[0].u()); + let write_reg = args[1].u(); + + let r = if val.is_zero() { 1 } else { 0 }; + self.proc.set_reg_mem(write_reg, r.into()); + + Vec::new() + } + "is_not_equal" => { + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + let write_reg = args[2].u(); + let val: Elem = (val1.bin() - val2.bin()).into(); - vec![r.into()] + let r = if !val.is_zero() { 1 } else { 0 }; + self.proc.set_reg_mem(write_reg, r.into()); + + Vec::new() } - "is_not_equal_zero" => { - let r = if !args[0].is_zero() { 1 } else { 0 }; + "add_wrap" => { + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + let offset = args[2]; + let write_reg = args[3].u(); + let val = val1.add(&val2).add(&offset); + // don't use .u() here: we are deliberately discarding the + // higher bits + let r = val.bin() as u32; + self.proc.set_reg_mem(write_reg, r.into()); - vec![r.into()] + Vec::new() } - "wrap" | "wrap16" => { + "wrap16" => { + let val = self.proc.get_reg_mem(args[0].u()); + let factor = args[1].bin(); + let write_reg = args[2].u(); + let val: Elem = (val.bin() * factor).into(); + // don't use .u() here: we are deliberately discarding the // higher bits - let r = args[0].bin() as u32; + let r = val.bin() as u32; + self.proc.set_reg_mem(write_reg, r.into()); - vec![r.into()] + Vec::new() } - "wrap_signed" => { - let r = (args[0].bin() + 0x100000000) as u32; + "sub_wrap_with_offset" => { + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + let offset = args[2]; + let write_reg = args[3].u(); + let val = val1.sub(&val2).add(&offset); + + let r = (val.bin() + 0x100000000) as u32; + self.proc.set_reg_mem(write_reg, r.into()); - vec![r.into()] + Vec::new() } "sign_extend_byte" => { - let r = args[0].u() as i8 as u32; + let val = self.proc.get_reg_mem(args[0].u()); + let write_reg = args[1].u(); + + let r = val.u() as i8 as u32; + self.proc.set_reg_mem(write_reg, r.into()); - vec![r.into()] + Vec::new() } "sign_extend_16_bits" => { - let r = args[0].u() as i16 as u32; + let val = self.proc.get_reg_mem(args[0].u()); + let write_reg = args[1].u(); + + let r = val.u() as i16 as u32; - vec![r.into()] + self.proc.set_reg_mem(write_reg, r.into()); + + Vec::new() } "to_signed" => { - let r = args[0].u() as i32; + let val = self.proc.get_reg_mem(args[0].u()); + let write_reg = args[1].u(); + let r = val.u() as i32; - vec![r.into()] + self.proc.set_reg_mem(write_reg, r.into()); + + Vec::new() } "fail" => { // TODO: handle it better panic!("reached a fail instruction") } "divremu" => { - let y = args[0].u(); - let x = args[1].u(); + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + let write_reg1 = args[2].u(); + let write_reg2 = args[3].u(); + + let y = val1.u(); + let x = val2.u(); let div; let rem; if x != 0 { @@ -730,29 +937,77 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { rem = y; } - vec![div.into(), rem.into()] + self.proc.set_reg_mem(write_reg1, div.into()); + self.proc.set_reg_mem(write_reg2, rem.into()); + + Vec::new() } "mul" => { - let r = args[0].u() as u64 * args[1].u() as u64; + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + let write_reg1 = args[2].u(); + let write_reg2 = args[3].u(); + + let r = val1.u() as u64 * val2.u() as u64; let lo = r as u32; let hi = (r >> 32) as u32; - vec![lo.into(), hi.into()] + self.proc.set_reg_mem(write_reg1, lo.into()); + self.proc.set_reg_mem(write_reg2, hi.into()); + + Vec::new() + } + "and" | "or" | "xor" => { + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + let offset = args[2].bin(); + let write_reg = args[3].u(); + let val2: Elem = (val2.bin() + offset).into(); + + let r = match name { + "and" => val1.u() & val2.u(), + "or" => val1.u() | val2.u(), + "xor" => val1.u() ^ val2.u(), + _ => unreachable!(), + }; + + self.proc.set_reg_mem(write_reg, r.into()); + + Vec::new() + } + "shl" | "shr" => { + let val1 = self.proc.get_reg_mem(args[0].u()); + let val2 = self.proc.get_reg_mem(args[1].u()); + let offset = args[2].bin(); + let write_reg = args[3].u(); + let val2: Elem = (val2.bin() + offset).into(); + + let r = match name { + "shl" => val1.u() << val2.u(), + "shr" => val1.u() >> val2.u(), + _ => unreachable!(), + }; + + self.proc.set_reg_mem(write_reg, r.into()); + + Vec::new() } - "and" => vec![(args[0].u() & args[1].u()).into()], - "or" => vec![(args[0].u() | args[1].u()).into()], - "xor" => vec![(args[0].u() ^ args[1].u()).into()], - "shl" => vec![(args[0].u() << args[1].u()).into()], - "shr" => vec![(args[0].u() >> args[1].u()).into()], "split_gl" => { - let value = args[0].into_fe().to_integer(); + let val1 = self.proc.get_reg_mem(args[0].u()); + let write_reg1 = args[1].u(); + let write_reg2 = args[2].u(); + + let value = val1.into_fe().to_integer(); // This instruction is only for Goldilocks, so the value must // fit into a u64. let value = value.try_into_u64().unwrap(); let lo = (value & 0xffffffff) as u32; let hi = (value >> 32) as u32; - vec![lo.into(), hi.into()] + self.proc.set_reg_mem(write_reg1, lo.into()); + self.proc.set_reg_mem(write_reg2, hi.into()); + + Vec::new() } "poseidon_gl" => { assert!(args.is_empty()); @@ -764,49 +1019,51 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { self.proc .set_reg(®ister_by_idx(i), Elem::Field(result[i])) }); + vec![] } "affine_256" => { assert!(args.is_empty()); // take input from registers let x1 = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 3)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i)).into_fe()) .collect::>(); let y1 = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 11)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i + 8)).into_fe()) .collect::>(); let x2 = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 19)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i + 16)).into_fe()) .collect::>(); let result = arith::affine_256(&x1, &y1, &x2); // store result in registers (0..8).for_each(|i| { self.proc - .set_reg(®ister_by_idx(i + 3), Elem::Field(result.0[i])) + .set_reg(®ister_by_idx(i), Elem::Field(result.0[i])) }); (0..8).for_each(|i| { self.proc - .set_reg(®ister_by_idx(i + 11), Elem::Field(result.1[i])) + .set_reg(®ister_by_idx(i + 8), Elem::Field(result.1[i])) }); + vec![] } "mod_256" => { assert!(args.is_empty()); // take input from registers let y2 = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 3)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i)).into_fe()) .collect::>(); let y3 = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 11)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i + 8)).into_fe()) .collect::>(); let x1 = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 19)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i + 16)).into_fe()) .collect::>(); let result = arith::mod_256(&y2, &y3, &x1); // store result in registers (0..8).for_each(|i| { self.proc - .set_reg(®ister_by_idx(i + 3), Elem::Field(result[i])) + .set_reg(®ister_by_idx(i), Elem::Field(result[i])) }); vec![] } @@ -814,48 +1071,50 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { assert!(args.is_empty()); // take input from registers let x1 = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 4)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i)).into_fe()) .collect::>(); let y1 = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 12)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i + 8)).into_fe()) .collect::>(); let x2 = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 20)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i + 16)).into_fe()) .collect::>(); let y2 = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 28)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i + 24)).into_fe()) .collect::>(); let result = arith::ec_add(&x1, &y1, &x2, &y2); // store result in registers (0..8).for_each(|i| { self.proc - .set_reg(®ister_by_idx(i + 4), Elem::Field(result.0[i])) + .set_reg(®ister_by_idx(i), Elem::Field(result.0[i])) }); (0..8).for_each(|i| { self.proc - .set_reg(®ister_by_idx(i + 12), Elem::Field(result.1[i])) + .set_reg(®ister_by_idx(i + 8), Elem::Field(result.1[i])) }); + vec![] } "ec_double" => { assert!(args.is_empty()); // take input from registers let x = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 2)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i)).into_fe()) .collect::>(); let y = (0..8) - .map(|i| self.proc.get_reg(®ister_by_idx(i + 10)).into_fe()) + .map(|i| self.proc.get_reg(®ister_by_idx(i + 8)).into_fe()) .collect::>(); let result = arith::ec_double(&x, &y); // store result in registers (0..8).for_each(|i| { self.proc - .set_reg(®ister_by_idx(i + 2), Elem::Field(result.0[i])) + .set_reg(®ister_by_idx(i), Elem::Field(result.0[i])) }); (0..8).for_each(|i| { self.proc - .set_reg(®ister_by_idx(i + 10), Elem::Field(result.1[i])) + .set_reg(®ister_by_idx(i + 8), Elem::Field(result.1[i])) }); + vec![] } instr => { @@ -1041,7 +1300,7 @@ pub fn execute_ast( max_steps_to_execute: usize, mode: ExecMode, profiling: Option, -) -> (ExecutionTrace, MemoryState) { +) -> (ExecutionTrace, MemoryState, RegisterMemoryState) { let main_machine = get_main_machine(program); let PreprocessedMain { statements, @@ -1101,7 +1360,11 @@ pub fn execute_ast( if a.lhs_with_reg[0].0 == "tmp1" { p.jump(pc_after); } else { - p.jump_and_link(pc_before, pc_after, pc_return); + p.jump_and_link( + pc_before as usize, + pc_after as usize, + pc_return as usize, + ); } } } @@ -1111,11 +1374,37 @@ pub fn execute_ast( } } FunctionStatement::Instruction(i) => { - assert!(!["jump", "jump_dyn"].contains(&i.instruction.as_str())); if let Some(p) = &mut profiler { p.add_instruction_cost(e.proc.get_pc().u() as usize); } - e.exec_instruction(&i.instruction, &i.inputs); + + if ["jump", "jump_dyn"].contains(&i.instruction.as_str()) { + let pc_return = e.proc.get_pc().u() + 1; + let pc_before = e.proc.get_reg("pc").u(); + + e.exec_instruction(&i.instruction, &i.inputs); + + let pc_after = e.proc.get_reg("pc").u(); + + let target_reg = e.eval_expression(&i.inputs[1]); + assert_eq!(target_reg.len(), 1); + let target_reg = target_reg[0].u(); + + if let Some(p) = &mut profiler { + // in the generated powdr asm, not writing to `x1` means the returning pc is ignored + if target_reg != 1 { + p.jump(pc_after as usize); + } else { + p.jump_and_link( + pc_before as usize, + pc_after as usize, + pc_return as usize, + ); + } + } + } else { + e.exec_instruction(&i.instruction, &i.inputs); + } } FunctionStatement::Return(_) => break, FunctionStatement::DebugDirective(dd) => { @@ -1163,7 +1452,7 @@ pub fn execute( bootloader_inputs: &[Elem], mode: ExecMode, profiling: Option, -) -> (ExecutionTrace, MemoryState) { +) -> (ExecutionTrace, MemoryState, RegisterMemoryState) { log::info!("Parsing..."); let parsed = powdr_parser::parse_asm(None, asm_source).unwrap(); log::info!("Resolving imports..."); @@ -1185,22 +1474,6 @@ pub fn execute( /// FIXME: copied from `riscv/runtime.rs` instead of adding dependency. /// Helper function for register names used in submachine instruction params. -fn register_by_idx(mut idx: usize) -> String { - // s0..11 callee saved registers - static SAVED_REGS: [&str; 12] = [ - "x8", "x9", "x18", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", - ]; - - // first, use syscall_registers - if idx < SYSCALL_REGISTERS.len() { - return SYSCALL_REGISTERS[idx].to_string(); - } - idx -= SYSCALL_REGISTERS.len(); - // second, callee saved registers - if idx < SAVED_REGS.len() { - return SAVED_REGS[idx].to_string(); - } - idx -= SAVED_REGS.len(); - // lastly, use extra submachine registers +fn register_by_idx(idx: usize) -> String { format!("xtra{idx}") } diff --git a/riscv-syscalls/src/lib.rs b/riscv-syscalls/src/lib.rs index e5ded5587..91bec597d 100644 --- a/riscv-syscalls/src/lib.rs +++ b/riscv-syscalls/src/lib.rs @@ -1,17 +1,5 @@ #![no_std] -/// For parameter passing and return values in `ecall`, we allow use of all "caller saved" registers in RISC-V: `a0-a7,t0-t6`. -/// Register `t0` is reserved for passing the syscall number. -/// By convention, arguments and return values are used in this given order -/// (e.g., if 7 registers are needed, use reg indexes `0..6`). -/// Registers directly used as input/output to the ecall shouldn't need any special handling. -/// Any other register used inside the syscall implementation *must be explicitly saved and restored*, -/// as they are not visible from LLVM. -/// For an example, see the `poseidon_gl` syscall implementation in the riscv runtime. -pub static SYSCALL_REGISTERS: [&str; 14] = [ - "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x6", "x7", "x28", "x29", "x30", "x31", -]; - macro_rules! syscalls { ($(($num:expr, $identifier:ident, $name:expr)),* $(,)?) => { #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] diff --git a/riscv/src/code_gen.rs b/riscv/src/code_gen.rs index fa1cd458c..3afdba318 100644 --- a/riscv/src/code_gen.rs +++ b/riscv/src/code_gen.rs @@ -1,4 +1,4 @@ -use std::fmt; +use std::{fmt, vec}; use itertools::Itertools; use powdr_asm_utils::data_storage::SingleDataValue; @@ -21,13 +21,53 @@ impl Register { pub fn is_zero(&self) -> bool { self.value == 0 } + + pub fn addr(&self) -> u8 { + self.value + } } impl powdr_asm_utils::ast::Register for Register {} impl fmt::Display for Register { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "x{}", self.value) + if self.value < 32 { + // 0 indexed + write!(f, "x{}", self.value) + } else if self.value < 36 { + // 1 indexed + write!(f, "tmp{}", self.value - 31 + 1) + } else if self.value == 36 { + write!(f, "lr_sc_reservation") + } else { + // 0 indexed + write!(f, "xtra{}", self.value - 37) + } + } +} + +impl From<&str> for Register { + fn from(s: &str) -> Self { + if let Some(prefix) = s.strip_prefix("xtra") { + // 0 indexed + let value: u8 = prefix.parse().expect("Invalid register"); + Self::new(value + 37) + } else if let Some(prefix) = s.strip_prefix('x') { + // 0 indexed + let value = prefix.parse().expect("Invalid register"); + assert!(value < 32, "Invalid register"); + Self::new(value) + } else if let Some(prefix) = s.strip_prefix("tmp") { + // 1 indexed + let value: u8 = prefix.parse().expect("Invalid register"); + assert!(value >= 1); + assert!(value <= 4); + Self::new(value - 1 + 32) + } else if s == "lr_sc_reservation" { + Self::new(36) + } else { + panic!("Invalid register") + } } } @@ -95,10 +135,11 @@ pub fn translate_program( let (initial_mem, instructions, degree) = translate_program_impl(program, runtime, with_bootloader); + let degree_log = degree.ilog2(); riscv_machine( runtime, degree, - &preamble::(runtime, with_bootloader), + &preamble::(runtime, degree_log.into(), with_bootloader), initial_mem, instructions, ) @@ -130,9 +171,19 @@ fn translate_program_impl( // snapshot, committed by the bootloader. initial_mem.push(format!("(0x{addr:x}, 0x{v:x})")); } else { - // There is no bootloader to commit to memory, so we have to - // load it explicitly. - data_code.push(format!("mstore 0x{addr:x}, 0x{v:x};")); + data_code.push(format!( + "set_reg {}, 0x{v:x};", + Register::from("tmp2").addr() + )); + data_code.push(format!( + "set_reg {}, 0x{addr:x};", + Register::from("tmp1").addr() + )); + data_code.push(format!( + "mstore {}, 0, 0, {};", + Register::from("tmp1").addr(), + Register::from("tmp2").addr() + )); } } SingleDataValue::LabelReference(sym) => { @@ -141,8 +192,17 @@ fn translate_program_impl( // // TODO should be possible without temporary data_code.extend([ - format!("tmp1 <== load_label({});", escape_label(&sym)), - format!("mstore 0x{addr:x}, tmp1;"), + format!( + "load_label {}, {};", + Register::from("tmp2").addr(), + escape_label(&sym) + ), + format!("set_reg {}, 0x{addr:x};", Register::from("tmp1").addr()), + format!( + "mstore {}, 0, 0, {};", + Register::from("tmp1").addr(), + Register::from("tmp2").addr() + ), ]); } SingleDataValue::Offset(_, _) => { @@ -189,10 +249,11 @@ fn translate_program_impl( .chain(bootloader_and_shutdown_routine_lines) .collect(); if !data_code.is_empty() { - statements.push("x1 <== jump(__data_init);".to_string()); + statements.push("jump __data_init, 1;".to_string()); } statements.extend([ - format!("x1 <== jump({});", program.start_function().as_ref()), + "set_reg 0, 0;".to_string(), + format!("jump {}, 1;", program.start_function().as_ref()), "return;".to_string(), // This is not "riscv ret", but "return from powdr asm function". ]); for s in program.take_executable_statements() { @@ -213,12 +274,17 @@ fn translate_program_impl( if !data_code.is_empty() { statements.extend( - ["// This is the data initialization routine.\n__data_init:".to_string()].into_iter() - .chain(data_code) - .chain([ - "// This is the end of the data initialization routine.\ntmp1 <== jump_dyn(x1);" - .to_string(), - ])); + [ + "// This is the data initialization routine.".to_string(), + "__data_init:".to_string(), + ] + .into_iter() + .chain(data_code) + .chain([ + "// This is the end of the data initialization routine.".to_string(), + format!("jump_dyn 1, {};", Register::from("tmp1").addr()), + ]), + ); } statements.extend(runtime.ecall_handler()); @@ -277,7 +343,7 @@ let initial_memory: (fe, fe)[] = [ ) } -fn preamble(runtime: &Runtime, with_bootloader: bool) -> String { +fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bool) -> String { let bootloader_preamble_if_included = if with_bootloader { bootloader_preamble() } else { @@ -299,17 +365,8 @@ fn preamble(runtime: &Runtime, with_bootloader: bool) -> String reg Y[<=]; reg Z[<=]; reg W[<=]; - reg tmp1; - reg tmp2; - reg tmp3; - reg tmp4; - reg lr_sc_reservation; "# .to_string() - // risc-v x* registers - + &(0..32) - .map(|i| format!("\t\treg x{i};\n")) - .join("") // runtime extra registers + &runtime .submachines_extra_registers() @@ -319,43 +376,171 @@ fn preamble(runtime: &Runtime, with_bootloader: bool) -> String + &bootloader_preamble_if_included + &memory(with_bootloader) + r#" - // ============== Constraint on x0 ======================= - - x0 = 0; - - // ============== iszero check for X ======================= - let XIsZero = std::utils::is_zero(X); + // =============== Register memory ======================= +"# + format!("std::machines::memory::Memory_{} regs;", degree + 2) + .as_str() + + r#" + // Get the value in register Y. + instr get_reg Y -> X link ~> X = regs.mload(Y, STEP); + + // Set the value in register X to the value in register Y. + instr set_reg X, Y -> link ~> regs.mstore(X, STEP, Y); + + // We still need these registers prover inputs. + reg query_arg_1; + reg query_arg_2; + + // Witness columns used in instuctions for intermediate values inside instructions. + col witness val1_col; + col witness val2_col; + col witness val3_col; + col witness val4_col; + + // We need to add these inline instead of using std::utils::is_zero + // because when XX is not constrained, witgen will try to set XX, + // XX_inv and XXIsZero to zero, which fails this constraint. + // Therefore, we have to activate constrained whenever XXIsZero is used. + // XXIsZero = 1 - XX * XX_inv + col witness XX, XX_inv, XXIsZero; + std::utils::force_bool(XXIsZero); + XXIsZero * XX = 0; // ============== control-flow instructions ============== - instr load_label l: label -> X { X = l } + // Load the value of label `l` into register X. + instr load_label X, l: label + link ~> regs.mstore(X, STEP, val1_col) + { + val1_col = l + } - instr jump l: label -> Y { pc' = l, Y = pc + 1} - instr jump_dyn X -> Y { pc' = X, Y = pc + 1} + // Jump to `l` and store the return program counter in register W. + instr jump l: label, W + link ~> regs.mstore(W, STEP, val3_col) + { + pc' = l, + val3_col = pc + 1 + } + + // Jump to the address in register X and store the return program counter in register W. + instr jump_dyn X, W + link ~> val1_col = regs.mload(X, STEP) + link ~> regs.mstore(W, STEP, val3_col) + { + pc' = val1_col, + val3_col = pc + 1 + } + + // Jump to `l` if val(X) - val(Y) is nonzero, where X and Y are register ids. + instr branch_if_diff_nonzero X, Y, l: label + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + { + XXIsZero = 1 - XX * XX_inv, + XX = val1_col - val2_col, + pc' = (1 - XXIsZero) * l + XXIsZero * (pc + 1) + } - instr branch_if_nonzero X, l: label { pc' = (1 - XIsZero) * l + XIsZero * (pc + 1) } - instr branch_if_zero X, l: label { pc' = XIsZero * l + (1 - XIsZero) * (pc + 1) } + // Jump to `l` if val(X) - (val(Y) + Z) is zero, where X and Y are register ids and Z is a + // constant offset. + instr branch_if_zero X, Y, Z, l: label + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + { + XXIsZero = 1 - XX * XX_inv, + XX = val1_col - (val2_col + Z), + pc' = XXIsZero * l + (1 - XXIsZero) * (pc + 1) + } - // Skips Y instructions if X is zero - instr skip_if_zero X, Y { pc' = pc + 1 + (XIsZero * Y) } + // Skips W instructions if val(X) - val(Y) + Z is zero, where X and Y are register ids and Z is a + // constant offset. + instr skip_if_zero X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + { + XXIsZero = 1 - XX * XX_inv, + XX = val1_col - val2_col + Z, + pc' = pc + 1 + (XXIsZero * W) + } - // input X is required to be the difference of two 32-bit unsigend values. - // i.e. -2**32 < X < 2**32 - instr branch_if_positive X, l: label { - X + 2**32 - 1 = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + wrap_bit * 2**32, + // Branches to `l` if V = val(X) - val(Y) + Z is positive, where X and Y are register ids and Z is a + // constant offset. + // V is required to be the difference of two 32-bit unsigned values. + // i.e. -2**32 < V < 2**32. + instr branch_if_positive X, Y, Z, l: label + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + { + (val1_col - val2_col + Z) + 2**32 - 1 = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + wrap_bit * 2**32, pc' = wrap_bit * l + (1 - wrap_bit) * (pc + 1) } - // input X is required to be the difference of two 32-bit unsigend values. - // i.e. -2**32 < X < 2**32 - instr is_positive X -> Y { - X + 2**32 - 1 = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + wrap_bit * 2**32, - Y = wrap_bit + + // Stores 1 in register W if V = val(X) - val(Y) + Z is positive, where X and Y are register ids and Z is a + // constant offset. + // V is required to be the difference of two 32-bit unsigend values. + // i.e. -2**32 < V < 2**32 + instr is_positive X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> regs.mstore(W, STEP + 2, wrap_bit) + { + (val1_col - val2_col + Z) + 2**32 - 1 = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + wrap_bit * 2**32 + } + + // Stores val(X) * Z + W in register Y. + instr move_reg X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> regs.mstore(Y, STEP + 1, val3_col) + { + val3_col = val1_col * Z + W + } + + // ================= wrapping instructions ================= + + // Computes V = val(X) + val(Y) + Z, wraps it in 32 bits, and stores the result in register W. + // Requires 0 <= V < 2**33. + instr add_wrap X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> regs.mstore(W, STEP + 2, val3_col) + { + val1_col + val2_col + Z = val3_col + wrap_bit * 2**32, + val3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + } + + // Computes V = val(X) - val(Y) + Z, wraps it in 32 bits, and stores the result in register W. + // Requires -2**32 <= V < 2**32. + instr sub_wrap_with_offset X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> regs.mstore(W, STEP + 2, val3_col) + { + (val1_col - val2_col + Z) + 2**32 = val3_col + wrap_bit * 2**32, + val3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 } // ================= logical instructions ================= - instr is_equal_zero X -> Y { Y = XIsZero } - instr is_not_equal_zero X -> Y { Y = 1 - XIsZero } + // Stores 1 in register W if the value in register X is zero, + // otherwise stores 0. + instr is_equal_zero X, W + link ~> val1_col = regs.mload(X, STEP) + link ~> regs.mstore(W, STEP + 2, XXIsZero) + { + XXIsZero = 1 - XX * XX_inv, + XX = val1_col + } + + // Stores 1 in register W if val(X) == val(Y), otherwise stores 0. + instr is_not_equal X, Y, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> regs.mstore(W, STEP + 2, val3_col) + { + XXIsZero = 1 - XX * XX_inv, + XX = val1_col - val2_col, + val3_col = 1 - XXIsZero + } // ================= submachine instructions ================= "# + &runtime @@ -364,11 +549,6 @@ fn preamble(runtime: &Runtime, with_bootloader: bool) -> String .map(|s| format!(" {s}")) .join("\n") + r#" - // Wraps a value in Y to 32 bits. - // Requires 0 <= Y < 2**33 - instr wrap Y -> X { Y = X + wrap_bit * 2**32, X = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 } - // Requires -2**32 <= Y < 2**32 - instr wrap_signed Y -> X { Y + 2**32 = X + wrap_bit * 2**32, X = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 } col fixed bytes(i) { i & 0xff }; col witness X_b1; col witness X_b2; @@ -381,41 +561,63 @@ fn preamble(runtime: &Runtime, with_bootloader: bool) -> String col witness wrap_bit; wrap_bit * (1 - wrap_bit) = 0; + // Sign extends the value in register X and stores it in register Y. // Input is a 32 bit unsigned number. We check bit 7 and set all higher bits to that value. - instr sign_extend_byte Y -> X { + instr sign_extend_byte X, Y + link ~> val1_col = regs.mload(X, STEP) + link ~> regs.mstore(Y, STEP + 3, val3_col) + { // wrap_bit is used as sign_bit here. - Y = Y_7bit + wrap_bit * 0x80 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, - X = Y_7bit + wrap_bit * 0xffffff80 + val1_col = Y_7bit + wrap_bit * 0x80 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, + val3_col = Y_7bit + wrap_bit * 0xffffff80 } + col fixed seven_bit(i) { i & 0x7f }; col witness Y_7bit; [ Y_7bit ] in [ seven_bit ]; + // Sign extends the value in register X and stores it in register Y. // Input is a 32 bit unsigned number. We check bit 15 and set all higher bits to that value. - instr sign_extend_16_bits Y -> X { + instr sign_extend_16_bits X, Y + link ~> val1_col = regs.mload(X, STEP) + link ~> regs.mstore(Y, STEP + 3, val3_col) + { Y_15bit = X_b1 + Y_7bit * 0x100, // wrap_bit is used as sign_bit here. - Y = Y_15bit + wrap_bit * 0x8000 + X_b3 * 0x10000 + X_b4 * 0x1000000, - X = Y_15bit + wrap_bit * 0xffff8000 + val1_col = Y_15bit + wrap_bit * 0x8000 + X_b3 * 0x10000 + X_b4 * 0x1000000, + val3_col = Y_15bit + wrap_bit * 0xffff8000 } col witness Y_15bit; - // Input is a 32 but unsigned number (0 <= Y < 2**32) interpreted as a two's complement numbers. - // Returns a signed number (-2**31 <= X < 2**31). - instr to_signed Y -> X { + // Converts the value in register X to a signed number and stores it in register Y. + // Input is a 32 bit unsigned number (0 <= val(X) < 2**32) interpreted as a two's complement numbers. + // Returns a signed number (-2**31 <= val(Y) < 2**31). + instr to_signed X, Y + link ~> val1_col = regs.mload(X, STEP) + link ~> regs.mstore(Y, STEP + 1, val3_col) + { // wrap_bit is used as sign_bit here. - Y = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + Y_7bit * 0x1000000 + wrap_bit * 0x80000000, - X = Y - wrap_bit * 2**32 + val1_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + Y_7bit * 0x1000000 + wrap_bit * 0x80000000, + val3_col = val1_col - wrap_bit * 0x100000000 } // ======================= assertions ========================= instr fail { 1 = 0 } + // Wraps V = val(X) * Y and stores it in register Z, + // where X is a register and Y is a constant factor. // Removes up to 16 bits beyond 32 // TODO is this really safe? - instr wrap16 Y -> X { Y = Y_b5 * 2**32 + Y_b6 * 2**40 + X, X = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 } + instr wrap16 X, Y, Z + link ~> val1_col = regs.mload(X, STEP) + link ~> regs.mstore(Z, STEP + 3, val3_col) + { + (val1_col * Y) = Y_b5 * 2**32 + Y_b6 * 2**40 + val3_col, + val3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + } + col witness Y_b5; col witness Y_b6; col witness Y_b7; @@ -434,27 +636,30 @@ fn preamble(runtime: &Runtime, with_bootloader: bool) -> String [ REM_b3 ] in [ bytes ]; [ REM_b4 ] in [ bytes ]; - // implements Z = Y / X and W = Y % X. - instr divremu Y, X -> Z, W { - // main division algorithm: - // Y is the known dividend - // X is the known divisor - // Z is the unknown quotient - // W is the unknown remainder + // Computes Q = val(Y) / val(X) and R = val(Y) % val(X) and stores them in registers Z and W. + instr divremu Y, X, Z, W + link ~> val1_col = regs.mload(Y, STEP) + link ~> val2_col = regs.mload(X, STEP + 1) + link ~> regs.mstore(Z, STEP + 2, val3_col) + link ~> regs.mstore(W, STEP + 3, val4_col) + { + XXIsZero = 1 - XX * XX_inv, + XX = val2_col, + // if X is zero, remainder is set to dividend, as per RISC-V specification: - X * Z + W = Y, + val2_col * val3_col + val4_col = val1_col, // remainder >= 0: - W = REM_b1 + REM_b2 * 0x100 + REM_b3 * 0x10000 + REM_b4 * 0x1000000, + val4_col = REM_b1 + REM_b2 * 0x100 + REM_b3 * 0x10000 + REM_b4 * 0x1000000, - // remainder < divisor, conditioned to X not being 0: - (1 - XIsZero) * (X - W - 1 - Y_b5 - Y_b6 * 0x100 - Y_b7 * 0x10000 - Y_b8 * 0x1000000) = 0, + // remainder < divisor, conditioned to val(X) not being 0: + (1 - XXIsZero) * (val2_col - val4_col - 1 - Y_b5 - Y_b6 * 0x100 - Y_b7 * 0x10000 - Y_b8 * 0x1000000) = 0, // in case X is zero, we set quotient according to RISC-V specification - XIsZero * (Z - 0xffffffff) = 0, + XXIsZero * (val3_col - 0xffffffff) = 0, // quotient is 32 bits: - Z = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + val3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 } "# + mul_instruction } @@ -465,14 +670,17 @@ fn mul_instruction(runtime: &Runtime) -> &'static str { // The BN254 field can fit any 64-bit number, so we can naively de-compose // Z * W into 8 bytes and put them together to get the upper and lower word. r#" - // Multiply two 32-bits unsigned, return the upper and lower unsigned 32-bit - // halves of the result. - // X is the lower half (least significant bits) - // Y is the higher half (most significant bits) - instr mul Z, W -> X, Y { - Z * W = X + Y * 2**32, - X = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, - Y = Y_b5 + Y_b6 * 0x100 + Y_b7 * 0x10000 + Y_b8 * 0x1000000 + // Computes V = val(X) * val(Y) and + // stores the lower 32 bits in register Z and the upper 32 bits in register W. + instr mul X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> regs.mstore(Z, STEP + 2, val3_col) + link ~> regs.mstore(W, STEP + 3, val4_col); + { + val1_col * val2_col = val3_col + val4_col * 2**32, + val3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, + val4_col = Y_b5 + Y_b6 * 0x100 + Y_b7 * 0x10000 + Y_b8 * 0x1000000 } "# } @@ -484,11 +692,14 @@ fn mul_instruction(runtime: &Runtime) -> &'static str { // The Goldilocks field cannot fit some 64-bit numbers, so we have to use // the split machine. Note that it can fit a product of two 32-bit numbers. r#" - // Multiply two 32-bits unsigned, return the upper and lower unsigned 32-bit - // halves of the result. - // X is the lower half (least significant bits) - // Y is the higher half (most significant bits) - instr mul Z, W -> X, Y link ~> (X, Y) = split_gl.split(Z * W); + // Computes V = val(X) * val(Y) and + // stores the lower 32 bits in register Z and the upper 32 bits in register W. + instr mul X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> (val3_col, val4_col) = split_gl.split(val1_col * val2_col) + link ~> regs.mstore(Z, STEP + 2, val3_col) + link ~> regs.mstore(W, STEP + 3, val4_col); "# } } @@ -498,8 +709,16 @@ fn memory(with_bootloader: bool) -> String { let memory_machine = if with_bootloader { r#" std::machines::memory_with_bootloader_write::MemoryWithBootloaderWrite memory; - instr mstore_bootloader Y, Z -> link ~> memory.mstore_bootloader(X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, STEP, Z) { - Y = (X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000) + wrap_bit * 2**32 + + // Stores val(W) at address (V = val(X) - val(Z) + Y) % 2**32. + // V can be between 0 and 2**33. + instr mstore_bootloader X, Z, Y, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Z, STEP + 1) + link ~> val3_col = regs.mload(W, STEP + 2) + link ~> memory.mstore_bootloader(X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, STEP + 3, val3_col) + { + val1_col - val2_col + Y = (X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000) + wrap_bit * 2**32 } "# } else { @@ -511,25 +730,38 @@ fn memory(with_bootloader: bool) -> String { memory_machine.to_string() + r#" - col fixed STEP(i) { i }; + // Increased by 4 in each step, because we do up to 4 register memory accesses per step + col fixed STEP(i) { 4 * i }; // ============== memory instructions ============== let up_to_three: col = |i| i % 4; let six_bits: col = |i| i % 2**6; - /// Loads one word from an address Y, where Y can be between 0 and 2**33 (sic!), + /// Loads one word from an address V = val(X) + Y, where V can be between 0 and 2**33 (sic!), /// wraps the address to 32 bits and rounds it down to the next multiple of 4. - /// Returns the loaded word and the remainder of the division by 4. - instr mload Y -> X, Z link ~> X = memory.mload(X_b4 * 0x1000000 + X_b3 * 0x10000 + X_b2 * 0x100 + X_b1 * 4, STEP) { - [ Z ] in [ up_to_three ], - Y = wrap_bit * 2**32 + X_b4 * 0x1000000 + X_b3 * 0x10000 + X_b2 * 0x100 + X_b1 * 4 + Z, + /// Writes the loaded word and the remainder of the division by 4 to registers Z and W, + /// respectively. + instr mload X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val3_col = memory.mload(X_b4 * 0x1000000 + X_b3 * 0x10000 + X_b2 * 0x100 + X_b1 * 4, STEP + 1) + link ~> regs.mstore(Z, STEP + 2, val3_col) + link ~> regs.mstore(W, STEP + 3, val4_col) + { + [ val4_col ] in [ up_to_three ], + val1_col + Y = wrap_bit * 2**32 + X_b4 * 0x1000000 + X_b3 * 0x10000 + X_b2 * 0x100 + X_b1 * 4 + val4_col, [ X_b1 ] in [ six_bits ] } - /// Stores Z at address Y % 2**32. Y can be between 0 and 2**33. - /// Y should be a multiple of 4, but this instruction does not enforce it. - instr mstore Y, Z -> link ~> memory.mstore(X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, STEP, Z) { - Y = (X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000) + wrap_bit * 2**32 + // Stores val(W) at address (V = val(X) - val(Y) + Z) % 2**32. + // V can be between 0 and 2**33. + // V should be a multiple of 4, but this instruction does not enforce it. + instr mstore X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> val3_col = regs.mload(W, STEP + 2) + link ~> memory.mstore(X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, STEP + 3, val3_col) + { + val1_col - val2_col + Z = (X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000) + wrap_bit * 2**32 } "# } @@ -552,11 +784,11 @@ pub trait InstructionArgs { fn empty(&self) -> Result<(), Self::Error>; } -fn only_if_no_write_to_zero(statement: String, reg: Register) -> Vec { - only_if_no_write_to_zero_vec(vec![statement], reg) +fn only_if_no_write_to_zero(reg: Register, statement: String) -> Vec { + only_if_no_write_to_zero_vec(reg, vec![statement]) } -fn only_if_no_write_to_zero_vec(statements: Vec, reg: Register) -> Vec { +fn only_if_no_write_to_zero_vec(reg: Register, statements: Vec) -> Vec { if reg.is_zero() { vec![] } else { @@ -565,23 +797,38 @@ fn only_if_no_write_to_zero_vec(statements: Vec, reg: Register) -> Vec [String; 2] { - [ - "x2 <=X= wrap(x2 - 4);".to_string(), - format!("mstore x2, {name};"), + +pub fn push_register(name: &str) -> Vec { + assert!(name.starts_with('x'), "Only x registers are supported"); + let reg = Register::from(name); + vec![ + // x2 + x0 - 4 => x2 + format!("add_wrap 2, 0, -4, 2;",), + format!("mstore 2, 0, 0, {};", reg.addr()), ] } /// Pop register from the stack -pub fn pop_register(name: &str) -> [String; 2] { - [ - format!("{name}, tmp1 <== mload(x2);"), - "x2 <=X= wrap(x2 + 4);".to_string(), +pub fn pop_register(name: &str) -> Vec { + assert!(name.starts_with('x'), "Only x registers are supported"); + let reg = Register::from(name); + vec![ + format!( + "mload 2, 0, {}, {};", + reg.addr(), + Register::from("tmp1").addr() + ), + "add_wrap 2, 0, 4, 2;".to_string(), ] } fn process_instruction(instr: &str, args: A) -> Result, A::Error> { - Ok(match instr { + let tmp1 = Register::from("tmp1"); + let tmp2 = Register::from("tmp2"); + let tmp3 = Register::from("tmp3"); + let tmp4 = Register::from("tmp4"); + let lr_sc_reservation = Register::from("lr_sc_reservation"); + let statements = match instr { // load/store registers "li" | "la" => { // The difference between "li" and "la" in RISC-V is that the former @@ -590,132 +837,230 @@ fn process_instruction(instr: &str, args: A) -> Result { let (rd, imm) = args.ri()?; - only_if_no_write_to_zero(format!("{rd} <=X= {};", imm << 12), rd) + only_if_no_write_to_zero(rd, format!("set_reg {}, {};", rd.addr(), imm << 12)) } "mv" => { let (rd, rs) = args.rr()?; - only_if_no_write_to_zero(format!("{rd} <=X= {rs};"), rd) + only_if_no_write_to_zero(rd, format!("move_reg {}, {}, 1, 0;", rs.addr(), rd.addr())) } // Arithmetic "add" => { let (rd, r1, r2) = args.rrr()?; - only_if_no_write_to_zero(format!("{rd} <== wrap({r1} + {r2});"), rd) + only_if_no_write_to_zero( + rd, + format!( + "add_wrap {}, {}, {}, {};", + r1.addr(), + r2.addr(), + 0, + rd.addr() + ), + ) } "addi" => { let (rd, rs, imm) = args.rri()?; - only_if_no_write_to_zero(format!("{rd} <== wrap({rs} + {imm});"), rd) + only_if_no_write_to_zero( + rd, + format!("add_wrap {}, 0, {imm}, {};", rs.addr(), rd.addr()), + ) } "sub" => { let (rd, r1, r2) = args.rrr()?; - only_if_no_write_to_zero(format!("{rd} <== wrap_signed({r1} - {r2});"), rd) + only_if_no_write_to_zero( + rd, + format!( + "sub_wrap_with_offset {}, {}, 0, {};", + r1.addr(), + r2.addr(), + rd.addr() + ), + ) } "neg" => { let (rd, r1) = args.rr()?; - only_if_no_write_to_zero(format!("{rd} <== wrap_signed(0 - {r1});"), rd) + only_if_no_write_to_zero( + rd, + format!("sub_wrap_with_offset 0, {}, 0, {};", r1.addr(), rd.addr()), + ) } "mul" => { let (rd, r1, r2) = args.rrr()?; - only_if_no_write_to_zero(format!("{rd}, tmp1 <== mul({r1}, {r2});"), rd) + only_if_no_write_to_zero( + rd, + format!( + "mul {}, {}, {}, {};", + r1.addr(), + r2.addr(), + rd.addr(), + tmp1.addr() + ), + ) } "mulhu" => { let (rd, r1, r2) = args.rrr()?; - only_if_no_write_to_zero(format!("tmp1, {rd} <== mul({r1}, {r2});"), rd) + only_if_no_write_to_zero( + rd, + format!( + "mul {}, {}, {}, {};", + r1.addr(), + r2.addr(), + tmp1.addr(), + rd.addr(), + ), + ) } "mulh" => { let (rd, r1, r2) = args.rrr()?; only_if_no_write_to_zero_vec( + rd, vec![ - format!("tmp1 <== to_signed({r1});"), - format!("tmp2 <== to_signed({r2});"), + format!("to_signed {}, {};", r1.addr(), tmp1.addr()), + format!("to_signed {}, {};", r2.addr(), tmp2.addr()), // tmp3 is 1 if tmp1 is non-negative - "tmp3 <== is_positive(tmp1 + 1);".into(), + format!("is_positive {}, 0, 1, {};", tmp1.addr(), tmp3.addr()), // tmp4 is 1 if tmp2 is non-negative - "tmp4 <== is_positive(tmp2 + 1);".into(), + format!("is_positive {}, 0, 1, {};", tmp2.addr(), tmp4.addr()), // If tmp1 is negative, convert to positive - "skip_if_zero 0, tmp3;".into(), - "tmp1 <=X= 0 - tmp1;".into(), + format!("skip_if_zero 0, {}, 1, 1;", tmp3.addr()), + format!("move_reg {}, {}, -1, 0;", tmp1.addr(), tmp1.addr()), // If tmp2 is negative, convert to positive - "skip_if_zero 0, tmp4;".into(), - "tmp2 <=X= 0 - tmp2;".into(), - format!("tmp1, {rd} <== mul(tmp1, tmp2);"), + format!("skip_if_zero 0, {}, 1, 1;", tmp4.addr()), + format!("move_reg {}, {}, -1, 0;", tmp2.addr(), tmp2.addr()), + format!( + "mul {}, {}, {}, {};", + tmp1.addr(), + tmp2.addr(), + tmp1.addr(), + rd.addr() + ), // Determine the sign of the result based on the signs of tmp1 and tmp2 - "tmp3 <== is_not_equal_zero(tmp3 - tmp4);".into(), + format!( + "is_not_equal {}, {}, {};", + tmp3.addr(), + tmp4.addr(), + tmp3.addr() + ), // If the result should be negative, convert back to negative - "skip_if_zero tmp3, 2;".into(), - "tmp1 <== is_equal_zero(tmp1);".into(), - format!("{rd} <== wrap_signed(-{rd} - 1 + tmp1);"), + format!("skip_if_zero {}, 0, 0, 2;", tmp3.addr()), + format!("is_equal_zero {}, {};", tmp1.addr(), tmp1.addr()), + format!( + "sub_wrap_with_offset {}, {}, -1, {};", + tmp1.addr(), + rd.addr(), + rd.addr() + ), ], - rd, ) } "mulhsu" => { let (rd, r1, r2) = args.rrr()?; only_if_no_write_to_zero_vec( + rd, vec![ - format!("tmp1 <== to_signed({r1});"), + format!("to_signed {}, {};", r1.addr(), tmp1.addr()), // tmp2 is 1 if tmp1 is non-negative - "tmp2 <== is_positive(tmp1 + 1);".into(), + format!("is_positive {}, 0, 1, {};", tmp1.addr(), tmp2.addr()), // If negative, convert to positive - "skip_if_zero 0, tmp2;".into(), - "tmp1 <=X= 0 - tmp1;".into(), - format!("tmp1, {rd} <== mul(tmp1, {r2});"), + format!("skip_if_zero 0, {}, 1, 1;", tmp2.addr()), + format!("move_reg {}, {}, -1, 0;", tmp1.addr(), tmp1.addr()), + format!( + "mul {}, {}, {}, {};", + tmp1.addr(), + r2.addr(), + tmp1.addr(), + rd.addr() + ), // If was negative before, convert back to negative - "skip_if_zero (1-tmp2), 2;".into(), - "tmp1 <== is_equal_zero(tmp1);".into(), + format!("skip_if_zero 0, {}, 1, 2;", tmp2.addr()), + format!("is_equal_zero {}, {};", tmp1.addr(), tmp1.addr()), // If the lower bits are zero, return the two's complement, // otherwise return one's complement. - format!("{rd} <== wrap_signed(-{rd} - 1 + tmp1);"), + format!( + "sub_wrap_with_offset {}, {}, -1, {};", + tmp1.addr(), + rd.addr(), + rd.addr() + ), ], - rd, ) } "divu" => { let (rd, r1, r2) = args.rrr()?; - only_if_no_write_to_zero(format!("{rd}, tmp1 <== divremu({r1}, {r2});"), rd) + only_if_no_write_to_zero( + rd, + format!( + "divremu {}, {}, {}, {};", + r1.addr(), + r2.addr(), + rd.addr(), + tmp1.addr() + ), + ) } "remu" => { let (rd, r1, r2) = args.rrr()?; - only_if_no_write_to_zero(format!("tmp1, {rd} <== divremu({r1}, {r2});"), rd) + only_if_no_write_to_zero( + rd, + format!( + "divremu {}, {}, {}, {};", + r1.addr(), + r2.addr(), + tmp1.addr(), + rd.addr() + ), + ) } // bitwise "xor" => { let (rd, r1, r2) = args.rrr()?; - only_if_no_write_to_zero(format!("{rd} <== xor({r1}, {r2});"), rd) + only_if_no_write_to_zero( + rd, + format!("xor {}, {}, 0, {};", r1.addr(), r2.addr(), rd.addr()), + ) } "xori" => { let (rd, r1, imm) = args.rri()?; - only_if_no_write_to_zero(format!("{rd} <== xor({r1}, {imm});"), rd) + only_if_no_write_to_zero(rd, format!("xor {}, 0, {imm}, {};", r1.addr(), rd.addr())) } "and" => { let (rd, r1, r2) = args.rrr()?; - only_if_no_write_to_zero(format!("{rd} <== and({r1}, {r2});"), rd) + only_if_no_write_to_zero( + rd, + format!("and {}, {}, 0, {};", r1.addr(), r2.addr(), rd.addr()), + ) } "andi" => { let (rd, r1, imm) = args.rri()?; - only_if_no_write_to_zero(format!("{rd} <== and({r1}, {imm});"), rd) + only_if_no_write_to_zero(rd, format!("and {}, 0, {imm}, {};", r1.addr(), rd.addr())) } "or" => { let (rd, r1, r2) = args.rrr()?; - only_if_no_write_to_zero(format!("{rd} <== or({r1}, {r2});"), rd) + only_if_no_write_to_zero( + rd, + format!("or {}, {}, 0, {};", r1.addr(), r2.addr(), rd.addr()), + ) } "ori" => { let (rd, r1, imm) = args.rri()?; - only_if_no_write_to_zero(format!("{rd} <== or({r1}, {imm});"), rd) + only_if_no_write_to_zero(rd, format!("or {}, 0, {imm}, {};", r1.addr(), rd.addr())) } "not" => { let (rd, rs) = args.rr()?; - only_if_no_write_to_zero(format!("{rd} <== wrap_signed(-{rs} - 1);"), rd) + only_if_no_write_to_zero( + rd, + format!("sub_wrap_with_offset 0, {}, -1, {};", rs.addr(), rd.addr()), + ) } // shift @@ -723,42 +1068,55 @@ fn process_instruction(instr: &str, args: A) -> Result { let (rd, r1, r2) = args.rrr()?; only_if_no_write_to_zero_vec( + rd, vec![ - format!("tmp1 <== and({r2}, 0x1f);"), - format!("{rd} <== shl({r1}, tmp1);"), + format!("and {}, 0, 0x1f, {};", r2.addr(), tmp1.addr()), + format!("shl {}, {}, 0, {};", r1.addr(), tmp1.addr(), rd.addr()), ], - rd, ) } "srli" => { // logical shift right let (rd, rs, amount) = args.rri()?; assert!(amount <= 31); - only_if_no_write_to_zero(format!("{rd} <== shr({rs}, {amount});"), rd) + only_if_no_write_to_zero( + rd, + format!("shr {}, 0, {amount}, {};", rs.addr(), rd.addr()), + ) } "srl" => { // logical shift right let (rd, r1, r2) = args.rrr()?; only_if_no_write_to_zero_vec( + rd, vec![ - format!("tmp1 <== and({r2}, 0x1f);"), - format!("{rd} <== shr({r1}, tmp1);"), + format!("and {}, 0, 0x1f, {};", r2.addr(), tmp1.addr()), + format!("shr {}, {}, 0, {};", r1.addr(), tmp1.addr(), rd.addr()), ], - rd, ) } "srai" => { @@ -769,66 +1127,87 @@ fn process_instruction(instr: &str, args: A) -> Result { let (rd, rs) = args.rr()?; - only_if_no_write_to_zero(format!("{rd} <=Y= is_equal_zero({rs});"), rd) + only_if_no_write_to_zero(rd, format!("is_equal_zero {}, {};", rs.addr(), rd.addr())) } "snez" => { let (rd, rs) = args.rr()?; - only_if_no_write_to_zero(format!("{rd} <=Y= is_not_equal_zero({rs});"), rd) + only_if_no_write_to_zero(rd, format!("is_not_equal {}, 0, {};", rs.addr(), rd.addr())) } "slti" => { let (rd, rs, imm) = args.rri()?; only_if_no_write_to_zero_vec( + rd, vec![ - format!("tmp1 <== to_signed({rs});"), - format!("{rd} <=Y= is_positive({} - tmp1);", imm as i32), + format!("to_signed {}, {};", rs.addr(), tmp1.addr()), + format!( + "is_positive 0, {}, {}, {};", + tmp1.addr(), + imm as i32, + rd.addr() + ), ], - rd, ) } "slt" => { let (rd, r1, r2) = args.rrr()?; only_if_no_write_to_zero_vec( + rd, vec![ - format!("tmp1 <== to_signed({r1});"), - format!("tmp2 <== to_signed({r2});"), - format!("{rd} <=Y= is_positive(tmp2 - tmp1);"), + format!("to_signed {}, {};", r1.addr(), tmp1.addr()), + format!("to_signed {}, {};", r2.addr(), tmp2.addr()), + format!( + "is_positive {}, {}, 0, {};", + tmp2.addr(), + tmp1.addr(), + rd.addr() + ), ], - rd, ) } "sltiu" => { let (rd, rs, imm) = args.rri()?; - only_if_no_write_to_zero(format!("{rd} <=Y= is_positive({imm} - {rs});"), rd) + only_if_no_write_to_zero( + rd, + format!("is_positive 0, {}, {imm}, {};", rs.addr(), rd.addr()), + ) } "sltu" => { let (rd, r1, r2) = args.rrr()?; - only_if_no_write_to_zero(format!("{rd} <=Y= is_positive({r2} - {r1});"), rd) + only_if_no_write_to_zero( + rd, + format!( + "is_positive {}, {}, 0, {};", + r2.addr(), + r1.addr(), + rd.addr() + ), + ) } "sgtz" => { let (rd, rs) = args.rr()?; only_if_no_write_to_zero_vec( + rd, vec![ - format!("tmp1 <== to_signed({rs});"), - format!("{rd} <=Y= is_positive(tmp1);"), + format!("to_signed {}, {};", rs.addr(), tmp1.addr()), + format!("is_positive {}, 0, 0, {};", tmp1.addr(), rd.addr()), ], - rd, ) } @@ -836,31 +1215,43 @@ fn process_instruction(instr: &str, args: A) -> Result { let (r1, r2, label) = args.rrl()?; let label = escape_label(label.as_ref()); - vec![format!("branch_if_zero {r1} - {r2}, {label};")] + vec![format!( + "branch_if_zero {}, {}, 0, {label};", + r1.addr(), + r2.addr() + )] } "beqz" => { let (r1, label) = args.rl()?; let label = escape_label(label.as_ref()); - vec![format!("branch_if_zero {r1}, {label};")] + vec![format!("branch_if_zero {}, 0, 0, {label};", r1.addr())] } "bgeu" => { let (r1, r2, label) = args.rrl()?; let label = escape_label(label.as_ref()); // TODO does this fulfill the input requirements for branch_if_positive? - vec![format!("branch_if_positive {r1} - {r2} + 1, {label};")] + vec![format!( + "branch_if_positive {}, {}, 1, {label};", + r1.addr(), + r2.addr() + )] } "bgez" => { let (r1, label) = args.rl()?; let label = escape_label(label.as_ref()); vec![ - format!("tmp1 <== to_signed({r1});"), - format!("branch_if_positive tmp1 + 1, {label};"), + format!("to_signed {}, {};", r1.addr(), tmp1.addr()), + format!("branch_if_positive {}, 0, 1, {label};", tmp1.addr()), ] } "bltu" => { let (r1, r2, label) = args.rrl()?; let label = escape_label(label.as_ref()); - vec![format!("branch_if_positive {r2} - {r1}, {label};")] + vec![format!( + "branch_if_positive {}, {}, 0, {label};", + r2.addr(), + r1.addr() + )] } "blt" => { let (r1, r2, label) = args.rrl()?; @@ -868,9 +1259,13 @@ fn process_instruction(instr: &str, args: A) -> Result { @@ -879,24 +1274,31 @@ fn process_instruction(instr: &str, args: A) -> Result= r2 (signed). // TODO does this fulfill the input requirements for branch_if_positive? vec![ - format!("tmp1 <== to_signed({r1});"), - format!("tmp2 <== to_signed({r2});"), - format!("branch_if_positive tmp1 - tmp2 + 1, {label};"), + format!("to_signed {}, {};", r1.addr(), tmp1.addr()), + format!("to_signed {}, {};", r2.addr(), tmp2.addr()), + format!( + "branch_if_positive {}, {}, 1, {label};", + tmp1.addr(), + tmp2.addr() + ), ] } "bltz" => { // branch if 2**31 <= r1 < 2**32 let (r1, label) = args.rl()?; let label = escape_label(label.as_ref()); - vec![format!("branch_if_positive {r1} - 2**31 + 1, {label};")] + vec![format!( + "branch_if_positive {}, 0, -(2**31) + 1, {label};", + r1.addr() + )] } "blez" => { // branch less or equal zero let (r1, label) = args.rl()?; let label = escape_label(label.as_ref()); vec![ - format!("tmp1 <== to_signed({r1});"), - format!("branch_if_positive -tmp1 + 1, {label};"), + format!("to_signed {}, {};", r1.addr(), tmp1.addr()), + format!("branch_if_positive 0, {}, 1, {label};", tmp1.addr()), ] } "bgtz" => { @@ -904,61 +1306,66 @@ fn process_instruction(instr: &str, args: A) -> Result { let (r1, r2, label) = args.rrl()?; let label = escape_label(label.as_ref()); - vec![format!("branch_if_nonzero {r1} - {r2}, {label};")] + vec![format!( + "branch_if_diff_nonzero {}, {}, {label};", + r1.addr(), + r2.addr() + )] } "bnez" => { let (r1, label) = args.rl()?; let label = escape_label(label.as_ref()); - vec![format!("branch_if_nonzero {r1}, {label};")] + vec![format!("branch_if_diff_nonzero {}, 0, {label};", r1.addr())] } // jump and call "j" | "tail" => { let label = args.l()?; let label = escape_label(label.as_ref()); - vec![format!("tmp1 <== jump({label});",)] + vec![format!("jump {label}, {};", tmp1.addr(),)] } "jr" => { let rs = args.r()?; - vec![format!("tmp1 <== jump_dyn({rs});")] + vec![format!("jump_dyn {}, {};", rs.addr(), tmp1.addr())] } "jal" => { if let Ok(label) = args.l() { let label = escape_label(label.as_ref()); - vec![format!("x1 <== jump({label});")] + vec![format!("jump {label}, 1;")] } else { let (rd, label) = args.rl()?; let label = escape_label(label.as_ref()); - let statement = if rd.is_zero() { - format!("tmp1 <== jump({label});") + if rd.is_zero() { + vec![format!("jump {label}, {};", tmp1.addr())] } else { - format!("{rd} <== jump({label});") - }; - vec![statement] + vec![format!("jump {label}, {};", rd.addr())] + } } } - "jalr" => vec![if let Ok(rs) = args.r() { - format!("x1 <== jump_dyn({rs});") - } else { - let (rd, rs, off) = args.rro()?; - assert_eq!(off, 0, "jalr with non-zero offset is not supported"); - if rd.is_zero() { - format!("tmp1 <== jump_dyn({rs});") + "jalr" => { + if let Ok(rs) = args.r() { + vec![format!("jump_dyn {}, 1;", rs.addr())] } else { - format!("{rd} <== jump_dyn({rs});") + let (rd, rs, off) = args.rro()?; + assert_eq!(off, 0, "jalr with non-zero offset is not supported"); + if rd.is_zero() { + vec![format!("jump_dyn {}, {};", rs.addr(), tmp1.addr())] + } else { + vec![format!("jump_dyn {}, {};", rs.addr(), rd.addr())] + } } - }], + } "call" => { let label = args.l()?; let label = escape_label(label.as_ref()); - vec![format!("x1 <== jump({label});")] + vec![format!("jump {label}, 1;")] } "ecall" => { args.empty()?; @@ -966,7 +1373,7 @@ fn process_instruction(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result { args.empty()?; - vec!["tmp1 <== jump_dyn(x1);".to_string()] + vec![format!("jump_dyn 1, {};", tmp1.addr())] } // memory access "lw" => { let (rd, rs, off) = args.rro()?; // TODO we need to consider misaligned loads / stores - only_if_no_write_to_zero_vec(vec![format!("{rd}, tmp1 <== mload({rs} + {off});")], rd) + only_if_no_write_to_zero( + rd, + format!( + "mload {}, {off}, {}, {};", + rs.addr(), + rd.addr(), + tmp1.addr() + ), + ) } "lb" => { // load byte and sign-extend. the memory is little-endian. let (rd, rs, off) = args.rro()?; only_if_no_write_to_zero_vec( + rd, vec![ - format!("{rd}, tmp2 <== mload({rs} + {off});"), - format!("{rd} <== shr({rd}, 8 * tmp2);"), - format!("{rd} <== sign_extend_byte({rd});"), + format!( + "mload {}, {off}, {}, {};", + rs.addr(), + tmp1.addr(), + tmp2.addr() + ), + format!("move_reg {}, {}, 8, 0;", tmp2.addr(), tmp2.addr()), + format!("shr {}, {}, 0, {};", tmp1.addr(), tmp2.addr(), tmp1.addr()), + format!("sign_extend_byte {}, {};", tmp1.addr(), rd.addr()), ], - rd, ) } "lbu" => { // load byte and zero-extend. the memory is little-endian. let (rd, rs, off) = args.rro()?; only_if_no_write_to_zero_vec( + rd, vec![ - format!("{rd}, tmp2 <== mload({rs} + {off});"), - format!("{rd} <== shr({rd}, 8 * tmp2);"), - format!("{rd} <== and({rd}, 0xff);"), + format!( + "mload {}, {off}, {}, {};", + rs.addr(), + tmp1.addr(), + tmp2.addr() + ), + format!("move_reg {}, {}, 8, 0;", tmp2.addr(), tmp2.addr()), + format!("shr {}, {}, 0, {};", tmp1.addr(), tmp2.addr(), tmp1.addr()), + format!("and {}, 0, 0xff, {};", tmp1.addr(), rd.addr()), ], - rd, ) } "lh" => { @@ -1016,12 +1443,18 @@ fn process_instruction(instr: &str, args: A) -> Result { @@ -1029,17 +1462,23 @@ fn process_instruction(instr: &str, args: A) -> Result { let (r2, r1, off) = args.rro()?; - vec![format!("mstore {r1} + {off}, {r2};")] + vec![format!("mstore {}, 0, {off}, {};", r1.addr(), r2.addr())] } "sh" => { // store half word (two bytes) @@ -1048,28 +1487,52 @@ fn process_instruction(instr: &str, args: A) -> Result { // store byte let (r2, r1, off) = args.rro()?; vec![ - format!("tmp1, tmp2 <== mload({r1} + {off});"), - "tmp3 <== shl(0xff, 8 * tmp2);".to_string(), - "tmp3 <== xor(tmp3, 0xffffffff);".to_string(), - "tmp1 <== and(tmp1, tmp3);".to_string(), - format!("tmp3 <== and({r2}, 0xff);"), - "tmp3 <== shl(tmp3, 8 * tmp2);".to_string(), - "tmp1 <== or(tmp1, tmp3);".to_string(), - format!("mstore {r1} + {off} - tmp2, tmp1;"), + format!( + "mload {}, {off}, {}, {};", + r1.addr(), + tmp1.addr(), + tmp2.addr() + ), + format!("set_reg {}, 0xff;", tmp3.addr()), + format!("move_reg {}, {}, 8, 0;", tmp2.addr(), tmp4.addr()), + format!("shl {}, {}, 0, {};", tmp3.addr(), tmp4.addr(), tmp3.addr()), + format!("xor {}, 0, 0xffffffff, {};", tmp3.addr(), tmp3.addr()), + format!("and {}, {}, 0, {};", tmp1.addr(), tmp3.addr(), tmp1.addr()), + format!("and {}, 0, 0xff, {};", r2.addr(), tmp3.addr()), + format!("shl {}, {}, 0, {};", tmp3.addr(), tmp4.addr(), tmp3.addr()), + format!("or {}, {}, 0, {};", tmp1.addr(), tmp3.addr(), tmp1.addr()), + format!( + "mstore {}, {}, {off}, {};", + r1.addr(), + tmp2.addr(), + tmp1.addr() + ), ] } "fence" | "nop" => vec![], @@ -1081,11 +1544,19 @@ fn process_instruction(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result { // Some overlap with "sw", but also writes 0 to rd on success let (rd, rs2, rs1) = args.rrr2()?; // TODO: misaligned access should raise misaligned address exceptions - let mut statements = vec![ - "skip_if_zero lr_sc_reservation, 1;".into(), - format!("mstore {rs1}, {rs2};"), - ]; - if !rd.is_zero() { - statements.push(format!("{rd} <=X= (1 - lr_sc_reservation);")); - } - statements.push("lr_sc_reservation <=X= 0;".into()); - statements + [ + format!("skip_if_zero {}, 0, 0, 1;", lr_sc_reservation.addr()), + format!("mstore {}, 0, 0, {};", rs1.addr(), rs2.addr()), + ] + .into_iter() + .chain(only_if_no_write_to_zero_vec( + rd, + vec![format!( + "move_reg {}, {}, -1, 1;", + lr_sc_reservation.addr(), + rd.addr() + )], + )) + .chain([format!("set_reg {}, 0;", lr_sc_reservation.addr())]) + .collect() } _ => { panic!("Unknown instruction: {instr}"); } - }) + }; + for s in &statements { + log::debug!(" {s}"); + } + Ok(statements) } diff --git a/riscv/src/continuations.rs b/riscv/src/continuations.rs index d912d69e8..5ed348b51 100644 --- a/riscv/src/continuations.rs +++ b/riscv/src/continuations.rs @@ -4,7 +4,7 @@ use std::{ }; use powdr_ast::{ - asm_analysis::{AnalysisASMFile, RegisterTy}, + asm_analysis::AnalysisASMFile, parsed::{asm::parse_absolute_path, Expression, Number, PilStatement}, }; use powdr_number::FieldElement; @@ -23,6 +23,8 @@ use crate::continuations::bootloader::{ WORDS_PER_PAGE, }; +use crate::code_gen::Register; + fn transposed_trace(trace: &ExecutionTrace) -> HashMap>> { let mut reg_values: HashMap<&str, Vec>> = HashMap::with_capacity(trace.reg_map.len()); @@ -68,10 +70,19 @@ where let num_chunks = bootloader_inputs.len(); log::info!("Computing fixed columns..."); - pipeline.compute_fixed_cols().unwrap(); + let fixed_cols = pipeline.compute_fixed_cols().unwrap(); + + // Advance the pipeline to the optimized PIL stage, so that it doesn't need to be computed + // in every chunk. + pipeline.compute_optimized_pil().unwrap(); - // we can assume optimized_pil has been computed - let length = pipeline.compute_optimized_pil().unwrap().degree(); + // TODO hacky way to find the degree of the main machine, fix. + let length = fixed_cols + .iter() + .find(|(col, _)| col == "main.STEP") + .unwrap() + .1 + .len() as u64; bootloader_inputs .into_iter() @@ -139,25 +150,6 @@ fn sanity_check(program: &AnalysisASMFile) { panic!(); } } - - // Check that the registers of the machine are as expected. - let machine_registers = main_machine - .registers - .iter() - .filter_map(|r| { - ((r.ty == RegisterTy::Pc || r.ty == RegisterTy::Write) && r.name != "x0") - .then_some(format!("main.{}", r.name)) - }) - .collect::>(); - let expected_registers = REGISTER_NAMES - .iter() - .map(|s| s.to_string()) - .collect::>(); - // FIXME: Currently, continuations don't support a Runtime with extra - // registers. This has not been fixed because extra registers will not be - // needed once we support accessing the memory machine from multiple - // machines. This comment can be removed then. - assert_eq!(machine_registers, expected_registers); } pub fn load_initial_memory(program: &AnalysisASMFile) -> MemoryState { @@ -336,18 +328,23 @@ pub fn rust_continuations_dry_run( log::info!("Bootloader inputs length: {}", bootloader_inputs.len()); log::info!("Simulating chunk execution..."); - let (chunk_trace, memory_snapshot_update) = { - let (trace, memory_snapshot_update) = powdr_riscv_executor::execute_ast::( - &program, - MemoryState::new(), - pipeline.data_callback().unwrap(), - &bootloader_inputs, - num_rows, - powdr_riscv_executor::ExecMode::Trace, - // profiling was done when full trace was generated - None, - ); - (transposed_trace(&trace), memory_snapshot_update) + let (chunk_trace, memory_snapshot_update, register_memory_snapshot) = { + let (trace, memory_snapshot_update, register_memory_snapshot) = + powdr_riscv_executor::execute_ast::( + &program, + MemoryState::new(), + pipeline.data_callback().unwrap(), + &bootloader_inputs, + num_rows, + powdr_riscv_executor::ExecMode::Trace, + // profiling was done when full trace was generated + None, + ); + ( + transposed_trace(&trace), + memory_snapshot_update, + register_memory_snapshot, + ) }; let mut memory_updates_by_page = merkle_tree.organize_updates_by_page(memory_snapshot_update.into_iter()); @@ -388,11 +385,17 @@ pub fn rust_continuations_dry_run( .copy_from_slice(&page_hash.map(Elem::Field)); } - // Update initial register values for the next chunk. - register_values = REGISTER_NAMES - .iter() - .map(|&r| *chunk_trace[r].last().unwrap()) - .collect(); + // Go over all registers except the PC + let register_iter = REGISTER_NAMES.iter().take(REGISTER_NAMES.len() - 1); + register_values = register_iter + .map(|reg| { + let reg = reg.strip_prefix("main.").unwrap(); + let id = Register::from(reg).addr(); + *register_memory_snapshot.get(&(id as u32)).unwrap() + }) + .collect::>(); + + register_values.push(*chunk_trace["main.pc"].last().unwrap()); // Replace final register values of the current chunk bootloader_inputs[REGISTER_NAMES.len()..2 * REGISTER_NAMES.len()] @@ -435,7 +438,7 @@ pub fn rust_continuations_dry_run( (length - start - shutdown_routine_rows) * 100 / length ); for i in 0..(chunk_trace["main.pc"].len() - start) { - for ® in REGISTER_NAMES.iter() { + for ® in ["main.pc", "main.query_arg_1", "main.query_arg_1"].iter() { let chunk_i = i + start; let full_i = i + proven_trace; if chunk_trace[reg][chunk_i] != full_trace[reg][full_i] { diff --git a/riscv/src/continuations/bootloader.rs b/riscv/src/continuations/bootloader.rs index d2c5af48c..a7c1f6874 100644 --- a/riscv/src/continuations/bootloader.rs +++ b/riscv/src/continuations/bootloader.rs @@ -1,9 +1,8 @@ use powdr_number::FieldElement; use powdr_riscv_executor::Elem; -use powdr_riscv_syscalls::SYSCALL_REGISTERS; - use super::memory_merkle_tree::MerkleTree; +use crate::code_gen::Register; /// 32-Bit architecture -> 2^32 bytes of addressable memory pub const MEMORY_SIZE_LOG: usize = 32; @@ -54,8 +53,15 @@ pub fn bootloader_preamble() -> String { // Write-once memory std::machines::write_once_memory::WriteOnceMemory bootloader_inputs; - instr load_bootloader_input X -> Y link => bootloader_inputs.access(X, Y); - instr assert_bootloader_input X, Y -> link => bootloader_inputs.access(X, Y); + instr load_bootloader_input X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link => bootloader_inputs.access(val1_col * Z + W, val3_col) + link ~> regs.mstore(Y, STEP + 2, val3_col); + + instr assert_bootloader_input X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link => bootloader_inputs.access(val1_col * Z + W, val2_col); // Sets the PC to the bootloader input at the provided index instr jump_to_bootloader_input X link => bootloader_inputs.access(X, pc'); @@ -111,19 +117,25 @@ pub fn bootloader_preamble() -> String { preamble } +// TODO also save/load the extra registers +static EXTRA_REGISTERS: [&str; 12] = [ + "xtra0", "xtra1", "xtra2", "xtra3", "xtra4", "xtra5", "xtra6", "xtra7", "xtra8", "xtra9", + "xtra10", "xtra11", +]; + // registers used by poseidon, for easier reference and string interpolation -static P0: &str = SYSCALL_REGISTERS[0]; -static P1: &str = SYSCALL_REGISTERS[1]; -static P2: &str = SYSCALL_REGISTERS[2]; -static P3: &str = SYSCALL_REGISTERS[3]; -static P4: &str = SYSCALL_REGISTERS[4]; -static P5: &str = SYSCALL_REGISTERS[5]; -static P6: &str = SYSCALL_REGISTERS[6]; -static P7: &str = SYSCALL_REGISTERS[7]; -static P8: &str = SYSCALL_REGISTERS[8]; -static P9: &str = SYSCALL_REGISTERS[9]; -static P10: &str = SYSCALL_REGISTERS[10]; -static P11: &str = SYSCALL_REGISTERS[11]; +static P0: &str = EXTRA_REGISTERS[0]; +static P1: &str = EXTRA_REGISTERS[1]; +static P2: &str = EXTRA_REGISTERS[2]; +static P3: &str = EXTRA_REGISTERS[3]; +static P4: &str = EXTRA_REGISTERS[4]; +static P5: &str = EXTRA_REGISTERS[5]; +static P6: &str = EXTRA_REGISTERS[6]; +static P7: &str = EXTRA_REGISTERS[7]; +static P8: &str = EXTRA_REGISTERS[8]; +static P9: &str = EXTRA_REGISTERS[9]; +static P10: &str = EXTRA_REGISTERS[10]; +static P11: &str = EXTRA_REGISTERS[11]; /// The bootloader: An assembly program that can be executed at the beginning a RISC-V execution. /// It lets the prover provide arbitrary memory pages and writes them to memory, as well as values for @@ -148,19 +160,19 @@ pub fn bootloader_and_shutdown_routine(submachine_initialization: &[String]) -> bootloader.push_str(&format!( r#" // Skip the next instruction -tmp1 <== jump(submachine_init); +jump submachine_init, 32; // For convenience, this instruction has a known fixed PC ({DEFAULT_PC}) and just jumps // to whatever comes after the bootloader + shutdown routine. This avoids having to count // the instructions of the bootloader and the submachine initialization. -tmp1 <== jump(computation_start); +jump computation_start, 32; // Similarly, this instruction has a known fixed PC ({SHUTDOWN_START}) and just jumps // to the shutdown routine. -tmp1 <== jump(shutdown_start); +jump shutdown_start, 32; shutdown_sink: -tmp1 <== jump(shutdown_sink); +jump shutdown_sink, 32; // Submachine initialization: Calls each submachine once, because that helps witness // generation figure out default values that can be used if the machine is never used. @@ -188,25 +200,25 @@ submachine_init: // - P8-P11 will contain the capacity elements (0 throughout the execution) // Number of pages -x1 <== load_bootloader_input({NUM_PAGES_INDEX}); -x1 <== wrap(x1); +load_bootloader_input 0, 1, 1, {NUM_PAGES_INDEX}; +add_wrap 1, 0, 0, 1; // Initialize memory hash -x18 <== load_bootloader_input({MEMORY_HASH_START_INDEX}); -x19 <== load_bootloader_input({MEMORY_HASH_START_INDEX} + 1); -x20 <== load_bootloader_input({MEMORY_HASH_START_INDEX} + 2); -x21 <== load_bootloader_input({MEMORY_HASH_START_INDEX} + 3); +load_bootloader_input 0, 18, 1, {MEMORY_HASH_START_INDEX}; +load_bootloader_input 0, 19, 1, {MEMORY_HASH_START_INDEX} + 1; +load_bootloader_input 0, 20, 1, {MEMORY_HASH_START_INDEX} + 2; +load_bootloader_input 0, 21, 1, {MEMORY_HASH_START_INDEX} + 3; // Current page index -x2 <=X= 0; +set_reg 2, 0; -branch_if_zero x1, bootloader_end_page_loop; +branch_if_zero 1, 0, 0, bootloader_end_page_loop; bootloader_start_page_loop: // Page number -x3 <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET}); -x3 <== and(x3, {PAGE_NUMBER_MASK}); +load_bootloader_input 2, 3, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET}; +and 3, 0, {PAGE_NUMBER_MASK}, 3; // Store & hash {WORDS_PER_PAGE} page words. This is an unrolled loop that for each each word: // - Loads the word into the P{{(i % 4) + 4}} register @@ -228,12 +240,16 @@ x3 <== and(x3, {PAGE_NUMBER_MASK}); "#, )); + bootloader.push_str(&format!("move_reg 3, 90, {PAGE_SIZE_BYTES}, 0;\n")); for i in 0..WORDS_PER_PAGE { - let reg = SYSCALL_REGISTERS[(i % 4) + 4]; + let reg = EXTRA_REGISTERS[(i % 4) + 4]; bootloader.push_str(&format!( r#" -{reg} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {i}); -mstore_bootloader x3 * {PAGE_SIZE_BYTES} + {i} * {BYTES_PER_WORD}, {reg};"# +load_bootloader_input 2, 91, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {i}; +{reg} <== get_reg(91); + +move_reg 3, 90, {PAGE_SIZE_BYTES}, {i} * {BYTES_PER_WORD}; +mstore_bootloader 90, 0, 0, 91;"# )); // Hash if buffer is full @@ -263,7 +279,7 @@ mstore_bootloader x3 * {PAGE_SIZE_BYTES} + {i} * {BYTES_PER_WORD}, {reg};"# // root is as claimed. // Set phase to validation -x9 <=X= 0; +set_reg 9, 0; bootloader_merkle_proof_validation_loop: @@ -283,22 +299,41 @@ bootloader_merkle_proof_validation_loop: let mask = 1 << i; bootloader.push_str(&format!( r#" -x4 <== and(x3, {mask}); -branch_if_nonzero x4, bootloader_level_{i}_is_right; -{P4} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 0); -{P5} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 1); -{P6} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 2); -{P7} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 3); -tmp1 <== jump(bootloader_level_{i}_end); +and 3, 0, {mask}, 4; + +branch_if_diff_nonzero 4, 0, bootloader_level_{i}_is_right; + +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 0; +{P4} <== get_reg(90); + +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 1; +{P5} <== get_reg(90); + +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 2; +{P6} <== get_reg(90); + +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 3; +{P7} <== get_reg(90); + +jump bootloader_level_{i}_end, 90; bootloader_level_{i}_is_right: {P4} <=X= {P0}; {P5} <=X= {P1}; {P6} <=X= {P2}; {P7} <=X= {P3}; -{P0} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 0); -{P1} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 1); -{P2} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 2); -{P3} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 3); + +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 0; +{P0} <== get_reg(90); + +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 1; +{P1} <== get_reg(90); + +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 2; +{P2} <== get_reg(90); + +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 4 + {i} * 4 + 3; +{P3} <== get_reg(90); + bootloader_level_{i}_end: poseidon_gl; "# @@ -307,49 +342,68 @@ bootloader_level_{i}_end: bootloader.push_str(&format!( r#" -branch_if_nonzero x9, bootloader_update_memory_hash; +branch_if_diff_nonzero 9, 0, bootloader_update_memory_hash; // Assert Correct Merkle Root -branch_if_nonzero {P0} - x18, bootloader_memory_hash_mismatch; -branch_if_nonzero {P1} - x19, bootloader_memory_hash_mismatch; -branch_if_nonzero {P2} - x20, bootloader_memory_hash_mismatch; -branch_if_nonzero {P3} - x21, bootloader_memory_hash_mismatch; -tmp1 <== jump(bootloader_memory_hash_ok); +move_reg 18, 90, -1, 0; +move_reg 90, 90, 1, {P0}; +branch_if_diff_nonzero 90, 0, bootloader_memory_hash_mismatch; + +move_reg 19, 90, -1, 0; +move_reg 90, 90, 1, {P1}; +branch_if_diff_nonzero 90, 0, bootloader_memory_hash_mismatch; + +move_reg 20, 90, -1, 0; +move_reg 90, 90, 1, {P2}; +branch_if_diff_nonzero 90, 0, bootloader_memory_hash_mismatch; + +move_reg 21, 90, -1, 0; +move_reg 90, 90, 1, {P3}; +branch_if_diff_nonzero 90, 0, bootloader_memory_hash_mismatch; + +jump bootloader_memory_hash_ok, 90; bootloader_memory_hash_mismatch: fail; bootloader_memory_hash_ok: // Set phase to update -x9 <=X= 1; +set_reg 9, 1; // Load claimed updated page hash into {P0}-{P3} -{P0} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 0); -{P1} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 1); -{P2} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 2); -{P3} <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 3); +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 0; +{P0} <== get_reg(90); + +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 1; +{P1} <== get_reg(90); + +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 2; +{P2} <== get_reg(90); + +load_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {WORDS_PER_PAGE} + 3; +{P3} <== get_reg(90); // Repeat Merkle proof validation loop to compute updated Merkle root -tmp1 <== jump(bootloader_merkle_proof_validation_loop); +jump bootloader_merkle_proof_validation_loop, 90; bootloader_update_memory_hash: -x18 <=X= {P0}; -x19 <=X= {P1}; -x20 <=X= {P2}; -x21 <=X= {P3}; +set_reg 18, {P0}; +set_reg 19, {P1}; +set_reg 20, {P2}; +set_reg 21, {P3}; // Increment page index -x2 <=X= x2 + 1; +move_reg 2, 2, 1, 1; -branch_if_nonzero x2 - x1, bootloader_start_page_loop; +branch_if_diff_nonzero 2, 1, bootloader_start_page_loop; bootloader_end_page_loop: // Assert final Merkle root is as claimed -assert_bootloader_input {MEMORY_HASH_START_INDEX} + 4, x18; -assert_bootloader_input {MEMORY_HASH_START_INDEX} + 5, x19; -assert_bootloader_input {MEMORY_HASH_START_INDEX} + 6, x20; -assert_bootloader_input {MEMORY_HASH_START_INDEX} + 7, x21; +assert_bootloader_input 0, 18, 1, {MEMORY_HASH_START_INDEX} + 4; +assert_bootloader_input 0, 19, 1, {MEMORY_HASH_START_INDEX} + 5; +assert_bootloader_input 0, 20, 1, {MEMORY_HASH_START_INDEX} + 6; +assert_bootloader_input 0, 21, 1, {MEMORY_HASH_START_INDEX} + 7; // Initialize registers, starting with index 0 "# @@ -360,9 +414,13 @@ assert_bootloader_input {MEMORY_HASH_START_INDEX} + 7, x21; for (i, reg) in register_iter.enumerate() { let reg = reg.strip_prefix("main.").unwrap(); - bootloader.push_str(&format!(r#"{reg} <== load_bootloader_input({i});"#)); + bootloader.push_str(&format!( + r#"load_bootloader_input 0, {}, 1, {i};"#, + Register::from(reg).addr() + )); bootloader.push('\n'); } + bootloader.push_str(&format!( r#" // Default PC is 0, but we already started from 0, so in that case we do nothing. @@ -407,27 +465,28 @@ shutdown_start: for (i, reg) in register_iter.enumerate() { let reg = reg.strip_prefix("main.").unwrap(); bootloader.push_str(&format!( - "assert_bootloader_input {}, {reg};\n", - i + REGISTER_NAMES.len() + "assert_bootloader_input {i}, {}, 1, {};\n", + Register::from(reg).addr(), + REGISTER_NAMES.len() )); } bootloader.push_str(&format!( r#" // Number of pages -x1 <== load_bootloader_input({NUM_PAGES_INDEX}); -x1 <== wrap(x1); +load_bootloader_input 0, 1, 1, {NUM_PAGES_INDEX}; +add_wrap 1, 0, 0, 1; // Current page index -x2 <=X= 0; +set_reg 2, 0; -branch_if_zero x1, shutdown_end_page_loop; +branch_if_zero 1, 0, 0, shutdown_end_page_loop; shutdown_start_page_loop: // Page number -x3 <== load_bootloader_input(x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET}); -x3 <== and(x3, {PAGE_NUMBER_MASK}); +load_bootloader_input 2, 3, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET}; +and 3, 0, {PAGE_NUMBER_MASK}, 3; // Store & hash {WORDS_PER_PAGE} page words. This is an unrolled loop that for each each word: // - Loads the word at the address x3 * {PAGE_SIZE_BYTES} + i * {BYTES_PER_WORD} @@ -453,11 +512,11 @@ x3 <== and(x3, {PAGE_NUMBER_MASK}); "#, )); + bootloader.push_str(&format!("move_reg 90, 3, {PAGE_SIZE_BYTES}, 0;\n")); for i in 0..WORDS_PER_PAGE { - let reg = SYSCALL_REGISTERS[(i % 4) + 4]; - bootloader.push_str(&format!( - "{reg}, x0 <== mload(x3 * {PAGE_SIZE_BYTES} + {i} * {BYTES_PER_WORD});\n" - )); + let reg = EXTRA_REGISTERS[(i % 4) + 4]; + bootloader.push_str(&format!("mload 90, {i} * {BYTES_PER_WORD}, 90, 91;\n")); + bootloader.push_str(&format!("{reg} <== get_reg(90);\n")); // Hash if buffer is full if i % 4 == 3 { @@ -470,20 +529,27 @@ x3 <== and(x3, {PAGE_NUMBER_MASK}); // Assert page hash is as claimed // At this point, P0-P3 contain the actual page hash at the end of the execution. -assert_bootloader_input x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + {WORDS_PER_PAGE} + 1 + 0, {P0}; -assert_bootloader_input x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + {WORDS_PER_PAGE} + 1 + 1, {P1}; -assert_bootloader_input x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + {WORDS_PER_PAGE} + 1 + 2, {P2}; -assert_bootloader_input x2 * {BOOTLOADER_INPUTS_PER_PAGE} + {PAGE_INPUTS_OFFSET} + {WORDS_PER_PAGE} + 1 + 3, {P3}; + +set_reg 90, {P0}; +assert_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + {WORDS_PER_PAGE} + 1 + 0; + +set_reg 90, {P1}; +assert_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + {WORDS_PER_PAGE} + 1 + 1; + +set_reg 90, {P2}; +assert_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + {WORDS_PER_PAGE} + 1 + 2; + +set_reg 90, {P3}; +assert_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + {WORDS_PER_PAGE} + 1 + 3; // Increment page index -x2 <=X= x2 + 1; +move_reg 2, 2, 1, 1; -branch_if_nonzero x2 - x1, shutdown_start_page_loop; +branch_if_diff_nonzero 2, 1, shutdown_start_page_loop; shutdown_end_page_loop: - -tmp1 <== jump(shutdown_sink); +jump shutdown_sink, 90; // END OF SHUTDOWN ROUTINE diff --git a/riscv/src/runtime.rs b/riscv/src/runtime.rs index a14635cc6..64cca3f04 100644 --- a/riscv/src/runtime.rs +++ b/riscv/src/runtime.rs @@ -1,13 +1,13 @@ use std::{collections::BTreeMap, convert::TryFrom}; -use powdr_riscv_syscalls::{Syscall, SYSCALL_REGISTERS}; +use powdr_riscv_syscalls::Syscall; use powdr_ast::parsed::asm::{FunctionStatement, MachineStatement, SymbolPath}; use itertools::Itertools; use powdr_parser::ParserContext; -use crate::code_gen::{pop_register, push_register}; +use crate::code_gen::Register; static EXTRA_REG_PREFIX: &str = "xtra"; @@ -41,8 +41,7 @@ struct SubMachine { instance_name: String, /// Instruction declarations instructions: Vec, - /// Number of extra registers needed by this machine's instruction declarations. - /// 26 of the RISC-V registers are available for use, these are added to that number. + /// Number of registers needed by this machine's instruction declarations if > 4. extra_registers: u8, /// TODO: only needed because of witgen requiring that each machine be called at least once init_call: Vec, @@ -90,12 +89,24 @@ impl Runtime { None, "binary", [ - "instr and Y, Z -> X link ~> X = binary.and(Y, Z);", - "instr or Y, Z -> X link ~> X = binary.or(Y, Z);", - "instr xor Y, Z -> X link ~> X = binary.xor(Y, Z);", + r#"instr and X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> val3_col = binary.and(val1_col, val2_col + Z) + link ~> regs.mstore(W, STEP + 3, val3_col);"#, + r#"instr or X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> val3_col = binary.or(val1_col, val2_col + Z) + link ~> regs.mstore(W, STEP + 3, val3_col);"#, + r#"instr xor X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> val3_col = binary.xor(val1_col, val2_col + Z) + link ~> regs.mstore(W, STEP + 3, val3_col);"#, ], 0, - ["x10 <== and(x10, x10);"], + ["and 0, 0, 0, 0;"], ); r.add_submachine( @@ -103,38 +114,63 @@ impl Runtime { None, "shift", [ - "instr shl Y, Z -> X link ~> X = shift.shl(Y, Z);", - "instr shr Y, Z -> X link ~> X = shift.shr(Y, Z);", + r#"instr shl X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> val3_col = shift.shl(val1_col, val2_col + Z) + link ~> regs.mstore(W, STEP + 3, val3_col);"#, + r#"instr shr X, Y, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> val3_col = shift.shr(val1_col, val2_col + Z) + link ~> regs.mstore(W, STEP + 3, val3_col);"#, ], 0, - ["x10 <== shl(x10, x10);"], + ["shl 0, 0, 0, 0;"], ); r.add_submachine( "std::machines::split::split_gl::SplitGL", None, "split_gl", - ["instr split_gl Z -> X, Y link ~> (X, Y) = split_gl.split(Z);"], + [r#"instr split_gl X, Z, W + link ~> val1_col = regs.mload(X, STEP) + link ~> (val3_col, val4_col) = split_gl.split(val1_col) + link ~> regs.mstore(Z, STEP + 2, val3_col) + link ~> regs.mstore(W, STEP + 3, val4_col);"#], 0, - ["x10, x11 <== split_gl(x10);", "x10 <=X= 0;", "x11 <=X= 0;"], + ["split_gl 0, 0, 0;"], ); // Base syscalls r.add_syscall( Syscall::Input, - ["x10 <=X= ${ std::prover::Query::Input(std::convert::int(std::prover::eval(x10))) };"], + [ + // TODO this is a quite inefficient way of getting prover inputs. + // We need to be able to access the register memory within PIL functions. + "query_arg_1 <== get_reg(10);", + "set_reg 10, ${ std::prover::Query::Input(std::convert::int(std::prover::eval(query_arg_1))) };", + ], ); r.add_syscall( Syscall::DataIdentifier, - ["x10 <=X= ${ std::prover::Query::DataIdentifier(std::convert::int(std::prover::eval(x11)), std::convert::int(std::prover::eval(x10))) };"] + [ + "query_arg_1 <== get_reg(10);", + "query_arg_2 <== get_reg(11);", + "set_reg 10, ${ std::prover::Query::DataIdentifier(std::convert::int(std::prover::eval(query_arg_2)), std::convert::int(std::prover::eval(query_arg_1))) };", + ] ); r.add_syscall( Syscall::Output, // This is using x0 on purpose, because we do not want to introduce // nondeterminism with this. - ["x0 <=X= ${ std::prover::Query::Output(std::convert::int(std::prover::eval(x10)), std::convert::int(std::prover::eval(x11))) };"] + [ + "query_arg_1 <== get_reg(10);", + "query_arg_2 <== get_reg(11);", + "set_reg 0, ${ std::prover::Query::Output(std::convert::int(std::prover::eval(query_arg_1)), std::convert::int(std::prover::eval(query_arg_2))) };" + ] ); r @@ -155,13 +191,13 @@ impl Runtime { "poseidon_gl", [format!( "instr poseidon_gl link ~> {};", - instr_link("poseidon_gl.poseidon_permutation", 0, 12, 4) + instr_link("poseidon_gl.poseidon_permutation", 12, 4) )], - 0, + 12, // init call std::iter::once("poseidon_gl;".to_string()) // zero out output registers - .chain((0..4).map(|i| format!("{} <=X= 0;", reg(i)))), + .chain((0..4).map(|i| format!("{} <== get_reg(0);", reg(i)))), ); // The poseidon syscall has a single argument passed on x10, the @@ -170,20 +206,11 @@ impl Runtime { let implementation = // The poseidon syscall uses x10 for input, we store it in tmp3 and // reuse x10 as input to the poseidon machine instruction. - std::iter::once("tmp3 <=X= x10;".to_string()) // The poseidon instruction uses registers 0..12 as input/output. // The memory field elements are loaded into these registers before calling the instruction. - // They might be in use by the riscv machine, so we save the registers on the stack. - .chain((0..12).flat_map(|i| push_register(®(i)))) - .chain((0..12).flat_map(|i| load_gl_fe("tmp3", i as u32 * 8, ®(i)))) + (0..12).flat_map(|i| load_gl_fe(10, i as u32 * 8, ®(i))) .chain(std::iter::once("poseidon_gl;".to_string())) - .chain((0..4).flat_map(|i| store_gl_fe("tmp3", i as u32 * 8, ®(i)))) - // After copying the result back into memory, we restore the original register values. - .chain( - (0..12) - .rev() - .flat_map(|i| pop_register(SYSCALL_REGISTERS[i])), - ); + .chain((0..4).flat_map(|i| store_gl_fe(10, i as u32 * 8, ®(i)))); self.add_syscall(Syscall::PoseidonGL, implementation); self @@ -197,25 +224,24 @@ impl Runtime { [ format!( "instr affine_256 link ~> {};", - instr_link("arith.affine_256", 3, 24, 16) // will use registers 3..27 + instr_link("arith.affine_256", 24, 16) ), format!( "instr ec_add link ~> {};", - instr_link("arith.ec_add", 4, 32, 16) // will use registers 4..36 + instr_link("arith.ec_add", 32, 16) ), format!( "instr ec_double link ~> {};", - instr_link("arith.ec_double", 2, 16, 16) // will use registers 2..18 + instr_link("arith.ec_double", 16, 16) ), format!( "instr mod_256 link ~> {};", - instr_link("arith.mod_256", 3, 24, 8) // will use registers 3..27 + instr_link("arith.mod_256", 24, 8) ), ], - // machine uses the 26 registers from risc-v plus 10 extra registers - 10, + 32, // calling ec_double for machine initialization. - // store x in registers 2..10 + // store x in registers 0..8 [ 0x60297556u32, 0x2f057a14, @@ -228,8 +254,8 @@ impl Runtime { ] .into_iter() .enumerate() - .map(|(i, fe)| format!("{} <=X= {fe};", reg(i + 2))) - // store y in registers 10..18 + .map(|(i, fe)| format!("{} <=X= {fe};", reg(i))) + // store y in registers 8..16 .chain( [ 0xb075f297u32, @@ -243,108 +269,82 @@ impl Runtime { ] .into_iter() .enumerate() - .map(|(i, fe)| format!("{} <=X= {fe};", reg(i + 10))), + .map(|(i, fe)| format!("{} <=X= {fe};", reg(i + 8))), ) // call machine instruction .chain(std::iter::once("ec_double;".to_string())) // set output registers to zero - .chain((2..18).map(|i| format!("{} <=X= 0;", reg(i)))), + .chain((0..16).map(|i| format!("{} <=X= 0;", reg(i)))), ); - // TODO: we're also saving the "extra registers", but those don't have to be saved - // The affine_256 syscall takes as input the addresses of x1, y1 and x2. let affine256 = - // Save instruction registers - (3..27).flat_map(|i| push_register(®(i))) - // Load x1 in 3..11 - .chain((0..8).flat_map(|i| load_word(®(0), i as u32 *4 , ®(i + 3)))) - // Load y1 in 11..19 - .chain((0..8).flat_map(|i| load_word(®(1), i as u32 *4 , ®(i + 11)))) - // Load x2 in 19..27 - .chain((0..8).flat_map(|i| load_word(®(2), i as u32 *4 , ®(i + 19)))) + // Load x1 in 0..8 + (0..8).flat_map(|i| load_word(10, i as u32 *4 , ®(i))) + // Load y1 in 8..16 + .chain((0..8).flat_map(|i| load_word(11, i as u32 *4 , ®(i + 8)))) + // Load x2 in 16..24 + .chain((0..8).flat_map(|i| load_word(12, i as u32 *4 , ®(i + 16)))) // Call instruction .chain(std::iter::once("affine_256;".to_string())) // Store result y2 in x1's memory - .chain((0..8).flat_map(|i| store_word(®(0), i as u32 *4 , ®(i + 3)))) + .chain((0..8).flat_map(|i| store_word(10, i as u32 *4 , ®(i)))) // Store result y3 in y1's memory - .chain((0..8).flat_map(|i| store_word(®(1), i as u32 *4 , ®(i + 11)))) - // Restore instruction registers - .chain( - (3..27) - .rev() - .flat_map(|i| pop_register(®(i)))); + .chain((0..8).flat_map(|i| store_word(11, i as u32 *4 , ®(i + 8)))); + self.add_syscall(Syscall::Affine256, affine256); // The mod_256 syscall takes as input the addresses of y2, y3, and x1. let mod256 = - // Save instruction registers - (3..27).flat_map(|i| push_register(®(i))) - // Load y2 in 3..11 - .chain((0..8).flat_map(|i| load_word(®(0), i as u32 *4 , ®(i + 3)))) - // Load y3 in 11..19 - .chain((0..8).flat_map(|i| load_word(®(1), i as u32 *4 , ®(i + 11)))) - // Load x1 in 19..27 - .chain((0..8).flat_map(|i| load_word(®(2), i as u32 *4 , ®(i + 19)))) + // Load y2 in 0..8 + (0..8).flat_map(|i| load_word(10, i as u32 *4 , ®(i))) + // Load y3 in 8..16 + .chain((0..8).flat_map(|i| load_word(11, i as u32 *4 , ®(i + 8)))) + // Load x1 in 16..24 + .chain((0..8).flat_map(|i| load_word(12, i as u32 *4 , ®(i + 16)))) // Call instruction .chain(std::iter::once("mod_256;".to_string())) // Store result x2 in y2's memory - .chain((0..8).flat_map(|i| store_word(®(0), i as u32 *4 , ®(i + 3)))) - // Restore instruction registers - .chain( - (3..27) - .rev() - .flat_map(|i| pop_register(®(i)))); + .chain((0..8).flat_map(|i| store_word(10, i as u32 *4 , ®(i)))); + self.add_syscall(Syscall::Mod256, mod256); // The ec_add syscall takes as input the four addresses of x1, y1, x2, y2. let ec_add = - // Save instruction registers. - (4..36).flat_map(|i| push_register(®(i))) - // Load x1 in 4..12 - .chain((0..8).flat_map(|i| load_word(®(0), i as u32 * 4, ®(i + 4)))) - // Load y1 in 12..20 - .chain((0..8).flat_map(|i| load_word(®(1), i as u32 * 4, ®(i + 12)))) - // Load x2 in 20..28 - .chain((0..8).flat_map(|i| load_word(®(2), i as u32 * 4, ®(i + 20)))) - // Load y2 in 28..36 - .chain((0..8).flat_map(|i| load_word(®(3), i as u32 * 4, ®(i + 28)))) + // Load x1 in 0..8 + (0..8).flat_map(|i| load_word(10, i as u32 * 4, ®(i))) + // Load y1 in 8..16 + .chain((0..8).flat_map(|i| load_word(11, i as u32 * 4, ®(i + 8)))) + // Load x2 in 16..24 + .chain((0..8).flat_map(|i| load_word(12, i as u32 * 4, ®(i + 16)))) + // Load y2 in 24..32 + .chain((0..8).flat_map(|i| load_word(13, i as u32 * 4, ®(i + 24)))) // Call instruction .chain(std::iter::once("ec_add;".to_string())) // Save result x3 in x1 - .chain((0..8).flat_map(|i| store_word(®(0), i as u32 * 4, ®(i + 4)))) + .chain((0..8).flat_map(|i| store_word(10, i as u32 * 4, ®(i)))) // Save result y3 in y1 - .chain((0..8).flat_map(|i| store_word(®(1), i as u32 * 4, ®(i + 12)))) - // Restore instruction registers. - .chain( - (4..36) - .rev() - .flat_map(|i| pop_register(®(i)))); + .chain((0..8).flat_map(|i| store_word(11, i as u32 * 4, ®(i + 8)))); + self.add_syscall(Syscall::EcAdd, ec_add); // The ec_double syscall takes as input the addresses of x and y in x10 and x11 respectively. - // We load x and y from memory into registers 2..10 and registers 10..18 respectively. + // We load x and y from memory into registers 0..8 and registers 8..16 respectively. // We then store the result from those registers into the same addresses (x10 and x11). let ec_double = - // Save instruction registers. - (2..18).flat_map(|i| push_register(®(i))) - // Load x in 2..10 - .chain((0..8).flat_map(|i| load_word(®(0), i as u32 * 4, ®(i + 2)))) - // Load y in 10..18 - .chain((0..8).flat_map(|i| load_word(®(1), i as u32 * 4, ®(i + 10)))) + // Load x in 0..8 + (0..8).flat_map(|i| load_word(10, i as u32 * 4, ®(i))) + // Load y in 8..16 + .chain((0..8).flat_map(|i| load_word(11, i as u32 * 4, ®(i + 8)))) // Call instruction .chain(std::iter::once("ec_double;".to_string())) // Store result in x - .chain((0..8).flat_map(|i| store_word(®(0), i as u32 * 4, ®(i + 2)))) + .chain((0..8).flat_map(|i| store_word(10, i as u32 * 4, ®(i)))) // Store result in y - .chain((0..8).flat_map(|i| store_word(®(1), i as u32 * 4, ®(i + 10)))) - // Restore instruction registers. - .chain( - (2..18) - .rev() - .flat_map(|i| pop_register(®(i)))); + .chain((0..8).flat_map(|i| store_word(11, i as u32 * 4, ®(i + 8)))); self.add_syscall(Syscall::EcDouble, ec_double); + self } @@ -446,14 +446,14 @@ impl Runtime { let jump_table = self .syscalls .keys() - .map(|s| format!("branch_if_zero x5 - {}, __ecall_handler_{};", *s as u32, s)); + .map(|s| format!("branch_if_zero 5, 0, {}, __ecall_handler_{};", *s as u32, s)); let invalid_handler = ["__invalid_syscall:".to_string(), "fail;".to_string()].into_iter(); let handlers = self.syscalls.iter().flat_map(|(syscall, implementation)| { std::iter::once(format!("__ecall_handler_{syscall}:")) .chain(implementation.0.iter().map(|i| i.to_string())) - .chain(std::iter::once("tmp1 <== jump_dyn(x1);".to_string())) + .chain([format!("jump_dyn 1, {};", Register::from("tmp1").addr())]) }); ecall @@ -489,77 +489,98 @@ impl TryFrom<&[&str]> for Runtime { } /// Helper function for register names used in instruction params -fn reg(mut idx: usize) -> String { - // s0..11 callee saved registers - static SAVED_REGS: [&str; 12] = [ - "x8", "x9", "x18", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", - ]; - - // first, use syscall_registers - if idx < SYSCALL_REGISTERS.len() { - return SYSCALL_REGISTERS[idx].to_string(); - } - idx -= SYSCALL_REGISTERS.len(); - // second, callee saved registers - if idx < SAVED_REGS.len() { - return SAVED_REGS[idx].to_string(); - } - idx -= SAVED_REGS.len(); - // lastly, use extra submachine registers +fn reg(idx: usize) -> String { format!("{EXTRA_REG_PREFIX}{idx}") } /// Helper function to generate instr link for large number input/output registers -fn instr_link(call: &str, start_idx: usize, inputs: usize, outputs: usize) -> String { +fn instr_link(call: &str, inputs: usize, outputs: usize) -> String { format!( "{}{}({})", if outputs > 0 { format!( "({}) = ", - (start_idx..start_idx + outputs) - .map(|i| format!("{}'", reg(i))) - .join(", ") + (0..outputs).map(|i| format!("{}'", reg(i))).join(", ") ) } else { "".to_string() }, call, - (start_idx..start_idx + inputs).map(reg).join(", ") + (0..inputs).map(reg).join(", ") ) } /// Load gl field element from addr+offset into register -fn load_gl_fe(addr: &str, offset: u32, reg: &str) -> [String; 3] { +fn load_gl_fe(addr_reg_id: u32, offset: u32, reg: &str) -> [String; 5] { let lo = offset; let hi = offset + 4; + let tmp1 = Register::from("tmp1"); + let tmp2 = Register::from("tmp2"); + let tmp3 = Register::from("tmp3"); + let tmp4 = Register::from("tmp4"); [ - format!("{reg}, tmp2 <== mload({lo} + {addr});"), - format!("tmp1, tmp2 <== mload({hi} + {addr});"), - format!("{reg} <=X= {reg} + tmp1 * 2**32;"), + format!( + "mload {addr_reg_id}, {lo}, {}, {};", + tmp1.addr(), + tmp2.addr() + ), + format!( + "mload {addr_reg_id}, {hi}, {}, {};", + tmp3.addr(), + tmp4.addr() + ), + format!("query_arg_1 <== get_reg({});", tmp1.addr()), + format!("query_arg_2 <== get_reg({});", tmp3.addr()), + format!("{reg} <=X= query_arg_1 + query_arg_2 * 2**32;"), ] } /// Store gl field element from register into addr+offset -fn store_gl_fe(addr: &str, offset: u32, reg: &str) -> [String; 3] { +fn store_gl_fe(addr_reg_id: u32, offset: u32, reg: &str) -> [String; 4] { let lo = offset; let hi = offset + 4; + let tmp1 = Register::from("tmp1"); + let tmp2 = Register::from("tmp2"); [ - format!("tmp1, tmp2 <== split_gl({reg});"), - format!("mstore {lo} + {addr}, tmp1;"), - format!("mstore {hi} + {addr}, tmp2;"), + format!("set_reg {}, {reg};", tmp1.addr()), + format!( + "split_gl {}, {}, {};", + tmp1.addr(), + tmp1.addr(), + tmp2.addr() + ), + format!("mstore {addr_reg_id}, 0, {lo}, {};", tmp1.addr()), + format!("mstore {addr_reg_id}, 0, {hi}, {};", tmp2.addr()), ] } /// Load word from addr+offset into register -fn load_word(addr: &str, offset: u32, reg: &str) -> [String; 1] { - [format!("{reg}, tmp2 <== mload({offset} + {addr});")] +fn load_word(addr_reg_id: u32, offset: u32, reg: &str) -> [String; 2] { + let tmp1 = Register::from("tmp1"); + let tmp2 = Register::from("tmp2"); + [ + format!( + "mload {addr_reg_id}, {offset}, {}, {};", + tmp1.addr(), + tmp2.addr() + ), + format!("{reg} <=X= get_reg({});", tmp1.addr()), + ] } /// Store word from register into addr+offset -fn store_word(addr: &str, offset: u32, reg: &str) -> [String; 2] { +fn store_word(addr_reg_id: u32, offset: u32, reg: &str) -> [String; 3] { + let tmp1 = Register::from("tmp1"); + let tmp2 = Register::from("tmp2"); [ // split_gl ensures we store a 32-bit value - format!("tmp1, tmp2 <== split_gl({reg});"), - format!("mstore {offset} + {addr}, tmp1;"), + format!("set_reg {}, {reg};", tmp1.addr()), + format!( + "split_gl {}, {}, {};", + tmp1.addr(), + tmp1.addr(), + tmp2.addr() + ), + format!("mstore {addr_reg_id}, 0, {offset}, {};", tmp1.addr()), ] } diff --git a/riscv/tests/riscv.rs b/riscv/tests/riscv.rs index 546b3aa3d..4aa331472 100644 --- a/riscv/tests/riscv.rs +++ b/riscv/tests/riscv.rs @@ -248,6 +248,45 @@ fn sum_serde() { ); } +#[test] +fn read_slice() { + let case = "read_slice"; + let runtime = Runtime::base(); + let temp_dir = Temp::new_dir().unwrap(); + let riscv_asm = powdr_riscv::compile_rust_crate_to_riscv_asm( + &format!("tests/riscv_data/{case}/Cargo.toml"), + &temp_dir, + ); + let powdr_asm = powdr_riscv::asm::compile::(riscv_asm, &runtime, false); + + let data: Vec = vec![]; + let answer = data.iter().sum::(); + + use std::collections::BTreeMap; + let d: BTreeMap> = vec![( + 42, + vec![ + 0u32.into(), + 1u32.into(), + 2u32.into(), + 3u32.into(), + 4u32.into(), + 5u32.into(), + 6u32.into(), + 7u32.into(), + ], + )] + .into_iter() + .collect(); + + let mut pipeline = Pipeline::::default() + .from_asm_string(powdr_asm, Some(PathBuf::from(case))) + .with_prover_inputs(vec![answer.into()]) + .with_prover_dict_inputs(d); + + pipeline.compute_witness().unwrap(); +} + #[ignore = "Too slow"] #[test] fn two_sums_serde() { diff --git a/riscv/tests/riscv_data/read_slice/Cargo.toml b/riscv/tests/riscv_data/read_slice/Cargo.toml new file mode 100644 index 000000000..87b75667a --- /dev/null +++ b/riscv/tests/riscv_data/read_slice/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "sum_serde" +version = "0.1.0" +edition = "2021" + +[dependencies] +powdr-riscv-runtime = { path = "../../../../riscv-runtime" } +serde = { version = "1.0", default-features = false, features = ["alloc", "derive", "rc"] } +serde_cbor = { version = "0.11.2", default-features = false, features = ["alloc"] } + +[workspace] diff --git a/riscv/tests/riscv_data/read_slice/rust-toolchain.toml b/riscv/tests/riscv_data/read_slice/rust-toolchain.toml new file mode 100644 index 000000000..ffe8ad460 --- /dev/null +++ b/riscv/tests/riscv_data/read_slice/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +channel = "nightly-2024-02-01" +targets = ["riscv32imac-unknown-none-elf"] +profile = "minimal" diff --git a/riscv/tests/riscv_data/read_slice/src/lib.rs b/riscv/tests/riscv_data/read_slice/src/lib.rs new file mode 100644 index 000000000..b39b268e6 --- /dev/null +++ b/riscv/tests/riscv_data/read_slice/src/lib.rs @@ -0,0 +1,20 @@ +#![no_std] + +extern crate alloc; +use alloc::vec; + +use powdr_riscv_runtime::io::read_slice; + +#[no_mangle] +pub fn main() { + let mut a = vec![0u32; 8]; + read_slice(42, &mut a); + assert_eq!(a[0], 0); + assert_eq!(a[1], 1); + assert_eq!(a[2], 2); + assert_eq!(a[3], 3); + assert_eq!(a[4], 4); + assert_eq!(a[5], 5); + assert_eq!(a[6], 6); + assert_eq!(a[7], 7); +} diff --git a/std/machines/memory.asm b/std/machines/memory.asm index c68fe8bff..f47c77560 100644 --- a/std/machines/memory.asm +++ b/std/machines/memory.asm @@ -70,3 +70,213 @@ machine Memory with col diff = (m_change * (m_addr' - m_addr) + (1 - m_change) * (m_step' - m_step)); (1 - LAST) * (diff - 1 - m_diff_upper * 2**16 - m_diff_lower) = 0; } + +// TODO Remove when https://github.com/powdr-labs/powdr/issues/1572 is done +machine Memory_20 with + degree: 2**20, + latch: LATCH, + operation_id: m_is_write, + call_selectors: selectors, +{ + Byte2 byte2; + + operation mload<0> m_addr, m_step -> m_value; + operation mstore<1> m_addr, m_step, m_value ->; + + let LATCH = 1; + + // =============== read-write memory ======================= + // Read-write memory. Columns are sorted by addr and + // then by step. change is 1 if and only if addr changes + // in the next row. + // Note that these column names are used by witgen to detect + // this machine... + col witness m_addr; + col witness m_step; + col witness m_change; + col witness m_value; + + // Memory operation flags + col witness m_is_write; + std::utils::force_bool(m_is_write); + + // is_write can only be 1 if a selector is active + let is_mem_op = array::sum(selectors); + std::utils::force_bool(is_mem_op); + (1 - is_mem_op) * m_is_write = 0; + + // If the next line is a not a write and we have an address change, + // then the value is zero. + (1 - m_is_write') * m_change * m_value' = 0; + + // change has to be 1 in the last row, so that a first read on row zero is constrained to return 0 + (1 - m_change) * LAST = 0; + + // If the next line is a read and we stay at the same address, then the + // value cannot change. + (1 - m_is_write') * (1 - m_change) * (m_value' - m_value) = 0; + + col witness m_diff_lower; + col witness m_diff_upper; + + col fixed FIRST = [1] + [0]*; + let LAST = FIRST'; + col fixed STEP(i) { i }; + col fixed BIT16(i) { i & 0xffff }; + + link => byte2.check(m_diff_lower); + link => byte2.check(m_diff_upper); + + std::utils::force_bool(m_change); + + // if change is zero, addr has to stay the same. + (m_addr' - m_addr) * (1 - m_change) = 0; + + // Except for the last row, if change is 1, then addr has to increase, + // if it is zero, step has to increase. + // `m_diff_upper * 2**16 + m_diff_lower` has to be equal to the difference **minus one**. + // Since we know that both addr and step can only be 32-Bit, this enforces that + // the values are strictly increasing. + col diff = (m_change * (m_addr' - m_addr) + (1 - m_change) * (m_step' - m_step)); + (1 - LAST) * (diff - 1 - m_diff_upper * 2**16 - m_diff_lower) = 0; +} + +// TODO Remove when https://github.com/powdr-labs/powdr/issues/1572 is done +machine Memory_21 with + degree: 2**21, + latch: LATCH, + operation_id: m_is_write, + call_selectors: selectors, +{ + Byte2 byte2; + + operation mload<0> m_addr, m_step -> m_value; + operation mstore<1> m_addr, m_step, m_value ->; + + let LATCH = 1; + + // =============== read-write memory ======================= + // Read-write memory. Columns are sorted by addr and + // then by step. change is 1 if and only if addr changes + // in the next row. + // Note that these column names are used by witgen to detect + // this machine... + col witness m_addr; + col witness m_step; + col witness m_change; + col witness m_value; + + // Memory operation flags + col witness m_is_write; + std::utils::force_bool(m_is_write); + + // is_write can only be 1 if a selector is active + let is_mem_op = array::sum(selectors); + std::utils::force_bool(is_mem_op); + (1 - is_mem_op) * m_is_write = 0; + + // If the next line is a not a write and we have an address change, + // then the value is zero. + (1 - m_is_write') * m_change * m_value' = 0; + + // change has to be 1 in the last row, so that a first read on row zero is constrained to return 0 + (1 - m_change) * LAST = 0; + + // If the next line is a read and we stay at the same address, then the + // value cannot change. + (1 - m_is_write') * (1 - m_change) * (m_value' - m_value) = 0; + + col witness m_diff_lower; + col witness m_diff_upper; + + col fixed FIRST = [1] + [0]*; + let LAST = FIRST'; + col fixed STEP(i) { i }; + col fixed BIT16(i) { i & 0xffff }; + + link => byte2.check(m_diff_lower); + link => byte2.check(m_diff_upper); + + std::utils::force_bool(m_change); + + // if change is zero, addr has to stay the same. + (m_addr' - m_addr) * (1 - m_change) = 0; + + // Except for the last row, if change is 1, then addr has to increase, + // if it is zero, step has to increase. + // `m_diff_upper * 2**16 + m_diff_lower` has to be equal to the difference **minus one**. + // Since we know that both addr and step can only be 32-Bit, this enforces that + // the values are strictly increasing. + col diff = (m_change * (m_addr' - m_addr) + (1 - m_change) * (m_step' - m_step)); + (1 - LAST) * (diff - 1 - m_diff_upper * 2**16 - m_diff_lower) = 0; +} + +// TODO Remove when https://github.com/powdr-labs/powdr/issues/1572 is done +machine Memory_22 with + degree: 2**22, + latch: LATCH, + operation_id: m_is_write, + call_selectors: selectors, +{ + Byte2 byte2; + + operation mload<0> m_addr, m_step -> m_value; + operation mstore<1> m_addr, m_step, m_value ->; + + let LATCH = 1; + + // =============== read-write memory ======================= + // Read-write memory. Columns are sorted by addr and + // then by step. change is 1 if and only if addr changes + // in the next row. + // Note that these column names are used by witgen to detect + // this machine... + col witness m_addr; + col witness m_step; + col witness m_change; + col witness m_value; + + // Memory operation flags + col witness m_is_write; + std::utils::force_bool(m_is_write); + + // is_write can only be 1 if a selector is active + let is_mem_op = array::sum(selectors); + std::utils::force_bool(is_mem_op); + (1 - is_mem_op) * m_is_write = 0; + + // If the next line is a not a write and we have an address change, + // then the value is zero. + (1 - m_is_write') * m_change * m_value' = 0; + + // change has to be 1 in the last row, so that a first read on row zero is constrained to return 0 + (1 - m_change) * LAST = 0; + + // If the next line is a read and we stay at the same address, then the + // value cannot change. + (1 - m_is_write') * (1 - m_change) * (m_value' - m_value) = 0; + + col witness m_diff_lower; + col witness m_diff_upper; + + col fixed FIRST = [1] + [0]*; + let LAST = FIRST'; + col fixed STEP(i) { i }; + col fixed BIT16(i) { i & 0xffff }; + + link => byte2.check(m_diff_lower); + link => byte2.check(m_diff_upper); + + std::utils::force_bool(m_change); + + // if change is zero, addr has to stay the same. + (m_addr' - m_addr) * (1 - m_change) = 0; + + // Except for the last row, if change is 1, then addr has to increase, + // if it is zero, step has to increase. + // `m_diff_upper * 2**16 + m_diff_lower` has to be equal to the difference **minus one**. + // Since we know that both addr and step can only be 32-Bit, this enforces that + // the values are strictly increasing. + col diff = (m_change * (m_addr' - m_addr) + (1 - m_change) * (m_step' - m_step)); + (1 - LAST) * (diff - 1 - m_diff_upper * 2**16 - m_diff_lower) = 0; +} From de2905de69390a28f786c2ae0e8c662592226ac2 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Mon, 22 Jul 2024 18:10:23 +0200 Subject: [PATCH 18/24] Add data structures for variably-sized fixed columns (#1542) This PR adds `number::VariablySizedColumns`, which can store several sizes of the same column. Currently, we always just have one size, but as part of #1496, we can relax that. --- backend/src/composite/mod.rs | 19 ++++++-- backend/src/estark/mod.rs | 10 +++- backend/src/estark/polygon_wrapper.rs | 10 +++- backend/src/estark/starky_wrapper.rs | 11 ++++- backend/src/halo2/mod.rs | 12 ++++- backend/src/lib.rs | 4 +- backend/src/plonky3/mod.rs | 11 ++++- executor/Cargo.toml | 1 + .../src/constant_evaluator/data_structures.rs | 48 +++++++++++++++++++ executor/src/constant_evaluator/mod.rs | 22 +++++++-- executor/src/witgen/block_processor.rs | 8 ++-- executor/src/witgen/global_constraints.rs | 3 ++ executor/src/witgen/mod.rs | 16 ++++--- number/src/lib.rs | 4 +- number/src/serialize.rs | 32 ++++++------- pipeline/src/pipeline.rs | 33 ++++++++----- pipeline/src/util.rs | 45 ++++++++--------- pipeline/tests/asm.rs | 8 ++-- plonky3/src/stark.rs | 4 ++ riscv/src/continuations.rs | 4 +- 20 files changed, 215 insertions(+), 90 deletions(-) create mode 100644 executor/src/constant_evaluator/data_structures.rs diff --git a/backend/src/composite/mod.rs b/backend/src/composite/mod.rs index 535bee188..502558e6d 100644 --- a/backend/src/composite/mod.rs +++ b/backend/src/composite/mod.rs @@ -8,7 +8,10 @@ use std::{ use itertools::Itertools; use powdr_ast::analyzed::Analyzed; -use powdr_executor::witgen::WitgenCallback; +use powdr_executor::{ + constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn}, + witgen::WitgenCallback, +}; use powdr_number::{DegreeType, FieldElement}; use serde::{Deserialize, Serialize}; use split::{machine_fixed_columns, machine_witness_columns}; @@ -49,7 +52,7 @@ impl> BackendFactory for CompositeBacke fn create<'a>( &self, pil: Arc>, - fixed: Arc)>>, + fixed: Arc)>>, output_dir: Option, setup: Option<&mut dyn std::io::Read>, verification_key: Option<&mut dyn std::io::Read>, @@ -60,6 +63,11 @@ impl> BackendFactory for CompositeBacke unimplemented!(); } + // TODO: Handle multiple sizes. + let fixed = Arc::new( + get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, + ); + let pils = split::split_pil((*pil).clone()); // Read the setup once (if any) to pass to all backends. @@ -105,7 +113,12 @@ impl> BackendFactory for CompositeBacke if let Some(ref output_dir) = output_dir { std::fs::create_dir_all(output_dir)?; } - let fixed = Arc::new(machine_fixed_columns(&fixed, &pil)); + let fixed = Arc::new( + machine_fixed_columns(&fixed, &pil) + .into_iter() + .map(|(column_name, values)| (column_name, values.into())) + .collect(), + ); let backend = self.factory.create( pil.clone(), fixed, diff --git a/backend/src/estark/mod.rs b/backend/src/estark/mod.rs index 137ce144a..ac3223873 100644 --- a/backend/src/estark/mod.rs +++ b/backend/src/estark/mod.rs @@ -15,7 +15,10 @@ use std::{ use crate::{Backend, BackendFactory, BackendOptions, Error, Proof}; use powdr_ast::analyzed::Analyzed; -use powdr_executor::witgen::WitgenCallback; +use powdr_executor::{ + constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn}, + witgen::WitgenCallback, +}; use powdr_number::{DegreeType, FieldElement}; use serde::Serialize; use starky::types::{StarkStruct, Step, PIL}; @@ -222,13 +225,16 @@ impl BackendFactory for DumpFactory { fn create<'a>( &self, analyzed: Arc>, - fixed: Arc)>>, + fixed: Arc)>>, output_dir: Option, setup: Option<&mut dyn std::io::Read>, verification_key: Option<&mut dyn std::io::Read>, verification_app_key: Option<&mut dyn std::io::Read>, options: BackendOptions, ) -> Result + 'a>, Error> { + let fixed = Arc::new( + get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, + ); Ok(Box::new(DumpBackend(EStarkFilesCommon::create( &analyzed, fixed, diff --git a/backend/src/estark/polygon_wrapper.rs b/backend/src/estark/polygon_wrapper.rs index 4758d5d6f..c8d7c594b 100644 --- a/backend/src/estark/polygon_wrapper.rs +++ b/backend/src/estark/polygon_wrapper.rs @@ -1,7 +1,10 @@ use std::{fs, path::PathBuf, sync::Arc}; use powdr_ast::analyzed::Analyzed; -use powdr_executor::witgen::WitgenCallback; +use powdr_executor::{ + constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn}, + witgen::WitgenCallback, +}; use powdr_number::FieldElement; use crate::{Backend, BackendFactory, BackendOptions, Error, Proof}; @@ -14,13 +17,16 @@ impl BackendFactory for Factory { fn create<'a>( &self, analyzed: Arc>, - fixed: Arc)>>, + fixed: Arc)>>, output_dir: Option, setup: Option<&mut dyn std::io::Read>, verification_key: Option<&mut dyn std::io::Read>, verification_app_key: Option<&mut dyn std::io::Read>, options: BackendOptions, ) -> Result + 'a>, Error> { + let fixed = Arc::new( + get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, + ); Ok(Box::new(PolygonBackend(EStarkFilesCommon::create( &analyzed, fixed, diff --git a/backend/src/estark/starky_wrapper.rs b/backend/src/estark/starky_wrapper.rs index 9d57aa3c7..1c368bc00 100644 --- a/backend/src/estark/starky_wrapper.rs +++ b/backend/src/estark/starky_wrapper.rs @@ -4,7 +4,10 @@ use std::time::Instant; use crate::{Backend, BackendFactory, BackendOptions, Error}; use powdr_ast::analyzed::Analyzed; -use powdr_executor::witgen::WitgenCallback; +use powdr_executor::{ + constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn}, + witgen::WitgenCallback, +}; use powdr_number::{FieldElement, GoldilocksField, LargeInt}; use starky::{ @@ -26,7 +29,7 @@ impl BackendFactory for Factory { fn create<'a>( &self, pil: Arc>, - fixed: Arc)>>, + fixed: Arc)>>, _output_dir: Option, setup: Option<&mut dyn std::io::Read>, verification_key: Option<&mut dyn std::io::Read>, @@ -49,6 +52,10 @@ impl BackendFactory for Factory { return Err(Error::NoVariableDegreeAvailable); } + let fixed = Arc::new( + get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, + ); + let proof_type: ProofType = ProofType::from(options); let params = create_stark_struct(pil.degree(), proof_type.hash_type()); diff --git a/backend/src/halo2/mod.rs b/backend/src/halo2/mod.rs index 4305468a8..24c9457e3 100644 --- a/backend/src/halo2/mod.rs +++ b/backend/src/halo2/mod.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use crate::{Backend, BackendFactory, BackendOptions, Error, Proof}; use powdr_ast::analyzed::Analyzed; +use powdr_executor::constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn}; use powdr_executor::witgen::WitgenCallback; use powdr_number::{DegreeType, FieldElement}; use prover::{generate_setup, Halo2Prover}; @@ -76,7 +77,7 @@ impl BackendFactory for Halo2ProverFactory { fn create<'a>( &self, pil: Arc>, - fixed: Arc)>>, + fixed: Arc)>>, _output_dir: Option, setup: Option<&mut dyn io::Read>, verification_key: Option<&mut dyn io::Read>, @@ -87,6 +88,9 @@ impl BackendFactory for Halo2ProverFactory { return Err(Error::NoVariableDegreeAvailable); } let proof_type = ProofType::from(options); + let fixed = Arc::new( + get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, + ); let mut halo2 = Box::new(Halo2Prover::new(pil, fixed, setup, proof_type)?); if let Some(vk) = verification_key { halo2.add_verification_key(vk); @@ -183,7 +187,7 @@ impl BackendFactory for Halo2MockFactory { fn create<'a>( &self, pil: Arc>, - fixed: Arc)>>, + fixed: Arc)>>, _output_dir: Option, setup: Option<&mut dyn io::Read>, verification_key: Option<&mut dyn io::Read>, @@ -200,6 +204,10 @@ impl BackendFactory for Halo2MockFactory { return Err(Error::NoAggregationAvailable); } + let fixed = Arc::new( + get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, + ); + Ok(Box::new(Halo2Mock { pil, fixed })) } } diff --git a/backend/src/lib.rs b/backend/src/lib.rs index e490fa9bb..9bc782545 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -9,7 +9,7 @@ mod plonky3; mod composite; use powdr_ast::analyzed::Analyzed; -use powdr_executor::witgen::WitgenCallback; +use powdr_executor::{constant_evaluator::VariablySizedColumn, witgen::WitgenCallback}; use powdr_number::{DegreeType, FieldElement}; use std::{io, path::PathBuf, sync::Arc}; use strum::{Display, EnumString, EnumVariantNames}; @@ -134,7 +134,7 @@ pub trait BackendFactory { fn create<'a>( &self, pil: Arc>, - fixed: Arc)>>, + fixed: Arc)>>, output_dir: Option, setup: Option<&mut dyn io::Read>, verification_key: Option<&mut dyn io::Read>, diff --git a/backend/src/plonky3/mod.rs b/backend/src/plonky3/mod.rs index 1b22ad1cf..76d99466c 100644 --- a/backend/src/plonky3/mod.rs +++ b/backend/src/plonky3/mod.rs @@ -1,7 +1,10 @@ use std::{io, path::PathBuf, sync::Arc}; use powdr_ast::analyzed::Analyzed; -use powdr_executor::witgen::WitgenCallback; +use powdr_executor::{ + constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn}, + witgen::WitgenCallback, +}; use powdr_number::{FieldElement, GoldilocksField, LargeInt}; use powdr_plonky3::Plonky3Prover; @@ -13,7 +16,7 @@ impl BackendFactory for Factory { fn create<'a>( &self, pil: Arc>, - fixed: Arc)>>, + fixed: Arc)>>, _output_dir: Option, setup: Option<&mut dyn io::Read>, verification_key: Option<&mut dyn io::Read>, @@ -34,6 +37,10 @@ impl BackendFactory for Factory { return Err(Error::NoVariableDegreeAvailable); } + let fixed = Arc::new( + get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, + ); + let mut p3 = Box::new(Plonky3Prover::new(pil, fixed)); if let Some(verification_key) = verification_key { diff --git a/executor/Cargo.toml b/executor/Cargo.toml index 83f2f9219..0079cce9a 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -20,6 +20,7 @@ bit-vec = "0.6.3" num-traits = "0.2.15" lazy_static = "1.4.0" indicatif = "0.17.7" +serde = { version = "1.0", default-features = false, features = ["alloc", "derive", "rc"] } [dev-dependencies] test-log = "0.2.12" diff --git a/executor/src/constant_evaluator/data_structures.rs b/executor/src/constant_evaluator/data_structures.rs new file mode 100644 index 000000000..cd5c288ff --- /dev/null +++ b/executor/src/constant_evaluator/data_structures.rs @@ -0,0 +1,48 @@ +use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; + +#[derive(Serialize, Deserialize)] +pub struct VariablySizedColumn { + column_by_size: BTreeMap>, +} + +#[derive(Debug)] +pub struct HasMultipleSizesError; + +impl VariablySizedColumn { + /// Create a view where each column has a single size. Fails if any column has multiple sizes. + pub fn get_uniquely_sized(&self) -> Result<&Vec, HasMultipleSizesError> { + if self.column_by_size.len() != 1 { + return Err(HasMultipleSizesError); + } + Ok(self.column_by_size.values().next().unwrap()) + } +} + +pub fn get_uniquely_sized( + column: &[(String, VariablySizedColumn)], +) -> Result)>, HasMultipleSizesError> { + column + .iter() + .map(|(name, column)| Ok((name.clone(), column.get_uniquely_sized()?))) + .collect() +} + +pub fn get_uniquely_sized_cloned( + column: &[(String, VariablySizedColumn)], +) -> Result)>, HasMultipleSizesError> { + get_uniquely_sized(column).map(|column| { + column + .into_iter() + .map(|(name, column)| (name, column.clone())) + .collect() + }) +} + +impl From> for VariablySizedColumn { + fn from(column: Vec) -> Self { + VariablySizedColumn { + column_by_size: [(column.len(), column)].into_iter().collect(), + } + } +} diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index 9fa007bc9..5168c1d84 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -3,6 +3,7 @@ use std::{ sync::{Arc, RwLock}, }; +pub use data_structures::{get_uniquely_sized, get_uniquely_sized_cloned, VariablySizedColumn}; use itertools::Itertools; use powdr_ast::{ analyzed::{Analyzed, FunctionValueDefinition, Symbol, TypedExpression}, @@ -15,12 +16,14 @@ use powdr_number::{BigInt, BigUint, DegreeType, FieldElement}; use powdr_pil_analyzer::evaluator::{self, Definitions, SymbolLookup, Value}; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +mod data_structures; + /// Generates the fixed column values for all fixed columns that are defined /// (and not just declared). /// @returns the names (in source order) and the values for the columns. /// Arrays of columns are flattened, the name of the `i`th array element /// is `name[i]`. -pub fn generate(analyzed: &Analyzed) -> Vec<(String, Vec)> { +pub fn generate(analyzed: &Analyzed) -> Vec<(String, VariablySizedColumn)> { let mut fixed_cols = HashMap::new(); for (poly, value) in analyzed.constant_polys_in_source_order() { if let Some(value) = value { @@ -37,8 +40,8 @@ pub fn generate(analyzed: &Analyzed) -> Vec<(String, Vec) fixed_cols .into_iter() .sorted_by_key(|(_, (id, _))| *id) - .map(|(name, (_, values))| (name, values)) - .collect::>() + .map(|(name, (_, values))| (name, values.into())) + .collect() } fn generate_values( @@ -169,17 +172,28 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for CachedSymbols<'a, T> { #[cfg(test)] mod test { + use powdr_ast::analyzed::Analyzed; use powdr_number::GoldilocksField; use powdr_pil_analyzer::analyze_string; use pretty_assertions::assert_eq; use test_log::test; - use super::*; + use crate::constant_evaluator::{ + data_structures::get_uniquely_sized, generate as generate_variably_sized, + }; fn convert(input: Vec) -> Vec { input.into_iter().map(|x| x.into()).collect() } + fn generate(analyzed: &Analyzed) -> Vec<(String, Vec)> { + get_uniquely_sized(&generate_variably_sized(analyzed)) + .unwrap() + .into_iter() + .map(|(name, values)| (name, values.clone())) + .collect() + } + #[test] fn last() { let src = r#" diff --git a/executor/src/witgen/block_processor.rs b/executor/src/witgen/block_processor.rs index ca78f96df..06fd07fec 100644 --- a/executor/src/witgen/block_processor.rs +++ b/executor/src/witgen/block_processor.rs @@ -121,7 +121,7 @@ mod tests { use powdr_pil_analyzer::analyze_string; use crate::{ - constant_evaluator::generate, + constant_evaluator::{generate, get_uniquely_sized}, witgen::{ data_structures::finalizable_data::FinalizableData, identity_processor::Machines, @@ -152,10 +152,8 @@ mod tests { f: impl Fn(BlockProcessor, BTreeMap, u64, usize) -> R, ) -> R { let analyzed = analyze_string(src); - let constants = generate(&analyzed) - .into_iter() - .map(|(n, c)| (n.to_string(), c)) - .collect::>(); + let constants = generate(&analyzed); + let constants = get_uniquely_sized(&constants).unwrap(); let fixed_data = FixedData::new(&analyzed, &constants, &[], Default::default(), 0); // No submachines diff --git a/executor/src/witgen/global_constraints.rs b/executor/src/witgen/global_constraints.rs index 37a48d393..7a61ac5f9 100644 --- a/executor/src/witgen/global_constraints.rs +++ b/executor/src/witgen/global_constraints.rs @@ -365,6 +365,8 @@ mod test { use pretty_assertions::assert_eq; use test_log::test; + use crate::constant_evaluator::get_uniquely_sized; + use super::*; #[test] @@ -437,6 +439,7 @@ namespace Global(2**20); "; let analyzed = powdr_pil_analyzer::analyze_string::(pil_source); let constants = crate::constant_evaluator::generate(&analyzed); + let constants = get_uniquely_sized(&constants).unwrap(); let fixed_polys = (0..constants.len()) .map(|i| constant_poly_id(i as u64)) .collect::>(); diff --git a/executor/src/witgen/mod.rs b/executor/src/witgen/mod.rs index 099840342..d3ddbd17b 100644 --- a/executor/src/witgen/mod.rs +++ b/executor/src/witgen/mod.rs @@ -10,6 +10,8 @@ use powdr_ast::parsed::visitor::ExpressionVisitable; use powdr_ast::parsed::{FunctionKind, LambdaExpression}; use powdr_number::{DegreeType, FieldElement}; +use crate::constant_evaluator::{get_uniquely_sized, VariablySizedColumn}; + use self::data_structures::column_map::{FixedColumnMap, WitnessColumnMap}; pub use self::eval_result::{ Constraint, Constraints, EvalError, EvalResult, EvalStatus, EvalValue, IncompleteCause, @@ -50,14 +52,14 @@ impl QueryCallback for F where F: Fn(&str) -> Result, String> #[derive(Clone)] pub struct WitgenCallback { analyzed: Arc>, - fixed_col_values: Arc)>>, + fixed_col_values: Arc)>>, query_callback: Arc>, } impl WitgenCallback { pub fn new( analyzed: Arc>, - fixed_col_values: Arc)>>, + fixed_col_values: Arc)>>, query_callback: Option>>, ) -> Self { let query_callback = query_callback.unwrap_or_else(|| Arc::new(unused_query_callback())); @@ -111,7 +113,7 @@ pub struct MutableState<'a, 'b, T: FieldElement, Q: QueryCallback> { pub struct WitnessGenerator<'a, 'b, T: FieldElement> { analyzed: &'a Analyzed, - fixed_col_values: &'b [(String, Vec)], + fixed_col_values: &'b Vec<(String, VariablySizedColumn)>, query_callback: &'b dyn QueryCallback, external_witness_values: &'b [(String, Vec)], stage: u8, @@ -121,7 +123,7 @@ pub struct WitnessGenerator<'a, 'b, T: FieldElement> { impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> { pub fn new( analyzed: &'a Analyzed, - fixed_col_values: &'b [(String, Vec)], + fixed_col_values: &'b Vec<(String, VariablySizedColumn)>, query_callback: &'b dyn QueryCallback, ) -> Self { WitnessGenerator { @@ -156,9 +158,11 @@ impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> { /// @returns the values (in source order) and the degree of the polynomials. pub fn generate(self) -> Vec<(String, Vec)> { record_start(OUTER_CODE_NAME); + // TODO: Handle multiple sizes + let fixed_col_values = get_uniquely_sized(self.fixed_col_values).unwrap(); let fixed = FixedData::new( self.analyzed, - self.fixed_col_values, + &fixed_col_values, self.external_witness_values, self.challenges, self.stage, @@ -328,7 +332,7 @@ impl<'a, T: FieldElement> FixedData<'a, T> { pub fn new( analyzed: &'a Analyzed, - fixed_col_values: &'a [(String, Vec)], + fixed_col_values: &'a [(String, &'a Vec)], external_witness_values: &'a [(String, Vec)], challenges: BTreeMap, stage: u8, diff --git a/number/src/lib.rs b/number/src/lib.rs index 2c1cb2dc4..8842281c6 100644 --- a/number/src/lib.rs +++ b/number/src/lib.rs @@ -8,10 +8,8 @@ mod bn254; mod goldilocks; mod serialize; mod traits; - pub use serialize::{ - buffered_write_file, read_polys_csv_file, read_polys_file, write_polys_csv_file, - write_polys_file, CsvRenderMode, + buffered_write_file, read_polys_csv_file, write_polys_csv_file, CsvRenderMode, ReadWrite, }; pub use bn254::Bn254Field; diff --git a/number/src/serialize.rs b/number/src/serialize.rs index 631298c91..3bd80d963 100644 --- a/number/src/serialize.rs +++ b/number/src/serialize.rs @@ -6,6 +6,7 @@ use std::{ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; use csv::{Reader, Writer}; +use serde::{de::DeserializeOwned, Serialize}; use serde_with::{DeserializeAs, SerializeAs}; use crate::FieldElement; @@ -102,24 +103,19 @@ pub fn buffered_write_file( Ok(result) } -pub fn write_polys_file( - path: &Path, - polys: &[(String, Vec)], -) -> Result<(), serde_cbor::Error> { - buffered_write_file(path, |writer| write_polys_stream(writer, polys))??; - - Ok(()) -} - -fn write_polys_stream( - file: &mut impl Write, - polys: &[(String, Vec)], -) -> Result<(), serde_cbor::Error> { - serde_cbor::to_writer(file, &polys) +pub trait ReadWrite { + fn read(file: &mut impl Read) -> Self; + fn write(&self, path: &Path) -> Result<(), serde_cbor::Error>; } -pub fn read_polys_file(file: &mut impl Read) -> Vec<(String, Vec)> { - serde_cbor::from_reader(file).unwrap() +impl ReadWrite for T { + fn read(file: &mut impl Read) -> Self { + serde_cbor::from_reader(file).unwrap() + } + fn write(&self, path: &Path) -> Result<(), serde_cbor::Error> { + buffered_write_file(path, |writer| serde_cbor::to_writer(writer, &self))??; + Ok(()) + } } // Serde wrappers for serialize/deserialize @@ -164,8 +160,8 @@ mod tests { let polys = test_polys(); - write_polys_stream(&mut buf, &polys).unwrap(); - let read_polys = read_polys_file::(&mut Cursor::new(buf)); + serde_cbor::to_writer(&mut buf, &polys).unwrap(); + let read_polys: Vec<(String, Vec)> = ReadWrite::read(&mut Cursor::new(buf)); assert_eq!(read_polys, polys); } diff --git a/pipeline/src/pipeline.rs b/pipeline/src/pipeline.rs index 9836a10ad..d7d9dfd70 100644 --- a/pipeline/src/pipeline.rs +++ b/pipeline/src/pipeline.rs @@ -10,6 +10,7 @@ use std::{ time::Instant, }; +use crate::util::PolySet; use log::Level; use mktemp::Temp; use powdr_ast::{ @@ -20,23 +21,24 @@ use powdr_ast::{ }; use powdr_backend::{BackendOptions, BackendType, Proof}; use powdr_executor::{ - constant_evaluator, + constant_evaluator::{self, get_uniquely_sized_cloned, VariablySizedColumn}, witgen::{ chain_callbacks, extract_publics, unused_query_callback, QueryCallback, WitgenCallback, WitnessGenerator, }, }; -use powdr_number::{write_polys_csv_file, write_polys_file, CsvRenderMode, FieldElement}; +use powdr_number::{write_polys_csv_file, CsvRenderMode, FieldElement, ReadWrite}; use powdr_schemas::SerializedAnalyzed; use crate::{ dict_data_to_query_callback, handle_simple_queries_callback, inputs_to_query_callback, serde_data_to_query_callback, - util::{read_poly_set, FixedPolySet, WitnessPolySet}, + util::{FixedPolySet, WitnessPolySet}, }; use std::collections::BTreeMap; type Columns = Vec<(String, Vec)>; +type VariablySizedColumns = Vec<(String, VariablySizedColumn)>; #[derive(Default, Clone)] pub struct Artifacts { @@ -68,7 +70,7 @@ pub struct Artifacts { /// An optimized .pil file. optimized_pil: Option>>, /// Fully evaluated fixed columns. - fixed_cols: Option>>, + fixed_cols: Option>>, /// Generated witnesses. witness: Option>>, /// The proof (if successful). @@ -397,7 +399,7 @@ impl Pipeline { /// Reads previously generated fixed columns from the provided directory. pub fn read_constants(self, directory: &Path) -> Self { - let fixed = read_poly_set::(directory); + let fixed = FixedPolySet::::read(directory); Pipeline { artifact: Artifacts { @@ -410,7 +412,7 @@ impl Pipeline { /// Reads a previously generated witness from the provided directory. pub fn read_witness(self, directory: &Path) -> Self { - let witness = read_poly_set::(directory); + let witness = WitnessPolySet::::read(directory); Pipeline { artifact: Artifacts { @@ -499,24 +501,29 @@ impl Pipeline { Ok(()) } - fn maybe_write_constants(&self, constants: &[(String, Vec)]) -> Result<(), Vec> { + fn maybe_write_constants( + &self, + constants: &VariablySizedColumns, + ) -> Result<(), Vec> { if let Some(path) = self.path_if_should_write(|_| "constants.bin".to_string())? { - write_polys_file(&path, constants).map_err(|e| vec![format!("{}", e)])?; + constants.write(&path).map_err(|e| vec![format!("{}", e)])?; } Ok(()) } fn maybe_write_witness( &self, - fixed: &[(String, Vec)], - witness: &[(String, Vec)], + fixed: &VariablySizedColumns, + witness: &Columns, ) -> Result<(), Vec> { if let Some(path) = self.path_if_should_write(|_| "commits.bin".to_string())? { - write_polys_file(&path, witness).map_err(|e| vec![format!("{}", e)])?; + witness.write(&path).map_err(|e| vec![format!("{}", e)])?; } if self.arguments.export_witness_csv { if let Some(path) = self.path_if_should_write(|name| format!("{name}_columns.csv"))? { + // TODO: Handle multiple sizes + let fixed = get_uniquely_sized_cloned(fixed).unwrap(); let columns = fixed.iter().chain(witness.iter()).collect::>(); let csv_file = fs::File::create(path).map_err(|e| vec![format!("{}", e)])?; @@ -799,7 +806,7 @@ impl Pipeline { Ok(self.artifact.optimized_pil.as_ref().unwrap().clone()) } - pub fn compute_fixed_cols(&mut self) -> Result>, Vec> { + pub fn compute_fixed_cols(&mut self) -> Result>, Vec> { if let Some(ref fixed_cols) = self.artifact.fixed_cols { return Ok(fixed_cols.clone()); } @@ -821,7 +828,7 @@ impl Pipeline { Ok(self.artifact.fixed_cols.as_ref().unwrap().clone()) } - pub fn fixed_cols(&self) -> Result>, Vec> { + pub fn fixed_cols(&self) -> Result>, Vec> { Ok(self.artifact.fixed_cols.as_ref().unwrap().clone()) } diff --git a/pipeline/src/util.rs b/pipeline/src/util.rs index 8c1ddf620..c6415f387 100644 --- a/pipeline/src/util.rs +++ b/pipeline/src/util.rs @@ -1,38 +1,39 @@ use powdr_ast::analyzed::{Analyzed, FunctionValueDefinition, Symbol}; -use powdr_number::{read_polys_file, FieldElement}; -use std::{fs::File, io::BufReader, path::Path}; +use powdr_executor::constant_evaluator::VariablySizedColumn; +use powdr_number::ReadWrite; +use serde::{de::DeserializeOwned, Serialize}; +use std::{fs::File, io::BufReader, marker::PhantomData, path::Path}; -pub trait PolySet { +pub trait PolySet { const FILE_NAME: &'static str; - fn get_polys( - pil: &Analyzed, - ) -> Vec<&(Symbol, Option)>; + fn get_polys(pil: &Analyzed) -> Vec<&(Symbol, Option)>; + + fn read(dir: &Path) -> C { + let path = dir.join(Self::FILE_NAME); + C::read(&mut BufReader::new(File::open(path).unwrap())) + } } -pub struct FixedPolySet; -impl PolySet for FixedPolySet { +pub struct FixedPolySet { + _phantom: PhantomData, +} +impl PolySet)>, T> + for FixedPolySet +{ const FILE_NAME: &'static str = "constants.bin"; - fn get_polys( - pil: &Analyzed, - ) -> Vec<&(Symbol, Option)> { + fn get_polys(pil: &Analyzed) -> Vec<&(Symbol, Option)> { pil.constant_polys_in_source_order() } } -pub struct WitnessPolySet; -impl PolySet for WitnessPolySet { +pub struct WitnessPolySet { + _phantom: PhantomData, +} +impl PolySet)>, T> for WitnessPolySet { const FILE_NAME: &'static str = "commits.bin"; - fn get_polys( - pil: &Analyzed, - ) -> Vec<&(Symbol, Option)> { + fn get_polys(pil: &Analyzed) -> Vec<&(Symbol, Option)> { pil.committed_polys_in_source_order() } } - -#[allow(clippy::type_complexity)] -pub fn read_poly_set(dir: &Path) -> Vec<(String, Vec)> { - let path = dir.join(P::FILE_NAME); - read_polys_file(&mut BufReader::new(File::open(path).unwrap())) -} diff --git a/pipeline/tests/asm.rs b/pipeline/tests/asm.rs index 44027edec..c1540b725 100644 --- a/pipeline/tests/asm.rs +++ b/pipeline/tests/asm.rs @@ -1,4 +1,5 @@ use powdr_backend::BackendType; +use powdr_executor::constant_evaluator::get_uniquely_sized; use powdr_number::{Bn254Field, FieldElement, GoldilocksField}; use powdr_pipeline::{ test_util::{ @@ -6,7 +7,7 @@ use powdr_pipeline::{ resolve_test_file, run_pilcom_test_file, run_pilcom_with_backend_variant, test_halo2, test_halo2_with_backend_variant, BackendVariant, }, - util::{read_poly_set, FixedPolySet, WitnessPolySet}, + util::{FixedPolySet, PolySet, WitnessPolySet}, Pipeline, }; use test_log::test; @@ -401,13 +402,14 @@ fn read_poly_files() { pipeline.compute_proof().unwrap(); // check fixed cols (may have no fixed cols) - let fixed = read_poly_set::(tmp_dir.as_path()); + let fixed = FixedPolySet::::read(tmp_dir.as_path()); + let fixed = get_uniquely_sized(&fixed).unwrap(); if !fixed.is_empty() { assert_eq!(pil.degree(), fixed[0].1.len() as u64); } // check witness cols (examples assumed to have at least one witness col) - let witness = read_poly_set::(tmp_dir.as_path()); + let witness = WitnessPolySet::::read(tmp_dir.as_path()); assert_eq!(pil.degree(), witness[0].1.len() as u64); } } diff --git a/plonky3/src/stark.rs b/plonky3/src/stark.rs index eeb199fc1..5017be074 100644 --- a/plonky3/src/stark.rs +++ b/plonky3/src/stark.rs @@ -203,6 +203,9 @@ impl Plonky3Prover { #[cfg(test)] mod tests { + use std::sync::Arc; + + use powdr_executor::constant_evaluator::get_uniquely_sized_cloned; use powdr_number::GoldilocksField; use powdr_pipeline::Pipeline; use test_log::test; @@ -221,6 +224,7 @@ mod tests { let witness_callback = pipeline.witgen_callback().unwrap(); let witness = pipeline.compute_witness().unwrap(); let fixed = pipeline.compute_fixed_cols().unwrap(); + let fixed = Arc::new(get_uniquely_sized_cloned(&fixed).unwrap()); let mut prover = Plonky3Prover::new(pil, fixed); prover.setup(); diff --git a/riscv/src/continuations.rs b/riscv/src/continuations.rs index 5ed348b51..0e882692a 100644 --- a/riscv/src/continuations.rs +++ b/riscv/src/continuations.rs @@ -7,6 +7,7 @@ use powdr_ast::{ asm_analysis::AnalysisASMFile, parsed::{asm::parse_absolute_path, Expression, Number, PilStatement}, }; +use powdr_executor::constant_evaluator::get_uniquely_sized; use powdr_number::FieldElement; use powdr_pipeline::Pipeline; use powdr_riscv_executor::{get_main_machine, Elem, ExecutionTrace, MemoryState, ProfilerOptions}; @@ -77,7 +78,8 @@ where pipeline.compute_optimized_pil().unwrap(); // TODO hacky way to find the degree of the main machine, fix. - let length = fixed_cols + let length = get_uniquely_sized(&fixed_cols) + .unwrap() .iter() .find(|(col, _)| col == "main.STEP") .unwrap() From 5bfeea8b584758345710d34c58e1c9727643b1a6 Mon Sep 17 00:00:00 2001 From: Thibaut Schaeffer Date: Tue, 23 Jul 2024 11:22:42 +0200 Subject: [PATCH 19/24] Extract range checks from riscv into std machines (#1594) Also, allow vms to have links outside instructions. --------- Co-authored-by: Leo Alt --- analysis/src/machine_check.rs | 14 +---- riscv/src/code_gen.rs | 39 ++++++-------- riscv/src/runtime.rs | 8 +++ std/machines/arith.asm | 2 +- std/machines/byte2.asm | 11 ---- std/machines/memory.asm | 2 +- std/machines/memory_with_bootloader_write.asm | 2 +- std/machines/mod.asm | 2 +- std/machines/range.asm | 54 +++++++++++++++++++ 9 files changed, 85 insertions(+), 49 deletions(-) delete mode 100644 std/machines/byte2.asm create mode 100644 std/machines/range.asm diff --git a/analysis/src/machine_check.rs b/analysis/src/machine_check.rs index 70259f053..4286d7e84 100644 --- a/analysis/src/machine_check.rs +++ b/analysis/src/machine_check.rs @@ -247,11 +247,6 @@ impl TypeChecker { "Machine {ctx} should not have call_selectors as it has a pc" )); } - for _ in &links { - errors.push(format!( - "Machine {ctx} has a pc, links cannot be used outside of instructions." - )); - } for o in callable.operation_definitions() { errors.push(format!( "Machine {ctx} should not have operations as it has a pc, found `{}`", @@ -446,7 +441,7 @@ machine Main with latch: latch, operation_id: id { } #[test] - fn virtual_machine_has_no_links() { + fn virtual_machine_with_links() { let src = r#" machine Main { reg pc[@pc]; @@ -456,12 +451,7 @@ machine Main { link => B = submachine.foo(A); } "#; - expect_check_str( - src, - Err(vec![ - "Machine ::Main has a pc, links cannot be used outside of instructions.", - ]), - ); + expect_check_str(src, Ok(())); } #[test] diff --git a/riscv/src/code_gen.rs b/riscv/src/code_gen.rs index 3afdba318..a12a879a8 100644 --- a/riscv/src/code_gen.rs +++ b/riscv/src/code_gen.rs @@ -350,7 +350,7 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo "".to_string() }; - for machine in ["binary", "shift"] { + for machine in ["binary", "shift", "bit2", "bit6", "bit7", "byte"] { assert!( runtime.has_submachine(machine), "RISC-V machine requires the `{machine}` submachine" @@ -549,15 +549,14 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo .map(|s| format!(" {s}")) .join("\n") + r#" - col fixed bytes(i) { i & 0xff }; col witness X_b1; col witness X_b2; col witness X_b3; col witness X_b4; - [ X_b1 ] in [ bytes ]; - [ X_b2 ] in [ bytes ]; - [ X_b3 ] in [ bytes ]; - [ X_b4 ] in [ bytes ]; + link => byte.check(X_b1); + link => byte.check(X_b2); + link => byte.check(X_b3); + link => byte.check(X_b4); col witness wrap_bit; wrap_bit * (1 - wrap_bit) = 0; @@ -571,10 +570,8 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo val1_col = Y_7bit + wrap_bit * 0x80 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, val3_col = Y_7bit + wrap_bit * 0xffffff80 } - - col fixed seven_bit(i) { i & 0x7f }; col witness Y_7bit; - [ Y_7bit ] in [ seven_bit ]; + link => bit7.check(Y_7bit); // Sign extends the value in register X and stores it in register Y. // Input is a 32 bit unsigned number. We check bit 15 and set all higher bits to that value. @@ -622,19 +619,19 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo col witness Y_b6; col witness Y_b7; col witness Y_b8; - [ Y_b5 ] in [ bytes ]; - [ Y_b6 ] in [ bytes ]; - [ Y_b7 ] in [ bytes ]; - [ Y_b8 ] in [ bytes ]; + link => byte.check(Y_b5); + link => byte.check(Y_b6); + link => byte.check(Y_b7); + link => byte.check(Y_b8); col witness REM_b1; col witness REM_b2; col witness REM_b3; col witness REM_b4; - [ REM_b1 ] in [ bytes ]; - [ REM_b2 ] in [ bytes ]; - [ REM_b3 ] in [ bytes ]; - [ REM_b4 ] in [ bytes ]; + link => byte.check(REM_b1); + link => byte.check(REM_b2); + link => byte.check(REM_b3); + link => byte.check(REM_b4); // Computes Q = val(Y) / val(X) and R = val(Y) % val(X) and stores them in registers Z and W. instr divremu Y, X, Z, W @@ -735,8 +732,6 @@ fn memory(with_bootloader: bool) -> String { // ============== memory instructions ============== - let up_to_three: col = |i| i % 4; - let six_bits: col = |i| i % 2**6; /// Loads one word from an address V = val(X) + Y, where V can be between 0 and 2**33 (sic!), /// wraps the address to 32 bits and rounds it down to the next multiple of 4. /// Writes the loaded word and the remainder of the division by 4 to registers Z and W, @@ -746,10 +741,10 @@ fn memory(with_bootloader: bool) -> String { link ~> val3_col = memory.mload(X_b4 * 0x1000000 + X_b3 * 0x10000 + X_b2 * 0x100 + X_b1 * 4, STEP + 1) link ~> regs.mstore(Z, STEP + 2, val3_col) link ~> regs.mstore(W, STEP + 3, val4_col) + link => bit2.check(val4_col) + link => bit6.check(X_b1) { - [ val4_col ] in [ up_to_three ], - val1_col + Y = wrap_bit * 2**32 + X_b4 * 0x1000000 + X_b3 * 0x10000 + X_b2 * 0x100 + X_b1 * 4 + val4_col, - [ X_b1 ] in [ six_bits ] + val1_col + Y = wrap_bit * 2**32 + X_b4 * 0x1000000 + X_b3 * 0x10000 + X_b2 * 0x100 + X_b1 * 4 + val4_col } // Stores val(W) at address (V = val(X) - val(Y) + Z) % 2**32. diff --git a/riscv/src/runtime.rs b/riscv/src/runtime.rs index 64cca3f04..a26e9909e 100644 --- a/riscv/src/runtime.rs +++ b/riscv/src/runtime.rs @@ -142,6 +142,14 @@ impl Runtime { ["split_gl 0, 0, 0;"], ); + r.add_submachine::<&str, _, _>("std::machines::range::Bit2", None, "bit2", [], 0, []); + + r.add_submachine::<&str, _, _>("std::machines::range::Bit6", None, "bit6", [], 0, []); + + r.add_submachine::<&str, _, _>("std::machines::range::Bit7", None, "bit7", [], 0, []); + + r.add_submachine::<&str, _, _>("std::machines::range::Byte", None, "byte", [], 0, []); + // Base syscalls r.add_syscall( Syscall::Input, diff --git a/std/machines/arith.asm b/std/machines/arith.asm index d32c36ecc..fdb21de32 100644 --- a/std/machines/arith.asm +++ b/std/machines/arith.asm @@ -9,7 +9,7 @@ use std::convert::fe; use std::convert::expr; use std::prover::eval; use std::prover::Query; -use std::machines::byte2::Byte2; +use std::machines::range::Byte2; // Arithmetic machine, ported mainly from Polygon: https://github.com/0xPolygonHermez/zkevm-proverjs/blob/main/pil/arith.pil // Currently only supports "Equation 0", i.e., 256-Bit addition and multiplication. diff --git a/std/machines/byte2.asm b/std/machines/byte2.asm deleted file mode 100644 index 7e19e90eb..000000000 --- a/std/machines/byte2.asm +++ /dev/null @@ -1,11 +0,0 @@ -/// A machine to check that a field element represents two bytes. It uses an exhaustive lookup table. -machine Byte2 with - latch: latch, - operation_id: operation_id -{ - operation check<0> BYTE2 -> ; - - let BYTE2: col = |i| i & 0xffff; - col fixed latch = [1]*; - col fixed operation_id = [0]*; -} \ No newline at end of file diff --git a/std/machines/memory.asm b/std/machines/memory.asm index f47c77560..5793d608f 100644 --- a/std/machines/memory.asm +++ b/std/machines/memory.asm @@ -1,5 +1,5 @@ use std::array; -use std::machines::byte2::Byte2; +use std::machines::range::Byte2; // A read/write memory, similar to that of Polygon: // https://github.com/0xPolygonHermez/zkevm-proverjs/blob/main/pil/mem.pil diff --git a/std/machines/memory_with_bootloader_write.asm b/std/machines/memory_with_bootloader_write.asm index 3782e2471..d3ecceddc 100644 --- a/std/machines/memory_with_bootloader_write.asm +++ b/std/machines/memory_with_bootloader_write.asm @@ -1,5 +1,5 @@ use std::array; -use std::machines::byte2::Byte2; +use std::machines::range::Byte2; /// This machine is a slightly extended version of std::machines::memory::Memory, /// where in addition to mstore, there is an mstore_bootloader operation. It behaves diff --git a/std/machines/mod.asm b/std/machines/mod.asm index a5b59bf2f..5e0306287 100644 --- a/std/machines/mod.asm +++ b/std/machines/mod.asm @@ -1,6 +1,6 @@ mod arith; mod binary; -mod byte2; +mod range; mod hash; mod memory; mod memory_with_bootloader_write; diff --git a/std/machines/range.asm b/std/machines/range.asm new file mode 100644 index 000000000..033889c92 --- /dev/null +++ b/std/machines/range.asm @@ -0,0 +1,54 @@ +machine Byte with + latch: latch, + operation_id: operation_id +{ + operation check<0> BYTE -> ; + + let BYTE: col = |i| i & 0xff; + col fixed latch = [1]*; + col fixed operation_id = [0]*; +} + +machine Byte2 with + latch: latch, + operation_id: operation_id +{ + operation check<0> BYTE2 -> ; + + let BYTE2: col = |i| i & 0xffff; + col fixed latch = [1]*; + col fixed operation_id = [0]*; +} + +machine Bit2 with + latch: latch, + operation_id: operation_id +{ + operation check<0> BIT2 -> ; + + let BIT2: col = |i| i % 4; + col fixed latch = [1]*; + col fixed operation_id = [0]*; +} + +machine Bit6 with + latch: latch, + operation_id: operation_id +{ + operation check<0> BIT6 -> ; + + let BIT6: col = |i| i % 64; + col fixed latch = [1]*; + col fixed operation_id = [0]*; +} + +machine Bit7 with + latch: latch, + operation_id: operation_id +{ + operation check<0> BIT7 -> ; + + let BIT7: col = |i| i % 128; + col fixed latch = [1]*; + col fixed operation_id = [0]*; +} From 7a1ccfceada7b1e25b52ed366e4976748e708641 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Tue, 23 Jul 2024 11:41:44 +0200 Subject: [PATCH 20/24] Add dynamic VADCOP proving (no witgen yet) (#1574) Fixes #1496 Also, a step towards #1572 This PR implements the steps needed in `CompositeBackend` to implement dynamic VADCOP. In summary: - If a machines size (a.k.a. "degree") is set to `None`, fixed columns are computed in all powers of too in some hard-coded range. This fixes #1572. As a result, machines with a size set to `None` are available in multiple sizes. If the size is explicitly set by the user, the machine is only available in that one size. - Note that the ASM linker still sets the size of machines without a size. So, currently, this can only happen when coming from PIL directly. - `CompositeBackend` instantiates a new backend for each machine *and size*: - The verification key contains a key for each machine and size. - When proving, it it uses the backend of whatever size the witness has. The size chosen is also stored in the proof. - When verifying, the verification key of the reported size is used. - Witness generation currently chooses the largest available size. This will change in a future PR. This is an example: ``` $ cargo run pil test_data/pil/vm_to_block_dynamic_length.pil -o output -f --field bn254 --prove-with halo2-mock-composite ... == Proving machine: main (size 256) ==> Machine proof of 256 rows (0 bytes) computed in 209.101166ms == Proving machine: main__rom (size 256) ==> Machine proof of 256 rows (0 bytes) computed in 226.87175ms == Proving machine: main_arith (size 1024) ==> Machine proof of 1024 rows (0 bytes) computed in 432.807583ms ``` --- backend/src/composite/mod.rs | 246 ++++++++++++------ backend/src/composite/split.rs | 85 +++++- .../src/constant_evaluator/data_structures.rs | 38 ++- executor/src/constant_evaluator/mod.rs | 21 +- executor/src/witgen/mod.rs | 13 +- pipeline/tests/pil.rs | 20 ++ test_data/pil/vm_to_block_dynamic_length.pil | 67 +++++ 7 files changed, 381 insertions(+), 109 deletions(-) create mode 100644 test_data/pil/vm_to_block_dynamic_length.pil diff --git a/backend/src/composite/mod.rs b/backend/src/composite/mod.rs index 502558e6d..44265c15f 100644 --- a/backend/src/composite/mod.rs +++ b/backend/src/composite/mod.rs @@ -8,10 +8,7 @@ use std::{ use itertools::Itertools; use powdr_ast::analyzed::Analyzed; -use powdr_executor::{ - constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn}, - witgen::WitgenCallback, -}; +use powdr_executor::{constant_evaluator::VariablySizedColumn, witgen::WitgenCallback}; use powdr_number::{DegreeType, FieldElement}; use serde::{Deserialize, Serialize}; use split::{machine_fixed_columns, machine_witness_columns}; @@ -20,18 +17,30 @@ use crate::{Backend, BackendFactory, BackendOptions, Error, Proof}; mod split; +/// Maps each size to the corresponding verification key. +type VerificationKeyBySize = BTreeMap>; + /// A composite verification key that contains a verification key for each machine separately. #[derive(Serialize, Deserialize)] struct CompositeVerificationKey { - /// Verification key for each machine (if available, otherwise None), sorted by machine name. - verification_keys: Vec>>, + /// Verification keys for each machine (if available, otherwise None), sorted by machine name. + verification_keys: Vec>, +} + +/// A proof for a single machine. +#[derive(Serialize, Deserialize)] +struct MachineProof { + /// The (dynamic) size of the machine. + size: usize, + /// The proof for the machine. + proof: Vec, } /// A composite proof that contains a proof for each machine separately, sorted by machine name. #[derive(Serialize, Deserialize)] struct CompositeProof { - /// Map from machine name to proof - proofs: Vec>, + /// Machine proofs, sorted by machine name. + proofs: Vec, } pub(crate) struct CompositeBackendFactory> { @@ -63,11 +72,6 @@ impl> BackendFactory for CompositeBacke unimplemented!(); } - // TODO: Handle multiple sizes. - let fixed = Arc::new( - get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, - ); - let pils = split::split_pil((*pil).clone()); // Read the setup once (if any) to pass to all backends. @@ -97,41 +101,46 @@ impl> BackendFactory for CompositeBacke .into_iter() .zip_eq(verification_keys.into_iter()) .map(|((machine_name, pil), verification_key)| { - // Set up readers for the setup and verification key - let mut setup_cursor = setup_bytes.as_ref().map(Cursor::new); - let setup = setup_cursor.as_mut().map(|cursor| cursor as &mut dyn Read); - - let mut verification_key_cursor = verification_key.as_ref().map(Cursor::new); - let verification_key = verification_key_cursor - .as_mut() - .map(|cursor| cursor as &mut dyn Read); - let pil = Arc::new(pil); - let output_dir = output_dir - .clone() - .map(|output_dir| output_dir.join(&machine_name)); - if let Some(ref output_dir) = output_dir { - std::fs::create_dir_all(output_dir)?; - } - let fixed = Arc::new( - machine_fixed_columns(&fixed, &pil) - .into_iter() - .map(|(column_name, values)| (column_name, values.into())) - .collect(), - ); - let backend = self.factory.create( - pil.clone(), - fixed, - output_dir, - setup, - verification_key, - // TODO: Handle verification_app_key - None, - backend_options.clone(), - ); - backend.map(|backend| (machine_name.to_string(), MachineData { pil, backend })) + machine_fixed_columns(&fixed, &pil) + .into_iter() + .map(|(size, fixed)| { + let pil = set_size(pil.clone(), size as DegreeType); + // Set up readers for the setup and verification key + let mut setup_cursor = setup_bytes.as_ref().map(Cursor::new); + let setup = setup_cursor.as_mut().map(|cursor| cursor as &mut dyn Read); + + let mut verification_key_cursor = verification_key + .as_ref() + .map(|keys| Cursor::new(keys.get(&size).unwrap())); + let verification_key = verification_key_cursor + .as_mut() + .map(|cursor| cursor as &mut dyn Read); + + let output_dir = output_dir + .clone() + .map(|output_dir| output_dir.join(&machine_name)); + if let Some(ref output_dir) = output_dir { + std::fs::create_dir_all(output_dir)?; + } + let fixed = Arc::new(fixed); + let backend = self.factory.create( + pil.clone(), + fixed, + output_dir, + setup, + verification_key, + // TODO: Handle verification_app_key + None, + backend_options.clone(), + ); + backend.map(|backend| (size, MachineData { pil, backend })) + }) + .collect::, _>>() + .map(|backends| (machine_name, backends)) }) - .collect::>()?; + .collect::, _>>()?; + Ok(Box::new(CompositeBackend { machine_data })) } @@ -180,7 +189,46 @@ pub(crate) struct CompositeBackend<'a, F> { /// Maps each machine name to the corresponding machine data /// Note that it is essential that we use BTreeMap here to ensure that the machines are /// deterministically ordered. - machine_data: BTreeMap>, + machine_data: BTreeMap>>, +} + +/// Makes sure that all columns in the machine PIL have the provided degree, cloning +/// the machine PIL if necessary. This is needed because backends other than `CompositeBackend` +/// typically expect that the degree is static. +/// +/// # Panics +/// Panics if the machine PIL contains definitions with different degrees, or if the machine +/// already has a degree set that is different from the provided degree. +fn set_size(pil: Arc>, degree: DegreeType) -> Arc> { + let current_degrees = pil.degrees(); + assert!( + current_degrees.len() <= 1, + "Expected at most one degree within a machine" + ); + + match current_degrees.iter().next() { + None => { + // Clone the PIL and set the degree for all definitions + let pil = (*pil).clone(); + let definitions = pil + .definitions + .into_iter() + .map(|(name, (mut symbol, def))| { + symbol.degree = Some(degree); + (name, (symbol, def)) + }) + .collect(); + Arc::new(Analyzed { definitions, ..pil }) + } + Some(existing_degree) => { + // Keep the the PIL as is + assert_eq!( + existing_degree, °ree, + "Expected all definitions within a machine to have the same degree" + ); + pil + } + } } // TODO: This just forwards to the backend for now. In the future this should: @@ -199,45 +247,64 @@ impl<'a, F: FieldElement> Backend<'a, F> for CompositeBackend<'a, F> { unimplemented!(); } - let proof = CompositeProof { - proofs: self - .machine_data - .iter() - .map(|(machine, MachineData { pil, backend })| { - let witgen_callback = witgen_callback.clone().with_pil(pil.clone()); - - log::info!("== Proving machine: {} (size {})", machine, pil.degree()); - log::debug!("PIL:\n{}", pil); - - let start = std::time::Instant::now(); - - let witness = machine_witness_columns(witness, pil, machine); - - let proof = backend.prove(&witness, None, witgen_callback); + let proofs = self + .machine_data + .iter() + .map(|(machine, machine_data)| { + let start = std::time::Instant::now(); + // Pick any available PIL; they all contain the same witness columns + let any_pil = &machine_data.values().next().unwrap().pil; + let witness = machine_witness_columns(witness, any_pil, machine); + let size = witness + .iter() + .map(|(_, witness)| witness.len()) + .unique() + .exactly_one() + .expect("All witness columns of a machine must have the same size"); + let machine_data = machine_data + .get(&size) + .expect("Machine does not support the given size"); + let witgen_callback = witgen_callback.clone().with_pil(machine_data.pil.clone()); + + log::info!("== Proving machine: {} (size {})", machine, size); + log::debug!("PIL:\n{}", machine_data.pil); + + let proof = machine_data.backend.prove(&witness, None, witgen_callback); + + match proof { + Ok(inner_proof) => { + log::info!( + "==> Machine proof of {size} rows ({} bytes) computed in {:?}", + inner_proof.len(), + start.elapsed() + ); + Ok(MachineProof { + size, + proof: inner_proof, + }) + } + Err(e) => { + log::error!("==> Machine proof failed: {:?}", e); + Err(e) + } + } + }) + .collect::>()?; - match &proof { - Ok(proof) => { - log::info!( - "==> Machine proof of {} bytes computed in {:?}", - proof.len(), - start.elapsed() - ); - } - Err(e) => { - log::error!("==> Machine proof failed: {:?}", e); - } - }; - proof - }) - .collect::>()?, - }; + let proof = CompositeProof { proofs }; Ok(bincode::serialize(&proof).unwrap()) } fn verify(&self, proof: &[u8], instances: &[Vec]) -> Result<(), Error> { let proof: CompositeProof = bincode::deserialize(proof).unwrap(); - for (machine_data, machine_proof) in self.machine_data.values().zip_eq(proof.proofs) { - machine_data.backend.verify(&machine_proof, instances)?; + for (machine_data, machine_proof) in + self.machine_data.values().zip_eq(proof.proofs.into_iter()) + { + machine_data + .get(&machine_proof.size) + .unwrap() + .backend + .verify(&machine_proof.proof, instances)?; } Ok(()) } @@ -245,6 +312,9 @@ impl<'a, F: FieldElement> Backend<'a, F> for CompositeBackend<'a, F> { fn export_setup(&self, output: &mut dyn io::Write) -> Result<(), Error> { // All backend are the same, just pick the first self.machine_data + .values() + .next() + .unwrap() .values() .next() .unwrap() @@ -258,10 +328,18 @@ impl<'a, F: FieldElement> Backend<'a, F> for CompositeBackend<'a, F> { .machine_data .values() .map(|machine_data| { - let backend = machine_data.backend.as_ref(); - let vk_bytes = backend.verification_key_bytes(); - match vk_bytes { - Ok(vk_bytes) => Ok(Some(vk_bytes)), + let verification_keys = machine_data + .iter() + .map(|(size, machine_data)| { + machine_data + .backend + .verification_key_bytes() + .map(|vk_bytes| (*size, vk_bytes)) + }) + .collect::>(); + + match verification_keys { + Ok(verification_keys) => Ok(Some(verification_keys)), Err(Error::NoVerificationAvailable) => Ok(None), Err(e) => Err(e), } diff --git a/backend/src/composite/split.rs b/backend/src/composite/split.rs index 59ffa1666..07e5de483 100644 --- a/backend/src/composite/split.rs +++ b/backend/src/composite/split.rs @@ -15,6 +15,7 @@ use powdr_ast::{ visitor::{ExpressionVisitable, VisitOrder}, }, }; +use powdr_executor::constant_evaluator::{VariablySizedColumn, MAX_DEGREE_LOG, MIN_DEGREE_LOG}; use powdr_number::FieldElement; const DUMMY_COLUMN_NAME: &str = "__dummy"; @@ -43,32 +44,91 @@ pub(crate) fn machine_witness_columns( machine_pil: &Analyzed, machine_name: &str, ) -> Vec<(String, Vec)> { + let machine_columns = select_machine_columns( + all_witness_columns, + machine_pil.committed_polys_in_source_order(), + ); + let size = machine_columns + .iter() + .map(|(_, column)| column.len()) + .unique() + .exactly_one() + .unwrap_or_else(|err| { + if err.try_len().unwrap() == 0 { + // No witness column, use degree of provided PIL + // In practice, we'd at least expect a bus accumulator here, so this should not happen + // in any sound setup (after #1498) + machine_pil.degree() as usize + } else { + panic!("Machine {machine_name} has witness columns of different sizes") + } + }); let dummy_column_name = format!("{machine_name}.{DUMMY_COLUMN_NAME}"); - let dummy_column = vec![F::zero(); machine_pil.degree() as usize]; + let dummy_column = vec![F::zero(); size]; iter::once((dummy_column_name, dummy_column)) - .chain(select_machine_columns( - all_witness_columns, - machine_pil.committed_polys_in_source_order(), - )) + .chain(machine_columns.into_iter().cloned()) .collect::>() } /// Given a set of columns and a PIL describing the machine, returns the fixed column that belong to the machine. pub(crate) fn machine_fixed_columns( - all_fixed_columns: &[(String, Vec)], + all_fixed_columns: &[(String, VariablySizedColumn)], machine_pil: &Analyzed, -) -> Vec<(String, Vec)> { - select_machine_columns( +) -> BTreeMap)>> { + let machine_columns = select_machine_columns( all_fixed_columns, machine_pil.constant_polys_in_source_order(), - ) + ); + let sizes = machine_columns + .iter() + .map(|(_, column)| column.available_sizes()) + .collect::>(); + + assert!( + sizes.len() <= 1, + "All fixed columns of a machine must have the same sizes" + ); + + let sizes = sizes.into_iter().next().unwrap_or_else(|| { + // There is no fixed column with a set size. So either the PIL has a degree, or we + // assume all possible degrees. + let machine_degrees = machine_pil.degrees(); + assert!( + machine_degrees.len() <= 1, + "All fixed columns of a machine must have the same sizes" + ); + match machine_degrees.iter().next() { + Some(°ree) => iter::once(degree as usize).collect(), + None => (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) + .map(|log_size| 1 << log_size) + .collect(), + } + }); + + sizes + .into_iter() + .map(|size| { + ( + size, + machine_columns + .iter() + .map(|(name, column)| { + ( + name.clone(), + column.get_by_size_cloned(size).unwrap().into(), + ) + }) + .collect::>(), + ) + }) + .collect() } /// Filter the given columns to only include those that are referenced by the given symbols. -fn select_machine_columns( - columns: &[(String, Vec)], +fn select_machine_columns<'a, T, C>( + columns: &'a [(String, C)], symbols: Vec<&(Symbol, T)>, -) -> Vec<(String, Vec)> { +) -> Vec<&'a (String, C)> { let names = symbols .into_iter() .flat_map(|(symbol, _)| symbol.array_elements().map(|(name, _)| name)) @@ -76,7 +136,6 @@ fn select_machine_columns( columns .iter() .filter(|(name, _)| names.contains(name)) - .cloned() .collect::>() } diff --git a/executor/src/constant_evaluator/data_structures.rs b/executor/src/constant_evaluator/data_structures.rs index cd5c288ff..931409840 100644 --- a/executor/src/constant_evaluator/data_structures.rs +++ b/executor/src/constant_evaluator/data_structures.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; #[derive(Serialize, Deserialize)] pub struct VariablySizedColumn { @@ -17,8 +17,22 @@ impl VariablySizedColumn { } Ok(self.column_by_size.values().next().unwrap()) } + + /// Returns the set of available sizes. + pub fn available_sizes(&self) -> BTreeSet { + self.column_by_size.keys().cloned().collect() + } + + /// Clones and returns the column with the given size. + pub fn get_by_size_cloned(&self, size: usize) -> Option> + where + F: Clone, + { + self.column_by_size.get(&size).cloned() + } } +/// Returns all columns with their unique sizes. Fails if any column has multiple sizes. pub fn get_uniquely_sized( column: &[(String, VariablySizedColumn)], ) -> Result)>, HasMultipleSizesError> { @@ -28,6 +42,17 @@ pub fn get_uniquely_sized( .collect() } +/// Returns all columns with their maximum sizes. +pub fn get_max_sized(column: &[(String, VariablySizedColumn)]) -> Vec<(String, &Vec)> { + column + .iter() + .map(|(name, column)| { + let max_size = column.column_by_size.keys().max().unwrap(); + (name.clone(), &column.column_by_size[max_size]) + }) + .collect() +} + pub fn get_uniquely_sized_cloned( column: &[(String, VariablySizedColumn)], ) -> Result)>, HasMultipleSizesError> { @@ -46,3 +71,14 @@ impl From> for VariablySizedColumn { } } } + +impl From>> for VariablySizedColumn { + fn from(columns: Vec>) -> Self { + VariablySizedColumn { + column_by_size: columns + .into_iter() + .map(|column| (column.len(), column)) + .collect(), + } + } +} diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index 5168c1d84..cd5d11651 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -3,7 +3,9 @@ use std::{ sync::{Arc, RwLock}, }; -pub use data_structures::{get_uniquely_sized, get_uniquely_sized_cloned, VariablySizedColumn}; +pub use data_structures::{ + get_max_sized, get_uniquely_sized, get_uniquely_sized_cloned, VariablySizedColumn, +}; use itertools::Itertools; use powdr_ast::{ analyzed::{Analyzed, FunctionValueDefinition, Symbol, TypedExpression}, @@ -18,6 +20,9 @@ use rayon::prelude::{IntoParallelIterator, ParallelIterator}; mod data_structures; +pub const MIN_DEGREE_LOG: usize = 5; +pub const MAX_DEGREE_LOG: usize = 10; + /// Generates the fixed column values for all fixed columns that are defined /// (and not just declared). /// @returns the names (in source order) and the values for the columns. @@ -31,7 +36,17 @@ pub fn generate(analyzed: &Analyzed) -> Vec<(String, Variabl // for non-arrays, set index to None. for (index, (name, id)) in poly.array_elements().enumerate() { let index = poly.is_array().then_some(index as u64); - let values = generate_values(analyzed, poly.degree.unwrap(), &name, value, index); + let values = if let Some(degree) = poly.degree { + generate_values(analyzed, degree, &name, value, index).into() + } else { + (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) + .map(|degree_log| { + let degree = 1 << degree_log; + generate_values(analyzed, degree, &name, value, index) + }) + .collect::>() + .into() + }; assert!(fixed_cols.insert(name, (id, values)).is_none()); } } @@ -40,7 +55,7 @@ pub fn generate(analyzed: &Analyzed) -> Vec<(String, Variabl fixed_cols .into_iter() .sorted_by_key(|(_, (id, _))| *id) - .map(|(name, (_, values))| (name, values.into())) + .map(|(name, (_, values))| (name, values)) .collect() } diff --git a/executor/src/witgen/mod.rs b/executor/src/witgen/mod.rs index d3ddbd17b..8b4eadbe9 100644 --- a/executor/src/witgen/mod.rs +++ b/executor/src/witgen/mod.rs @@ -10,7 +10,7 @@ use powdr_ast::parsed::visitor::ExpressionVisitable; use powdr_ast::parsed::{FunctionKind, LambdaExpression}; use powdr_number::{DegreeType, FieldElement}; -use crate::constant_evaluator::{get_uniquely_sized, VariablySizedColumn}; +use crate::constant_evaluator::{get_max_sized, VariablySizedColumn, MAX_DEGREE_LOG}; use self::data_structures::column_map::{FixedColumnMap, WitnessColumnMap}; pub use self::eval_result::{ @@ -159,7 +159,7 @@ impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> { pub fn generate(self) -> Vec<(String, Vec)> { record_start(OUTER_CODE_NAME); // TODO: Handle multiple sizes - let fixed_col_values = get_uniquely_sized(self.fixed_col_values).unwrap(); + let fixed_col_values = get_max_sized(self.fixed_col_values); let fixed = FixedData::new( self.analyzed, &fixed_col_values, @@ -315,12 +315,9 @@ impl<'a, T: FieldElement> FixedData<'a, T> { }) // get all array elements and their degrees .flat_map(|symbol| { - symbol.array_elements().map(|(_, id)| { - ( - id, - symbol.degree.expect("all polynomials should have a degree"), - ) - }) + symbol + .array_elements() + .map(|(_, id)| (id, symbol.degree.unwrap_or(1 << MAX_DEGREE_LOG))) }) // only keep the ones matching our set .filter_map(|(id, degree)| ids.contains(&id).then_some(degree)) diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index b855563ab..4d73ee4ab 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -329,6 +329,26 @@ fn different_degrees() { ); } +#[test] +fn vm_to_block_dynamic_length() { + let f = "pil/vm_to_block_dynamic_length.pil"; + // Because machines have different lengths, this can only be proven + // with a composite proof. + run_pilcom_with_backend_variant( + make_prepared_pipeline(f, vec![], vec![]), + BackendVariant::Composite, + ) + .unwrap(); + test_halo2_with_backend_variant( + make_prepared_pipeline(f, vec![], vec![]), + BackendVariant::Composite, + ); + gen_estark_proof_with_backend_variant( + make_prepared_pipeline(f, vec![], vec![]), + BackendVariant::Composite, + ); +} + #[test] fn serialize_deserialize_optimized_pil() { let f = "pil/fibonacci.pil"; diff --git a/test_data/pil/vm_to_block_dynamic_length.pil b/test_data/pil/vm_to_block_dynamic_length.pil new file mode 100644 index 000000000..6bc2a2f2d --- /dev/null +++ b/test_data/pil/vm_to_block_dynamic_length.pil @@ -0,0 +1,67 @@ +// This an adjusted copy of the optimized PIL coming out of `test_data/asm/vm_to_block_different_length.asm`. +// We can remove this once the linker allows us to specify a machine with a variable degree. + +namespace std::prover; + enum Query { + Input(int), + Output(int, int), + Hint(fe), + DataIdentifier(int, int), + None, + } +namespace main(256); + col witness _operation_id(i) query std::prover::Query::Hint(6); + col fixed _block_enforcer_last_step = [0]* + [1]; + (1 - main._block_enforcer_last_step) * (1 - main.instr_return) * (main._operation_id' - main._operation_id) = 0; + col witness pc; + col witness X; + col witness Y; + col witness reg_write_Z_A; + col witness A; + col witness Z; + col witness instr_add; + col witness instr_mul; + col witness instr_assert_eq; + main.instr_assert_eq * (main.X - main.Y) = 0; + col witness instr__jump_to_operation; + col witness instr__reset; + col witness instr__loop; + col witness instr_return; + col witness X_const; + col witness read_X_A; + main.X = main.read_X_A * main.A + main.X_const; + col witness Y_const; + main.Y = main.Y_const; + col witness Z_read_free; + main.Z = main.Z_read_free * main.Z_free_value; + col fixed first_step = [1] + [0]*; + main.A' = main.reg_write_Z_A * main.Z + (1 - (main.reg_write_Z_A + main.instr__reset)) * main.A; + col pc_update = main.instr__jump_to_operation * main._operation_id + main.instr__loop * main.pc + (1 - (main.instr__jump_to_operation + main.instr__loop + main.instr_return)) * (main.pc + 1); + main.pc' = (1 - main.first_step') * main.pc_update; + col witness Z_free_value; + [main.pc, main.reg_write_Z_A, main.instr_add, main.instr_mul, main.instr_assert_eq, main.instr__jump_to_operation, main.instr__reset, main.instr__loop, main.instr_return, main.X_const, main.read_X_A, main.Y_const, main.Z_read_free] in [main__rom.p_line, main__rom.p_reg_write_Z_A, main__rom.p_instr_add, main__rom.p_instr_mul, main__rom.p_instr_assert_eq, main__rom.p_instr__jump_to_operation, main__rom.p_instr__reset, main__rom.p_instr__loop, main__rom.p_instr_return, main__rom.p_X_const, main__rom.p_read_X_A, main__rom.p_Y_const, main__rom.p_Z_read_free]; + main.instr_add $ [0, main.X, main.Y, main.Z] in [main_arith.operation_id, main_arith.x[0], main_arith.x[1], main_arith.y]; + main.instr_mul $ [1, main.X, main.Y, main.Z] in [main_arith.operation_id, main_arith.x[0], main_arith.x[1], main_arith.y]; + col fixed _linker_first_step = [1] + [0]*; + main._linker_first_step * (main._operation_id - 2) = 0; +namespace main__rom(256); + col fixed p_line = [0, 1, 2, 3, 4, 5, 6] + [6]*; + col fixed p_X_const = [0, 0, 2, 0, 0, 0, 0] + [0]*; + col fixed p_Y_const = [0, 0, 1, 9, 27, 0, 0] + [0]*; + col fixed p_Z_read_free = [0, 0, 1, 1, 0, 0, 0] + [0]*; + col fixed p_instr__jump_to_operation = [0, 1, 0, 0, 0, 0, 0] + [0]*; + col fixed p_instr__loop = [0, 0, 0, 0, 0, 0, 1] + [1]*; + col fixed p_instr__reset = [1, 0, 0, 0, 0, 0, 0] + [0]*; + col fixed p_instr_add = [0, 0, 1, 0, 0, 0, 0] + [0]*; + col fixed p_instr_assert_eq = [0, 0, 0, 0, 1, 0, 0] + [0]*; + col fixed p_instr_mul = [0, 0, 0, 1, 0, 0, 0] + [0]*; + col fixed p_instr_return = [0, 0, 0, 0, 0, 1, 0] + [0]*; + col fixed p_read_X_A = [0, 0, 0, 1, 1, 0, 0] + [0]*; + col fixed p_reg_write_Z_A = [0, 0, 1, 1, 0, 0, 0] + [0]*; + +// CHANGED HERE: The degree of this namespace is None, meaning that this machine has a variable size. +namespace main_arith; + col witness operation_id; + col witness x[2]; + col witness y; + main_arith.y = main_arith.operation_id * (main_arith.x[0] * main_arith.x[1]) + (1 - main_arith.operation_id) * (main_arith.x[0] + main_arith.x[1]); From 8f72dd436a16fc221415c3c8914040de17d36f54 Mon Sep 17 00:00:00 2001 From: Leo Date: Tue, 23 Jul 2024 14:11:04 +0200 Subject: [PATCH 21/24] reduce the degree of compiled asm to pil so it works with P3 (#1596) Depends on https://github.com/powdr-labs/powdr/pull/1594 --------- Co-authored-by: schaeff --- asm-to-pil/src/vm_to_constrained.rs | 14 ++++++++++---- linker/src/lib.rs | 24 ++++++++++++++++-------- pipeline/src/test_util.rs | 14 +++++++++++--- pipeline/tests/asm.rs | 3 ++- pipeline/tests/pil.rs | 8 ++++---- 5 files changed, 43 insertions(+), 20 deletions(-) diff --git a/asm-to-pil/src/vm_to_constrained.rs b/asm-to-pil/src/vm_to_constrained.rs index b7d5e2e09..191edb442 100644 --- a/asm-to-pil/src/vm_to_constrained.rs +++ b/asm-to-pil/src/vm_to_constrained.rs @@ -201,12 +201,18 @@ impl VMConverter { // introduce an intermediate witness polynomial to keep the degree of polynomial identities at 2 // this may not be optimal for backends which support higher degree constraints let pc_update_name = format!("{name}_update"); - vec![ - PilStatement::PolynomialDefinition( + witness_column( + SourceRef::unknown(), + pc_update_name.clone(), + None, + ), + PilStatement::Expression( SourceRef::unknown(), - pc_update_name.to_string(), - rhs, + build::identity( + direct_reference(pc_update_name.clone()), + rhs, + ), ), PilStatement::Expression( SourceRef::unknown(), diff --git a/linker/src/lib.rs b/linker/src/lib.rs index c2eecd2bd..66105396d 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -269,7 +269,8 @@ mod test { pol commit instr__loop; pol commit instr_return; pol constant first_step = [1] + [0]*; - pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pol commit pc_update; + pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; 1 $ [0, pc, instr__jump_to_operation, instr__reset, instr__loop, instr_return] in main__rom.latch $ [main__rom.operation_id, main__rom.p_line, main__rom.p_instr__jump_to_operation, main__rom.p_instr__reset, main__rom.p_instr__loop, main__rom.p_instr_return]; namespace main__rom(4 + 4); @@ -342,7 +343,8 @@ namespace main__rom(4 + 4); Y = read_Y_A * A + read_Y_pc * pc + Y_const + Y_read_free * Y_free_value; pol constant first_step = [1] + [0]*; A' = reg_write_X_A * X + reg_write_Y_A * Y + instr__reset * 0 + (1 - (reg_write_X_A + reg_write_Y_A + instr__reset)) * A; - pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pol commit pc_update; + pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; pol commit X_free_value; pol commit Y_free_value; @@ -392,7 +394,8 @@ namespace main_sub(16); _output_0 = read__output_0_pc * pc + read__output_0__input_0 * _input_0 + _output_0_const + _output_0_read_free * _output_0_free_value; pol constant first_step = [1] + [0]*; (1 - instr__reset) * (_input_0' - _input_0) = 0; - pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pol commit pc_update; + pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; pol commit _output_0_free_value; 1 $ [0, pc, instr__jump_to_operation, instr__reset, instr__loop, instr_return, _output_0_const, _output_0_read_free, read__output_0_pc, read__output_0__input_0] in main_sub__rom.latch $ [main_sub__rom.operation_id, main_sub__rom.p_line, main_sub__rom.p_instr__jump_to_operation, main_sub__rom.p_instr__reset, main_sub__rom.p_instr__loop, main_sub__rom.p_instr_return, main_sub__rom.p__output_0_const, main_sub__rom.p__output_0_read_free, main_sub__rom.p_read__output_0_pc, main_sub__rom.p_read__output_0__input_0]; @@ -459,7 +462,8 @@ namespace main_sub__rom(16); pol constant first_step = [1] + [0]*; A' = reg_write_X_A * X + instr__reset * 0 + (1 - (reg_write_X_A + instr__reset)) * A; CNT' = reg_write_X_CNT * X + instr_dec_CNT * (CNT - 1) + instr__reset * 0 + (1 - (reg_write_X_CNT + instr_dec_CNT + instr__reset)) * CNT; - pol pc_update = instr_jmpz * (instr_jmpz_pc_update + instr_jmpz_pc_update_1) + instr_jmp * instr_jmp_param_l + instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr_jmpz + instr_jmp + instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pol commit pc_update; + pc_update = instr_jmpz * (instr_jmpz_pc_update + instr_jmpz_pc_update_1) + instr_jmp * instr_jmp_param_l + instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr_jmpz + instr_jmp + instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; pol commit X_free_value(__i) query match std::prover::eval(pc) { 2 => std::prover::Query::Input(1), @@ -537,7 +541,8 @@ machine Machine { pol commit instr_return; pol constant first_step = [1] + [0]*; fp' = instr_inc_fp * (fp + instr_inc_fp_param_amount) + instr_adjust_fp * (fp + instr_adjust_fp_param_amount) + instr__reset * 0 + (1 - (instr_inc_fp + instr_adjust_fp + instr__reset)) * fp; - pol pc_update = instr_adjust_fp * instr_adjust_fp_param_t + instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr_adjust_fp + instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pol commit pc_update; + pc_update = instr_adjust_fp * instr_adjust_fp_param_t + instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr_adjust_fp + instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; 1 $ [0, pc, instr_inc_fp, instr_inc_fp_param_amount, instr_adjust_fp, instr_adjust_fp_param_amount, instr_adjust_fp_param_t, instr__jump_to_operation, instr__reset, instr__loop, instr_return] in main__rom.latch $ [main__rom.operation_id, main__rom.p_line, main__rom.p_instr_inc_fp, main__rom.p_instr_inc_fp_param_amount, main__rom.p_instr_adjust_fp, main__rom.p_instr_adjust_fp_param_amount, main__rom.p_instr_adjust_fp_param_t, main__rom.p_instr__jump_to_operation, main__rom.p_instr__reset, main__rom.p_instr__loop, main__rom.p_instr_return]; pol constant _linker_first_step = [1] + [0]*; @@ -630,7 +635,8 @@ machine Main { X = read_X_A * A + read_X_pc * pc + X_const + X_read_free * X_free_value; pol constant first_step = [1] + [0]*; A' = reg_write_X_A * X + instr_add5_into_A * A' + instr__reset * 0 + (1 - (reg_write_X_A + instr_add5_into_A + instr__reset)) * A; - pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pol commit pc_update; + pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; pol commit X_free_value; instr_add5_into_A $ [0, X, A'] in main_vm.latch $ [main_vm.operation_id, main_vm.x, main_vm.y]; @@ -711,7 +717,8 @@ namespace main_vm(1024); pol constant first_step = [1] + [0]*; A' = reg_write_X_A * X + reg_write_Y_A * Y + reg_write_Z_A * Z + instr__reset * 0 + (1 - (reg_write_X_A + reg_write_Y_A + reg_write_Z_A + instr__reset)) * A; B' = reg_write_X_B * X + reg_write_Y_B * Y + reg_write_Z_B * Z + instr_or_into_B * B' + instr__reset * 0 + (1 - (reg_write_X_B + reg_write_Y_B + reg_write_Z_B + instr_or_into_B + instr__reset)) * B; - pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pol commit pc_update; + pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; pol commit X_free_value; pol commit Y_free_value; @@ -859,7 +866,8 @@ namespace main_bin(65536); A' = reg_write_X_A * X + reg_write_Y_A * Y + reg_write_Z_A * Z + reg_write_W_A * W + instr_add_to_A * A' + instr_add_BC_to_A * A' + instr__reset * 0 + (1 - (reg_write_X_A + reg_write_Y_A + reg_write_Z_A + reg_write_W_A + instr_add_to_A + instr_add_BC_to_A + instr__reset)) * A; B' = reg_write_X_B * X + reg_write_Y_B * Y + reg_write_Z_B * Z + reg_write_W_B * W + instr__reset * 0 + (1 - (reg_write_X_B + reg_write_Y_B + reg_write_Z_B + reg_write_W_B + instr__reset)) * B; C' = reg_write_X_C * X + reg_write_Y_C * Y + reg_write_Z_C * Z + reg_write_W_C * W + instr__reset * 0 + (1 - (reg_write_X_C + reg_write_Y_C + reg_write_Z_C + reg_write_W_C + instr__reset)) * C; - pol pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); + pol commit pc_update; + pc_update = instr__jump_to_operation * _operation_id + instr__loop * pc + instr_return * 0 + (1 - (instr__jump_to_operation + instr__loop + instr_return)) * (pc + 1); pc' = (1 - first_step') * pc_update; pol commit X_free_value; pol commit Y_free_value; diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index 3a3bb5df2..d89b559d7 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -259,12 +259,20 @@ pub fn gen_halo2_proof(pipeline: Pipeline, backend: BackendVariant) pub fn gen_halo2_proof(_pipeline: Pipeline, _backend: BackendVariant) {} #[cfg(feature = "plonky3")] -pub fn test_plonky3(file_name: &str, inputs: Vec) { +pub fn test_plonky3_with_backend_variant( + file_name: &str, + inputs: Vec, + backend: BackendVariant, +) { + let backend = match backend { + BackendVariant::Monolithic => powdr_backend::BackendType::Plonky3, + BackendVariant::Composite => powdr_backend::BackendType::Plonky3Composite, + }; let mut pipeline = Pipeline::default() .with_tmp_output() .from_file(resolve_test_file(file_name)) .with_prover_inputs(inputs) - .with_backend(powdr_backend::BackendType::Plonky3, None); + .with_backend(backend, None); // Generate a proof let proof = pipeline.compute_proof().cloned().unwrap(); @@ -296,7 +304,7 @@ pub fn test_plonky3(file_name: &str, inputs: Vec) { } #[cfg(not(feature = "plonky3"))] -pub fn test_plonky3(_: &str, _: Vec) {} +pub fn test_plonky3_with_backend_variant(_: &str, _: Vec, _: BackendVariant) {} #[cfg(not(feature = "plonky3"))] pub fn gen_plonky3_proof(_: &str, _: Vec) {} diff --git a/pipeline/tests/asm.rs b/pipeline/tests/asm.rs index c1540b725..eaacb93f9 100644 --- a/pipeline/tests/asm.rs +++ b/pipeline/tests/asm.rs @@ -5,7 +5,7 @@ use powdr_pipeline::{ test_util::{ gen_estark_proof, gen_estark_proof_with_backend_variant, make_prepared_pipeline, resolve_test_file, run_pilcom_test_file, run_pilcom_with_backend_variant, test_halo2, - test_halo2_with_backend_variant, BackendVariant, + test_halo2_with_backend_variant, test_plonky3_with_backend_variant, BackendVariant, }, util::{FixedPolySet, PolySet, WitnessPolySet}, Pipeline, @@ -42,6 +42,7 @@ fn simple_sum_asm() { verify_asm(f, slice_to_vec(&i)); test_halo2(f, slice_to_vec(&i)); gen_estark_proof(f, slice_to_vec(&i)); + test_plonky3_with_backend_variant(f, slice_to_vec(&i), BackendVariant::Composite); } #[test] diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index 4d73ee4ab..002db9941 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -6,8 +6,8 @@ use powdr_pipeline::test_util::{ assert_proofs_fail_for_invalid_witnesses_halo2, assert_proofs_fail_for_invalid_witnesses_pilcom, gen_estark_proof, gen_estark_proof_with_backend_variant, make_prepared_pipeline, run_pilcom_test_file, - run_pilcom_with_backend_variant, test_halo2, test_halo2_with_backend_variant, test_plonky3, - BackendVariant, + run_pilcom_with_backend_variant, test_halo2, test_halo2_with_backend_variant, + test_plonky3_with_backend_variant, BackendVariant, }; use test_log::test; @@ -93,7 +93,7 @@ fn fibonacci() { verify_pil(f, Default::default()); test_halo2(f, Default::default()); gen_estark_proof(f, Default::default()); - test_plonky3(f, Default::default()); + test_plonky3_with_backend_variant(f, Default::default(), BackendVariant::Monolithic); } #[test] @@ -245,7 +245,7 @@ fn halo_without_lookup() { #[test] fn add() { let f = "pil/add.pil"; - test_plonky3(f, Default::default()); + test_plonky3_with_backend_variant(f, Default::default(), BackendVariant::Monolithic); } #[test] From 5c31155a7ed38b69af8d5d1ab54521bae398bf86 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Tue, 23 Jul 2024 14:21:33 +0200 Subject: [PATCH 22/24] Add dynamic VADCOP witness generation (for block machines) (#1595) Another step towards #1572 Builds on #1574 I modified witness generation as follows: - Each machine keeps track of its current size; whenever a fixed column value is read, it has to pass the requested size as well. - If fixed columns are available in several sizes, witness generation starts out by using the largest size, as before - When finalizing a block machine, it "downsizes" the machine to the smallest possible value Doing this for other machine types (e.g. VM, memory, etc) should be done in another PR. In the `vm_to_block_dynamic_length.pil` example, witness generation now pics the minimum size instead of the maximum size for `main_arith` ``` $ cargo run pil test_data/pil/vm_to_block_dynamic_length.pil -o output -f --field bn254 --prove-with halo2-mock-composite ... == Proving machine: main (size 256) ==> Machine proof of 256 rows (0 bytes) computed in 60.174583ms size: 256 Machine: main__rom == Proving machine: main__rom (size 256) ==> Machine proof of 256 rows (0 bytes) computed in 33.310292ms size: 32 Machine: main_arith == Proving machine: main_arith (size 32) ==> Machine proof of 32 rows (0 bytes) computed in 2.766541ms ``` --- backend/src/composite/split.rs | 2 +- .../src/constant_evaluator/data_structures.rs | 20 ++-------- executor/src/constant_evaluator/mod.rs | 6 +-- executor/src/witgen/block_processor.rs | 16 ++++++-- executor/src/witgen/fixed_evaluator.rs | 13 ++++-- executor/src/witgen/generator.rs | 1 + executor/src/witgen/global_constraints.rs | 2 +- executor/src/witgen/machines/block_machine.rs | 18 ++++++++- .../witgen/machines/fixed_lookup_machine.rs | 6 +-- .../witgen/machines/sorted_witness_machine.rs | 6 +-- .../src/witgen/machines/write_once_memory.rs | 2 +- executor/src/witgen/mod.rs | 40 +++++++++++++------ executor/src/witgen/processor.rs | 17 +++++++- executor/src/witgen/query_processor.rs | 14 +++++-- executor/src/witgen/rows.rs | 6 +++ .../src/witgen/symbolic_witness_evaluator.rs | 11 ++++- executor/src/witgen/vm_processor.rs | 9 ++++- 17 files changed, 130 insertions(+), 59 deletions(-) diff --git a/backend/src/composite/split.rs b/backend/src/composite/split.rs index 07e5de483..062770a89 100644 --- a/backend/src/composite/split.rs +++ b/backend/src/composite/split.rs @@ -115,7 +115,7 @@ pub(crate) fn machine_fixed_columns( .map(|(name, column)| { ( name.clone(), - column.get_by_size_cloned(size).unwrap().into(), + column.get_by_size(size).unwrap().to_vec().into(), ) }) .collect::>(), diff --git a/executor/src/constant_evaluator/data_structures.rs b/executor/src/constant_evaluator/data_structures.rs index 931409840..f0af3d457 100644 --- a/executor/src/constant_evaluator/data_structures.rs +++ b/executor/src/constant_evaluator/data_structures.rs @@ -24,11 +24,10 @@ impl VariablySizedColumn { } /// Clones and returns the column with the given size. - pub fn get_by_size_cloned(&self, size: usize) -> Option> - where - F: Clone, - { - self.column_by_size.get(&size).cloned() + pub fn get_by_size(&self, size: usize) -> Option<&[F]> { + self.column_by_size + .get(&size) + .map(|column| column.as_slice()) } } @@ -42,17 +41,6 @@ pub fn get_uniquely_sized( .collect() } -/// Returns all columns with their maximum sizes. -pub fn get_max_sized(column: &[(String, VariablySizedColumn)]) -> Vec<(String, &Vec)> { - column - .iter() - .map(|(name, column)| { - let max_size = column.column_by_size.keys().max().unwrap(); - (name.clone(), &column.column_by_size[max_size]) - }) - .collect() -} - pub fn get_uniquely_sized_cloned( column: &[(String, VariablySizedColumn)], ) -> Result)>, HasMultipleSizesError> { diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index cd5d11651..2d4726dd0 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -3,9 +3,7 @@ use std::{ sync::{Arc, RwLock}, }; -pub use data_structures::{ - get_max_sized, get_uniquely_sized, get_uniquely_sized_cloned, VariablySizedColumn, -}; +pub use data_structures::{get_uniquely_sized, get_uniquely_sized_cloned, VariablySizedColumn}; use itertools::Itertools; use powdr_ast::{ analyzed::{Analyzed, FunctionValueDefinition, Symbol, TypedExpression}, @@ -21,7 +19,7 @@ use rayon::prelude::{IntoParallelIterator, ParallelIterator}; mod data_structures; pub const MIN_DEGREE_LOG: usize = 5; -pub const MAX_DEGREE_LOG: usize = 10; +pub const MAX_DEGREE_LOG: usize = 22; /// Generates the fixed column values for all fixed columns that are defined /// (and not just declared). diff --git a/executor/src/witgen/block_processor.rs b/executor/src/witgen/block_processor.rs index 06fd07fec..f6be9fac2 100644 --- a/executor/src/witgen/block_processor.rs +++ b/executor/src/witgen/block_processor.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use powdr_ast::analyzed::{AlgebraicReference, PolyID}; -use powdr_number::FieldElement; +use powdr_number::{DegreeType, FieldElement}; use crate::Identity; @@ -33,8 +33,16 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> BlockProcessor<'a, 'b, 'c identities: &'c [&'a Identity], fixed_data: &'a FixedData<'a, T>, witness_cols: &'c HashSet, + size: DegreeType, ) -> Self { - let processor = Processor::new(row_offset, data, mutable_state, fixed_data, witness_cols); + let processor = Processor::new( + row_offset, + data, + mutable_state, + fixed_data, + witness_cols, + size, + ); Self { processor, identities, @@ -121,7 +129,7 @@ mod tests { use powdr_pil_analyzer::analyze_string; use crate::{ - constant_evaluator::{generate, get_uniquely_sized}, + constant_evaluator::generate, witgen::{ data_structures::finalizable_data::FinalizableData, identity_processor::Machines, @@ -153,7 +161,6 @@ mod tests { ) -> R { let analyzed = analyze_string(src); let constants = generate(&analyzed); - let constants = get_uniquely_sized(&constants).unwrap(); let fixed_data = FixedData::new(&analyzed, &constants, &[], Default::default(), 0); // No submachines @@ -189,6 +196,7 @@ mod tests { &identities, &fixed_data, &witness_cols, + degree, ); f( diff --git a/executor/src/witgen/fixed_evaluator.rs b/executor/src/witgen/fixed_evaluator.rs index 77a57f9e6..aa97108d3 100644 --- a/executor/src/witgen/fixed_evaluator.rs +++ b/executor/src/witgen/fixed_evaluator.rs @@ -2,17 +2,22 @@ use super::affine_expression::AffineResult; use super::expression_evaluator::SymbolicVariables; use super::FixedData; use powdr_ast::analyzed::AlgebraicReference; -use powdr_number::FieldElement; +use powdr_number::{DegreeType, FieldElement}; /// Evaluates only fixed columns on a specific row. pub struct FixedEvaluator<'a, T: FieldElement> { fixed_data: &'a FixedData<'a, T>, row: usize, + size: DegreeType, } impl<'a, T: FieldElement> FixedEvaluator<'a, T> { - pub fn new(fixed_data: &'a FixedData<'a, T>, row: usize) -> Self { - FixedEvaluator { fixed_data, row } + pub fn new(fixed_data: &'a FixedData<'a, T>, row: usize, size: DegreeType) -> Self { + FixedEvaluator { + fixed_data, + row, + size, + } } } @@ -23,7 +28,7 @@ impl<'a, T: FieldElement> SymbolicVariables for FixedEvaluator<'a, T> { poly.is_fixed(), "Can only access fixed columns in the fixed evaluator." ); - let col_data = self.fixed_data.fixed_cols[&poly.poly_id].values; + let col_data = self.fixed_data.fixed_cols[&poly.poly_id].values(self.size); let degree = col_data.len(); let row = if poly.next { (self.row + 1) % degree diff --git a/executor/src/witgen/generator.rs b/executor/src/witgen/generator.rs index 8b07dd148..74d7f13af 100644 --- a/executor/src/witgen/generator.rs +++ b/executor/src/witgen/generator.rs @@ -197,6 +197,7 @@ impl<'a, T: FieldElement> Generator<'a, T> { &identities_with_next_reference, self.fixed_data, &self.witnesses, + self.degree, ); let mut sequence_iterator = ProcessingSequenceIterator::Default( DefaultSequenceIterator::new(0, identities_with_next_reference.len(), None), diff --git a/executor/src/witgen/global_constraints.rs b/executor/src/witgen/global_constraints.rs index 7a61ac5f9..face80d4e 100644 --- a/executor/src/witgen/global_constraints.rs +++ b/executor/src/witgen/global_constraints.rs @@ -118,7 +118,7 @@ pub fn set_global_constraints<'a, T: FieldElement>( // It allows us to completely remove some lookups. let mut full_span = BTreeSet::new(); for (poly_id, col) in fixed_data.fixed_cols.iter() { - if let Some((cons, full)) = process_fixed_column(col.values) { + if let Some((cons, full)) = process_fixed_column(col.values_max_size()) { assert!(known_constraints.insert(poly_id, cons).is_none()); if full { full_span.insert(poly_id); diff --git a/executor/src/witgen/machines/block_machine.rs b/executor/src/witgen/machines/block_machine.rs index 8e8f2575d..5ac029c3f 100644 --- a/executor/src/witgen/machines/block_machine.rs +++ b/executor/src/witgen/machines/block_machine.rs @@ -4,6 +4,7 @@ use std::iter::{self, once}; use super::{EvalResult, FixedData, FixedLookup}; +use crate::constant_evaluator::MIN_DEGREE_LOG; use crate::witgen::block_processor::BlockProcessor; use crate::witgen::data_structures::finalizable_data::FinalizableData; use crate::witgen::processor::{OuterQuery, Processor}; @@ -249,7 +250,7 @@ fn try_to_period( let degree = fixed_data.common_degree(once(&poly.poly_id)); - let values = fixed_data.fixed_cols[&poly.poly_id].values; + let values = fixed_data.fixed_cols[&poly.poly_id].values(degree); let offset = values.iter().position(|v| v.is_one())?; let period = 1 + values.iter().skip(offset + 1).position(|v| v.is_one())?; @@ -317,6 +318,19 @@ impl<'a, T: FieldElement> Machine<'a, T> for BlockMachine<'a, T> { ); } + if self.fixed_data.is_variable_size(&self.witness_cols) { + let new_degree = self.data.len().next_power_of_two() as DegreeType; + let new_degree = new_degree.max(1 << MIN_DEGREE_LOG); + log::info!( + "Resizing variable length machine '{}': {} -> {} (rounded up from {})", + self.name, + self.degree, + new_degree, + self.data.len() + ); + self.degree = new_degree; + } + if matches!(self.connection_type, ConnectionType::Permutation) { // We have to make sure that *all* selectors are 0 in the dummy block, // because otherwise this block won't have a matching block on the LHS. @@ -343,6 +357,7 @@ impl<'a, T: FieldElement> Machine<'a, T> for BlockMachine<'a, T> { &mut mutable_state, self.fixed_data, &self.witness_cols, + self.degree, ); // Set all selectors to 0 @@ -558,6 +573,7 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> { &self.identities, self.fixed_data, &self.witness_cols, + self.degree, ) .with_outer_query(outer_query); diff --git a/executor/src/witgen/machines/fixed_lookup_machine.rs b/executor/src/witgen/machines/fixed_lookup_machine.rs index b87d255b1..50f8b7c37 100644 --- a/executor/src/witgen/machines/fixed_lookup_machine.rs +++ b/executor/src/witgen/machines/fixed_lookup_machine.rs @@ -101,12 +101,12 @@ impl IndexedColumns { // get all values for the columns to be indexed let input_column_values = sorted_input_fixed_columns .iter() - .map(|id| fixed_data.fixed_cols[id].values) + .map(|id| fixed_data.fixed_cols[id].values_max_size()) .collect::>(); let output_column_values = sorted_output_fixed_columns .iter() - .map(|id| fixed_data.fixed_cols[id].values) + .map(|id| fixed_data.fixed_cols[id].values_max_size()) .collect::>(); let degree = input_column_values @@ -290,7 +290,7 @@ impl FixedLookup { let output = output_columns .iter() - .map(|column| fixed_data.fixed_cols[column].values[row]); + .map(|column| fixed_data.fixed_cols[column].values_max_size()[row]); let mut result = EvalValue::complete(vec![]); for (l, r) in output_expressions.into_iter().zip(output) { diff --git a/executor/src/witgen/machines/sorted_witness_machine.rs b/executor/src/witgen/machines/sorted_witness_machine.rs index 51b35b0a6..e9f83c8c8 100644 --- a/executor/src/witgen/machines/sorted_witness_machine.rs +++ b/executor/src/witgen/machines/sorted_witness_machine.rs @@ -117,9 +117,9 @@ fn check_identity( // TODO this could be rather slow. We should check the code for identity instead // of evaluating it. - let degree = degree as usize; - for row in 0..(degree) { - let ev = ExpressionEvaluator::new(FixedEvaluator::new(fixed_data, row)); + for row in 0..(degree as usize) { + let ev = ExpressionEvaluator::new(FixedEvaluator::new(fixed_data, row, degree)); + let degree = degree as usize; let nl = ev.evaluate(not_last).ok()?.constant_value()?; if (row == degree - 1 && !nl.is_zero()) || (row < degree - 1 && !nl.is_one()) { return None; diff --git a/executor/src/witgen/machines/write_once_memory.rs b/executor/src/witgen/machines/write_once_memory.rs index 85a992702..085c1eb1e 100644 --- a/executor/src/witgen/machines/write_once_memory.rs +++ b/executor/src/witgen/machines/write_once_memory.rs @@ -106,7 +106,7 @@ impl<'a, T: FieldElement> WriteOnceMemory<'a, T> { for row in 0..degree { let key = key_polys .iter() - .map(|k| fixed_data.fixed_cols[k].values[row as usize]) + .map(|k| fixed_data.fixed_cols[k].values(degree)[row as usize]) .collect::>(); if key_to_index.insert(key, row).is_some() { // Duplicate keys, can't be a write-once memory diff --git a/executor/src/witgen/mod.rs b/executor/src/witgen/mod.rs index 8b4eadbe9..71ef8325c 100644 --- a/executor/src/witgen/mod.rs +++ b/executor/src/witgen/mod.rs @@ -10,7 +10,7 @@ use powdr_ast::parsed::visitor::ExpressionVisitable; use powdr_ast::parsed::{FunctionKind, LambdaExpression}; use powdr_number::{DegreeType, FieldElement}; -use crate::constant_evaluator::{get_max_sized, VariablySizedColumn, MAX_DEGREE_LOG}; +use crate::constant_evaluator::{VariablySizedColumn, MAX_DEGREE_LOG}; use self::data_structures::column_map::{FixedColumnMap, WitnessColumnMap}; pub use self::eval_result::{ @@ -158,11 +158,9 @@ impl<'a, 'b, T: FieldElement> WitnessGenerator<'a, 'b, T> { /// @returns the values (in source order) and the degree of the polynomials. pub fn generate(self) -> Vec<(String, Vec)> { record_start(OUTER_CODE_NAME); - // TODO: Handle multiple sizes - let fixed_col_values = get_max_sized(self.fixed_col_values); let fixed = FixedData::new( self.analyzed, - &fixed_col_values, + self.fixed_col_values, self.external_witness_values, self.challenges, self.stage, @@ -302,7 +300,10 @@ impl<'a, T: FieldElement> FixedData<'a, T> { /// - the degree is not unique /// - the set of polynomials is empty /// - a declared polynomial does not have an explicit degree - pub fn common_degree<'b>(&self, ids: impl IntoIterator) -> DegreeType { + fn common_set_degree<'b>( + &self, + ids: impl IntoIterator, + ) -> Option { let ids: HashSet<_> = ids.into_iter().collect(); self.analyzed @@ -314,11 +315,7 @@ impl<'a, T: FieldElement> FixedData<'a, T> { matches!(symbol.kind, SymbolKind::Poly(_)).then_some(symbol) }) // get all array elements and their degrees - .flat_map(|symbol| { - symbol - .array_elements() - .map(|(_, id)| (id, symbol.degree.unwrap_or(1 << MAX_DEGREE_LOG))) - }) + .flat_map(|symbol| symbol.array_elements().map(|(_, id)| (id, symbol.degree))) // only keep the ones matching our set .filter_map(|(id, degree)| ids.contains(&id).then_some(degree)) // get the common degree @@ -327,9 +324,17 @@ impl<'a, T: FieldElement> FixedData<'a, T> { .unwrap_or_else(|_| panic!("expected all polynomials to have the same degree")) } + fn common_degree<'b>(&self, ids: impl IntoIterator) -> DegreeType { + self.common_set_degree(ids).unwrap_or(1 << MAX_DEGREE_LOG) + } + + fn is_variable_size<'b>(&self, ids: impl IntoIterator) -> bool { + self.common_set_degree(ids).is_none() + } + pub fn new( analyzed: &'a Analyzed, - fixed_col_values: &'a [(String, &'a Vec)], + fixed_col_values: &'a [(String, VariablySizedColumn)], external_witness_values: &'a [(String, Vec)], challenges: BTreeMap, stage: u8, @@ -453,14 +458,23 @@ impl<'a, T: FieldElement> FixedData<'a, T> { pub struct FixedColumn<'a, T> { name: String, - values: &'a Vec, + values: &'a VariablySizedColumn, } impl<'a, T> FixedColumn<'a, T> { - pub fn new(name: &'a str, values: &'a Vec) -> FixedColumn<'a, T> { + pub fn new(name: &'a str, values: &'a VariablySizedColumn) -> FixedColumn<'a, T> { let name = name.to_string(); FixedColumn { name, values } } + + pub fn values(&self, size: DegreeType) -> &[T] { + self.values.get_by_size(size as usize).unwrap() + } + + pub fn values_max_size(&self) -> &[T] { + let max_size = self.values.available_sizes().into_iter().max().unwrap() as DegreeType; + self.values(max_size) + } } #[derive(Debug)] diff --git a/executor/src/witgen/processor.rs b/executor/src/witgen/processor.rs index 446ad00f9..8c97cdede 100644 --- a/executor/src/witgen/processor.rs +++ b/executor/src/witgen/processor.rs @@ -85,6 +85,7 @@ pub struct Processor<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> { inputs: Vec<(PolyID, T)>, previously_set_inputs: BTreeMap, copy_constraints: CopyConstraints<(PolyID, RowIndex)>, + size: DegreeType, } impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> Processor<'a, 'b, 'c, T, Q> { @@ -94,6 +95,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> Processor<'a, 'b, 'c, T, mutable_state: &'c mut MutableState<'a, 'b, T, Q>, fixed_data: &'a FixedData<'a, T>, witness_cols: &'c HashSet, + size: DegreeType, ) -> Self { let is_relevant_witness = WitnessColumnMap::from( fixed_data @@ -121,6 +123,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> Processor<'a, 'b, 'c, T, previously_set_inputs: BTreeMap::new(), // TODO(#1333): Get copy constraints from PIL. copy_constraints: Default::default(), + size, } } @@ -166,6 +169,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> Processor<'a, 'b, 'c, T, self.row_offset + row_index as u64, self.fixed_data, UnknownStrategy::Unknown, + self.size, ); self.outer_query .as_ref() @@ -176,8 +180,11 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> Processor<'a, 'b, 'c, T, } pub fn process_queries(&mut self, row_index: usize) -> Result> { - let mut query_processor = - QueryProcessor::new(self.fixed_data, self.mutable_state.query_callback); + let mut query_processor = QueryProcessor::new( + self.fixed_data, + self.mutable_state.query_callback, + self.size, + ); let global_row_index = self.row_offset + row_index as u64; let row_pair = RowPair::new( &self.data[row_index], @@ -185,6 +192,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> Processor<'a, 'b, 'c, T, global_row_index, self.fixed_data, UnknownStrategy::Unknown, + self.size, ); let mut updates = EvalValue::complete(vec![]); for poly_id in &self.prover_query_witnesses { @@ -211,6 +219,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> Processor<'a, 'b, 'c, T, global_row_index, self.fixed_data, unknown_strategy, + self.size, ); // Compute updates @@ -286,6 +295,7 @@ Known values in current row (local: {row_index}, global {global_row_index}): self.row_offset + row_index as u64, self.fixed_data, UnknownStrategy::Unknown, + self.size, ); let mut identity_processor = IdentityProcessor::new(self.fixed_data, self.mutable_state); @@ -365,6 +375,7 @@ Known values in current row (local: {row_index}, global {global_row_index}): self.row_offset + row_index as u64, self.fixed_data, UnknownStrategy::Unknown, + self.size, ); let affine_expression = row_pair.evaluate(expression)?; let updates = (affine_expression - value.into()) @@ -503,6 +514,7 @@ Known values in current row (local: {row_index}, global {global_row_index}): self.row_offset + (row_index - 1) as DegreeType, self.fixed_data, UnknownStrategy::Zero, + self.size, ) } // Check whether identities without a reference to the next row are satisfied @@ -513,6 +525,7 @@ Known values in current row (local: {row_index}, global {global_row_index}): self.row_offset + row_index as DegreeType, self.fixed_data, UnknownStrategy::Zero, + self.size, ), }; diff --git a/executor/src/witgen/query_processor.rs b/executor/src/witgen/query_processor.rs index 9b53b8628..592f8bf14 100644 --- a/executor/src/witgen/query_processor.rs +++ b/executor/src/witgen/query_processor.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use powdr_ast::analyzed::Challenge; use powdr_ast::analyzed::{AlgebraicReference, Expression, PolyID, PolynomialType}; use powdr_ast::parsed::types::Type; -use powdr_number::{BigInt, FieldElement}; +use powdr_number::{BigInt, DegreeType, FieldElement}; use powdr_pil_analyzer::evaluator::{self, Definitions, EvalError, SymbolLookup, Value}; use super::{rows::RowPair, Constraint, EvalResult, EvalValue, FixedData, IncompleteCause}; @@ -12,15 +12,21 @@ use super::{rows::RowPair, Constraint, EvalResult, EvalValue, FixedData, Incompl pub struct QueryProcessor<'a, 'b, T: FieldElement, QueryCallback: Send + Sync> { fixed_data: &'a FixedData<'a, T>, query_callback: &'b mut QueryCallback, + size: DegreeType, } impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback> QueryProcessor<'a, 'b, T, QueryCallback> { - pub fn new(fixed_data: &'a FixedData<'a, T>, query_callback: &'b mut QueryCallback) -> Self { + pub fn new( + fixed_data: &'a FixedData<'a, T>, + query_callback: &'b mut QueryCallback, + size: DegreeType, + ) -> Self { Self { fixed_data, query_callback, + size, } } @@ -87,6 +93,7 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback> let mut symbols = Symbols { fixed_data: self.fixed_data, rows, + size: self.size, }; let fun = evaluator::evaluate(query, &mut symbols)?; evaluator::evaluate_function_call(fun, arguments, &mut symbols).map(|v| v.to_string()) @@ -97,6 +104,7 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback> struct Symbols<'a, T: FieldElement> { fixed_data: &'a FixedData<'a, T>, rows: &'a RowPair<'a, 'a, T>, + size: DegreeType, } impl<'a, T: FieldElement> SymbolLookup<'a, T> for Symbols<'a, T> { @@ -144,7 +152,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Symbols<'a, T> { .get_value(poly_ref) .ok_or(EvalError::DataNotAvailable)?, PolynomialType::Constant => { - let values = self.fixed_data.fixed_cols[&poly_ref.poly_id].values; + let values = self.fixed_data.fixed_cols[&poly_ref.poly_id].values(self.size); let row = self.rows.current_row_index + if poly_ref.next { 1 } else { 0 }; values[usize::from(row)] } diff --git a/executor/src/witgen/rows.rs b/executor/src/witgen/rows.rs index f30d5e385..8db21e5e6 100644 --- a/executor/src/witgen/rows.rs +++ b/executor/src/witgen/rows.rs @@ -398,6 +398,7 @@ pub struct RowPair<'row, 'a, T: FieldElement> { pub current_row_index: RowIndex, fixed_data: &'a FixedData<'a, T>, unknown_strategy: UnknownStrategy, + size: DegreeType, } impl<'row, 'a, T: FieldElement> RowPair<'row, 'a, T> { /// Creates a new row pair. @@ -407,6 +408,7 @@ impl<'row, 'a, T: FieldElement> RowPair<'row, 'a, T> { current_row_index: RowIndex, fixed_data: &'a FixedData<'a, T>, unknown_strategy: UnknownStrategy, + size: DegreeType, ) -> Self { Self { current, @@ -414,6 +416,7 @@ impl<'row, 'a, T: FieldElement> RowPair<'row, 'a, T> { current_row_index, fixed_data, unknown_strategy, + size, } } @@ -423,6 +426,7 @@ impl<'row, 'a, T: FieldElement> RowPair<'row, 'a, T> { current_row_index: RowIndex, fixed_data: &'a FixedData<'a, T>, unknown_strategy: UnknownStrategy, + size: DegreeType, ) -> Self { Self { current, @@ -430,6 +434,7 @@ impl<'row, 'a, T: FieldElement> RowPair<'row, 'a, T> { current_row_index, fixed_data, unknown_strategy, + size, } } @@ -472,6 +477,7 @@ impl<'row, 'a, T: FieldElement> RowPair<'row, 'a, T> { self.fixed_data, self.current_row_index.into(), self, + self.size, )) .evaluate(expr) } diff --git a/executor/src/witgen/symbolic_witness_evaluator.rs b/executor/src/witgen/symbolic_witness_evaluator.rs index 6d9831d8b..679341263 100644 --- a/executor/src/witgen/symbolic_witness_evaluator.rs +++ b/executor/src/witgen/symbolic_witness_evaluator.rs @@ -17,6 +17,7 @@ pub struct SymbolicWitnessEvaluator<'a, T: FieldElement, WA: WitnessColumnEvalua fixed_data: &'a FixedData<'a, T>, row: DegreeType, witness_access: &'a WA, + size: DegreeType, } impl<'a, T: FieldElement, WA> SymbolicWitnessEvaluator<'a, T, WA> @@ -26,11 +27,17 @@ where /// Constructs a new SymbolicWitnessEvaluator /// @param row the row on which to evaluate plain fixed /// columns ("next columns" - f' - are evaluated on row + 1). - pub fn new(fixed_data: &'a FixedData<'a, T>, row: DegreeType, witness_access: &'a WA) -> Self { + pub fn new( + fixed_data: &'a FixedData<'a, T>, + row: DegreeType, + witness_access: &'a WA, + size: DegreeType, + ) -> Self { Self { fixed_data, row, witness_access, + size, } } } @@ -45,7 +52,7 @@ where self.witness_access.value(poly) } else { // Constant polynomial (or something else) - let values = self.fixed_data.fixed_cols[&poly.poly_id].values; + let values = self.fixed_data.fixed_cols[&poly.poly_id].values(self.size); let row = if poly.next { self.row + 1 } else { self.row } % (values.len() as DegreeType); Ok(values[row as usize].into()) diff --git a/executor/src/witgen/vm_processor.rs b/executor/src/witgen/vm_processor.rs index e1412c2cc..81a5509d7 100644 --- a/executor/src/witgen/vm_processor.rs +++ b/executor/src/witgen/vm_processor.rs @@ -76,7 +76,14 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback> VmProcessor<'a, 'b, 'c, T let (identities_with_next, identities_without_next): (Vec<_>, Vec<_>) = identities .iter() .partition(|identity| identity.contains_next_ref()); - let processor = Processor::new(row_offset, data, mutable_state, fixed_data, witnesses); + let processor = Processor::new( + row_offset, + data, + mutable_state, + fixed_data, + witnesses, + degree, + ); let progress_bar = ProgressBar::new(degree); progress_bar.set_style( From 3932a6838f267b4a30369a5746e6b0eddad85e1b Mon Sep 17 00:00:00 2001 From: Leo Date: Wed, 24 Jul 2024 17:37:51 +0200 Subject: [PATCH 23/24] Some simplficiations for the regs in mem ricsv machine (#1598) --- riscv-executor/src/lib.rs | 16 +- riscv/src/code_gen.rs | 339 ++++++++++++++------------ riscv/src/continuations/bootloader.rs | 42 ++-- riscv/src/runtime.rs | 58 ++--- 4 files changed, 240 insertions(+), 215 deletions(-) diff --git a/riscv-executor/src/lib.rs b/riscv-executor/src/lib.rs index 7b8d8d83b..e7dd3bc0b 100644 --- a/riscv-executor/src/lib.rs +++ b/riscv-executor/src/lib.rs @@ -670,7 +670,7 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { vec![val] } - "move_reg" => { + "affine" => { let val = self.proc.get_reg_mem(args[0].u()); let write_reg = args[1].u(); let factor = args[2]; @@ -781,18 +781,18 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { Vec::new() } - "branch_if_zero" => { + "branch_if_diff_equal" => { let val1 = self.proc.get_reg_mem(args[0].u()); let val2 = self.proc.get_reg_mem(args[1].u()); let offset = args[2]; - let val: Elem = val1.sub(&val2.add(&offset)); + let val: Elem = val1.sub(&val2).sub(&offset); if val.is_zero() { self.proc.set_pc(args[3]); } Vec::new() } - "skip_if_zero" => { + "skip_if_equal" => { let val1 = self.proc.get_reg_mem(args[0].u()); let val2 = self.proc.get_reg_mem(args[1].u()); let offset = args[2]; @@ -805,24 +805,24 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { Vec::new() } - "branch_if_positive" => { + "branch_if_diff_greater_than" => { let val1 = self.proc.get_reg_mem(args[0].u()); let val2 = self.proc.get_reg_mem(args[1].u()); let offset = args[2]; - let val: Elem = val1.sub(&val2).add(&offset); + let val: Elem = val1.sub(&val2).sub(&offset); if val.bin() > 0 { self.proc.set_pc(args[3]); } Vec::new() } - "is_positive" => { + "is_diff_greater_than" => { let val1 = self.proc.get_reg_mem(args[0].u()); let val2 = self.proc.get_reg_mem(args[1].u()); let offset = args[2]; let write_reg = args[3].u(); - let val = val1.sub(&val2).add(&offset); + let val = val1.sub(&val2).sub(&offset); let r = if val.bin() > 0 { 1 } else { 0 }; self.proc.set_reg_mem(write_reg, r.into()); diff --git a/riscv/src/code_gen.rs b/riscv/src/code_gen.rs index a12a879a8..af3792720 100644 --- a/riscv/src/code_gen.rs +++ b/riscv/src/code_gen.rs @@ -391,10 +391,10 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo reg query_arg_2; // Witness columns used in instuctions for intermediate values inside instructions. - col witness val1_col; - col witness val2_col; - col witness val3_col; - col witness val4_col; + col witness tmp1_col; + col witness tmp2_col; + col witness tmp3_col; + col witness tmp4_col; // We need to add these inline instead of using std::utils::is_zero // because when XX is not constrained, witgen will try to set XX, @@ -409,114 +409,105 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo // Load the value of label `l` into register X. instr load_label X, l: label - link ~> regs.mstore(X, STEP, val1_col) + link ~> regs.mstore(X, STEP, tmp1_col) { - val1_col = l + tmp1_col = l } // Jump to `l` and store the return program counter in register W. instr jump l: label, W - link ~> regs.mstore(W, STEP, val3_col) + link ~> regs.mstore(W, STEP, pc + 1) { - pc' = l, - val3_col = pc + 1 + pc' = l } // Jump to the address in register X and store the return program counter in register W. instr jump_dyn X, W - link ~> val1_col = regs.mload(X, STEP) - link ~> regs.mstore(W, STEP, val3_col) - { - pc' = val1_col, - val3_col = pc + 1 - } + link ~> pc' = regs.mload(X, STEP) + link ~> regs.mstore(W, STEP, pc + 1); // Jump to `l` if val(X) - val(Y) is nonzero, where X and Y are register ids. instr branch_if_diff_nonzero X, Y, l: label - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) { XXIsZero = 1 - XX * XX_inv, - XX = val1_col - val2_col, + XX = tmp1_col - tmp2_col, pc' = (1 - XXIsZero) * l + XXIsZero * (pc + 1) } - // Jump to `l` if val(X) - (val(Y) + Z) is zero, where X and Y are register ids and Z is a - // constant offset. - instr branch_if_zero X, Y, Z, l: label - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) + // Jump to `l` if (val(X) - val(Y)) == Z, where X and Y are register ids and Z is a number. + instr branch_if_diff_equal X, Y, Z, l: label + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) { XXIsZero = 1 - XX * XX_inv, - XX = val1_col - (val2_col + Z), + XX = tmp1_col - tmp2_col - Z, pc' = XXIsZero * l + (1 - XXIsZero) * (pc + 1) } // Skips W instructions if val(X) - val(Y) + Z is zero, where X and Y are register ids and Z is a // constant offset. - instr skip_if_zero X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) + instr skip_if_equal X, Y, Z, W + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) { XXIsZero = 1 - XX * XX_inv, - XX = val1_col - val2_col + Z, + XX = tmp1_col - tmp2_col + Z, pc' = pc + 1 + (XXIsZero * W) } - // Branches to `l` if V = val(X) - val(Y) + Z is positive, where X and Y are register ids and Z is a - // constant offset. + // Branches to `l` if V = val(X) - val(Y) - Z is positive, i.e. val(X) - val(Y) > Z, + // where X and Y are register ids and Z is a constant. // V is required to be the difference of two 32-bit unsigned values. // i.e. -2**32 < V < 2**32. - instr branch_if_positive X, Y, Z, l: label - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) + instr branch_if_diff_greater_than X, Y, Z, l: label + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) { - (val1_col - val2_col + Z) + 2**32 - 1 = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + wrap_bit * 2**32, + (tmp1_col - tmp2_col - Z) + 2**32 - 1 = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + wrap_bit * 2**32, pc' = wrap_bit * l + (1 - wrap_bit) * (pc + 1) } - // Stores 1 in register W if V = val(X) - val(Y) + Z is positive, where X and Y are register ids and Z is a - // constant offset. + // Stores 1 in register W if V = val(X) - val(Y) - Z is positive, + // i.e. val(X) - val(Y) > Z, where X and Y are register ids and Z is a constant. // V is required to be the difference of two 32-bit unsigend values. // i.e. -2**32 < V < 2**32 - instr is_positive X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) + instr is_diff_greater_than X, Y, Z, W + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) link ~> regs.mstore(W, STEP + 2, wrap_bit) { - (val1_col - val2_col + Z) + 2**32 - 1 = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + wrap_bit * 2**32 + (tmp1_col - tmp2_col - Z) + 2**32 - 1 = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + wrap_bit * 2**32 } // Stores val(X) * Z + W in register Y. - instr move_reg X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> regs.mstore(Y, STEP + 1, val3_col) - { - val3_col = val1_col * Z + W - } + instr affine X, Y, Z, W + link ~> tmp1_col = regs.mload(X, STEP) + link ~> regs.mstore(Y, STEP + 1, tmp1_col * Z + W); // ================= wrapping instructions ================= // Computes V = val(X) + val(Y) + Z, wraps it in 32 bits, and stores the result in register W. // Requires 0 <= V < 2**33. instr add_wrap X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link ~> regs.mstore(W, STEP + 2, val3_col) + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link ~> regs.mstore(W, STEP + 2, tmp3_col) { - val1_col + val2_col + Z = val3_col + wrap_bit * 2**32, - val3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + tmp1_col + tmp2_col + Z = tmp3_col + wrap_bit * 2**32, + tmp3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 } // Computes V = val(X) - val(Y) + Z, wraps it in 32 bits, and stores the result in register W. // Requires -2**32 <= V < 2**32. instr sub_wrap_with_offset X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link ~> regs.mstore(W, STEP + 2, val3_col) + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link ~> regs.mstore(W, STEP + 2, tmp3_col) { - (val1_col - val2_col + Z) + 2**32 = val3_col + wrap_bit * 2**32, - val3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + (tmp1_col - tmp2_col + Z) + 2**32 = tmp3_col + wrap_bit * 2**32, + tmp3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 } // ================= logical instructions ================= @@ -524,22 +515,22 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo // Stores 1 in register W if the value in register X is zero, // otherwise stores 0. instr is_equal_zero X, W - link ~> val1_col = regs.mload(X, STEP) + link ~> tmp1_col = regs.mload(X, STEP) link ~> regs.mstore(W, STEP + 2, XXIsZero) { XXIsZero = 1 - XX * XX_inv, - XX = val1_col + XX = tmp1_col } // Stores 1 in register W if val(X) == val(Y), otherwise stores 0. instr is_not_equal X, Y, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link ~> regs.mstore(W, STEP + 2, val3_col) + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link ~> regs.mstore(W, STEP + 2, tmp3_col) { XXIsZero = 1 - XX * XX_inv, - XX = val1_col - val2_col, - val3_col = 1 - XXIsZero + XX = tmp1_col - tmp2_col, + tmp3_col = 1 - XXIsZero } // ================= submachine instructions ================= @@ -563,12 +554,12 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo // Sign extends the value in register X and stores it in register Y. // Input is a 32 bit unsigned number. We check bit 7 and set all higher bits to that value. instr sign_extend_byte X, Y - link ~> val1_col = regs.mload(X, STEP) - link ~> regs.mstore(Y, STEP + 3, val3_col) + link ~> tmp1_col = regs.mload(X, STEP) + link ~> regs.mstore(Y, STEP + 3, tmp3_col) { // wrap_bit is used as sign_bit here. - val1_col = Y_7bit + wrap_bit * 0x80 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, - val3_col = Y_7bit + wrap_bit * 0xffffff80 + tmp1_col = Y_7bit + wrap_bit * 0x80 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, + tmp3_col = Y_7bit + wrap_bit * 0xffffff80 } col witness Y_7bit; link => bit7.check(Y_7bit); @@ -576,14 +567,14 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo // Sign extends the value in register X and stores it in register Y. // Input is a 32 bit unsigned number. We check bit 15 and set all higher bits to that value. instr sign_extend_16_bits X, Y - link ~> val1_col = regs.mload(X, STEP) - link ~> regs.mstore(Y, STEP + 3, val3_col) + link ~> tmp1_col = regs.mload(X, STEP) + link ~> regs.mstore(Y, STEP + 3, tmp3_col) { Y_15bit = X_b1 + Y_7bit * 0x100, // wrap_bit is used as sign_bit here. - val1_col = Y_15bit + wrap_bit * 0x8000 + X_b3 * 0x10000 + X_b4 * 0x1000000, - val3_col = Y_15bit + wrap_bit * 0xffff8000 + tmp1_col = Y_15bit + wrap_bit * 0x8000 + X_b3 * 0x10000 + X_b4 * 0x1000000, + tmp3_col = Y_15bit + wrap_bit * 0xffff8000 } col witness Y_15bit; @@ -591,12 +582,12 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo // Input is a 32 bit unsigned number (0 <= val(X) < 2**32) interpreted as a two's complement numbers. // Returns a signed number (-2**31 <= val(Y) < 2**31). instr to_signed X, Y - link ~> val1_col = regs.mload(X, STEP) - link ~> regs.mstore(Y, STEP + 1, val3_col) + link ~> tmp1_col = regs.mload(X, STEP) + link ~> regs.mstore(Y, STEP + 1, tmp3_col) { // wrap_bit is used as sign_bit here. - val1_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + Y_7bit * 0x1000000 + wrap_bit * 0x80000000, - val3_col = val1_col - wrap_bit * 0x100000000 + tmp1_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + Y_7bit * 0x1000000 + wrap_bit * 0x80000000, + tmp3_col = tmp1_col - wrap_bit * 0x100000000 } // ======================= assertions ========================= @@ -608,11 +599,11 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo // Removes up to 16 bits beyond 32 // TODO is this really safe? instr wrap16 X, Y, Z - link ~> val1_col = regs.mload(X, STEP) - link ~> regs.mstore(Z, STEP + 3, val3_col) + link ~> tmp1_col = regs.mload(X, STEP) + link ~> regs.mstore(Z, STEP + 3, tmp3_col) { - (val1_col * Y) = Y_b5 * 2**32 + Y_b6 * 2**40 + val3_col, - val3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + (tmp1_col * Y) = Y_b5 * 2**32 + Y_b6 * 2**40 + tmp3_col, + tmp3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 } col witness Y_b5; @@ -635,28 +626,28 @@ fn preamble(runtime: &Runtime, degree: u64, with_bootloader: bo // Computes Q = val(Y) / val(X) and R = val(Y) % val(X) and stores them in registers Z and W. instr divremu Y, X, Z, W - link ~> val1_col = regs.mload(Y, STEP) - link ~> val2_col = regs.mload(X, STEP + 1) - link ~> regs.mstore(Z, STEP + 2, val3_col) - link ~> regs.mstore(W, STEP + 3, val4_col) + link ~> tmp1_col = regs.mload(Y, STEP) + link ~> tmp2_col = regs.mload(X, STEP + 1) + link ~> regs.mstore(Z, STEP + 2, tmp3_col) + link ~> regs.mstore(W, STEP + 3, tmp4_col) { XXIsZero = 1 - XX * XX_inv, - XX = val2_col, + XX = tmp2_col, // if X is zero, remainder is set to dividend, as per RISC-V specification: - val2_col * val3_col + val4_col = val1_col, + tmp2_col * tmp3_col + tmp4_col = tmp1_col, // remainder >= 0: - val4_col = REM_b1 + REM_b2 * 0x100 + REM_b3 * 0x10000 + REM_b4 * 0x1000000, + tmp4_col = REM_b1 + REM_b2 * 0x100 + REM_b3 * 0x10000 + REM_b4 * 0x1000000, // remainder < divisor, conditioned to val(X) not being 0: - (1 - XXIsZero) * (val2_col - val4_col - 1 - Y_b5 - Y_b6 * 0x100 - Y_b7 * 0x10000 - Y_b8 * 0x1000000) = 0, + (1 - XXIsZero) * (tmp2_col - tmp4_col - 1 - Y_b5 - Y_b6 * 0x100 - Y_b7 * 0x10000 - Y_b8 * 0x1000000) = 0, // in case X is zero, we set quotient according to RISC-V specification - XXIsZero * (val3_col - 0xffffffff) = 0, + XXIsZero * (tmp3_col - 0xffffffff) = 0, // quotient is 32 bits: - val3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 + tmp3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 } "# + mul_instruction } @@ -670,14 +661,14 @@ fn mul_instruction(runtime: &Runtime) -> &'static str { // Computes V = val(X) * val(Y) and // stores the lower 32 bits in register Z and the upper 32 bits in register W. instr mul X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link ~> regs.mstore(Z, STEP + 2, val3_col) - link ~> regs.mstore(W, STEP + 3, val4_col); + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link ~> regs.mstore(Z, STEP + 2, tmp3_col) + link ~> regs.mstore(W, STEP + 3, tmp4_col); { - val1_col * val2_col = val3_col + val4_col * 2**32, - val3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, - val4_col = Y_b5 + Y_b6 * 0x100 + Y_b7 * 0x10000 + Y_b8 * 0x1000000 + tmp1_col * tmp2_col = tmp3_col + tmp4_col * 2**32, + tmp3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, + tmp4_col = Y_b5 + Y_b6 * 0x100 + Y_b7 * 0x10000 + Y_b8 * 0x1000000 } "# } @@ -692,11 +683,11 @@ fn mul_instruction(runtime: &Runtime) -> &'static str { // Computes V = val(X) * val(Y) and // stores the lower 32 bits in register Z and the upper 32 bits in register W. instr mul X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link ~> (val3_col, val4_col) = split_gl.split(val1_col * val2_col) - link ~> regs.mstore(Z, STEP + 2, val3_col) - link ~> regs.mstore(W, STEP + 3, val4_col); + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link ~> (tmp3_col, tmp4_col) = split_gl.split(tmp1_col * tmp2_col) + link ~> regs.mstore(Z, STEP + 2, tmp3_col) + link ~> regs.mstore(W, STEP + 3, tmp4_col); "# } } @@ -710,12 +701,12 @@ fn memory(with_bootloader: bool) -> String { // Stores val(W) at address (V = val(X) - val(Z) + Y) % 2**32. // V can be between 0 and 2**33. instr mstore_bootloader X, Z, Y, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Z, STEP + 1) - link ~> val3_col = regs.mload(W, STEP + 2) - link ~> memory.mstore_bootloader(X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, STEP + 3, val3_col) + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Z, STEP + 1) + link ~> tmp3_col = regs.mload(W, STEP + 2) + link ~> memory.mstore_bootloader(X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, STEP + 3, tmp3_col) { - val1_col - val2_col + Y = (X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000) + wrap_bit * 2**32 + tmp1_col - tmp2_col + Y = (X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000) + wrap_bit * 2**32 } "# } else { @@ -737,26 +728,26 @@ fn memory(with_bootloader: bool) -> String { /// Writes the loaded word and the remainder of the division by 4 to registers Z and W, /// respectively. instr mload X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val3_col = memory.mload(X_b4 * 0x1000000 + X_b3 * 0x10000 + X_b2 * 0x100 + X_b1 * 4, STEP + 1) - link ~> regs.mstore(Z, STEP + 2, val3_col) - link ~> regs.mstore(W, STEP + 3, val4_col) - link => bit2.check(val4_col) + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp3_col = memory.mload(X_b4 * 0x1000000 + X_b3 * 0x10000 + X_b2 * 0x100 + X_b1 * 4, STEP + 1) + link ~> regs.mstore(Z, STEP + 2, tmp3_col) + link ~> regs.mstore(W, STEP + 3, tmp4_col) + link => bit2.check(tmp4_col) link => bit6.check(X_b1) { - val1_col + Y = wrap_bit * 2**32 + X_b4 * 0x1000000 + X_b3 * 0x10000 + X_b2 * 0x100 + X_b1 * 4 + val4_col + tmp1_col + Y = wrap_bit * 2**32 + X_b4 * 0x1000000 + X_b3 * 0x10000 + X_b2 * 0x100 + X_b1 * 4 + tmp4_col } // Stores val(W) at address (V = val(X) - val(Y) + Z) % 2**32. // V can be between 0 and 2**33. // V should be a multiple of 4, but this instruction does not enforce it. instr mstore X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link ~> val3_col = regs.mload(W, STEP + 2) - link ~> memory.mstore(X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, STEP + 3, val3_col) + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link ~> tmp3_col = regs.mload(W, STEP + 2) + link ~> memory.mstore(X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, STEP + 3, tmp3_col) { - val1_col - val2_col + Z = (X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000) + wrap_bit * 2**32 + tmp1_col - tmp2_col + Z = (X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000) + wrap_bit * 2**32 } "# } @@ -845,7 +836,7 @@ fn process_instruction(instr: &str, args: A) -> Result { let (rd, rs) = args.rr()?; - only_if_no_write_to_zero(rd, format!("move_reg {}, {}, 1, 0;", rs.addr(), rd.addr())) + only_if_no_write_to_zero(rd, format!("affine {}, {}, 1, 0;", rs.addr(), rd.addr())) } // Arithmetic @@ -922,15 +913,23 @@ fn process_instruction(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result { @@ -1188,7 +1199,7 @@ fn process_instruction(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result { let (r1, label) = args.rl()?; let label = escape_label(label.as_ref()); - vec![format!("branch_if_zero {}, 0, 0, {label};", r1.addr())] + vec![format!( + "branch_if_diff_equal {}, 0, 0, {label};", + r1.addr() + )] } "bgeu" => { let (r1, r2, label) = args.rrl()?; let label = escape_label(label.as_ref()); - // TODO does this fulfill the input requirements for branch_if_positive? + // TODO does this fulfill the input requirements for branch_if_diff_greater_than? vec![format!( - "branch_if_positive {}, {}, 1, {label};", + "branch_if_diff_greater_than {}, {}, -1, {label};", r1.addr(), r2.addr() )] @@ -1236,14 +1250,17 @@ fn process_instruction(instr: &str, args: A) -> Result { let (r1, r2, label) = args.rrl()?; let label = escape_label(label.as_ref()); vec![format!( - "branch_if_positive {}, {}, 0, {label};", + "branch_if_diff_greater_than {}, {}, 0, {label};", r2.addr(), r1.addr() )] @@ -1252,12 +1269,12 @@ fn process_instruction(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result= r2 (signed). - // TODO does this fulfill the input requirements for branch_if_positive? + // TODO does this fulfill the input requirements for branch_if_diff_greater_than? vec![ format!("to_signed {}, {};", r1.addr(), tmp1.addr()), format!("to_signed {}, {};", r2.addr(), tmp2.addr()), format!( - "branch_if_positive {}, {}, 1, {label};", + "branch_if_diff_greater_than {}, {}, -1, {label};", tmp1.addr(), tmp2.addr() ), @@ -1283,7 +1300,7 @@ fn process_instruction(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result { @@ -1302,7 +1322,10 @@ fn process_instruction(instr: &str, args: A) -> Result { @@ -1409,7 +1432,7 @@ fn process_instruction(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result(instr: &str, args: A) -> Result String { std::machines::write_once_memory::WriteOnceMemory bootloader_inputs; instr load_bootloader_input X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link => bootloader_inputs.access(val1_col * Z + W, val3_col) - link ~> regs.mstore(Y, STEP + 2, val3_col); + link ~> tmp1_col = regs.mload(X, STEP) + link => bootloader_inputs.access(tmp1_col * Z + W, tmp3_col) + link ~> regs.mstore(Y, STEP + 2, tmp3_col); instr assert_bootloader_input X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link => bootloader_inputs.access(val1_col * Z + W, val2_col); + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link => bootloader_inputs.access(tmp1_col * Z + W, tmp2_col); // Sets the PC to the bootloader input at the provided index instr jump_to_bootloader_input X link => bootloader_inputs.access(X, pc'); @@ -212,7 +212,7 @@ load_bootloader_input 0, 21, 1, {MEMORY_HASH_START_INDEX} + 3; // Current page index set_reg 2, 0; -branch_if_zero 1, 0, 0, bootloader_end_page_loop; +branch_if_diff_equal 1, 0, 0, bootloader_end_page_loop; bootloader_start_page_loop: @@ -240,7 +240,7 @@ and 3, 0, {PAGE_NUMBER_MASK}, 3; "#, )); - bootloader.push_str(&format!("move_reg 3, 90, {PAGE_SIZE_BYTES}, 0;\n")); + bootloader.push_str(&format!("affine 3, 90, {PAGE_SIZE_BYTES}, 0;\n")); for i in 0..WORDS_PER_PAGE { let reg = EXTRA_REGISTERS[(i % 4) + 4]; bootloader.push_str(&format!( @@ -248,7 +248,7 @@ and 3, 0, {PAGE_NUMBER_MASK}, 3; load_bootloader_input 2, 91, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {i}; {reg} <== get_reg(91); -move_reg 3, 90, {PAGE_SIZE_BYTES}, {i} * {BYTES_PER_WORD}; +affine 3, 90, {PAGE_SIZE_BYTES}, {i} * {BYTES_PER_WORD}; mstore_bootloader 90, 0, 0, 91;"# )); @@ -345,20 +345,20 @@ bootloader_level_{i}_end: branch_if_diff_nonzero 9, 0, bootloader_update_memory_hash; // Assert Correct Merkle Root -move_reg 18, 90, -1, 0; -move_reg 90, 90, 1, {P0}; +affine 18, 90, -1, 0; +affine 90, 90, 1, {P0}; branch_if_diff_nonzero 90, 0, bootloader_memory_hash_mismatch; -move_reg 19, 90, -1, 0; -move_reg 90, 90, 1, {P1}; +affine 19, 90, -1, 0; +affine 90, 90, 1, {P1}; branch_if_diff_nonzero 90, 0, bootloader_memory_hash_mismatch; -move_reg 20, 90, -1, 0; -move_reg 90, 90, 1, {P2}; +affine 20, 90, -1, 0; +affine 90, 90, 1, {P2}; branch_if_diff_nonzero 90, 0, bootloader_memory_hash_mismatch; -move_reg 21, 90, -1, 0; -move_reg 90, 90, 1, {P3}; +affine 21, 90, -1, 0; +affine 90, 90, 1, {P3}; branch_if_diff_nonzero 90, 0, bootloader_memory_hash_mismatch; jump bootloader_memory_hash_ok, 90; @@ -393,7 +393,7 @@ set_reg 20, {P2}; set_reg 21, {P3}; // Increment page index -move_reg 2, 2, 1, 1; +affine 2, 2, 1, 1; branch_if_diff_nonzero 2, 1, bootloader_start_page_loop; @@ -480,7 +480,7 @@ add_wrap 1, 0, 0, 1; // Current page index set_reg 2, 0; -branch_if_zero 1, 0, 0, shutdown_end_page_loop; +branch_if_diff_equal 1, 0, 0, shutdown_end_page_loop; shutdown_start_page_loop: @@ -512,7 +512,7 @@ and 3, 0, {PAGE_NUMBER_MASK}, 3; "#, )); - bootloader.push_str(&format!("move_reg 90, 3, {PAGE_SIZE_BYTES}, 0;\n")); + bootloader.push_str(&format!("affine 90, 3, {PAGE_SIZE_BYTES}, 0;\n")); for i in 0..WORDS_PER_PAGE { let reg = EXTRA_REGISTERS[(i % 4) + 4]; bootloader.push_str(&format!("mload 90, {i} * {BYTES_PER_WORD}, 90, 91;\n")); @@ -543,7 +543,7 @@ set_reg 90, {P3}; assert_bootloader_input 2, 90, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + {WORDS_PER_PAGE} + 1 + 3; // Increment page index -move_reg 2, 2, 1, 1; +affine 2, 2, 1, 1; branch_if_diff_nonzero 2, 1, shutdown_start_page_loop; diff --git a/riscv/src/runtime.rs b/riscv/src/runtime.rs index a26e9909e..078dea5b0 100644 --- a/riscv/src/runtime.rs +++ b/riscv/src/runtime.rs @@ -90,20 +90,20 @@ impl Runtime { "binary", [ r#"instr and X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link ~> val3_col = binary.and(val1_col, val2_col + Z) - link ~> regs.mstore(W, STEP + 3, val3_col);"#, + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link ~> tmp3_col = binary.and(tmp1_col, tmp2_col + Z) + link ~> regs.mstore(W, STEP + 3, tmp3_col);"#, r#"instr or X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link ~> val3_col = binary.or(val1_col, val2_col + Z) - link ~> regs.mstore(W, STEP + 3, val3_col);"#, + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link ~> tmp3_col = binary.or(tmp1_col, tmp2_col + Z) + link ~> regs.mstore(W, STEP + 3, tmp3_col);"#, r#"instr xor X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link ~> val3_col = binary.xor(val1_col, val2_col + Z) - link ~> regs.mstore(W, STEP + 3, val3_col);"#, + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link ~> tmp3_col = binary.xor(tmp1_col, tmp2_col + Z) + link ~> regs.mstore(W, STEP + 3, tmp3_col);"#, ], 0, ["and 0, 0, 0, 0;"], @@ -115,15 +115,15 @@ impl Runtime { "shift", [ r#"instr shl X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link ~> val3_col = shift.shl(val1_col, val2_col + Z) - link ~> regs.mstore(W, STEP + 3, val3_col);"#, + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link ~> tmp3_col = shift.shl(tmp1_col, tmp2_col + Z) + link ~> regs.mstore(W, STEP + 3, tmp3_col);"#, r#"instr shr X, Y, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> val2_col = regs.mload(Y, STEP + 1) - link ~> val3_col = shift.shr(val1_col, val2_col + Z) - link ~> regs.mstore(W, STEP + 3, val3_col);"#, + link ~> tmp1_col = regs.mload(X, STEP) + link ~> tmp2_col = regs.mload(Y, STEP + 1) + link ~> tmp3_col = shift.shr(tmp1_col, tmp2_col + Z) + link ~> regs.mstore(W, STEP + 3, tmp3_col);"#, ], 0, ["shl 0, 0, 0, 0;"], @@ -134,10 +134,10 @@ impl Runtime { None, "split_gl", [r#"instr split_gl X, Z, W - link ~> val1_col = regs.mload(X, STEP) - link ~> (val3_col, val4_col) = split_gl.split(val1_col) - link ~> regs.mstore(Z, STEP + 2, val3_col) - link ~> regs.mstore(W, STEP + 3, val4_col);"#], + link ~> tmp1_col = regs.mload(X, STEP) + link ~> (tmp3_col, tmp4_col) = split_gl.split(tmp1_col) + link ~> regs.mstore(Z, STEP + 2, tmp3_col) + link ~> regs.mstore(W, STEP + 3, tmp4_col);"#], 0, ["split_gl 0, 0, 0;"], ); @@ -451,10 +451,12 @@ impl Runtime { ] .into_iter(); - let jump_table = self - .syscalls - .keys() - .map(|s| format!("branch_if_zero 5, 0, {}, __ecall_handler_{};", *s as u32, s)); + let jump_table = self.syscalls.keys().map(|s| { + format!( + "branch_if_diff_equal 5, 0, {}, __ecall_handler_{};", + *s as u32, s + ) + }); let invalid_handler = ["__invalid_syscall:".to_string(), "fail;".to_string()].into_iter(); From 204b9e4ff3b2fae9c38a2af6bf31bd7b98e8542e Mon Sep 17 00:00:00 2001 From: Lucas Clemente Vella Date: Thu, 25 Jul 2024 10:09:03 +0100 Subject: [PATCH 24/24] DWARF debug symbols parsing in ELF files (#1564) Conversion from ELF now has feature parity with conversion from assembly, so I am setting it as the default. --------- Co-authored-by: Leandro Pacheco --- cli-rs/src/main.rs | 8 +- riscv-executor/src/profiler.rs | 21 +- riscv/Cargo.toml | 6 +- riscv/src/code_gen.rs | 5 +- riscv/src/elf/debug_info.rs | 601 +++++++++++++++++++++++++++++++ riscv/src/{elf.rs => elf/mod.rs} | 283 +++++++++------ 6 files changed, 805 insertions(+), 119 deletions(-) create mode 100644 riscv/src/elf/debug_info.rs rename riscv/src/{elf.rs => elf/mod.rs} (84%) diff --git a/cli-rs/src/main.rs b/cli-rs/src/main.rs index f09a8e668..7beaae473 100644 --- a/cli-rs/src/main.rs +++ b/cli-rs/src/main.rs @@ -62,10 +62,10 @@ enum Commands { #[arg(long)] coprocessors: Option, - /// Convert from the executable ELF file instead of the assembly. + /// Convert from the assembly files instead of the ELF executable. #[arg(short, long)] #[arg(default_value_t = false)] - elf: bool, + asm: bool, /// Run a long execution in chunks (Experimental and not sound!) #[arg(short, long)] @@ -224,14 +224,14 @@ fn run_command(command: Commands) { field, output_directory, coprocessors, - elf, + asm, continuations, } => { call_with_field!(compile_rust::( &file, Path::new(&output_directory), coprocessors, - elf, + !asm, continuations )) } diff --git a/riscv-executor/src/profiler.rs b/riscv-executor/src/profiler.rs index e5ee2cc76..c9bce731a 100644 --- a/riscv-executor/src/profiler.rs +++ b/riscv-executor/src/profiler.rs @@ -192,15 +192,19 @@ impl<'a> Profiler<'a> { self.function_begin .range(..=pc) .last() - .and_then(|(_, function)| { - self.location_begin + .map(|(_, function)| { + let (file, line) = self + .location_begin .range(..=pc) .last() - .map(|(_, (file, line))| Loc { - function, - file: *file, - line: *line, - }) + .map(|(_, (file, line))| (*file, *line)) + // for labels with no .loc above them, just point to main file + .unwrap_or((1, 0)); + Loc { + function, + file, + line, + } }) } @@ -265,9 +269,10 @@ impl<'a> Profiler<'a> { // we start profiling on the initial call to "__runtime_start" if target.function == "__runtime_start" { let call = Call { + // __runtime_start does not have a proper ".debug loc", just point to main file from: Loc { function: "", - file: 0, + file: 1, line: 0, }, target, diff --git a/riscv/Cargo.toml b/riscv/Cargo.toml index d316d7b0f..54543fe4d 100644 --- a/riscv/Cargo.toml +++ b/riscv/Cargo.toml @@ -24,7 +24,8 @@ powdr-pipeline.workspace = true powdr-riscv-executor.workspace = true powdr-riscv-syscalls.workspace = true -goblin = { version = "0.8" } +gimli = "0.31" +goblin = "0.8" lazy_static = "1.4.0" itertools = "0.13" lalrpop-util = { version = "^0.19", features = ["lexer"] } @@ -32,12 +33,13 @@ log = "0.4.17" mktemp = "0.5.0" num-traits = "0.2.15" raki = "0.1.4" -serde_json = "1.0" # This is only here to work around https://github.com/lalrpop/lalrpop/issues/750 # It should be removed once that workaround is no longer needed. regex-syntax = { version = "0.6", default_features = false, features = [ "unicode", ] } +serde_json = "1.0" +thiserror = "1.0" [build-dependencies] lalrpop = "^0.19" diff --git a/riscv/src/code_gen.rs b/riscv/src/code_gen.rs index af3792720..2a7298e36 100644 --- a/riscv/src/code_gen.rs +++ b/riscv/src/code_gen.rs @@ -253,7 +253,10 @@ fn translate_program_impl( } statements.extend([ "set_reg 0, 0;".to_string(), - format!("jump {}, 1;", program.start_function().as_ref()), + format!( + "jump {}, 1;", + escape_label(program.start_function().as_ref()) + ), "return;".to_string(), // This is not "riscv ret", but "return from powdr asm function". ]); for s in program.take_executable_statements() { diff --git a/riscv/src/elf/debug_info.rs b/riscv/src/elf/debug_info.rs new file mode 100644 index 000000000..756cf58c4 --- /dev/null +++ b/riscv/src/elf/debug_info.rs @@ -0,0 +1,601 @@ +use std::{ + borrow::Cow, + collections::{BTreeMap, BTreeSet, HashMap}, + path::Path, +}; + +use gimli::{ + read::AttributeValue, DebuggingInformationEntry, Dwarf, EndianSlice, LittleEndian, Operation, + Unit, UnitRef, +}; +use goblin::elf::{ + sym::{STT_FUNC, STT_OBJECT}, + Elf, SectionHeader, +}; +use itertools::Itertools; + +use super::AddressMap; + +type Reader<'a> = EndianSlice<'a, LittleEndian>; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("no debug information available")] + NoDebugInfo, + #[error("DIE tree traversal skipped a level")] + UnexpectedLevel, + #[error("failed to parse debug information: {0}")] + Parsing(#[from] gimli::Error), +} + +/// Debug information extracted from the ELF file. +#[derive(Default)] +pub struct DebugInfo { + /// List of source files: (directory, file name). + pub file_list: Vec<(String, String)>, + /// Relates addresses to source locations. + pub source_locations: Vec, + /// Maps (addresses, disambiguator) to symbol names. The disambiguator is + /// used to distinguish between multiple symbols at the same address. (i.e. + /// turns BTreeMap into a multimap.) + pub symbols: SymbolTable, + /// Human readable notes about an address + pub notes: HashMap, +} + +#[derive(Debug)] +pub struct SourceLocationInfo { + pub address: u32, + pub file: u64, + pub line: u64, + pub col: u64, +} + +impl DebugInfo { + /// Extracts debug information from the ELF file, if available. + pub fn new( + elf: &Elf, + file_buffer: &[u8], + address_map: &AddressMap, + is_data_addr: &dyn Fn(u32) -> bool, + jump_targets: &BTreeSet, + ) -> Result { + let dwarf = load_dwarf_sections(elf, file_buffer)?; + + let mut file_list = Vec::new(); + let mut source_locations = Vec::new(); + let mut notes = HashMap::new(); + + // Read the ELF symbol table, to be joined with symbols from the DWARF. + let mut symbols = read_symbol_table(elf); + + // Iterate over the compilation units: + let mut units_iter = dwarf.units(); + while let Some(unit) = units_iter.next()? { + let unit = dwarf.unit(unit)?; + // Shadows the Unit with a reference to itself, because it is more + // convenient to work with a UnitRef. + let unit = UnitRef::new(&dwarf, &unit); + + // Read the source locations for this compilation unit. + let file_index_delta = + read_source_locations(unit, &mut file_list, &mut source_locations)?; + + read_unit_symbols( + &dwarf, + unit, + file_index_delta, + is_data_addr, + jump_targets, + &mut symbols, + &mut notes, + )?; + } + + // Filter out the source locations that are not in the text section + filter_locations_in_text(&mut source_locations, address_map); + + // Deduplicate the symbols + dedup_names(&mut symbols); + + // Index by address, not by name. + let mut map_disambiguator = 0u32..; + let symbols = SymbolTable( + symbols + .into_iter() + .map(|(name, address)| ((address, map_disambiguator.next().unwrap()), name)) + .collect(), + ); + + Ok(DebugInfo { + file_list, + source_locations, + symbols, + notes, + }) + } +} + +/// Reads the source locations for a compilation unit. +fn read_source_locations( + unit: UnitRef, + file_list: &mut Vec<(String, String)>, + source_locations: &mut Vec, +) -> Result { + // Traverse all the line locations for the compilation unit. + let base_dir = Path::new( + unit.comp_dir + .map(|s| s.to_string()) + .transpose()? + .unwrap_or(""), + ); + let file_idx_delta = file_list.len() as u64; + if let Some(line_program) = unit.line_program.clone() { + // Get the source file listing + for file_entry in line_program.header().file_names() { + let directory = file_entry + .directory(line_program.header()) + .map(|attr| as_str(unit, attr)) + .transpose()? + .unwrap_or(""); + + // This unwrap can not panic because both base_dir and + // directory have been validated as UTF-8 strings. + let directory = base_dir + .join(directory) + .into_os_string() + .into_string() + .unwrap(); + + let path = as_str(unit, file_entry.path_name())?; + + file_list.push((directory, path.to_owned())); + } + + // Get the locations indexed by address + let mut rows = line_program.rows(); + while let Some((_, row)) = rows.next_row()? { + // End markers point to the address after the end, so we skip them. + if row.prologue_end() || row.end_sequence() { + continue; + } + + source_locations.push(SourceLocationInfo { + address: row.address() as u32, + file: row.file_index() + file_idx_delta, + line: match row.line() { + None => 0, + Some(v) => v.get(), + }, + col: match row.column() { + gimli::ColumnType::LeftEdge => 0, + gimli::ColumnType::Column(v) => v.get(), + }, + }) + } + } + + Ok(file_idx_delta) +} + +/// Traverse the tree in which the information about the compilation +/// unit is stored and extract function and variable names. +fn read_unit_symbols( + dwarf: &Dwarf, + unit: UnitRef, + file_idx_delta: u64, + is_data_addr: &dyn Fn(u32) -> bool, + jump_targets: &BTreeSet, + symbols: &mut Vec<(String, u32)>, + notes: &mut HashMap, +) -> Result<(), Error> { + // To simplify the algorithm, we start the name stack with a placeholder value. + let mut full_name = vec![None]; + let mut entries = unit.entries(); + while let Some((level_delta, entry)) = entries.next_dfs()? { + // Get the entry name as a human readable string (this is used in a comment) + let name = find_attr(entry, gimli::DW_AT_name) + .map(|name| unit.attr_string(name).map(|s| s.to_string_lossy())) + .transpose()?; + + match level_delta { + delta if delta > 1 => return Err(Error::UnexpectedLevel), + 1 => (), + _ => { + full_name.truncate((full_name.len() as isize + level_delta - 1) as usize); + } + } + full_name.push(name); + + match entry.tag() { + // This is the entry for a function or method. + gimli::DW_TAG_subprogram => { + let attr = find_attr(entry, gimli::DW_AT_linkage_name); + let Some(linkage_name) = attr.map(|ln| unit.attr_string(ln)).transpose()? else { + // This function has no linkage name in DWARF, so it + // must be in ELFs symbol table. + continue; + }; + + let start_addresses = get_function_start(dwarf, &unit, entry)?; + let name = linkage_name.to_string()?; + for address in start_addresses { + if jump_targets.contains(&address) { + symbols.push((name.to_owned(), address)); + } + } + } + // This is the entry for a variable. + gimli::DW_TAG_variable => { + let Some(address) = get_static_var_address(&unit, entry)? else { + continue; + }; + + if !is_data_addr(address) { + continue; + } + + if full_name.last().is_some() { + // The human readable name of the variable is available, + // so we assemble a pretty note to go into the comment. + let mut file_line = None; + if let Some(AttributeValue::FileIndex(file_idx)) = + find_attr(entry, gimli::DW_AT_decl_file) + { + if let Some(AttributeValue::Udata(line)) = + find_attr(entry, gimli::DW_AT_decl_line) + { + file_line = Some((file_idx + file_idx_delta, line)); + } + } + + let value = format!( + "{}{}", + full_name + .iter() + .map(|s| match s { + Some(s) => s, + None => &Cow::Borrowed("?"), + }) + .join("::"), + if let Some((file, line)) = file_line { + format!(" at file {file} line {line}") + } else { + String::new() + } + ); + + notes.insert(address, value); + } + + // The variable symbol name is only used as a fallback + // in case there is no pretty note. + if let Some(linkage_name) = find_attr(entry, gimli::DW_AT_linkage_name) + .map(|ln| unit.attr_string(ln)) + .transpose()? + { + symbols.push((linkage_name.to_string()?.to_owned(), address)); + } + } + _ => {} + }; + } + + Ok(()) +} + +fn load_dwarf_sections<'a>(elf: &Elf, file_buffer: &'a [u8]) -> Result>, Error> { + // Index the sections by their names: + let debug_sections: HashMap<&str, &SectionHeader> = elf + .section_headers + .iter() + .filter_map(|shdr| { + elf.shdr_strtab + .get_at(shdr.sh_name) + .map(|name| (name, shdr)) + }) + .collect(); + + if debug_sections.is_empty() { + return Err(Error::NoDebugInfo); + } + + // Load the DWARF sections: + Ok(gimli::Dwarf::load(move |section| { + Ok::<_, ()>(Reader::new( + debug_sections + .get(section.name()) + .map(|shdr| { + &file_buffer[shdr.sh_offset as usize..(shdr.sh_offset + shdr.sh_size) as usize] + }) + .unwrap_or(&[]), + Default::default(), + )) + }) + .unwrap()) +} + +/// This function linear searches for an attribute of an entry. +/// +/// My first idea was to iterate over the attribute list once, matching for all +/// attributes I was interested in. But then I figured out this operation is +/// N*M, where N is the number of attributes in the list and M is the number of +/// attributes I am interested in. So doing the inverse is easier and has the +/// same complexity. Since it is hard to tell in practice which one is faster, I +/// went with the easier approach. +fn find_attr<'a>( + entry: &DebuggingInformationEntry>, + attr_type: gimli::DwAt, +) -> Option>> { + let mut attrs = entry.attrs(); + while let Some(attr) = attrs.next().unwrap() { + if attr.name() == attr_type { + return Some(attr.value()); + } + } + None +} + +fn as_str<'a>( + unit: UnitRef>, + attr: AttributeValue>, +) -> Result<&'a str, gimli::Error> { + unit.attr_string(attr)?.to_string() +} + +fn get_static_var_address( + unit: &Unit, + entry: &DebuggingInformationEntry, +) -> Result, gimli::Error> { + let Some(attr) = find_attr(entry, gimli::DW_AT_location) else { + // No location available + return Ok(None); + }; + + let AttributeValue::Exprloc(address) = attr else { + // Not an static variable + return Ok(None); + }; + + // Do the magic to find the variable address + let mut ops = address.operations(unit.encoding()); + let first_op = ops.next()?; + let second_op = ops.next()?; + let (Some(Operation::Address { address }), None) = (first_op, second_op) else { + // The address is not a constant + return Ok(None); + }; + + Ok(Some(address as u32)) +} + +fn get_function_start( + dwarf: &Dwarf, + unit: &Unit, + entry: &DebuggingInformationEntry, +) -> Result, gimli::Error> { + let mut ret = Vec::new(); + + if let Some(low_pc) = find_attr(entry, gimli::DW_AT_low_pc) + .map(|val| dwarf.attr_address(unit, val)) + .transpose()? + .flatten() + { + ret.push(low_pc as u32); + } + + if let Some(ranges) = find_attr(entry, gimli::DW_AT_ranges) + .map(|val| dwarf.attr_ranges_offset(unit, val)) + .transpose()? + .flatten() + { + let mut iter = dwarf.ranges(unit, ranges)?; + while let Some(range) = iter.next()? { + ret.push(range.begin as u32); + } + } + + Ok(ret) +} + +/// Filter out source locations that are not in a text section. +fn filter_locations_in_text(locations: &mut Vec, address_map: &AddressMap) { + locations.sort_unstable_by_key(|loc| loc.address); + + let mut done_idx = 0; + for (&start_addr, &header) in address_map.0.iter() { + // Remove all entries that are in between done and the start address. + let start_idx = find_first_idx(&locations[done_idx..], start_addr) + done_idx; + locations.drain(done_idx..start_idx); + + // The end address is one past the last byte of the section. + let end_addr = start_addr + header.p_memsz as u32; + done_idx += find_first_idx(&locations[done_idx..], end_addr); + } +} + +fn find_first_idx(slice: &[SourceLocationInfo], addr: u32) -> usize { + match slice.binary_search_by_key(&addr, |loc| loc.address) { + Ok(mut idx) => { + while idx > 0 && slice[idx - 1].address == addr { + idx -= 1; + } + idx + } + Err(idx) => idx, + } +} + +/// Index the symbols by their addresses. +#[derive(Default)] +pub struct SymbolTable(BTreeMap<(u32, u32), String>); + +impl SymbolTable { + pub fn new(elf: &Elf) -> SymbolTable { + let mut symbols = read_symbol_table(elf); + + dedup_names(&mut symbols); + + let mut disambiguator = 0..; + SymbolTable( + symbols + .into_iter() + .map(|(name, addr)| ((addr, disambiguator.next().unwrap()), name.to_string())) + .collect(), + ) + } + + /// Returns an iterator over all symbols of a given address. + fn all_iter(&self, addr: u32) -> impl Iterator { + self.0 + .range((addr, 0)..=(addr, u32::MAX)) + .map(|(_, name)| name.as_ref()) + } + + fn default_label(addr: u32) -> Cow<'static, str> { + Cow::Owned(format!("__.L{addr:08x}")) + } + + /// Get a symbol, if the address has one. + pub fn try_get_one(&self, addr: u32) -> Option<&str> { + self.all_iter(addr).next() + } + + /// Get a symbol, or a default label formed from the address value. + pub fn get_one(&self, addr: u32) -> Cow { + match self.try_get_one(addr) { + Some(s) => Cow::Borrowed(s), + None => Self::default_label(addr), + } + } + + /// Get all symbol, or a default label formed from the address value. + pub fn get_all(&self, addr: u32) -> impl Iterator> { + let mut iter = self.all_iter(addr).peekable(); + let default = if iter.peek().is_none() { + Some(Self::default_label(addr)) + } else { + None + }; + iter.map(Cow::Borrowed).chain(default) + } +} + +fn read_symbol_table(elf: &Elf) -> Vec<(String, u32)> { + elf.syms + .iter() + .filter_map(|sym| { + // We only care about global symbols that have string names, and are + // either functions or variables. + if sym.st_name != 0 && (sym.st_type() == STT_OBJECT || sym.st_type() == STT_FUNC) { + Some((elf.strtab[sym.st_name].to_owned(), sym.st_value as u32)) + } else { + None + } + }) + .collect() +} + +/// Deduplicates by removing identical entries and appending the address to +/// repeated names. The vector ends up sorted. +fn dedup_names(symbols: &mut Vec<(String, u32)>) { + while dedup_names_pass(symbols) {} +} + +/// Deduplicates the names of the symbols by appending one level of address to +/// the name. +/// +/// Returns `true` if the names were deduplicated. +fn dedup_names_pass(symbols: &mut Vec<(String, u32)>) -> bool { + symbols.sort_unstable(); + symbols.dedup(); + + let mut deduplicated = false; + let mut iter = symbols.iter_mut(); + + // The first different name defines a group, which ends on the next + // different name. The whole group is deduplicated if it contains more than + // one element. + let mut next_group = iter.next().map(|(name, address)| (name, *address)); + while let Some((group_name, group_address)) = next_group { + let mut group_deduplicated = false; + next_group = None; + + // Find duplicates and update names in the group + for (name, address) in &mut iter { + if name == group_name { + group_deduplicated = true; + deduplicated = true; + *name = format!("{name}_{address:08x}"); + } else { + next_group = Some((name, *address)); + break; + } + } + + // If there were duplicates in the group, update the group leader, too. + if group_deduplicated { + *group_name = format!("{group_name}_{group_address:08x}"); + } + } + + deduplicated +} + +#[cfg(test)] +mod tests { + #[test] + fn dedup_names() { + let mut symbols = vec![ + ("baz".to_string(), 0x8000), + ("bar".to_string(), 0x3000), + ("foo".to_string(), 0x1000), + ("bar".to_string(), 0x5000), + ("foo".to_string(), 0x2000), + ("baz".to_string(), 0x7000), + ("baz".to_string(), 0x9000), + ("doo".to_string(), 0x0042), + ("baz".to_string(), 0xa000), + ("baz".to_string(), 0x6000), + ("bar".to_string(), 0x4000), + ]; + + super::dedup_names(&mut symbols); + + let expected = vec![ + ("bar_00003000".to_string(), 0x3000), + ("bar_00004000".to_string(), 0x4000), + ("bar_00005000".to_string(), 0x5000), + ("baz_00006000".to_string(), 0x6000), + ("baz_00007000".to_string(), 0x7000), + ("baz_00008000".to_string(), 0x8000), + ("baz_00009000".to_string(), 0x9000), + ("baz_0000a000".to_string(), 0xa000), + ("doo".to_string(), 0x0042), + ("foo_00001000".to_string(), 0x1000), + ("foo_00002000".to_string(), 0x2000), + ]; + assert_eq!(symbols, expected); + + let mut symbols = vec![ + ("john".to_string(), 0x42), + ("john".to_string(), 0x87), + ("john".to_string(), 0x1aa), + ("john_000001aa".to_string(), 0x1aa), + ("john_00000042".to_string(), 0x103), + ("john_00000087".to_string(), 0x103), + ]; + + super::dedup_names(&mut symbols); + + let expected = vec![ + ("john_00000042_00000042".to_string(), 0x42), + ("john_00000042_00000103".to_string(), 0x103), + ("john_00000087_00000087".to_string(), 0x87), + ("john_00000087_00000103".to_string(), 0x103), + ("john_000001aa".to_string(), 0x1aa), + ]; + + assert_eq!(symbols, expected); + } +} diff --git a/riscv/src/elf.rs b/riscv/src/elf/mod.rs similarity index 84% rename from riscv/src/elf.rs rename to riscv/src/elf/mod.rs index 075bc31df..fcf03ae43 100644 --- a/riscv/src/elf.rs +++ b/riscv/src/elf/mod.rs @@ -1,20 +1,17 @@ use std::{ borrow::Cow, + cell::Cell, cmp::Ordering, - collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap}, + collections::{btree_map::Entry, BTreeMap, BTreeSet}, fs, path::Path, }; -use goblin::{ - elf::sym::STT_OBJECT, - elf::{ - header::{EI_CLASS, EI_DATA, ELFCLASS32, ELFDATA2LSB, EM_RISCV, ET_DYN}, - program_header::PT_LOAD, - reloc::{R_RISCV_32, R_RISCV_HI20, R_RISCV_RELATIVE}, - sym::STT_FUNC, - Elf, ProgramHeader, - }, +use goblin::elf::{ + header::{EI_CLASS, EI_DATA, ELFCLASS32, ELFDATA2LSB, EM_RISCV, ET_DYN}, + program_header::PT_LOAD, + reloc::{R_RISCV_32, R_RISCV_HI20, R_RISCV_RELATIVE}, + Elf, ProgramHeader, }; use itertools::{Either, Itertools}; use powdr_asm_utils::data_storage::SingleDataValue; @@ -27,9 +24,14 @@ use raki::{ use crate::{ code_gen::{self, InstructionArgs, MemEntry, Register, RiscVProgram, Statement}, + elf::debug_info::SymbolTable, Runtime, }; +use self::debug_info::DebugInfo; + +mod debug_info; + /// Generates a Powdr Assembly program from a RISC-V 32 executable ELF file. pub fn translate( file_name: &Path, @@ -41,7 +43,7 @@ pub fn translate( } struct ElfProgram { - symbol_table: SymbolTable, + dbg: DebugInfo, data_map: BTreeMap, text_labels: BTreeSet, instructions: Vec, @@ -151,16 +153,44 @@ fn load_elf(file_name: &Path) -> ElfProgram { } // Sort text sections by address and flatten them. - lifted_text_sections.sort_by_key(|insns| insns[0].original_address); + lifted_text_sections.sort_by_key(|insns| insns[0].loc.address); let lifted_text_sections = lifted_text_sections .into_iter() .flatten() .collect::>(); - let symbol_table = SymbolTable::new(&elf); + // Try loading the debug information. + let debug_info = match debug_info::DebugInfo::new( + &elf, + &file_buffer, + &address_map, + &|key| data_map.contains_key(&key), + &referenced_text_addrs, + ) { + Ok(debug_info) => { + log::info!("Debug information loaded successfully."); + debug_info + } + Err(err) => { + match err { + debug_info::Error::NoDebugInfo => { + log::info!("No DWARF debug information found.") + } + err => { + log::warn!("Error reading DWARF debug information: {}", err) + } + } + log::info!("Falling back to using ELF symbol table."); + + DebugInfo { + symbols: SymbolTable::new(&elf), + ..Default::default() + } + } + }; ElfProgram { - symbol_table, + dbg: debug_info, data_map, text_labels: referenced_text_addrs, instructions: lifted_text_sections, @@ -235,60 +265,39 @@ fn static_relocate_data_sections( } } -/// Index the symbols by their addresses. -struct SymbolTable(HashMap); - -impl SymbolTable { - fn new(elf: &Elf) -> SymbolTable { - let mut deduplicator = HashMap::new(); - for sym in elf.syms.iter() { - // We only care about global symbols that have string names, and are - // either functions or variables. - if sym.st_name == 0 || (sym.st_type() != STT_OBJECT && sym.st_type() != STT_FUNC) { - continue; - } - deduplicator.insert(elf.strtab[sym.st_name].to_string(), sym.st_value as u32); - } - - Self( - deduplicator - .into_iter() - .map(|(name, addr)| (addr, name)) - .collect(), - ) - } - - /// Get the symbol if the address had one. - fn try_get(&self, addr: u32) -> Option<&str> { - self.0.get(&addr).map(|name| name.as_str()) - } - - /// Get the symbol or a default label formed from the address value. - fn get(&self, addr: u32) -> Cow { - self.0 - .get(&addr) - .map(|name| Cow::Borrowed(name.as_str())) - .unwrap_or_else(|| Cow::Owned(format!("L{addr:08x}"))) - } -} - impl RiscVProgram for ElfProgram { fn take_source_files_info(&mut self) -> impl Iterator { - // TODO: read the source files from the debug information. - std::iter::empty() + self.dbg + .file_list + .iter() + .enumerate() + .map(|(id, (dir, file))| crate::code_gen::SourceFileInfo { + // +1 because files are indexed from 1 + id: id as u32 + 1, + file, + dir, + }) } fn take_initial_mem(&mut self) -> impl Iterator { self.data_map.iter().map(|(addr, data)| { let value = match data { Data::TextLabel(label) => { - SingleDataValue::LabelReference(self.symbol_table.get(*label).into()) + SingleDataValue::LabelReference(self.dbg.symbols.get_one(*label).into()) } Data::Value(value) => SingleDataValue::Value(*value), }; + let label = self + .dbg + .notes + .get(addr) + .map(|note| note.as_str()) + .or_else(|| self.dbg.symbols.try_get_one(*addr)) + .map(|s| s.to_string()); + MemEntry { - label: self.symbol_table.try_get(*addr).map(|s| s.to_string()), + label, addr: *addr, value, } @@ -298,31 +307,71 @@ impl RiscVProgram for ElfProgram { fn take_executable_statements( &mut self, ) -> impl Iterator, WrappedArgs>> { - let labels = self.text_labels.iter(); + // In the output, the precedence is labels, locations, and then instructions. + // We merge the 3 iterators with this operations: merge(labels, merge(locs, instructions)), where each is sorted by address. + + // First the inner merge: locs and instructions. + let locs = self.dbg.source_locations.iter(); let instructions = self.instructions.iter(); + let locs_and_instructions = locs + .map(|loc| (Cell::new(0), loc)) + .merge_join_by(instructions, |next_loc, next_insn| { + assert!( + next_loc.1.address >= next_insn.loc.address, + "Debug location {:08x} doesn't match instruction address!", + next_loc.1.address + ); + if next_loc.1.address < next_insn.loc.address + next_insn.loc.size { + next_loc.0.set(next_insn.loc.address); + true + } else { + false + } + }) + .map(|result| match result { + // Extract the address from the Either, for easier comparison in the next step. + Either::Left((address, loc)) => (address.get(), Either::Left(loc)), + Either::Right(insn) => (insn.loc.address, Either::Right(insn)), + }); + // Now the outer merge: labels and locs_and_instructions. + let labels = self.text_labels.iter(); labels - .merge_join_by(instructions, |&&next_label, next_insn| { - match next_label.cmp(&next_insn.original_address) { - Ordering::Less => panic!("Label {next_label:08x} doesn't match exact address!"), + .merge_join_by( + locs_and_instructions, + |&label_addr, (right_addr, _)| match label_addr.cmp(right_addr) { + Ordering::Less => panic!("Label {label_addr:08x} doesn't match exact address!"), Ordering::Equal => true, Ordering::Greater => false, - } - }) - .map(|result| match result { - Either::Left(label) => Statement::Label(self.symbol_table.get(*label)), - Either::Right(insn) => Statement::Instruction { - op: insn.op, - args: WrappedArgs { - args: &insn.args, - symbol_table: &self.symbol_table, - }, }, + ) + .flat_map(|result| -> Box> { + match result { + Either::Left(label) => { + Box::new(self.dbg.symbols.get_all(*label).map(Statement::Label)) + } + Either::Right((_, Either::Left(loc))) => { + Box::new(std::iter::once(Statement::DebugLoc { + file: loc.file, + line: loc.line, + col: loc.col, + })) + } + Either::Right((_, Either::Right(insn))) => { + Box::new(std::iter::once(Statement::Instruction { + op: insn.op, + args: WrappedArgs { + args: &insn.args, + symbol_table: &self.dbg.symbols, + }, + })) + } + } }) } fn start_function(&self) -> Cow { - self.symbol_table.get(self.entry_point) + self.dbg.symbols.get_one(self.entry_point) } } @@ -343,7 +392,7 @@ impl<'a> InstructionArgs for WrappedArgs<'a> { rd: None, rs1: None, rs2: None, - } => Ok(self.symbol_table.get(*addr).into()), + } => Ok(self.symbol_table.get_one(*addr).into()), _ => Err(format!("Expected: label, got {:?}", self.args)), } } @@ -442,7 +491,7 @@ impl<'a> InstructionArgs for WrappedArgs<'a> { } => Ok(( Register::new(*rs1 as u8), Register::new(*rs2 as u8), - self.symbol_table.get(*addr).into(), + self.symbol_table.get_one(*addr).into(), )), _ => Err(format!("Expected: rs1, rs2, label, got {:?}", self.args)), } @@ -457,7 +506,7 @@ impl<'a> InstructionArgs for WrappedArgs<'a> { rs2: None, } => Ok(( Register::new(*rs1 as u8), - self.symbol_table.get(*addr).into(), + self.symbol_table.get_one(*addr).into(), )), HighLevelArgs { imm: HighLevelImmediate::CodeLabel(addr), @@ -466,7 +515,7 @@ impl<'a> InstructionArgs for WrappedArgs<'a> { rs2: None, } => Ok(( Register::new(*rd as u8), - self.symbol_table.get(*addr).into(), + self.symbol_table.get_one(*addr).into(), )), _ => Err(format!("Expected: {{rs1|rd}}, label, got {:?}", self.args)), } @@ -569,9 +618,29 @@ fn load_data_section(mut addr: u32, data: &[u8], data_map: &mut BTreeMap u32 { + match self { + UnimpOrInstruction::Unimp16 => 2, + UnimpOrInstruction::_Unimp32 => 4, + UnimpOrInstruction::Instruction(ins) => match ins.extension { + Extensions::C => 2, + _ => 4, + }, + } + } +} + struct MaybeInstruction { address: u32, - insn: Option, + insn: UnimpOrInstruction, } #[derive(Debug)] @@ -601,8 +670,13 @@ impl Default for HighLevelArgs { } } +struct Location { + address: u32, + size: u32, +} + struct HighLevelInsn { - original_address: u32, + loc: Location, op: &'static str, args: HighLevelArgs, } @@ -674,9 +748,14 @@ impl TwoOrOneMapper for InstructionLifter<'_> { insn1: &MaybeInstruction, insn2: &MaybeInstruction, ) -> Option { - let original_address = insn1.address; + use UnimpOrInstruction::Instruction as I; + + let loc = Location { + address: insn1.address, + size: insn1.insn.len() + insn2.insn.len(), + }; let insn2_addr = insn2.address; - let (Some(insn1), Some(insn2)) = (&insn1.insn, &insn2.insn) else { + let (I(insn1), I(insn2)) = (&insn1.insn, &insn2.insn) else { return None; }; @@ -701,15 +780,11 @@ impl TwoOrOneMapper for InstructionLifter<'_> { // to load an address into a register. We must check if this is // the case, and if the address points to a text section, we // must load it from a label. - let is_address = self.rellocs_set.contains(&original_address); + let is_address = self.rellocs_set.contains(&loc.address); let (op, args) = self.composed_immediate(*hi, *lo, *rd_lui, *rd_addi, insn2_addr, is_address)?; - HighLevelInsn { - op, - args, - original_address, - } + HighLevelInsn { op, args, loc } } ( // All other double instructions we can lift start with auipc. @@ -721,7 +796,7 @@ impl TwoOrOneMapper for InstructionLifter<'_> { }, insn2, ) => { - let hi = hi.wrapping_add(original_address as i32); + let hi = hi.wrapping_add(loc.address as i32); match insn2 { // la rd, symbol Ins { @@ -737,11 +812,7 @@ impl TwoOrOneMapper for InstructionLifter<'_> { hi, *lo, *rd_auipc, *rd_addi, insn2_addr, IS_ADDRESS, )?; - HighLevelInsn { - op, - args, - original_address, - } + HighLevelInsn { op, args, loc } } // l{b|h|w} rd, symbol Ins { @@ -768,7 +839,7 @@ impl TwoOrOneMapper for InstructionLifter<'_> { imm: HighLevelImmediate::Value(addr), ..Default::default() }, - original_address, + loc, } } // s{b|h|w} rd, symbol, rt @@ -804,7 +875,7 @@ impl TwoOrOneMapper for InstructionLifter<'_> { imm: HighLevelImmediate::CodeLabel(hi.wrapping_add(*lo) as u32), ..Default::default() }, - original_address, + loc, }, // tail offset Ins { @@ -820,11 +891,12 @@ impl TwoOrOneMapper for InstructionLifter<'_> { imm: HighLevelImmediate::CodeLabel(hi.wrapping_add(*lo) as u32), ..Default::default() }, - original_address, + loc, }, _ => { panic!( - "Unexpected instruction after AUIPC: {insn2:?} at {original_address:08x}" + "Unexpected instruction after AUIPC: {insn2:?} at {:08x}", + loc.address ); } } @@ -844,19 +916,22 @@ impl TwoOrOneMapper for InstructionLifter<'_> { } fn map_one(&mut self, insn: MaybeInstruction) -> HighLevelInsn { - let original_address = insn.address; - let Some(insn) = insn.insn else { + let loc = Location { + address: insn.address, + size: insn.insn.len(), + }; + let UnimpOrInstruction::Instruction(insn) = insn.insn else { return HighLevelInsn { op: "unimp", args: Default::default(), - original_address, + loc, }; }; let mut imm = match insn.opc { // All jump instructions that have an address as immediate Op::JAL | Op::BEQ | Op::BNE | Op::BLT | Op::BGE | Op::BLTU | Op::BGEU => { - let addr = (insn.imm.unwrap() + original_address as i32) as u32; + let addr = (insn.imm.unwrap() + loc.address as i32) as u32; if let ReadOrWrite::Write(refs) = &mut self.referenced_text_addrs { refs.insert(addr); } @@ -884,11 +959,11 @@ impl TwoOrOneMapper for InstructionLifter<'_> { args: HighLevelArgs { rd: insn.rd.map(|x| x as u32), imm: HighLevelImmediate::Value( - insn.imm.unwrap().wrapping_add(original_address as i32), + insn.imm.unwrap().wrapping_add(loc.address as i32), ), ..Default::default() }, - original_address, + loc, }; } // All other instructions, which have the immediate as a value @@ -917,7 +992,7 @@ impl TwoOrOneMapper for InstructionLifter<'_> { rs2: insn.rs2.map(|x| x as u32), imm, }, - original_address, + loc, } } } @@ -1009,7 +1084,7 @@ impl Iterator for RiscVInstructionIterator<'_> { maybe_insn = MaybeInstruction { address: self.curr_address, - insn: Some(insn), + insn: UnimpOrInstruction::Instruction(insn), }; } else { // 16 bits @@ -1022,7 +1097,7 @@ impl Iterator for RiscVInstructionIterator<'_> { maybe_insn = MaybeInstruction { address: self.curr_address, insn: match bin_instruction.decode(Isa::Rv32) { - Ok(c_insn) => Some(to_32bit_equivalent(c_insn)), + Ok(c_insn) => UnimpOrInstruction::Instruction(to_32bit_equivalent(c_insn)), Err(raki::decode::DecodingError::IllegalInstruction) => { // Although not a real RISC-V instruction, sometimes 0x0000 // is used on purpose as an illegal instruction (it even has @@ -1036,7 +1111,7 @@ impl Iterator for RiscVInstructionIterator<'_> { "Failed to decode 16-bit instruction at {:08x}", self.curr_address ); - None + UnimpOrInstruction::Unimp16 } Err(err) => panic!( "Unexpected decoding error at {:08x}: {err:?}",