Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrated plonky3 prover #1857

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,23 +335,23 @@ impl<T> Analyzed<T> {
.for_each(|definition| definition.post_visit_expressions_mut(f))
}

/// Retrieves (col_name, poly_id, offset) of each public witness in the trace.
pub fn get_publics(&self) -> Vec<(String, PolyID, usize)> {
/// Retrieves (col_name, poly_id, offset, stage) of each public witness in the trace.
pub fn get_publics(&self) -> Vec<(String, PolyID, usize, u32)> {
let mut publics = self
.public_declarations
.values()
.map(|public_declaration| {
let column_name = public_declaration.referenced_poly_name();
let poly_id = {
let (poly_id, stage) = {
let symbol = &self.definitions[&public_declaration.polynomial.name].0;
symbol
(symbol
.array_elements()
.nth(public_declaration.array_index.unwrap_or_default())
.unwrap()
.1
.1, symbol.stage.unwrap_or_default())
};
let row_offset = public_declaration.index as usize;
(column_name, poly_id, row_offset)
(column_name, poly_id, row_offset, stage)
})
.collect::<Vec<_>>();

Expand Down
6 changes: 3 additions & 3 deletions number/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ ark-bn254 = { version = "0.4.0", default-features = false, features = [
] }
ark-ff = "0.4.2"
ark-serialize = "0.4.2"
p3-baby-bear = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-mersenne-31 = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-field = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-baby-bear = { git = "https://github.com/plonky3/Plonky3.git" }
p3-mersenne-31 = { git = "https://github.com/plonky3/Plonky3.git" }
p3-field = { git = "https://github.com/plonky3/Plonky3.git" }
num-bigint = { version = "0.4.3", features = ["serde"] }
num-traits = "0.2.15"
csv = "1.3"
Expand Down
41 changes: 21 additions & 20 deletions plonky3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,38 @@ rand = "0.8.5"
powdr-analysis = { path = "../analysis" }
powdr-executor = { path = "../executor" }

p3-air = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-matrix = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-field = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-uni-stark = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-commit = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa", features = [
p3-air = { git = "https://github.com/plonky3/Plonky3.git" }
p3-matrix = { git = "https://github.com/plonky3/Plonky3.git" }
p3-field = { git = "https://github.com/plonky3/Plonky3.git" }
p3-uni-stark = { git = "https://github.com/plonky3/Plonky3.git" }
p3-commit = { git = "https://github.com/plonky3/Plonky3.git", features = [
"test-utils",
] }
p3-poseidon2 = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-poseidon = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-fri = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-poseidon2 = { git = "https://github.com/plonky3/Plonky3.git" }
p3-poseidon = { git = "https://github.com/plonky3/Plonky3.git" }
p3-fri = { git = "https://github.com/plonky3/Plonky3.git" }
# We don't use p3-maybe-rayon directly, but it is a dependency of p3-uni-stark.
# Activating the "parallel" feature gives us parallelism in the prover.
p3-maybe-rayon = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa", features = [
p3-maybe-rayon = { git = "https://github.com/plonky3/Plonky3.git", features = [
"parallel",
] }

p3-mds = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-merkle-tree = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-mersenne-31 = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-circle = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-baby-bear = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-goldilocks = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-symmetric = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-dft = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-challenger = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-util = { git = "https://github.com/powdr-labs/Plonky3.git", rev = "b38d3fa" }
p3-mds = { git = "https://github.com/plonky3/Plonky3.git" }
p3-merkle-tree = { git = "https://github.com/plonky3/Plonky3.git" }
p3-mersenne-31 = { git = "https://github.com/plonky3/Plonky3.git" }
p3-circle = { git = "https://github.com/plonky3/Plonky3.git" }
p3-baby-bear = { git = "https://github.com/plonky3/Plonky3.git" }
p3-goldilocks = { git = "https://github.com/plonky3/Plonky3.git" }
p3-symmetric = { git = "https://github.com/plonky3/Plonky3.git" }
p3-dft = { git = "https://github.com/plonky3/Plonky3.git" }
p3-challenger = { git = "https://github.com/plonky3/Plonky3.git" }
p3-util = { git = "https://github.com/plonky3/Plonky3.git" }
lazy_static = "1.4.0"
rand_chacha = "0.3.1"
bincode = "1.3.3"
itertools = "0.13.0"

tracing = "0.1.37"
serde = { version = "1.0", default-features = false, features = ["derive", "alloc"] }

[dev-dependencies]
powdr-pipeline.workspace = true
Expand Down
161 changes: 161 additions & 0 deletions plonky3/src/check_constraints.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use alloc::vec::Vec;

use itertools::Itertools;
use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, PairBuilder};
use p3_field::Field;
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
use p3_matrix::stack::VerticalPair;
use p3_matrix::Matrix;
use tracing::instrument;

use crate::traits::MultistageAirBuilder;

#[instrument(name = "check constraints", skip_all)]
pub(crate) fn check_constraints<F, A>(
air: &A,
preprocessed: &RowMajorMatrix<F>,
traces_by_stage: Vec<&RowMajorMatrix<F>>,
public_values_by_stage: &Vec<&Vec<F>>,
challenges: Vec<&Vec<F>>,
) where
F: Field,
A: for<'a> Air<DebugConstraintBuilder<'a, F>>,
{
let num_stages = traces_by_stage.len();
let height = traces_by_stage[0].height();

(0..height).for_each(|i| {
let i_next = (i + 1) % height;

let local_preprocessed = preprocessed.row_slice(i);
let next_preprocessed = preprocessed.row_slice(i_next);
let preprocessed = VerticalPair::new(
RowMajorMatrixView::new_row(&*local_preprocessed),
RowMajorMatrixView::new_row(&*next_preprocessed),
);

let stages_local_next = traces_by_stage
.iter()
.map(|trace| {
let stage_local = trace.row_slice(i);
let stage_next = trace.row_slice(i_next);
(stage_local, stage_next)
})
.collect_vec();

let traces_by_stage = (0..num_stages)
.map(|stage| {
VerticalPair::new(
RowMajorMatrixView::new_row(&*stages_local_next[stage].0),
RowMajorMatrixView::new_row(&*stages_local_next[stage].1),
)
})
.collect();

let mut builder = DebugConstraintBuilder {
row_index: i,
challenges: challenges.clone(),
preprocessed,
traces_by_stage,
public_values_by_stage,
is_first_row: F::from_bool(i == 0),
is_last_row: F::from_bool(i == height - 1),
is_transition: F::from_bool(i != height - 1),
};

air.eval(&mut builder);
});
}

/// An `AirBuilder` which asserts that each constraint is zero, allowing any failed constraints to
/// be detected early.
#[derive(Debug)]
pub struct DebugConstraintBuilder<'a, F: Field> {
row_index: usize,
preprocessed: VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>,
challenges: Vec<&'a Vec<F>>,
traces_by_stage: Vec<VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>>,
public_values_by_stage: &'a [&'a Vec<F>],
is_first_row: F,
is_last_row: F,
is_transition: F,
}

impl<'a, F> AirBuilder for DebugConstraintBuilder<'a, F>
where
F: Field,
{
type F = F;
type Expr = F;
type Var = F;
type M = VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>;

fn is_first_row(&self) -> Self::Expr {
self.is_first_row
}

fn is_last_row(&self) -> Self::Expr {
self.is_last_row
}

fn is_transition_window(&self, size: usize) -> Self::Expr {
if size == 2 {
self.is_transition
} else {
panic!("only supports a window size of 2")
}
}

fn main(&self) -> Self::M {
self.traces_by_stage[0]
}

fn assert_zero<I: Into<Self::Expr>>(&mut self, x: I) {
assert_eq!(
x.into(),
F::zero(),
"constraints had nonzero value on row {}",
self.row_index
);
}

fn assert_eq<I1: Into<Self::Expr>, I2: Into<Self::Expr>>(&mut self, x: I1, y: I2) {
let x = x.into();
let y = y.into();
assert_eq!(
x, y,
"values didn't match on row {}: {} != {}",
self.row_index, x, y
);
}
}

impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F> {
type PublicVar = Self::F;

fn public_values(&self) -> &[Self::PublicVar] {
self.stage_public_values(0)
}
}

impl<'a, F: Field> PairBuilder for DebugConstraintBuilder<'a, F> {
fn preprocessed(&self) -> Self::M {
self.preprocessed
}
}

impl<'a, F: Field> MultistageAirBuilder for DebugConstraintBuilder<'a, F> {
type Challenge = Self::Expr;

fn stage_public_values(&self, stage: usize) -> &[Self::F] {
self.public_values_by_stage[stage]
}

fn stage_trace(&self, stage: usize) -> Self::M {
self.traces_by_stage[stage]
}

fn stage_challenges(&self, stage: usize) -> &[Self::Expr] {
self.challenges[stage]
}
}
26 changes: 17 additions & 9 deletions plonky3/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use powdr_ast::analyzed::{
PolyID, PolynomialType, SelectedExpressions,
};

use p3_uni_stark::{CallbackResult, MultiStageAir, MultistageAirBuilder, NextStageTraceCallback};
use crate::{CallbackResult, MultiStageAir, MultistageAirBuilder, NextStageTraceCallback};
use powdr_ast::parsed::visitor::ExpressionVisitable;

use powdr_executor::witgen::WitgenCallback;
Expand All @@ -39,7 +39,7 @@ struct ConstraintSystem<T> {
// for each fixed column, the index of this column in the fixed columns
fixed_columns: HashMap<PolyID, usize>,
identities: Vec<Identity<SelectedExpressions<AlgebraicExpression<T>>>>,
publics: Vec<(String, PolyID, usize)>,
publics: Vec<(String, PolyID, usize, u32)>,
commitment_count: usize,
constant_count: usize,
// for each stage, the number of witness columns
Expand Down Expand Up @@ -182,7 +182,7 @@ where
self.constraint_system
.publics
.iter()
.filter_map(|(col_name, _, idx)| {
.filter_map(|(col_name, _, idx, _)| {
witness
.get(&col_name)
.map(|column| column[*idx].into_p3_field())
Expand Down Expand Up @@ -290,10 +290,6 @@ where
self.constraint_system.commitment_count
}

fn preprocessed_width(&self) -> usize {
self.constraint_system.constant_count + self.constraint_system.publics.len()
}

fn preprocessed_trace(&self) -> Option<RowMajorMatrix<Plonky3Field<T>>> {
#[cfg(debug_assertions)]
{
Expand Down Expand Up @@ -342,15 +338,15 @@ where
.publics
.iter()
.zip(pi.to_vec())
.map(|((id, _, _), val)| (id, val))
.map(|((id, _, _, _), val)| (id, val))
.collect::<BTreeMap<&String, <AB as AirBuilderWithPublicValues>::PublicVar>>();

// constrain public inputs using witness columns in stage 0
let fixed_local = fixed.row_slice(0);
let public_offset = self.constraint_system.constant_count;

self.constraint_system.publics.iter().enumerate().for_each(
|(index, (pub_id, poly_id, _))| {
|(index, (pub_id, poly_id, _, _))| {
let selector = fixed_local[public_offset + index];
let (stage, index) = self.constraint_system.witness_columns[poly_id];
assert_eq!(
Expand Down Expand Up @@ -401,6 +397,18 @@ where
ProverData<T>: Send,
Commitment<T>: Send,
{
fn stage_public_count(&self, stage: u32) -> usize {
self.constraint_system
.publics
.iter()
.filter(|(_, _, _, s)| *s == stage)
.count()
}

fn preprocessed_width(&self) -> usize {
self.constraint_system.constant_count + self.constraint_system.publics.len()
}

fn stage_count(&self) -> usize {
self.constraint_system.stage_widths.len()
}
Expand Down
Loading
Loading