Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(evm_arithmetization): remove duplicate constraint in keccak_stark #48

Merged
merged 8 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions evm_arithmetization/src/keccak/keccak_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let local_values = vars.get_local_values();
let next_values = vars.get_next_values();

// The filter must be 0 or 1.
let filter = local_values[reg_step(NUM_ROUNDS - 1)];
yield_constr.constraint(filter * (filter - P::ONES));

// If this is not the final step, the filter must be off.
let final_step = local_values[reg_step(NUM_ROUNDS - 1)];
let not_final_step = P::ONES - final_step;
yield_constr.constraint(not_final_step * filter);
let not_final_step = P::ONES - local_values[reg_step(NUM_ROUNDS - 1)];

// If this is not the final step or a padding row,
// the local and next timestamps must match.
Expand Down Expand Up @@ -446,16 +440,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let local_values = vars.get_local_values();
let next_values = vars.get_next_values();

// The filter must be 0 or 1.
let filter = local_values[reg_step(NUM_ROUNDS - 1)];
let constraint = builder.mul_sub_extension(filter, filter, filter);
yield_constr.constraint(builder, constraint);

// If this is not the final step, the filter must be off.
let final_step = local_values[reg_step(NUM_ROUNDS - 1)];
let not_final_step = builder.sub_extension(one_ext, final_step);
let constraint = builder.mul_extension(not_final_step, filter);
yield_constr.constraint(builder, constraint);
let not_final_step = builder.sub_extension(one_ext, local_values[reg_step(NUM_ROUNDS - 1)]);

// If this is not the final step or a padding row,
// the local and next timestamps must match.
Expand Down
26 changes: 18 additions & 8 deletions evm_arithmetization/src/keccak/round_flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,21 @@ pub(crate) fn eval_round_flags<F: Field, P: PackedField<Scalar = F>>(
}

// Flags should circularly increment, or be all zero for padding rows.
let current_any_flag = (0..NUM_ROUNDS)
.map(|i| local_values[reg_step(i)])
.sum::<P>();
let next_any_flag = (0..NUM_ROUNDS).map(|i| next_values[reg_step(i)]).sum::<P>();
let last_row_flag = local_values[reg_step(NUM_ROUNDS - 1)];
shuklaayush marked this conversation as resolved.
Show resolved Hide resolved
for i in 0..NUM_ROUNDS {
let current_round_flag = local_values[reg_step(i)];
let next_round_flag = next_values[reg_step((i + 1) % NUM_ROUNDS)];
yield_constr.constraint_transition(next_any_flag * (next_round_flag - current_round_flag));
yield_constr.constraint_transition(
next_any_flag * (next_round_flag - current_round_flag)
+ (next_any_flag - F::ONE) * current_any_flag * (last_row_flag - F::ONE),
);
}

// Padding rows should always be followed by padding rows.
let current_any_flag = (0..NUM_ROUNDS)
.map(|i| local_values[reg_step(i)])
.sum::<P>();
yield_constr.constraint_transition(next_any_flag * (current_any_flag - F::ONE));
}

Expand All @@ -56,19 +60,25 @@ pub(crate) fn eval_round_flags_recursively<F: RichField + Extendable<D>, const D
}

// Flags should circularly increment, or be all zero for padding rows.
let current_any_flag =
builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values[reg_step(i)]));
let next_any_flag =
builder.add_many_extension((0..NUM_ROUNDS).map(|i| next_values[reg_step(i)]));
let last_row_flag = local_values[reg_step(NUM_ROUNDS - 1)];
for i in 0..NUM_ROUNDS {
let current_round_flag = local_values[reg_step(i)];
let next_round_flag = next_values[reg_step((i + 1) % NUM_ROUNDS)];
let diff = builder.sub_extension(next_round_flag, current_round_flag);
let constraint = builder.mul_extension(next_any_flag, diff);
let diff1 = builder.sub_extension(next_round_flag, current_round_flag);
let constraint1 = builder.mul_extension(next_any_flag, diff1);
let diff2 = builder.sub_extension(next_any_flag, one);
let diff3 = builder.sub_extension(last_row_flag, one);
let prod = builder.mul_extension(diff2, diff3);
let constraint2 = builder.mul_extension(current_any_flag, prod);
let constraint = builder.add_extension(constraint1, constraint2);
shuklaayush marked this conversation as resolved.
Show resolved Hide resolved
yield_constr.constraint_transition(builder, constraint);
}

// Padding rows should always be followed by padding rows.
let current_any_flag =
builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values[reg_step(i)]));
let constraint = builder.mul_sub_extension(next_any_flag, current_any_flag, next_any_flag);
yield_constr.constraint_transition(builder, constraint);
}