Skip to content

Commit

Permalink
finish cross term & remove q_e
Browse files Browse the repository at this point in the history
  • Loading branch information
PayneJoe committed Dec 13, 2023
1 parent d5b0201 commit 1be1038
Showing 1 changed file with 242 additions and 32 deletions.
274 changes: 242 additions & 32 deletions src/primary/plonk.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
/// plonk instances for primary circuit over BN254 curve
///
/// computation of cross terms followed from chapter 3.5 of protostar: https://eprint.iacr.org/2023/620.pdf
///
use ark_ec::pairing::Pairing;
use ark_ec::CurveGroup;
use ark_ff::Field;
use ark_ff::{Field, PrimeField};
use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial};
use jf_primitives::pcs::prelude::Commitment;
use jf_primitives::pcs::{
Expand Down Expand Up @@ -31,13 +33,12 @@ pub struct PLONKShape<E: Pairing> {
pub(crate) num_wire_types: usize,
pub(crate) num_public_input: usize,

pub(crate) q_c: Vec<E::ScalarField>,
pub(crate) q_lc: Vec<Vec<E::ScalarField>>,
pub(crate) q_mul: Vec<Vec<E::ScalarField>>,
pub(crate) q_hash: Vec<Vec<E::ScalarField>>,
pub(crate) q_ecc: Vec<E::ScalarField>,
pub(crate) q_hash: Vec<E::ScalarField>,
pub(crate) q_o: Vec<E::ScalarField>,
pub(crate) q_e: Vec<E::ScalarField>,
pub(crate) q_c: Vec<E::ScalarField>,
}

/// A type that holds a witness for a given Plonk instance
Expand Down Expand Up @@ -303,9 +304,8 @@ impl<E: Pairing> PLONKShape<E> {
q_lc: &Vec<Vec<E::ScalarField>>,
q_mul: &Vec<Vec<E::ScalarField>>,
q_ecc: &Vec<E::ScalarField>,
q_hash: &Vec<E::ScalarField>,
q_hash: &Vec<Vec<E::ScalarField>>,
q_o: &Vec<E::ScalarField>,
q_e: &Vec<E::ScalarField>,
) -> Result<PLONKShape<E>, MyError> {
assert!(q_lc.len() == num_wire_types - 1);
assert!(q_mul.len() == 2);
Expand All @@ -318,9 +318,10 @@ impl<E: Pairing> PLONKShape<E> {
};

let invalid_num: i32 = vec![
vec![q_c, q_ecc, q_hash, q_o, q_e],
vec![q_c, q_ecc, q_o],
q_lc.into_iter().collect::<Vec<&Vec<E::ScalarField>>>(),
q_mul.into_iter().collect::<Vec<&Vec<E::ScalarField>>>(),
q_hash.into_iter().collect::<Vec<&Vec<E::ScalarField>>>(),
]
.concat()
.iter()
Expand Down Expand Up @@ -349,65 +350,274 @@ impl<E: Pairing> PLONKShape<E> {
q_ecc: q_ecc.to_owned(),
q_hash: q_hash.to_owned(),
q_o: q_o.to_owned(),
q_e: q_e.to_owned(),
})
}

fn grand_product(n: usize, vec: Vec<&E::ScalarField>) -> E::ScalarField {
let first: E::ScalarField = *vec[0];
if n == 1 {
first
} else {
vec[1..].iter().fold(first, |acc, cur| acc * *cur)
}
}

fn compute_cross_terms(
degree: usize,
u1: E::ScalarField,
u2: E::ScalarField,
inst1: &Vec<Vec<E::ScalarField>>,
inst2: &Vec<Vec<E::ScalarField>>,
) -> Vec<Vec<E::ScalarField>> {
assert!(inst1.len() == inst2.len(), "compute cross term");

let element_mul = |a_vec: &Vec<E::ScalarField>, b_vec: &Vec<E::ScalarField>| {
a_vec.par_iter().zip(b_vec).map(|(a, b)| *a * *b).collect()
let transpose_matrix = |mat: Vec<Vec<E::ScalarField>>| {
let num_row = mat[0].len();
let mut mut_cols: Vec<_> = mat.into_iter().map(|col| col.into_iter()).collect();
(0..num_row)
.map(|_| {
mut_cols
.iter_mut()
.map(|n| n.next().unwrap())
.collect::<Vec<E::ScalarField>>()
})
.collect::<Vec<Vec<E::ScalarField>>>()
};
let trans_inst1 = transpose_matrix(inst1.clone());
let trans_inst2 = transpose_matrix(inst2.clone());

let max_degree = inst1.len();
let max_degree = 5 as usize;
(1..max_degree)
.rev()
.map(|r_degree| {
let l_degree = max_degree - r_degree;
let l_first: &Vec<E::ScalarField> = &(inst1[0]);
let l_prod: Vec<E::ScalarField> = inst1[1..l_degree]
.into_iter()
.collect::<Vec<&Vec<E::ScalarField>>>()
.iter()
.fold(l_first.clone(), |acc, xs| element_mul(&acc, xs));

let r_first: &Vec<E::ScalarField> = &(inst2[0]);
let r_prod: Vec<E::ScalarField> = inst2[l_degree..]
.into_iter()
.collect::<Vec<&Vec<E::ScalarField>>>()
.iter()
.fold(r_first.clone(), |acc, xs| element_mul(&acc, xs));
trans_inst1
.par_iter()
.zip(&trans_inst2)
.map(|(row_a, row_b)| {
let l_vars = vec![
vec![&u1; degree],
row_a
.into_iter()
.map(|a| a)
.collect::<Vec<&E::ScalarField>>(),
]
.concat();
let r_vars = vec![
vec![&u2; degree],
row_b
.into_iter()
.map(|a| a)
.collect::<Vec<&E::ScalarField>>(),
]
.concat();
// let l_vars = vec![vec![u1; degree], row_a].concat();
// let r_vars = vec![vec![u2; degree], row_b].concat();
Self::grand_product(l_degree, l_vars)
* Self::grand_product(r_degree, r_vars)
})
.collect::<Vec<E::ScalarField>>()
})
.rev()
.collect::<Vec<Vec<E::ScalarField>>>()
}

fn compute_cross_terms_five_exp(
inst1: &Vec<E::ScalarField>,
inst2: &Vec<E::ScalarField>,
) -> Vec<Vec<E::ScalarField>> {
let count_combination = |n: usize, r: usize| {
if r > n {
0
} else {
(1..=r).fold(1, |acc, val| acc * (n - val + 1) / val)
}
};
let vec_pow = |n: usize, vec: &Vec<E::ScalarField>| {
vec.par_iter()
.map(|v| {
let first = *v;
if n == 1 {
first
} else {
vec![v; n - 1].iter().fold(first, |a, b| a * *b)
}
})
.collect::<Vec<E::ScalarField>>()
};

element_mul(&l_prod, &r_prod)
let max_degree: usize = 5;
(1..max_degree)
.rev()
.map(|r_degree| {
let l_degree = max_degree - r_degree;
let const_var = count_combination(max_degree, r_degree);
let const_scalar = <E::ScalarField as PrimeField>::from_bigint(
<E::ScalarField as PrimeField>::BigInt::from(const_var as u32),
)
.unwrap();
let ref_const_scalar = &const_scalar;
let l_pow = vec_pow(l_degree, inst1);
let r_pow = vec_pow(r_degree, inst2);
l_pow
.iter()
.zip(r_pow)
.map(|(a, b)| *ref_const_scalar * a * b)
.collect::<Vec<E::ScalarField>>()
})
.rev()
.collect::<Vec<Vec<E::ScalarField>>>()
}

//// compute cross terms and their commitments
/// 1. length of cross term vector equals max_degree - 1
pub fn commit_T(
&self,
ck: &CommitmentKey<E>,
U1: &RelaxedPLONKInstance<E>,
W1: &RelaxedPLONKWitness<E>,
U2: &PLONKInstance<E>,
W2: &PLONKWitness<E>,
) -> Result<(Vec<E::ScalarField>, Commitment<E>), MyError> {
) -> Result<(Vec<Vec<E::ScalarField>>, Vec<Commitment<E>>), MyError> {
assert!(W1.W.len() == self.num_wire_types - 1, "wrong wires");
// q_ecc operation
// q_ecc operation, u^0 * q_ecc * w_0 * w_1 * w_2 * w_3 * w_o
let ecc_T: Vec<Vec<E::ScalarField>> = Self::compute_cross_terms(
0 as usize,
U1.u,
<E::ScalarField as Field>::ONE,
&W1.W,
&W2.W,
);

// q_lc operation, u^4 * (q_lc_0 * w_0 + q_lc_1 * w_1 + q_lc_2 * w_2 + q_lc_3 * w_3)
let lc_T = (0..self.num_wire_types - 1)
.map(|i| {
Self::compute_cross_terms(
4,
U1.u,
<E::ScalarField as Field>::ONE,
&W1.W[i..i + 1].to_vec(),
&W2.W[i..i + 1].to_vec(),
)
})
.collect::<Vec<Vec<Vec<E::ScalarField>>>>();

// q_mul operation, u^3 * (q_mul_0 * w_0 * w_1 + q_mul_1 * w_2 * w_3)
let mul_T = (0..self.num_wire_types - 1)
.step_by(2)
.map(|i| {
Self::compute_cross_terms(
3,
U1.u,
<E::ScalarField as Field>::ONE,
&W1.W[i..i + 2].to_vec(),
&W2.W[i..i + 2].to_vec(),
)
})
.collect::<Vec<Vec<Vec<E::ScalarField>>>>();

// q_out operation, u^4 * (q_o * w_o)
let out_T = Self::compute_cross_terms(
4,
U1.u,
<E::ScalarField as Field>::ONE,
&W1.W[self.num_wire_types - 1..].to_vec(),
&W2.W[self.num_wire_types - 1..].to_vec(),
);

// q_c operation, u^5 * q_c
let u1_vec = vec![U1.u; self.num_cons];
let u2_vec = vec![<E::ScalarField as Field>::ONE; self.num_cons];
let const_T = Self::compute_cross_terms_five_exp(&u1_vec, &u2_vec);

// q_hash operation, u^0 * (q_hash_0 * w_0^5 + q_hash_1 * w_1^5 + q_hash_2 * w_2^5 + q_hash_3 * w_3^5)
let hash_T = (0..self.num_wire_types - 1)
.map(|i| Self::compute_cross_terms_five_exp(&W1.W[i], &W2.W[i]))
.collect::<Vec<Vec<Vec<E::ScalarField>>>>();

//////////////////////////////// apply selectors on cross terms
let apply_selector = |T: &Vec<Vec<E::ScalarField>>, selector: &Vec<E::ScalarField>| {
(0..self.num_wire_types - 1)
.map(|i| {
let ref_T = &T[i];
ref_T
.par_iter()
.zip(selector)
.map(|(a, b)| *a * *b)
.collect::<Vec<E::ScalarField>>()
})
.collect::<Vec<Vec<E::ScalarField>>>()
};

// q_lc operation
let (ref_ecc_T, ref_out_T, ref_const_T, ref_q_ecc, ref_q_out, ref_q_const) =
(&ecc_T, &out_T, &const_T, &self.q_ecc, &self.q_o, &self.q_c);
let ecc_result = apply_selector(ref_ecc_T, ref_q_ecc);
let out_result = apply_selector(ref_out_T, ref_q_out);
let const_result = apply_selector(ref_const_T, ref_q_const);

// q_mul operation
let lc_result = (0..self.num_wire_types - 1)
.map(|i| {
let (ref_lc_T, ref_q_lc) = (&lc_T[i], &self.q_lc[i]);
apply_selector(ref_lc_T, ref_q_lc)
})
.collect::<Vec<Vec<Vec<E::ScalarField>>>>();

let hash_result = (0..self.num_wire_types - 1)
.map(|i| {
let (ref_hash_T, ref_q_hash) = (&hash_T[i], &self.q_hash[i]);
apply_selector(ref_hash_T, ref_q_hash)
})
.collect::<Vec<Vec<Vec<E::ScalarField>>>>();

// q_out operation
let mul_result = (0..2)
.map(|i| {
let (ref_mul_T, ref_q_mul) = (&mul_T[i], &self.q_mul[i]);
apply_selector(ref_mul_T, ref_q_mul)
})
.collect::<Vec<Vec<Vec<E::ScalarField>>>>();

////////////////////////////////////////// add-on all cross terms
let apply_mat_element_add =
|acc: &Vec<Vec<E::ScalarField>>, cur: &Vec<Vec<E::ScalarField>>| {
acc.into_iter()
.zip(cur)
.map(|(a_col, b_col)| {
a_col
.iter()
.zip(b_col)
.map(|(a, b)| *a + *b)
.collect::<Vec<E::ScalarField>>()
})
.collect::<Vec<Vec<E::ScalarField>>>()
};

let stack_T = vec![
vec![&ecc_result, &out_result, &const_result],
lc_result.iter().collect::<Vec<&Vec<Vec<E::ScalarField>>>>(),
hash_result
.iter()
.collect::<Vec<&Vec<Vec<E::ScalarField>>>>(),
mul_result
.iter()
.collect::<Vec<&Vec<Vec<E::ScalarField>>>>(),
]
.concat();
let T = stack_T[1..].iter().fold(stack_T[0].clone(), |acc, cur| {
apply_mat_element_add(&acc, cur)
});

// q_c and pi operation
////////////////////////////////////////// commit T
let com_T = T
.iter()
.map(|coefficients| {
let poly = <DensePolynomial<E::ScalarField> as DenseUVPolynomial<
E::ScalarField,
>>::from_coefficients_vec(coefficients.clone());
UnivariateKzgPCS::<E>::commit(ck, &poly).unwrap()
})
.collect::<Vec<Commitment<E>>>();

// q_hash operation
todo!()
Ok((T, com_T))
}
}

0 comments on commit 1be1038

Please sign in to comment.