From 1be1038f0acaffd9241e3d608d010d0ee0cacad7 Mon Sep 17 00:00:00 2001 From: PayneJoe Date: Thu, 14 Dec 2023 01:30:25 +0800 Subject: [PATCH] finish cross term & remove q_e --- src/primary/plonk.rs | 274 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 242 insertions(+), 32 deletions(-) diff --git a/src/primary/plonk.rs b/src/primary/plonk.rs index 17aa06c..62c97b2 100644 --- a/src/primary/plonk.rs +++ b/src/primary/plonk.rs @@ -1,8 +1,10 @@ /// plonk instances for primary circuit over BN254 curve /// +/// computation of cross terms followed from chapter 3.5 of protostar: https://eprint.iacr.org/2023/620.pdf +/// use ark_ec::pairing::Pairing; use ark_ec::CurveGroup; -use ark_ff::Field; +use ark_ff::{Field, PrimeField}; use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial}; use jf_primitives::pcs::prelude::Commitment; use jf_primitives::pcs::{ @@ -31,13 +33,12 @@ pub struct PLONKShape { pub(crate) num_wire_types: usize, pub(crate) num_public_input: usize, - pub(crate) q_c: Vec, pub(crate) q_lc: Vec>, pub(crate) q_mul: Vec>, + pub(crate) q_hash: Vec>, pub(crate) q_ecc: Vec, - pub(crate) q_hash: Vec, pub(crate) q_o: Vec, - pub(crate) q_e: Vec, + pub(crate) q_c: Vec, } /// A type that holds a witness for a given Plonk instance @@ -303,9 +304,8 @@ impl PLONKShape { q_lc: &Vec>, q_mul: &Vec>, q_ecc: &Vec, - q_hash: &Vec, + q_hash: &Vec>, q_o: &Vec, - q_e: &Vec, ) -> Result, MyError> { assert!(q_lc.len() == num_wire_types - 1); assert!(q_mul.len() == 2); @@ -318,9 +318,10 @@ impl PLONKShape { }; let invalid_num: i32 = vec![ - vec![q_c, q_ecc, q_hash, q_o, q_e], + vec![q_c, q_ecc, q_o], q_lc.into_iter().collect::>>(), q_mul.into_iter().collect::>>(), + q_hash.into_iter().collect::>>(), ] .concat() .iter() @@ -349,45 +350,128 @@ impl PLONKShape { q_ecc: q_ecc.to_owned(), q_hash: q_hash.to_owned(), q_o: q_o.to_owned(), - q_e: q_e.to_owned(), }) } + fn grand_product(n: usize, vec: Vec<&E::ScalarField>) -> E::ScalarField { + let first: E::ScalarField = *vec[0]; + if n == 1 { + first + } else { + vec[1..].iter().fold(first, |acc, cur| acc * *cur) + } + } + fn compute_cross_terms( + degree: usize, + u1: E::ScalarField, + u2: E::ScalarField, inst1: &Vec>, inst2: &Vec>, ) -> Vec> { assert!(inst1.len() == inst2.len(), "compute cross term"); - let element_mul = |a_vec: &Vec, b_vec: &Vec| { - a_vec.par_iter().zip(b_vec).map(|(a, b)| *a * *b).collect() + let transpose_matrix = |mat: Vec>| { + let num_row = mat[0].len(); + let mut mut_cols: Vec<_> = mat.into_iter().map(|col| col.into_iter()).collect(); + (0..num_row) + .map(|_| { + mut_cols + .iter_mut() + .map(|n| n.next().unwrap()) + .collect::>() + }) + .collect::>>() }; + let trans_inst1 = transpose_matrix(inst1.clone()); + let trans_inst2 = transpose_matrix(inst2.clone()); - let max_degree = inst1.len(); + let max_degree = 5 as usize; (1..max_degree) .rev() .map(|r_degree| { let l_degree = max_degree - r_degree; - let l_first: &Vec = &(inst1[0]); - let l_prod: Vec = inst1[1..l_degree] - .into_iter() - .collect::>>() - .iter() - .fold(l_first.clone(), |acc, xs| element_mul(&acc, xs)); - let r_first: &Vec = &(inst2[0]); - let r_prod: Vec = inst2[l_degree..] - .into_iter() - .collect::>>() - .iter() - .fold(r_first.clone(), |acc, xs| element_mul(&acc, xs)); + trans_inst1 + .par_iter() + .zip(&trans_inst2) + .map(|(row_a, row_b)| { + let l_vars = vec![ + vec![&u1; degree], + row_a + .into_iter() + .map(|a| a) + .collect::>(), + ] + .concat(); + let r_vars = vec![ + vec![&u2; degree], + row_b + .into_iter() + .map(|a| a) + .collect::>(), + ] + .concat(); + // let l_vars = vec![vec![u1; degree], row_a].concat(); + // let r_vars = vec![vec![u2; degree], row_b].concat(); + Self::grand_product(l_degree, l_vars) + * Self::grand_product(r_degree, r_vars) + }) + .collect::>() + }) + .rev() + .collect::>>() + } + + fn compute_cross_terms_five_exp( + inst1: &Vec, + inst2: &Vec, + ) -> Vec> { + let count_combination = |n: usize, r: usize| { + if r > n { + 0 + } else { + (1..=r).fold(1, |acc, val| acc * (n - val + 1) / val) + } + }; + let vec_pow = |n: usize, vec: &Vec| { + vec.par_iter() + .map(|v| { + let first = *v; + if n == 1 { + first + } else { + vec![v; n - 1].iter().fold(first, |a, b| a * *b) + } + }) + .collect::>() + }; - element_mul(&l_prod, &r_prod) + let max_degree: usize = 5; + (1..max_degree) + .rev() + .map(|r_degree| { + let l_degree = max_degree - r_degree; + let const_var = count_combination(max_degree, r_degree); + let const_scalar = ::from_bigint( + ::BigInt::from(const_var as u32), + ) + .unwrap(); + let ref_const_scalar = &const_scalar; + let l_pow = vec_pow(l_degree, inst1); + let r_pow = vec_pow(r_degree, inst2); + l_pow + .iter() + .zip(r_pow) + .map(|(a, b)| *ref_const_scalar * a * b) + .collect::>() }) .rev() .collect::>>() } + //// compute cross terms and their commitments + /// 1. length of cross term vector equals max_degree - 1 pub fn commit_T( &self, ck: &CommitmentKey, @@ -395,19 +479,145 @@ impl PLONKShape { W1: &RelaxedPLONKWitness, U2: &PLONKInstance, W2: &PLONKWitness, - ) -> Result<(Vec, Commitment), MyError> { + ) -> Result<(Vec>, Vec>), MyError> { assert!(W1.W.len() == self.num_wire_types - 1, "wrong wires"); - // q_ecc operation + // q_ecc operation, u^0 * q_ecc * w_0 * w_1 * w_2 * w_3 * w_o + let ecc_T: Vec> = Self::compute_cross_terms( + 0 as usize, + U1.u, + ::ONE, + &W1.W, + &W2.W, + ); + + // q_lc operation, u^4 * (q_lc_0 * w_0 + q_lc_1 * w_1 + q_lc_2 * w_2 + q_lc_3 * w_3) + let lc_T = (0..self.num_wire_types - 1) + .map(|i| { + Self::compute_cross_terms( + 4, + U1.u, + ::ONE, + &W1.W[i..i + 1].to_vec(), + &W2.W[i..i + 1].to_vec(), + ) + }) + .collect::>>>(); + + // q_mul operation, u^3 * (q_mul_0 * w_0 * w_1 + q_mul_1 * w_2 * w_3) + let mul_T = (0..self.num_wire_types - 1) + .step_by(2) + .map(|i| { + Self::compute_cross_terms( + 3, + U1.u, + ::ONE, + &W1.W[i..i + 2].to_vec(), + &W2.W[i..i + 2].to_vec(), + ) + }) + .collect::>>>(); + + // q_out operation, u^4 * (q_o * w_o) + let out_T = Self::compute_cross_terms( + 4, + U1.u, + ::ONE, + &W1.W[self.num_wire_types - 1..].to_vec(), + &W2.W[self.num_wire_types - 1..].to_vec(), + ); + + // q_c operation, u^5 * q_c + let u1_vec = vec![U1.u; self.num_cons]; + let u2_vec = vec![::ONE; self.num_cons]; + let const_T = Self::compute_cross_terms_five_exp(&u1_vec, &u2_vec); + + // q_hash operation, u^0 * (q_hash_0 * w_0^5 + q_hash_1 * w_1^5 + q_hash_2 * w_2^5 + q_hash_3 * w_3^5) + let hash_T = (0..self.num_wire_types - 1) + .map(|i| Self::compute_cross_terms_five_exp(&W1.W[i], &W2.W[i])) + .collect::>>>(); + + //////////////////////////////// apply selectors on cross terms + let apply_selector = |T: &Vec>, selector: &Vec| { + (0..self.num_wire_types - 1) + .map(|i| { + let ref_T = &T[i]; + ref_T + .par_iter() + .zip(selector) + .map(|(a, b)| *a * *b) + .collect::>() + }) + .collect::>>() + }; - // q_lc operation + let (ref_ecc_T, ref_out_T, ref_const_T, ref_q_ecc, ref_q_out, ref_q_const) = + (&ecc_T, &out_T, &const_T, &self.q_ecc, &self.q_o, &self.q_c); + let ecc_result = apply_selector(ref_ecc_T, ref_q_ecc); + let out_result = apply_selector(ref_out_T, ref_q_out); + let const_result = apply_selector(ref_const_T, ref_q_const); - // q_mul operation + let lc_result = (0..self.num_wire_types - 1) + .map(|i| { + let (ref_lc_T, ref_q_lc) = (&lc_T[i], &self.q_lc[i]); + apply_selector(ref_lc_T, ref_q_lc) + }) + .collect::>>>(); + + let hash_result = (0..self.num_wire_types - 1) + .map(|i| { + let (ref_hash_T, ref_q_hash) = (&hash_T[i], &self.q_hash[i]); + apply_selector(ref_hash_T, ref_q_hash) + }) + .collect::>>>(); - // q_out operation + let mul_result = (0..2) + .map(|i| { + let (ref_mul_T, ref_q_mul) = (&mul_T[i], &self.q_mul[i]); + apply_selector(ref_mul_T, ref_q_mul) + }) + .collect::>>>(); + + ////////////////////////////////////////// add-on all cross terms + let apply_mat_element_add = + |acc: &Vec>, cur: &Vec>| { + acc.into_iter() + .zip(cur) + .map(|(a_col, b_col)| { + a_col + .iter() + .zip(b_col) + .map(|(a, b)| *a + *b) + .collect::>() + }) + .collect::>>() + }; + + let stack_T = vec![ + vec![&ecc_result, &out_result, &const_result], + lc_result.iter().collect::>>>(), + hash_result + .iter() + .collect::>>>(), + mul_result + .iter() + .collect::>>>(), + ] + .concat(); + let T = stack_T[1..].iter().fold(stack_T[0].clone(), |acc, cur| { + apply_mat_element_add(&acc, cur) + }); - // q_c and pi operation + ////////////////////////////////////////// commit T + let com_T = T + .iter() + .map(|coefficients| { + let poly = as DenseUVPolynomial< + E::ScalarField, + >>::from_coefficients_vec(coefficients.clone()); + UnivariateKzgPCS::::commit(ck, &poly).unwrap() + }) + .collect::>>(); - // q_hash operation - todo!() + Ok((T, com_T)) } }