From 0c6ac7d2c256f1e3f26b5edd03d3edc384cfbffd Mon Sep 17 00:00:00 2001 From: Champii1 Date: Mon, 12 Feb 2024 11:58:20 +0100 Subject: [PATCH] Add serialization for Analyzed and pipeline resume Fix pipeline name from file name with suffix Variable naming Replace `splitted` with `split` Co-authored-by: Leo Fix clippy Add test for serde of PIL Changed the optimized PIL file extension to .pilo Changed function names to reflect the operation on pil object --- ast/Cargo.toml | 1 + ast/src/analyzed/mod.rs | 35 ++++++++++++++++++----------------- ast/src/analyzed/types.rs | 11 ++++++----- ast/src/lib.rs | 3 ++- ast/src/parsed/mod.rs | 23 ++++++++++++----------- cli/src/main.rs | 2 +- number/Cargo.toml | 3 +++ number/src/bn254.rs | 2 ++ number/src/goldilocks.rs | 1 + number/src/macros.rs | 18 +++++++++++++++++- number/src/serialize.rs | 23 +++++++++++++++++++++++ number/src/traits.rs | 3 +++ pipeline/src/pipeline.rs | 37 +++++++++++++++++++++++++++++++++++++ pipeline/tests/pil.rs | 20 ++++++++++++++++++++ 14 files changed, 146 insertions(+), 36 deletions(-) diff --git a/ast/Cargo.toml b/ast/Cargo.toml index 4cc12aa5d..9101ddc14 100644 --- a/ast/Cargo.toml +++ b/ast/Cargo.toml @@ -16,6 +16,7 @@ num-traits = "0.2.15" diff = "0.1" log = "0.4.18" derive_more = "0.99.17" +serde = { version = "1.0", default-features = false, features = ["alloc", "derive", "rc"] } [dev-dependencies] pretty_assertions = "1.3.0" diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 8c0e3f5de..0991330cc 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -8,6 +8,7 @@ use std::fmt::Display; use std::ops::{self, ControlFlow}; use powdr_number::{DegreeType, FieldElement}; +use serde::{Deserialize, Serialize}; use crate::parsed::utils::expr_any; use crate::parsed::visitor::ExpressionVisitable; @@ -18,7 +19,7 @@ use crate::SourceRef; use self::types::TypedExpression; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum StatementIdentifier { /// Either an intermediate column or a definition. Definition(String), @@ -27,7 +28,7 @@ pub enum StatementIdentifier { Identity(usize), } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Analyzed { /// The degree of all namespaces, which must match. If there are no namespaces, then `None`. pub degree: Option, @@ -404,7 +405,7 @@ fn inlined_expression_from_intermediate_poly_id( expr } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Symbol { pub id: u64, pub source: SourceRef, @@ -461,7 +462,7 @@ impl Symbol { /// The "kind" of a symbol. In the future, this will be mostly /// replaced by its type. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum SymbolKind { /// Fixed, witness or intermediate polynomial Poly(PolynomialType), @@ -472,7 +473,7 @@ pub enum SymbolKind { Other(), } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum FunctionValueDefinition { Array(Vec>), Query(Expression), @@ -480,7 +481,7 @@ pub enum FunctionValueDefinition { } /// An array of elements that might be repeated. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct RepeatedArray { /// The pattern to be repeated pattern: Vec>, @@ -520,7 +521,7 @@ impl RepeatedArray { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct PublicDeclaration { pub id: u64, pub source: SourceRef, @@ -540,7 +541,7 @@ impl PublicDeclaration { } } -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] pub struct Identity { /// The ID is specific to the identity kind. pub id: u64, @@ -579,7 +580,7 @@ impl Identity> { } } -#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize)] pub enum IdentityKind { Polynomial, Plookup, @@ -600,13 +601,13 @@ impl SelectedExpressions> { pub type Expression = parsed::Expression; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum Reference { LocalVar(u64, String), Poly(PolynomialReference), } -#[derive(Debug, Clone, Eq)] +#[derive(Debug, Clone, Eq, Serialize, Deserialize)] pub struct AlgebraicReference { /// Name of the polynomial - just for informational purposes. /// Comparisons are based on polynomial ID. @@ -653,7 +654,7 @@ impl Hash for AlgebraicReference { self.next.hash(state); } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub enum AlgebraicExpression { Reference(AlgebraicReference), PublicReference(String), @@ -667,7 +668,7 @@ pub enum AlgebraicExpression { UnaryOperation(AlgebraicUnaryOperator, Box>), } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize)] pub enum AlgebraicBinaryOperator { Add, Sub, @@ -703,7 +704,7 @@ impl TryFrom for AlgebraicBinaryOperator { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize)] pub enum AlgebraicUnaryOperator { Minus, } @@ -790,7 +791,7 @@ impl From for AlgebraicExpression { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct PolynomialReference { /// Name of the polynomial - just for informational purposes. /// Comparisons are based on polynomial ID. @@ -801,7 +802,7 @@ pub struct PolynomialReference { pub poly_id: Option, } -#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq, Hash)] +#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct PolyID { pub id: u64, pub ptype: PolynomialType, @@ -819,7 +820,7 @@ impl From<&Symbol> for PolyID { } } -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub enum PolynomialType { Committed, Constant, diff --git a/ast/src/analyzed/types.rs b/ast/src/analyzed/types.rs index c6eec9587..50b312385 100644 --- a/ast/src/analyzed/types.rs +++ b/ast/src/analyzed/types.rs @@ -1,18 +1,19 @@ use std::fmt::Display; use powdr_number::FieldElement; +use serde::{Deserialize, Serialize}; use crate::parsed::{ArrayTypeName, Expression, FunctionTypeName, TupleTypeName, TypeName}; use super::Reference; -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub struct TypedExpression { pub e: Expression, pub ty: Option, } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub enum Type { /// Boolean Bool, @@ -71,7 +72,7 @@ impl From>> for Type } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub struct ArrayType { pub base: Box, pub length: Option, @@ -95,7 +96,7 @@ impl From>> for } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub struct TupleType { pub items: Vec, } @@ -108,7 +109,7 @@ impl From>> for } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub struct FunctionType { pub params: Vec, pub value: Box, diff --git a/ast/src/lib.rs b/ast/src/lib.rs index 3eee0b59d..c6c48414c 100644 --- a/ast/src/lib.rs +++ b/ast/src/lib.rs @@ -2,6 +2,7 @@ use itertools::Itertools; use log::log_enabled; +use serde::{Deserialize, Serialize}; use std::fmt::{Display, Result, Write}; use std::sync::Arc; @@ -21,7 +22,7 @@ pub struct DiffMonitor { current: Option, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct SourceRef { pub file: Option>, pub line: usize, diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index e3d060730..9cbefe2e3 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -11,6 +11,7 @@ use std::{ }; use powdr_number::{DegreeType, FieldElement}; +use serde::{Deserialize, Serialize}; use self::asm::{Part, SymbolPath}; use crate::SourceRef; @@ -151,7 +152,7 @@ impl PilStatement { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub struct SelectedExpressions { pub selector: Option, pub expressions: Vec, @@ -178,7 +179,7 @@ impl SelectedExpressions { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub enum Expression { Reference(Ref), PublicReference(String), @@ -292,18 +293,18 @@ impl NamespacedPolynomialReference { } } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct LambdaExpression { pub params: Vec, pub body: Box>, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct ArrayLiteral { pub items: Vec>, } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize)] pub enum UnaryOperator { Minus, LogicalNot, @@ -320,7 +321,7 @@ impl UnaryOperator { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize)] pub enum BinaryOperator { Add, Sub, @@ -344,32 +345,32 @@ pub enum BinaryOperator { Greater, } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub struct IndexAccess { pub array: Box>, pub index: Box>, } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub struct FunctionCall { pub function: Box>, pub arguments: Vec>, } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub struct MatchArm { pub pattern: MatchPattern, pub value: Expression, } /// A pattern for a match arm. We could extend this in the future. -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub enum MatchPattern { CatchAll, Pattern(Expression), } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize)] pub struct IfExpression { pub condition: Box>, pub body: Box>, diff --git a/cli/src/main.rs b/cli/src/main.rs index 18331d1e8..7b0b9396f 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -812,7 +812,7 @@ fn read_and_prove( params: Option, ) -> Result<(), Vec> { Pipeline::::default() - .from_file(file.to_path_buf()) + .from_maybe_pil_object(file.to_path_buf()) .with_output(dir.to_path_buf(), true) .read_generated_witness(dir) .with_setup_file(params.map(PathBuf::from)) diff --git a/number/Cargo.toml b/number/Cargo.toml index 2cdb2e38b..05c7657fe 100644 --- a/number/Cargo.toml +++ b/number/Cargo.toml @@ -12,9 +12,12 @@ ark-bn254 = { version = "0.4.0", default-features = false, features = [ "scalar_field", ] } ark-ff = "0.4.2" +ark-serialize = "0.4.2" num-bigint = "0.4.3" num-traits = "0.2.15" csv = "1.3" +serde = { version = "1.0", default-features = false, features = ["alloc", "derive", "rc"] } +serde_with = "3.6.1" [dev-dependencies] test-log = "0.2.12" diff --git a/number/src/bn254.rs b/number/src/bn254.rs index 5bbb382e8..a73575d8c 100644 --- a/number/src/bn254.rs +++ b/number/src/bn254.rs @@ -1,4 +1,6 @@ use ark_bn254::Fr; +use serde::{Deserialize, Serialize}; + powdr_field!(Bn254Field, Fr); #[cfg(test)] diff --git a/number/src/goldilocks.rs b/number/src/goldilocks.rs index f1ab60e6a..87d84c188 100644 --- a/number/src/goldilocks.rs +++ b/number/src/goldilocks.rs @@ -1,4 +1,5 @@ use ark_ff::{Fp64, MontBackend, MontConfig}; +use serde::{Deserialize, Serialize}; #[derive(MontConfig)] #[modulus = "18446744069414584321"] diff --git a/number/src/macros.rs b/number/src/macros.rs index 457a79112..fd55be65b 100644 --- a/number/src/macros.rs +++ b/number/src/macros.rs @@ -11,8 +11,24 @@ macro_rules! powdr_field { use std::ops::*; use std::str::FromStr; - #[derive(Clone, Copy, PartialEq, Eq, Debug, Default, PartialOrd, Ord, Hash)] + #[derive( + Clone, + Copy, + PartialEq, + Eq, + Debug, + Default, + PartialOrd, + Ord, + Hash, + Serialize, + Deserialize, + )] pub struct $name { + #[serde( + serialize_with = "crate::serialize::ark_se", + deserialize_with = "crate::serialize::ark_de" + )] value: $ark_type, } diff --git a/number/src/serialize.rs b/number/src/serialize.rs index c5b1c6b53..14567feaa 100644 --- a/number/src/serialize.rs +++ b/number/src/serialize.rs @@ -1,6 +1,8 @@ use std::io::{Read, Write}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; use csv::{Reader, Writer}; +use serde_with::{DeserializeAs, SerializeAs}; use crate::{DegreeType, FieldElement}; @@ -140,6 +142,27 @@ pub fn read_polys_file( } } +// Serde wrappers for serialize/deserialize + +pub fn ark_se(a: &A, s: S) -> Result +where + S: serde::Serializer, +{ + let mut bytes = vec![]; + a.serialize_with_mode(&mut bytes, Compress::Yes) + .map_err(serde::ser::Error::custom)?; + serde_with::Bytes::serialize_as(&bytes, s) +} + +pub fn ark_de<'de, D, A: CanonicalDeserialize>(data: D) -> Result +where + D: serde::de::Deserializer<'de>, +{ + let s: Vec = serde_with::Bytes::deserialize_as(data)?; + let a = A::deserialize_with_mode(s.as_slice(), Compress::Yes, Validate::Yes); + a.map_err(serde::de::Error::custom) +} + #[cfg(test)] mod tests { use crate::Bn254Field; diff --git a/number/src/traits.rs b/number/src/traits.rs index 18027cb43..988c19b1f 100644 --- a/number/src/traits.rs +++ b/number/src/traits.rs @@ -1,6 +1,7 @@ use std::{fmt, hash::Hash, ops::*, str::FromStr}; use num_traits::{One, Zero}; +use serde::{de::DeserializeOwned, Serialize}; use crate::{AbstractNumberType, DegreeType}; @@ -85,6 +86,8 @@ pub trait FieldElement: + From + From + fmt::LowerHex + + Serialize + + DeserializeOwned { /// The underlying fixed-width integer type type Integer: BigInt; diff --git a/pipeline/src/pipeline.rs b/pipeline/src/pipeline.rs index 1287a5c2f..1bb684f74 100644 --- a/pipeline/src/pipeline.rs +++ b/pipeline/src/pipeline.rs @@ -408,6 +408,26 @@ impl Pipeline { } } + pub fn from_maybe_pil_object(self, file: PathBuf) -> Self { + if file.extension().unwrap() == "pilo" { + self.from_pil_object(file) + } else { + self.from_file(file) + } + } + + pub fn from_pil_object(self, pil_file: PathBuf) -> Self { + let name = self + .name + .or(Some(Self::name_from_path_with_suffix(&pil_file))); + let analyzed = serde_cbor::from_reader(fs::File::open(&pil_file).unwrap()).unwrap(); + Pipeline { + artifact: Some(Artifact::OptimzedPil(analyzed)), + name, + ..self + } + } + /// Reads previously generated constants from the provided directory and /// advances the pipeline to the `PilWithEvaluatedFixedCols` stage. pub fn read_constants(mut self, directory: &Path) -> Self { @@ -487,6 +507,14 @@ impl Pipeline { path.file_stem().unwrap().to_str().unwrap().to_string() } + // This is used for parsing file names than ends with '_{suffix}' + fn name_from_path_with_suffix(path: &Path) -> String { + let file_name = Self::name_from_path(path); + let mut split = file_name.split('_').collect::>(); + split.pop(); + split.join("_") + } + fn log(&self, msg: &str) { log::log!(self.log_level, "{}", msg); } @@ -568,6 +596,7 @@ impl Pipeline { self.log("Optimizing pil..."); let optimized = powdr_pilopt::optimize(analyzed_pil); self.maybe_write_pil(&optimized, "_opt")?; + self.maybe_write_pil_object(&optimized, "_opt")?; Artifact::OptimzedPil(optimized) } Artifact::OptimzedPil(pil) => { @@ -721,6 +750,14 @@ impl Pipeline { Ok(()) } + fn maybe_write_pil_object(&self, pil: &Analyzed, suffix: &str) -> Result<(), Vec> { + if let Some(path) = self.path_if_should_write(|name| format!("{name}{suffix}.pilo"))? { + serde_cbor::to_writer(&mut fs::File::create(&path).unwrap(), pil) + .map_err(|e| vec![format!("Error writing {}: {e}", path.to_str().unwrap())])?; + } + Ok(()) + } + fn maybe_write_constants(&self, constants: &[(String, Vec)]) -> Result<(), Vec> { if let Some(path) = self.path_if_should_write(|name| format!("{name}_constants.bin"))? { let writer = BufWriter::new(fs::File::create(path).unwrap()); diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index 35e9d4beb..05550395f 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -243,6 +243,26 @@ fn referencing_arrays() { gen_estark_proof(f, Default::default()); } +#[test] +fn serialize_deserialize_optimized_pil() { + let f = "pil/fibonacci.pil"; + let path = powdr_pipeline::test_util::resolve_test_file(f); + + let optimized = powdr_pipeline::Pipeline::::default() + .from_file(path) + .optimized_pil() + .unwrap(); + + let optimized_serialized = serde_cbor::to_vec(&optimized).unwrap(); + let optimized_deserialized: powdr_ast::analyzed::Analyzed = + serde_cbor::from_slice(&optimized_serialized[..]).unwrap(); + + let input_pil_file = format!("{}", optimized); + let output_pil_file = format!("{}", optimized_deserialized); + + assert_eq!(input_pil_file, output_pil_file); +} + mod book { use super::*; use test_log::test;