Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(evm_arithmetization): remove duplicate constraint in keccak_stark #48

Merged
merged 8 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions evm_arithmetization/src/keccak/keccak_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let local_values = vars.get_local_values();
let next_values = vars.get_next_values();

// The filter must be 0 or 1.
let filter = local_values[reg_step(NUM_ROUNDS - 1)];
yield_constr.constraint(filter * (filter - P::ONES));

// If this is not the final step, the filter must be off.
let final_step = local_values[reg_step(NUM_ROUNDS - 1)];
let not_final_step = P::ONES - final_step;
yield_constr.constraint(not_final_step * filter);
let not_final_step = P::ONES - local_values[reg_step(NUM_ROUNDS - 1)];

// If this is not the final step or a padding row,
// the local and next timestamps must match.
Expand Down Expand Up @@ -446,16 +440,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for KeccakStark<F
let local_values = vars.get_local_values();
let next_values = vars.get_next_values();

// The filter must be 0 or 1.
let filter = local_values[reg_step(NUM_ROUNDS - 1)];
let constraint = builder.mul_sub_extension(filter, filter, filter);
yield_constr.constraint(builder, constraint);

// If this is not the final step, the filter must be off.
let final_step = local_values[reg_step(NUM_ROUNDS - 1)];
let not_final_step = builder.sub_extension(one_ext, final_step);
let constraint = builder.mul_extension(not_final_step, filter);
yield_constr.constraint(builder, constraint);
let not_final_step = builder.sub_extension(one_ext, local_values[reg_step(NUM_ROUNDS - 1)]);

// If this is not the final step or a padding row,
// the local and next timestamps must match.
Expand Down
27 changes: 19 additions & 8 deletions evm_arithmetization/src/keccak/round_flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,22 @@ pub(crate) fn eval_round_flags<F: Field, P: PackedField<Scalar = F>>(
}

// Flags should circularly increment, or be all zero for padding rows.
let current_any_flag = (0..NUM_ROUNDS)
.map(|i| local_values[reg_step(i)])
.sum::<P>();
let next_any_flag = (0..NUM_ROUNDS).map(|i| next_values[reg_step(i)]).sum::<P>();
// Padding row should only start after the last round row.
let not_final_step = P::ONES - local_values[reg_step(NUM_ROUNDS - 1)];
let padding_constraint = (next_any_flag - F::ONE) * current_any_flag * not_final_step;
for i in 0..NUM_ROUNDS {
let current_round_flag = local_values[reg_step(i)];
let next_round_flag = next_values[reg_step((i + 1) % NUM_ROUNDS)];
yield_constr.constraint_transition(next_any_flag * (next_round_flag - current_round_flag));
yield_constr.constraint_transition(
next_any_flag * (next_round_flag - current_round_flag) + padding_constraint,
);
}

// Padding rows should always be followed by padding rows.
let current_any_flag = (0..NUM_ROUNDS)
.map(|i| local_values[reg_step(i)])
.sum::<P>();
yield_constr.constraint_transition(next_any_flag * (current_any_flag - F::ONE));
}

Expand All @@ -56,19 +61,25 @@ pub(crate) fn eval_round_flags_recursively<F: RichField + Extendable<D>, const D
}

// Flags should circularly increment, or be all zero for padding rows.
let current_any_flag =
builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values[reg_step(i)]));
let next_any_flag =
builder.add_many_extension((0..NUM_ROUNDS).map(|i| next_values[reg_step(i)]));
// Padding row should only start after the last round row.
let not_final_step = builder.sub_extension(one, local_values[reg_step(NUM_ROUNDS - 1)]);
let padding_constraint = {
let tmp = builder.mul_sub_extension(current_any_flag, next_any_flag, current_any_flag);
builder.mul_extension(tmp, not_final_step)
shuklaayush marked this conversation as resolved.
Show resolved Hide resolved
};
for i in 0..NUM_ROUNDS {
let current_round_flag = local_values[reg_step(i)];
let next_round_flag = next_values[reg_step((i + 1) % NUM_ROUNDS)];
let diff = builder.sub_extension(next_round_flag, current_round_flag);
let constraint = builder.mul_extension(next_any_flag, diff);
let flag_diff = builder.sub_extension(next_round_flag, current_round_flag);
let constraint = builder.mul_add_extension(next_any_flag, flag_diff, padding_constraint);
yield_constr.constraint_transition(builder, constraint);
}

// Padding rows should always be followed by padding rows.
let current_any_flag =
builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values[reg_step(i)]));
let constraint = builder.mul_sub_extension(next_any_flag, current_any_flag, next_any_flag);
yield_constr.constraint_transition(builder, constraint);
}