Skip to content

Commit

Permalink
suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
dark64 committed Aug 1, 2023
1 parent 52b6815 commit 0fd334b
Showing 1 changed file with 35 additions and 43 deletions.
78 changes: 35 additions & 43 deletions zokrates_codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,15 +844,25 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let lhs_id = self.define(lhs_flattened, statements_flattened);
let rhs_id = self.define(rhs_flattened, statements_flattened);

// shifted_sub := 2**bit_width + lhs - rhs
println!("{}", bit_width);

// shifted_sub := lhs + 2**bit_width - rhs
let shifted_sub = FlatExpression::add(
FlatExpression::value(T::from(2).pow(bit_width)),
FlatExpression::identifier(lhs_id),
FlatExpression::sub(
FlatExpression::identifier(lhs_id),
FlatExpression::value(T::from(2).pow(bit_width)),
FlatExpression::identifier(rhs_id),
),
);

// let shifted_sub = FlatExpression::add(
// FlatExpression::value(T::from(2).pow(bit_width)),
// FlatExpression::sub(
// FlatExpression::identifier(lhs_id),
// FlatExpression::identifier(rhs_id),
// ),
// );

let sub_width = bit_width + 1;

let shifted_sub_bits_be = self.get_bits_unchecked(
Expand Down Expand Up @@ -2020,49 +2030,31 @@ impl<'ast, T: Field> Flattener<'ast, T> {

let from = std::cmp::max(from, to);
let res = match self.bits_cache.entry(e.field.clone().unwrap()) {
Entry::Occupied(mut entry) => {
let res: Vec<_> = entry.get().clone();
Entry::Occupied(entry) => {
let mut res: Vec<_> = entry.get().clone();

if res.len() > to {
// if the result is bigger than `to`, we zero check the sum of higher bits up to `to`
let bit_sum = res[..res.len() - to]
.iter()
.cloned()
.fold(FlatExpression::from(T::zero()), |acc, e| {
FlatExpression::add(acc, e)
});

// sum check
statements_flattened.push_back(FlatStatement::condition(
FlatExpression::value(T::from(0)),
bit_sum,
error,
));

// truncate to the `to` lowest bits
let bits = res[res.len() - to..].to_vec();
assert_eq!(bits.len(), to);

// update the entry
entry.insert(
(0..res.len() - to)
.map(|_| FlatExpression::value(T::zero()))
.chain(bits.clone())
.collect(),
);
// only keep the last `to` values and return the sum of the others
let sum = res
.drain(0..res.len().saturating_sub(to))
.fold(FlatExpression::from(T::zero()), |acc, e| {
FlatExpression::add(acc, e)
});

return bits;
}
// force the sum to be 0
statements_flattened.push_back(FlatStatement::condition(
FlatExpression::value(T::from(0)),
sum,
error,
));

// if result is smaller than `to` we pad it with zeroes on the left (big endian) to return `to` bits
if res.len() < to {
return (0..to - res.len())
.map(|_| FlatExpression::value(T::zero()))
.chain(res)
.collect();
}
// sanity check that we have at most `to` values
assert!(res.len() <= to);

res
// return the result left-padded to `to` values
std::iter::repeat(FlatExpression::value(T::zero()))
.take(to - res.len())
.chain(res.clone())
.collect()
}
Entry::Vacant(_) => {
let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>();
Expand Down Expand Up @@ -2700,7 +2692,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
c.value,
error.into(),
),
// c < e <=> 2^bw - 1 - e < 2^bw - 1 - c
// c <= e <=> 2^bw - 1 - e <= 2^bw - 1 - c
(FlatExpression::Value(c), e) => {
let max = T::from(2u32).pow(bitwidth) - T::one();
self.enforce_constant_le_check(
Expand Down

0 comments on commit 0fd334b

Please sign in to comment.