Commit dd3b250

Merge remote-tracking branch 'origin/main' into fixed_cols_in_functions

chriseth committed Jul 14, 2024
2 parents 36d2faf + e1169a7
Showing 19 changed files with 328 additions and 195 deletions.
1 change: 1 addition & 0 deletions backend/Cargo.toml
@@ -39,6 +39,7 @@ strum = { version = "0.24.1", features = ["derive"] }
 log = "0.4.17"
 serde = "1.0"
 serde_json = "1.0"
+bincode = "1.3.3"
 hex = "0.4"
 thiserror = "1.0.43"
 mktemp = "0.5.0"
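bincode is added because the composite backend below switches proof and verification-key serialization from JSON to a compact binary format. A minimal round-trip sketch, assuming bincode 1.x and serde with the derive feature; `Blob` is a hypothetical stand-in, not the actual powdr type:

use serde::{Deserialize, Serialize};

// Hypothetical container mirroring the shape of CompositeProof below;
// not the actual powdr type.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct Blob {
    parts: Vec<Vec<u8>>,
}

fn main() {
    let blob = Blob {
        parts: vec![vec![1, 2], vec![3]],
    };
    // bincode 1.x: serialize to a Vec<u8>, deserialize from a byte slice.
    let bytes = bincode::serialize(&blob).unwrap();
    let back: Blob = bincode::deserialize(&bytes).unwrap();
    assert_eq!(blob, back);
}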
117 changes: 87 additions & 30 deletions backend/src/composite/mod.rs
@@ -1,5 +1,12 @@
-use std::{collections::BTreeMap, io, marker::PhantomData, path::PathBuf, sync::Arc};
+use std::{
+    collections::BTreeMap,
+    io::{self, Cursor, Read},
+    marker::PhantomData,
+    path::PathBuf,
+    sync::Arc,
+};
 
+use itertools::Itertools;
 use powdr_ast::analyzed::Analyzed;
 use powdr_executor::witgen::WitgenCallback;
 use powdr_number::{DegreeType, FieldElement};
@@ -10,11 +17,18 @@ use crate::{Backend, BackendFactory, BackendOptions, Error, Proof};
 
 mod split;
 
-/// A composite proof that contains a proof for each machine separately.
+/// A composite verification key that contains a verification key for each machine separately.
+#[derive(Serialize, Deserialize)]
+struct CompositeVerificationKey {
+    /// Verification key for each machine (if available, otherwise None), sorted by machine name.
+    verification_keys: Vec<Option<Vec<u8>>>,
+}
+
+/// A composite proof that contains a proof for each machine separately, sorted by machine name.
 #[derive(Serialize, Deserialize)]
 struct CompositeProof {
-    /// Map from machine name to proof
-    proofs: BTreeMap<String, Vec<u8>>,
+    proofs: Vec<Vec<u8>>,
 }
 
 pub(crate) struct CompositeBackendFactory<F: FieldElement, B: BackendFactory<F>> {
@@ -42,13 +56,40 @@ impl<F: FieldElement, B: BackendFactory<F>> BackendFactory<F> for CompositeBacke
         verification_app_key: Option<&mut dyn std::io::Read>,
         backend_options: BackendOptions,
     ) -> Result<Box<dyn Backend<'a, F> + 'a>, Error> {
-        if setup.is_some() || verification_key.is_some() || verification_app_key.is_some() {
+        if verification_app_key.is_some() {
             unimplemented!();
         }
 
-        let per_machine_data = split::split_pil((*pil).clone())
+        let pils = split::split_pil((*pil).clone());
+
+        // Read the setup once (if any) to pass to all backends.
+        let setup_bytes = setup.map(|setup| {
+            let mut setup_data = Vec::new();
+            setup.read_to_end(&mut setup_data).unwrap();
+            setup_data
+        });
+
+        // Read all provided verification keys
+        let verification_keys = verification_key
+            .map(|verification_key| bincode::deserialize_from(verification_key).unwrap())
+            .unwrap_or(CompositeVerificationKey {
+                verification_keys: vec![None; pils.len()],
+            })
+            .verification_keys;
+
+        let machine_data = pils
             .into_iter()
-            .map(|(machine_name, pil)| {
+            .zip_eq(verification_keys.into_iter())
+            .map(|((machine_name, pil), verification_key)| {
+                // Set up readers for the setup and verification key
+                let mut setup_cursor = setup_bytes.as_ref().map(Cursor::new);
+                let setup = setup_cursor.as_mut().map(|cursor| cursor as &mut dyn Read);
+
+                let mut verification_key_cursor = verification_key.as_ref().map(Cursor::new);
+                let verification_key = verification_key_cursor
+                    .as_mut()
+                    .map(|cursor| cursor as &mut dyn Read);
+
                 let pil = Arc::new(pil);
                 let output_dir = output_dir
                     .clone()
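Since a `dyn Read` source can be consumed only once but every machine backend needs the same setup, the bytes are read into memory once and each machine gets a fresh `Cursor` over them. A freestanding sketch of that fan-out pattern (illustrative names, standard library only):

use std::io::{Cursor, Read};

// Read a one-shot source fully, then hand out as many readers as needed.
fn fan_out(source: Option<&mut dyn Read>, consumers: usize) {
    let bytes = source.map(|r| {
        let mut buf = Vec::new();
        r.read_to_end(&mut buf).unwrap();
        buf
    });
    for _ in 0..consumers {
        let mut cursor = bytes.as_ref().map(Cursor::new);
        // Coerce to the `&mut dyn Read` each backend constructor expects.
        let reader: Option<&mut dyn Read> = cursor.as_mut().map(|c| c as &mut dyn Read);
        if let Some(r) = reader {
            let mut copy = Vec::new();
            // Every consumer sees the full bytes from the start.
            r.read_to_end(&mut copy).unwrap();
        }
    }
}

fn main() {
    let mut data: &[u8] = b"setup";
    fan_out(Some(&mut data), 3);
}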
@@ -61,22 +102,20 @@ impl<F: FieldElement, B: BackendFactory<F>> BackendFactory<F> for CompositeBacke
                     pil.clone(),
                     fixed,
                     output_dir,
-                    // TODO: Handle setup, verification_key, verification_app_key
-                    None,
-                    None,
+                    setup,
+                    verification_key,
+                    // TODO: Handle verification_app_key
                     None,
                     backend_options.clone(),
                 );
                 backend.map(|backend| (machine_name.to_string(), MachineData { pil, backend }))
             })
-            .collect::<Result<BTreeMap<_, _>, _>>()?;
-        Ok(Box::new(CompositeBackend {
-            machine_data: per_machine_data,
-        }))
+            .collect::<Result<_, _>>()?;
+        Ok(Box::new(CompositeBackend { machine_data }))
     }
 
-    fn generate_setup(&self, _size: DegreeType, _output: &mut dyn io::Write) -> Result<(), Error> {
-        Err(Error::NoSetupAvailable)
+    fn generate_setup(&self, size: DegreeType, output: &mut dyn io::Write) -> Result<(), Error> {
+        self.factory.generate_setup(size, output)
     }
 }

@@ -86,6 +125,9 @@ struct MachineData<'a, F> {
 }
 
 pub(crate) struct CompositeBackend<'a, F> {
+    /// Maps each machine name to the corresponding machine data
+    /// Note that it is essential that we use BTreeMap here to ensure that the machines are
+    /// deterministically ordered.
     machine_data: BTreeMap<String, MachineData<'a, F>>,
 }

@@ -117,33 +159,48 @@ impl<'a, F: FieldElement> Backend<'a, F> for CompositeBackend<'a, F> {
 
                     let witness = machine_witness_columns(witness, pil, machine);
 
-                    backend
-                        .prove(&witness, None, witgen_callback)
-                        .map(|proof| (machine.clone(), proof))
+                    backend.prove(&witness, None, witgen_callback)
                 })
                 .collect::<Result<_, _>>()?,
         };
-        Ok(serde_json::to_vec(&proof).unwrap())
+        Ok(bincode::serialize(&proof).unwrap())
     }
 
     fn verify(&self, proof: &[u8], instances: &[Vec<F>]) -> Result<(), Error> {
-        let proof: CompositeProof = serde_json::from_slice(proof).unwrap();
-        for (machine, machine_proof) in proof.proofs {
-            let machine_data = self
-                .machine_data
-                .get(&machine)
-                .ok_or_else(|| Error::BackendError(format!("Unknown machine: {machine}")))?;
+        let proof: CompositeProof = bincode::deserialize(proof).unwrap();
+        for (machine_data, machine_proof) in self.machine_data.values().zip_eq(proof.proofs) {
             machine_data.backend.verify(&machine_proof, instances)?;
         }
         Ok(())
     }
 
-    fn export_setup(&self, _output: &mut dyn io::Write) -> Result<(), Error> {
-        unimplemented!()
+    fn export_setup(&self, output: &mut dyn io::Write) -> Result<(), Error> {
+        // All backends are the same, so just pick the first one.
+        self.machine_data
+            .values()
+            .next()
+            .unwrap()
+            .backend
+            .export_setup(output)
     }
 
-    fn export_verification_key(&self, _output: &mut dyn io::Write) -> Result<(), Error> {
-        unimplemented!();
+    fn verification_key_bytes(&self) -> Result<Vec<u8>, Error> {
+        let verification_key = CompositeVerificationKey {
+            verification_keys: self
+                .machine_data
+                .values()
+                .map(|machine_data| {
+                    let backend = machine_data.backend.as_ref();
+                    let vk_bytes = backend.verification_key_bytes();
+                    match vk_bytes {
+                        Ok(vk_bytes) => Ok(Some(vk_bytes)),
+                        Err(Error::NoVerificationAvailable) => Ok(None),
+                        Err(e) => Err(e),
+                    }
+                })
+                .collect::<Result<_, _>>()?,
+        };
+        Ok(bincode::serialize(&verification_key).unwrap())
     }
 
     fn export_ethereum_verifier(&self, _output: &mut dyn io::Write) -> Result<(), Error> {
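Proofs no longer carry machine names: `CompositeProof.proofs` is a plain `Vec`, and correctness relies on prover and verifier iterating the same `BTreeMap`, whose `values()` come out sorted by machine name. A toy sketch of that invariant, assuming the itertools crate for `zip_eq` (which panics if the lengths differ):

use std::collections::BTreeMap;

use itertools::Itertools;

fn main() {
    let machines: BTreeMap<String, u32> = BTreeMap::from([
        ("b_machine".to_string(), 2),
        ("a_machine".to_string(), 1),
    ]);
    // "Prove": emit one entry per machine, implicitly sorted by machine name.
    let proofs: Vec<u32> = machines.values().copied().collect();
    // "Verify": zip the same deterministic order back together;
    // zip_eq panics if the proof count does not match the machine count.
    for (machine, proof) in machines.values().zip_eq(proofs) {
        assert_eq!(*machine, proof);
    }
}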
11 changes: 3 additions & 8 deletions backend/src/estark/starky_wrapper.rs
@@ -1,4 +1,3 @@
-use std::io;
 use std::path::PathBuf;
 use std::sync::Arc;
 use std::time::Instant;
@@ -212,13 +211,9 @@ impl<'a, F: FieldElement> Backend<'a, F> for EStark<F> {
         }
     }
 
-    fn export_verification_key(&self, output: &mut dyn io::Write) -> Result<(), Error> {
-        match serde_json::to_writer(output, &self.setup) {
-            Ok(_) => Ok(()),
-            Err(_) => Err(Error::BackendError(
-                "Could not export verification key".to_string(),
-            )),
-        }
+    fn verification_key_bytes(&self) -> Result<Vec<u8>, Error> {
+        serde_json::to_vec(&self.setup)
+            .map_err(|_| Error::BackendError("Could not serialize verification key".to_string()))
     }
 }

6 changes: 2 additions & 4 deletions backend/src/halo2/mod.rs
@@ -159,11 +159,9 @@ impl<'a, T: FieldElement> Backend<'a, T> for Halo2Prover<T> {
         Ok(self.write_setup(&mut output)?)
     }
 
-    fn export_verification_key(&self, mut output: &mut dyn io::Write) -> Result<(), Error> {
+    fn verification_key_bytes(&self) -> Result<Vec<u8>, Error> {
         let vk = self.verification_key()?;
-        vk.write(&mut output, SerdeFormat::Processed)?;
-
-        Ok(())
+        Ok(vk.to_bytes(SerdeFormat::Processed))
     }
 
     fn export_ethereum_verifier(&self, output: &mut dyn io::Write) -> Result<(), Error> {
11 changes: 10 additions & 1 deletion backend/src/lib.rs
@@ -176,7 +176,16 @@ pub trait Backend<'a, F: FieldElement> {
 
     /// Exports the verification key in a backend specific format. Can be used
     /// to create a new backend object of the same kind.
-    fn export_verification_key(&self, _output: &mut dyn io::Write) -> Result<(), Error> {
+    fn export_verification_key(&self, output: &mut dyn io::Write) -> Result<(), Error> {
+        let v = self.verification_key_bytes()?;
+        log::info!("Verification key size: {} bytes", v.len());
+        output
+            .write_all(&v)
+            .map_err(|_| Error::BackendError("Could not write verification key".to_string()))?;
+        Ok(())
+    }
+
+    fn verification_key_bytes(&self) -> Result<Vec<u8>, Error> {
        Err(Error::NoVerificationAvailable)
    }

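The refactor inverts the old relationship: a backend now implements only the byte-returning method, and the writer-based export falls out of the trait's default implementation. A stripped-down sketch of that default-method pattern; the trait and types here are simplified stand-ins, not the actual powdr definitions:

use std::io::Write;

#[derive(Debug)]
enum Error {
    NoVerificationAvailable,
    BackendError(String),
}

trait Backend {
    // Backends that support verification keys override only this method.
    fn verification_key_bytes(&self) -> Result<Vec<u8>, Error> {
        Err(Error::NoVerificationAvailable)
    }

    // The writer-based export comes for free via this default method.
    fn export_verification_key(&self, output: &mut dyn Write) -> Result<(), Error> {
        let v = self.verification_key_bytes()?;
        output
            .write_all(&v)
            .map_err(|_| Error::BackendError("could not write key".to_string()))
    }
}

struct Dummy;

impl Backend for Dummy {
    fn verification_key_bytes(&self) -> Result<Vec<u8>, Error> {
        Ok(vec![0xde, 0xad])
    }
}

fn main() -> Result<(), Error> {
    let mut buf = Vec::new();
    Dummy.export_verification_key(&mut buf)?;
    assert_eq!(buf, vec![0xde, 0xad]);
    Ok(())
}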
20 changes: 20 additions & 0 deletions executor/src/constant_evaluator/mod.rs
@@ -609,4 +609,24 @@ mod test {
             ("F.a".to_string(), convert([14, 15, 16, 17].to_vec()))
         );
     }
+
+    #[test]
+    fn do_not_add_constraint_for_empty_tuple() {
+        let input = r#"namespace N(4);
+            let f: -> () = || ();
+            let g: col = |i| {
+                // This returns an empty tuple, we check that this does not lead to
+                // a call to add_constraints()
+                f();
+                i
+            };
+        "#;
+        let analyzed = analyze_string::<GoldilocksField>(input);
+        assert_eq!(analyzed.degree(), 4);
+        let constants = generate(&analyzed);
+        assert_eq!(
+            constants[0],
+            ("N.g".to_string(), convert([0, 1, 2, 3].to_vec()))
+        );
+    }
 }
5 changes: 4 additions & 1 deletion pil-analyzer/src/evaluator.rs
@@ -658,7 +658,10 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> {
             }
             Operation::AddConstraint => {
                 let result = self.value_stack.pop().unwrap();
-                self.symbols.add_constraints(result, SourceRef::unknown())?;
+                match result.as_ref() {
+                    Value::Tuple(t) if t.is_empty() => {}
+                    _ => self.symbols.add_constraints(result, SourceRef::unknown())?,
+                }
             }
         };
     }
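The new match arm makes statement-position expressions that evaluate to the empty tuple `()` a no-op instead of a call to `add_constraints`, which is what the new test above exercises. A minimal sketch of the same guard on a simplified value type (hypothetical `Value`, not the analyzer's):

#[derive(Debug)]
enum Value {
    Tuple(Vec<Value>),
    Number(i64),
}

fn maybe_add_constraint(result: &Value) -> bool {
    // Mirror of the new match: empty tuples are silently dropped,
    // everything else would be registered as a constraint.
    match result {
        Value::Tuple(t) if t.is_empty() => false,
        _ => true,
    }
}

fn main() {
    assert!(!maybe_add_constraint(&Value::Tuple(vec![])));
    assert!(maybe_add_constraint(&Value::Number(7)));
}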
40 changes: 34 additions & 6 deletions pipeline/src/pipeline.rs
@@ -5,11 +5,13 @@ use std::{
     io::{self, BufReader},
     marker::Send,
     path::{Path, PathBuf},
+    rc::Rc,
     sync::Arc,
     time::Instant,
 };
 
 use log::Level;
+use mktemp::Temp;
 use powdr_ast::{
     analyzed::Analyzed,
     asm_analysis::AnalysisASMFile,
@@ -114,6 +116,10 @@ pub struct Pipeline<T: FieldElement> {
     artifact: Artifacts<T>,
     /// Output directory for intermediate files. If None, no files are written.
     output_dir: Option<PathBuf>,
+    /// The temporary directory, owned by the pipeline (or any copies of it).
+    /// This object is not used directly, but keeping it here ensures that the directory
+    /// is not deleted until the pipeline is dropped.
+    _tmp_dir: Option<Rc<Temp>>,
     /// The name of the pipeline. Used to name output files.
     name: Option<String>,
     /// Whether to overwrite existing files. If false, an error is returned if a file
@@ -140,6 +146,7 @@
         Pipeline {
             artifact: Default::default(),
             output_dir: None,
+            _tmp_dir: None,
             log_level: Level::Info,
             name: None,
             force_overwrite: false,
@@ -190,12 +197,14 @@
 /// let proof = pipeline.compute_proof().unwrap();
 /// ```
 impl<T: FieldElement> Pipeline<T> {
-    /// Initializes the output directory to a temporary directory.
-    /// Note that the user is responsible for keeping the temporary directory alive.
-    pub fn with_tmp_output(self, tmp_dir: &mktemp::Temp) -> Self {
+    /// Initializes the output directory to a temporary directory which lives as long
+    /// as the pipeline does.
+    pub fn with_tmp_output(self) -> Self {
+        let tmp_dir = Rc::new(mktemp::Temp::new_dir().unwrap());
         Pipeline {
             output_dir: Some(tmp_dir.to_path_buf()),
             force_overwrite: true,
+            _tmp_dir: Some(tmp_dir),
             ..self
         }
     }
@@ -795,8 +804,11 @@ impl<T: FieldElement> Pipeline<T> {
 
         let start = Instant::now();
         let fixed_cols = constant_evaluator::generate(&pil);
+        self.log(&format!(
+            "Fixed column generation took {}s",
+            start.elapsed().as_secs_f32()
+        ));
         self.maybe_write_constants(&fixed_cols)?;
-        self.log(&format!("Took {}", start.elapsed().as_secs_f32()));
 
         self.artifact.fixed_cols = Some(Arc::new(fixed_cols));
 
@@ -830,7 +842,10 @@
             .with_external_witness_values(&external_witness_values)
             .generate();
 
-        self.log(&format!("Took {}", start.elapsed().as_secs_f32()));
+        self.log(&format!(
+            "Witness generation took {}s",
+            start.elapsed().as_secs_f32()
+        ));
 
         self.maybe_write_witness(&fixed_cols, &witness)?;
 
@@ -914,13 +929,19 @@
             .as_ref()
             .map(|path| fs::read(path).unwrap());
 
+        let start = Instant::now();
         let proof = match backend.prove(&witness, existing_proof, witgen_callback) {
             Ok(proof) => proof,
             Err(powdr_backend::Error::BackendError(e)) => {
                 return Err(vec![e.to_string()]);
             }
             Err(e) => panic!("{}", e),
         };
+        self.log(&format!(
+            "Proof generation took {}s",
+            start.elapsed().as_secs_f32()
+        ));
+        self.log(&format!("Proof size: {} bytes", proof.len()));
 
         drop(backend);
 
@@ -1096,8 +1117,15 @@
         )
         .unwrap();
 
+        let start = Instant::now();
         match backend.verify(proof, instances) {
-            Ok(_) => Ok(()),
+            Ok(_) => {
+                self.log(&format!(
+                    "Verification took {}s",
+                    start.elapsed().as_secs_f32()
+                ));
+                Ok(())
+            }
             Err(powdr_backend::Error::BackendError(e)) => Err(vec![e]),
             _ => panic!(),
         }
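`with_tmp_output` now owns its temporary directory through an `Rc`, so the directory survives clones of the pipeline and is deleted only when the last one is dropped; callers no longer have to keep a `Temp` alive themselves. A sketch of the ownership idea, assuming mktemp 0.5 (`Temp::new_dir` creates a directory that is removed on drop):

use std::path::PathBuf;
use std::rc::Rc;

struct Pipeline {
    output_dir: Option<PathBuf>,
    // Never read; held only to tie the temp dir's lifetime to the pipeline's.
    _tmp_dir: Option<Rc<mktemp::Temp>>,
    name: Option<String>,
}

impl Pipeline {
    fn with_tmp_output(self) -> Self {
        let tmp_dir = Rc::new(mktemp::Temp::new_dir().unwrap());
        Pipeline {
            output_dir: Some(tmp_dir.to_path_buf()),
            // The directory is deleted only when the last Rc clone is dropped.
            _tmp_dir: Some(tmp_dir),
            ..self
        }
    }
}

fn main() {
    let pipeline = Pipeline {
        output_dir: None,
        _tmp_dir: None,
        name: None,
    }
    .with_tmp_output();
    assert!(pipeline.output_dir.as_ref().unwrap().exists());
} // temp dir removed here, when `pipeline` drops its Rc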