Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plonky public values constrained by fixed columns #1610

Merged
merged 12 commits into from
Aug 1, 2024
24 changes: 24 additions & 0 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,30 @@ impl<T> Analyzed<T> {
.filter_map(|(_poly, definition)| definition.as_mut())
.for_each(|definition| definition.post_visit_expressions_mut(f))
}

/// Retrieves (col_name, col_idx, offset) of each public witness in the trace.
pub fn get_publics(&self) -> Vec<(String, usize, usize)> {
let mut publics = self
.public_declarations
.values()
.map(|public_declaration| {
let column_name = public_declaration.referenced_poly_name();
let column_idx = {
let base = public_declaration.polynomial.poly_id.unwrap().id as usize;
match public_declaration.array_index {
Some(array_idx) => base + array_idx,
None => base,
}
};
let row_offset = public_declaration.index as usize;
(column_name, column_idx, row_offset)
})
.collect::<Vec<_>>();

// Sort, so that the order is deterministic
publics.sort();
publics
}
}

impl<T: FieldElement> Analyzed<T> {
Expand Down
156 changes: 46 additions & 110 deletions plonky3/src/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,10 @@
//! A plonky3 adapter for powdr
//!
//! Supports public values without the use of fixed columns.
//!
//! Namely, given public value pub corresponding to a witness value in
//! row j of witness column x, a corresponding selector column s is constructed
//! to constrain s * (pub - x) on every row:
//!
//! col witness x;
//! public out_x = col x(j);
//! col witness s;
//! s * (pub - x) = 0;
//!
//! Moreover, s is constrained to be 1 at evaluation index s(j) and 0
//! everywhere else by applying the `is_zero` transformation to a column 'decr'
//! decrementing by 1 each row from an initial value set to j in the first row:
//!
//! col witness decr;
//! decr(0) = j;
//! decr - decr' - 1 = 0;
//! s = is_zero(decr);
//!
//! Note that in Plonky3 this transformation requires an additional column
//! `inv_decr` to track the inverse of decr for the `is_zero` operation,
//! therefore requiring a total of 3 extra witness columns per public value.
//! Supports public inputs with the use of fixed columns.
//! Namely, given public value pub corresponding to a witness value in row j
//! of witness column x, a corresponding fixed selector column s which is 0
//! everywhere save for at row j is constructed to constrain s * (pub - x) on
//! every row.

use std::{any::TypeId, collections::BTreeMap};

Expand Down Expand Up @@ -55,8 +37,6 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> {
pub fn generate_trace_rows(&self) -> RowMajorMatrix<Goldilocks> {
// an iterator over all columns, committed then fixed
let witness = self.witness().iter();

let publics = self.get_publics().into_iter();
let degrees = self.analyzed.degrees();

let values = match degrees.len() {
Expand All @@ -66,19 +46,7 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> {
(0..*degree)
.flat_map(move |i| {
// witness values
witness.clone().map(move |(_, v)| v[i as usize]).chain(
// publics rows: decrementor | inverse | selector
publics.clone().flat_map(move |(_, _, row_id)| {
let decr = T::from(row_id as u64) - T::from(i);
let inv_decr = if i as usize == row_id {
T::zero()
} else {
T::one() / decr
};
let s = T::from(i as usize == row_id);
[decr, inv_decr, s]
}),
)
witness.clone().map(move |(_, v)| v[i as usize])
})
.map(cast_to_goldilocks)
.collect()
Expand Down Expand Up @@ -122,35 +90,8 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> {
self.witness.as_ref().unwrap()
}

/// Retrieves (col_name, col_idx, offset) of each public witness in the trace.
pub(crate) fn get_publics(&self) -> Vec<(String, usize, usize)> {
let mut publics = self
.analyzed
.public_declarations
.values()
.map(|public_declaration| {
let witness_name = public_declaration.referenced_poly_name();
let witness_column = {
let base = public_declaration.polynomial.poly_id.unwrap().id as usize;
match public_declaration.array_index {
Some(array_idx) => base + array_idx,
None => base,
}
};
let witness_offset = public_declaration.index as usize;
(witness_name, witness_column, witness_offset)
})
.collect::<Vec<_>>();

// Sort, so that the order is deterministic
publics.sort();
publics
}

/// Calculates public values from generated witness values.
pub(crate) fn get_public_values(&self) -> Vec<Goldilocks> {
let publics = self.get_publics();

let witness = self
.witness
.as_ref()
Expand All @@ -159,11 +100,12 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> {
.map(|(name, values)| (name, values))
.collect::<BTreeMap<_, _>>();

publics
.into_iter()
self.analyzed
.get_publics()
.iter()
.map(|(col_name, _, idx)| {
let vals = *witness.get(&col_name).unwrap();
cast_to_goldilocks(vals[idx])
cast_to_goldilocks(vals[*idx])
})
.collect()
}
Expand Down Expand Up @@ -193,11 +135,12 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> {
}

/// Conversion to plonky3 expression
fn to_plonky3_expr<AB: AirBuilder<F = Val>>(
fn to_plonky3_expr<AB: AirBuilder<F = Val> + AirBuilderWithPublicValues>(
&self,
e: &AlgebraicExpression<T>,
main: &AB::M,
fixed: &AB::M,
publics: &BTreeMap<&String, <AB as AirBuilderWithPublicValues>::PublicVar>,
) -> AB::Expr {
let res = match e {
AlgebraicExpression::Reference(r) => {
Expand Down Expand Up @@ -225,13 +168,14 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> {
}
}
}
AlgebraicExpression::PublicReference(_) => unimplemented!(
"public references are not supported inside algebraic expressions in plonky3"
),
AlgebraicExpression::PublicReference(id) => (*publics
.get(id)
.expect("Referenced public value does not exist"))
.into(),
AlgebraicExpression::Number(n) => AB::Expr::from(cast_to_goldilocks(*n)),
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => {
let left = self.to_plonky3_expr::<AB>(left, main, fixed);
let right = self.to_plonky3_expr::<AB>(right, main, fixed);
let left = self.to_plonky3_expr::<AB>(left, main, fixed, publics);
let right = self.to_plonky3_expr::<AB>(right, main, fixed, publics);

match op {
AlgebraicBinaryOperator::Add => left + right,
Expand All @@ -243,7 +187,8 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> {
}
}
AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => {
let expr: <AB as AirBuilder>::Expr = self.to_plonky3_expr::<AB>(expr, main, fixed);
let expr: <AB as AirBuilder>::Expr =
self.to_plonky3_expr::<AB>(expr, main, fixed, publics);

match op {
AlgebraicUnaryOperator::Minus => -expr,
Expand All @@ -257,13 +202,15 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> {
}
}

/// An extension of [Air] allowing access to the number of fixed columns
topanisto marked this conversation as resolved.
Show resolved Hide resolved

impl<'a, T: FieldElement> BaseAir<Val> for PowdrCircuit<'a, T> {
fn width(&self) -> usize {
self.analyzed.commitment_count() + 3 * self.analyzed.publics_count()
self.analyzed.commitment_count()
}

fn preprocessed_width(&self) -> usize {
self.analyzed.constant_count()
self.analyzed.constant_count() + self.analyzed.publics_count()
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
}

fn preprocessed_trace(&self) -> Option<RowMajorMatrix<Val>> {
Expand All @@ -283,44 +230,32 @@ impl<'a, T: FieldElement, AB: AirBuilderWithPublicValues<F = Val> + PairBuilder>
let main = builder.main();
let fixed = builder.preprocessed();
let pi = builder.public_values();
let publics = self.get_publics();
let publics = self.analyzed.get_publics();
assert_eq!(publics.len(), pi.len());

let local = main.row_slice(0);

// public constraints
let pi_moved = pi.to_vec();
let (local, next) = (main.row_slice(0), main.row_slice(1));

let public_offset = self.analyzed.commitment_count();

publics.iter().zip(pi_moved).enumerate().for_each(
|(index, ((_, col_id, row_id), public_value))| {
//set decr for each public to be row_id in the first row and decrement by 1 each row
let (decr, inv_decr, s, decr_next) = (
local[public_offset + 3 * index],
local[public_offset + 3 * index + 1],
local[public_offset + 3 * index + 2],
next[public_offset + 3 * index],
);

let mut when_first_row = builder.when_first_row();
when_first_row.assert_eq(
decr,
cast_to_goldilocks(GoldilocksField::from(*row_id as u32)),
);

let mut when_transition = builder.when_transition();
when_transition.assert_eq(decr, decr_next + AB::Expr::one());

// is_zero logic-- s(row) is 1 if decr(row) is 0 and 0 otherwise
builder.assert_bool(s); //constraining s to 1 or 0
builder.assert_eq(s, AB::Expr::one() - inv_decr * decr);
builder.assert_zero(s * decr); //constraining is_zero
let public_vals_by_id = publics
.iter()
.zip(pi.to_vec())
.map(|((id, _, _), val)| (id, val))
.collect::<BTreeMap<&String, <AB as AirBuilderWithPublicValues>::PublicVar>>();

// constraining s(i) * (pub[i] - x(i)) = 0
let fixed_local = fixed.row_slice(0);
let public_offset = self.analyzed.constant_count();

publics
.iter()
.enumerate()
.for_each(|(index, (pub_id, col_id, _))| {
let selector = fixed_local[public_offset + index];
let witness_col = local[*col_id];
builder.assert_zero(s * (public_value.into() - witness_col));
},
);
let public_value = public_vals_by_id[pub_id];

// constraining s(i) * (pub[i] - x(i)) = 0
builder.assert_zero(selector * (public_value.into() - witness_col));
});

// circuit constraints
for identity in &self
Expand All @@ -337,6 +272,7 @@ impl<'a, T: FieldElement, AB: AirBuilderWithPublicValues<F = Val> + PairBuilder>
identity.left.selector.as_ref().unwrap(),
&main,
&fixed,
&public_vals_by_id,
);

builder.assert_zero(left);
Expand Down
66 changes: 60 additions & 6 deletions plonky3/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,19 @@ impl<T: FieldElement> Plonky3Prover<T> {
/// Returns preprocessed matrix based on the fixed inputs [`Plonky3Prover<T>`].
/// This is used when running the setup phase
pub fn get_preprocessed_matrix(&self) -> RowMajorMatrix<Goldilocks> {
match self.fixed.len() {
let publics = self
.analyzed
.get_publics()
.into_iter()
.map(|(name, _, row_id)| {
let selector = (0..self.analyzed.degree())
.map(move |i| T::from(i == row_id as u64))
.collect::<Vec<T>>();
(name, selector)
})
.collect::<Vec<(String, Vec<T>)>>();

match self.fixed.len() + publics.len() {
0 => RowMajorMatrix::new(Vec::<Goldilocks>::new(), 0),
_ => RowMajorMatrix::new(
// write fixed row by row
Expand All @@ -78,9 +90,14 @@ impl<T: FieldElement> Plonky3Prover<T> {
self.fixed
.iter()
.map(move |(_, values)| cast_to_goldilocks(values[i as usize]))
.chain(
publics
.iter()
.map(move |(_, values)| cast_to_goldilocks(values[i as usize])),
)
})
.collect(),
self.fixed.len(),
self.fixed.len() + publics.len(),
),
}
}
Expand All @@ -91,7 +108,20 @@ impl<T: FieldElement> Plonky3Prover<T> {
// get fixed columns
let fixed = &self.fixed;

if fixed.is_empty() {
// get selector columns for public values
let publics = self
.analyzed
.get_publics()
.into_iter()
.map(|(name, _, row_id)| {
let selector = (0..self.analyzed.degree())
.map(move |i| T::from(i == row_id as u64))
.collect::<Vec<T>>();
(name, selector)
})
.collect::<Vec<(String, Vec<T>)>>();

if fixed.is_empty() && publics.is_empty() {
return;
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
}

Expand All @@ -104,8 +134,18 @@ impl<T: FieldElement> Plonky3Prover<T> {
pcs,
self.analyzed.degree() as usize,
);
// get the preprocessed matrix
let matrix = self.get_preprocessed_matrix();
// write fixed into matrix row by row
let matrix = RowMajorMatrix::new(
(0..self.analyzed.degree())
.flat_map(|i| {
fixed
.iter()
Schaeff marked this conversation as resolved.
Show resolved Hide resolved
.chain(publics.iter())
.map(move |(_, values)| cast_to_goldilocks(values[i as usize]))
})
.collect(),
self.fixed.len() + publics.len(),
);

let evaluations = vec![(domain, matrix)];

Expand Down Expand Up @@ -238,11 +278,25 @@ mod tests {
}

#[test]
fn publics() {
fn public_values() {
let content = "namespace Global(8); pol witness x; x * (x - 1) = 0; public out = x(7);";
run_test_goldilocks(content);
}

#[test]
#[should_panic = "not implemented: Unexpected expression: :oldstate"]
fn public_reference() {
let content = r#"
namespace Global(8);
col witness x;
col witness y;
public oldstate = x(0);
x = 0;
y = 1 + :oldstate;
"#;
run_test_goldilocks(content);
}

#[test]
#[should_panic = "fri err: InvalidPowWitness"]
fn public_inputs_malicious() {
Expand Down
Loading