Skip to content

Commit

Permalink
Merge remote-tracking branch 'scroll/kevjue/phase_1_improvements' int…
Browse files Browse the repository at this point in the history
…o scroll-dev-0914
  • Loading branch information
lispc committed Sep 14, 2024
2 parents 9b9b512 + bf69f54 commit d8e89a4
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 154 deletions.
102 changes: 101 additions & 1 deletion crates/core/executor/src/events/precompiles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,75 @@ pub use bn254_scalar::{
pub use ec::*;
pub use edwards::*;
pub use fptower::*;
use hashbrown::HashMap;
pub use keccak256_permute::*;
use serde::{Deserialize, Serialize};
pub use sha256_compress::*;
pub use sha256_extend::*;
use strum::EnumIter;
use strum::{EnumIter, IntoEnumIterator};
pub use uint256::*;

use crate::syscalls::SyscallCode;

use super::{MemCopyEvent, MemoryLocalEvent};


#[derive(Clone, Debug, Serialize, Deserialize, EnumIter)]
/// Precompile event. There should be one variant for every precompile syscall.
pub enum PrecompileEvent {
/// Sha256 extend precompile event.
ShaExtend(ShaExtendEvent),
/// Sha256 compress precompile event.
ShaCompress(ShaCompressEvent),
/// Keccak256 permute precompile event.
KeccakPermute(KeccakPermuteEvent),
/// Edwards curve add precompile event.
EdAdd(EllipticCurveAddEvent),
/// Edwards curve decompress precompile event.
EdDecompress(EdDecompressEvent),
/// Secp256k1 curve add precompile event.
Secp256k1Add(EllipticCurveAddEvent),
/// Secp256k1 curve double precompile event.
Secp256k1Double(EllipticCurveDoubleEvent),
/// Secp256k1 curve decompress precompile event.
Secp256k1Decompress(EllipticCurveDecompressEvent),
/// K256 curve decompress precompile event.
K256Decompress(EllipticCurveDecompressEvent),
/// Bn254 curve add precompile event.
Bn254Add(EllipticCurveAddEvent),
/// Bn254 curve double precompile event.
Bn254Double(EllipticCurveDoubleEvent),
/// Bn254 base field operation precompile event.
Bn254Fp(FpOpEvent),
/// Bn254 quadratic field add/sub precompile event.
Bn254Fp2AddSub(Fp2AddSubEvent),
/// Bn254 quadratic field mul precompile event.
Bn254Fp2Mul(Fp2MulEvent),

Bn254ScalarMac(Bn254FieldArithEvent),
Bn254ScalarMul(Bn254FieldArithEvent),
MemCopy32(MemCopyEvent),
MemCopy64(MemCopyEvent),

/// Bls12-381 curve add precompile event.
Bls12381Add(EllipticCurveAddEvent),
/// Bls12-381 curve double precompile event.
Bls12381Double(EllipticCurveDoubleEvent),
/// Bls12-381 curve decompress precompile event.
Bls12381Decompress(EllipticCurveDecompressEvent),
/// Bls12-381 base field operation precompile event.
Bls12381Fp(FpOpEvent),
/// Bls12-381 quadratic field add/sub precompile event.
Bls12381Fp2AddSub(Fp2AddSubEvent),
/// Bls12-381 quadratic field mul precompile event.
Bls12381Fp2Mul(Fp2MulEvent),
/// Uint256 mul precompile event.
Uint256Mul(Uint256MulEvent),
}

/// Trait to retrieve all the local memory events from a vec of precompile events.
pub trait PrecompileLocalMemory {
/// Get an iterator of all the local memory events.
fn get_local_mem_events(&self) -> impl IntoIterator<Item = &MemoryLocalEvent>;
}

Expand Down Expand Up @@ -114,3 +143,74 @@ impl PrecompileLocalMemory for Vec<PrecompileEvent> {
iterators.into_iter().flatten()
}
}

/// A record of all the precompile events.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PrecompileEvents {
events: HashMap<SyscallCode, Vec<PrecompileEvent>>,
}

impl Default for PrecompileEvents {
fn default() -> Self {
let mut events = HashMap::new();
for syscall_code in SyscallCode::iter() {
if syscall_code.should_send() == 1 {
events.insert(syscall_code, Vec::new());
}
}

Self { events }
}
}

impl PrecompileEvents {
pub(crate) fn append(&mut self, other: &mut PrecompileEvents) {
for (syscall, events) in other.events.iter_mut() {
if !events.is_empty() {
self.events.entry(*syscall).or_default().append(events);
}
}
}

#[inline]
/// Add a precompile event for a given syscall code.
pub(crate) fn add_event(&mut self, syscall_code: SyscallCode, event: PrecompileEvent) {
assert!(syscall_code.should_send() == 1);
self.events.entry(syscall_code).or_default().push(event);
}

#[inline]
/// Insert a vector of precompile events for a given syscall code.
pub(crate) fn insert(&mut self, syscall_code: SyscallCode, events: Vec<PrecompileEvent>) {
assert!(syscall_code.should_send() == 1);
self.events.insert(syscall_code, events);
}

#[inline]
pub(crate) fn into_iter(self) -> impl Iterator<Item = (SyscallCode, Vec<PrecompileEvent>)> {
self.events.into_iter()
}

#[inline]
pub(crate) fn iter(&self) -> impl Iterator<Item = (&SyscallCode, &Vec<PrecompileEvent>)> {
self.events.iter()
}

#[inline]
/// Get all the precompile events for a given syscall code.
pub(crate) fn get_events(&self, syscall_code: SyscallCode) -> &Vec<PrecompileEvent> {
assert!(syscall_code.should_send() == 1);
&self.events[&syscall_code]
}

/// Get all the local events from all the precompile events.
pub(crate) fn get_local_mem_events(&self) -> impl Iterator<Item = &MemoryLocalEvent> {
let mut iterators = Vec::new();

for (_, events) in self.events.iter() {
iterators.push(events.get_local_mem_events());
}

iterators.into_iter().flatten()
}
}
102 changes: 12 additions & 90 deletions crates/core/executor/src/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use sp1_stark::{
MachineRecord, SP1CoreOpts, SplitOpts,
};
use std::{mem::take, sync::Arc};
use strum::IntoEnumIterator;

use serde::{Deserialize, Serialize};

Expand All @@ -15,7 +14,7 @@ use crate::{
events::{
add_sharded_byte_lookup_events, AluEvent, ByteLookupEvent, ByteRecord, CpuEvent, LookupId,
MemoryInitializeFinalizeEvent, MemoryLocalEvent, MemoryRecordEnum, PrecompileEvent,
PrecompileLocalMemory, SyscallEvent,
PrecompileEvents, SyscallEvent,
},
syscalls::SyscallCode,
CoreShape,
Expand Down Expand Up @@ -229,15 +228,20 @@ impl ExecutionRecord {
}

#[inline]
/// Add a precompile event to the execution record.
pub fn add_precompile_event(&mut self, syscall_code: SyscallCode, event: PrecompileEvent) {
self.precompile_events.add_precompile_event(syscall_code, event)
self.precompile_events.add_event(syscall_code, event);
}

/// Get all the precompile events for a syscall code.
#[inline]
#[must_use]
pub fn get_precompile_events(&self, syscall_code: SyscallCode) -> &Vec<PrecompileEvent> {
self.precompile_events.get_events(syscall_code)
}

/// Get all the local memory events.
#[inline]
pub fn get_local_mem_events(&self) -> impl Iterator<Item = &MemoryLocalEvent> {
let precompile_local_mem_events = self.precompile_events.get_local_mem_events();
precompile_local_mem_events.chain(self.cpu_local_memory_access.iter())
Expand Down Expand Up @@ -271,32 +275,11 @@ impl MachineRecord for ExecutionRecord {
stats.insert("shift_right_events".to_string(), self.shift_right_events.len());
stats.insert("divrem_events".to_string(), self.divrem_events.len());
stats.insert("lt_events".to_string(), self.lt_events.len());
// stats.insert("sha_extend_events".to_string(), self.sha_extend_events.len());
// stats.insert("sha_compress_events".to_string(), self.sha_compress_events.len());
// stats.insert("keccak_permute_events".to_string(), self.keccak_permute_events.len());
// stats.insert("ed_add_events".to_string(), self.ed_add_events.len());
// stats.insert("ed_decompress_events".to_string(), self.ed_decompress_events.len());
// stats.insert("secp256k1_add_events".to_string(), self.secp256k1_add_events.len());
// stats.insert("secp256k1_double_events".to_string(), self.secp256k1_double_events.len());
// stats.insert("bn254_add_events".to_string(), self.bn254_add_events.len());
// stats.insert("bn254_double_events".to_string(), self.bn254_double_events.len());
// stats.insert("k256_decompress_events".to_string(), self.k256_decompress_events.len());
// stats.insert("bls12381_add_events".to_string(), self.bls12381_add_events.len());
// stats.insert("bls12381_double_events".to_string(), self.bls12381_double_events.len());
// stats.insert("uint256_mul_events".to_string(), self.uint256_mul_events.len());
// stats.insert("bls12381_fp_event".to_string(), self.bls12381_fp_events.len());
// stats.insert(
// "bls12381_fp2_addsub_events".to_string(),
// self.bls12381_fp2_addsub_events.len(),
// );
// stats.insert("bls12381_fp2_mul_events".to_string(), self.bls12381_fp2_mul_events.len());
// stats.insert("bn254_fp_events".to_string(), self.bn254_fp_events.len());
// stats.insert("bn254_fp2_addsub_events".to_string(), self.bn254_fp2_addsub_events.len());
// stats.insert("bn254_fp2_mul_events".to_string(), self.bn254_fp2_mul_events.len());
// stats.insert(
// "bls12381_decompress_events".to_string(),
// self.bls12381_decompress_events.len(),
// );

for (syscall_code, events) in self.precompile_events.iter() {
stats.insert(format!("syscall {syscall_code:?}"), events.len());
}

stats.insert(
"global_memory_initialize_events".to_string(),
self.global_memory_initialize_events.len(),
Expand Down Expand Up @@ -397,64 +380,3 @@ impl ByteRecord for ExecutionRecord {
add_sharded_byte_lookup_events(&mut self.byte_lookups, new_events);
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PrecompileEvents {
events: HashMap<SyscallCode, Vec<PrecompileEvent>>,
}

impl Default for PrecompileEvents {
fn default() -> Self {
let mut events = HashMap::new();
for syscall_code in SyscallCode::iter() {
if syscall_code.should_send() == 1 {
events.insert(syscall_code, Vec::new());
}
}

Self { events }
}
}

impl PrecompileEvents {
fn append(&mut self, other: &mut PrecompileEvents) {
for (syscall, events) in other.events.iter_mut() {
if !events.is_empty() {
self.events.entry(*syscall).or_default().append(events);
}
}
}

#[inline]
fn add_precompile_event(&mut self, syscall_code: SyscallCode, event: PrecompileEvent) {
assert!(syscall_code.should_send() == 1);
self.events.entry(syscall_code).or_default().push(event);
}

#[inline]
fn insert(&mut self, syscall_code: SyscallCode, events: Vec<PrecompileEvent>) {
assert!(syscall_code.should_send() == 1);
self.events.insert(syscall_code, events);
}

#[inline]
fn into_iter(self) -> impl Iterator<Item = (SyscallCode, Vec<PrecompileEvent>)> {
self.events.into_iter()
}

#[inline]
fn get_events(&self, syscall_code: SyscallCode) -> &Vec<PrecompileEvent> {
assert!(syscall_code.should_send() == 1);
&self.events[&syscall_code]
}

fn get_local_mem_events(&self) -> impl Iterator<Item = &MemoryLocalEvent> {
let mut iterators = Vec::new();

for (_, events) in self.events.iter() {
iterators.push(events.get_local_mem_events());
}

iterators.into_iter().flatten()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ impl<E: EdwardsParameters> Syscall for EdwardsDecompressSyscall<E> {

let (y_memory_records_vec, y_vec) =
rt.mr_slice(slice_ptr + (COMPRESSED_POINT_BYTES as u32), WORDS_FIELD_ELEMENT);

let y_memory_records: [MemoryReadRecord; 8] = y_memory_records_vec.try_into().unwrap();

let sign_bool = sign != 0;
Expand Down
7 changes: 6 additions & 1 deletion crates/core/executor/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::Opcode;

/// Serialize a `HashMap<u32, V>` as a `Vec<(u32, V)>`.
pub fn serialize_hashmap_as_vec<V: Serialize, S: Serializer>(
map: &HashMap<u32, V, BuildNoHashHasher<u32>>,
serializer: S,
) -> Result<S::Ok, S::Error> {
Serialize::serialize(&map.iter().collect::<Vec<_>>(), serializer)
}

/// Deserialize a `Vec<(u32, V)>` as a `HashMap<u32, V>`.
pub fn deserialize_hashmap_as_vec<'de, V: Deserialize<'de>, D: Deserializer<'de>>(
deserializer: D,
) -> Result<HashMap<u32, V, BuildNoHashHasher<u32>>, D::Error> {
Expand All @@ -20,11 +22,13 @@ pub fn deserialize_hashmap_as_vec<'de, V: Deserialize<'de>, D: Deserializer<'de>
}

/// Returns `true` if the given `opcode` is a signed operation.
#[must_use]
pub fn is_signed_operation(opcode: Opcode) -> bool {
opcode == Opcode::DIV || opcode == Opcode::REM
}

/// Calculate the correct `quotient` and `remainder` for the given `b` and `c` per RISC-V spec.
#[must_use]
pub fn get_quotient_and_remainder(b: u32, c: u32, opcode: Opcode) -> (u32, u32) {
if c == 0 {
// When c is 0, the quotient is 2^32 - 1 and the remainder is b regardless of whether we
Expand All @@ -33,11 +37,12 @@ pub fn get_quotient_and_remainder(b: u32, c: u32, opcode: Opcode) -> (u32, u32)
} else if is_signed_operation(opcode) {
((b as i32).wrapping_div(c as i32) as u32, (b as i32).wrapping_rem(c as i32) as u32)
} else {
((b as u32).wrapping_div(c as u32) as u32, (b as u32).wrapping_rem(c as u32) as u32)
(b.wrapping_div(c), b.wrapping_rem(c))
}
}

/// Calculate the most significant bit of the given 32-bit integer `a`, and returns it as a u8.
#[must_use]
pub const fn get_msb(a: u32) -> u8 {
((a >> 31) & 1) as u8
}
1 change: 0 additions & 1 deletion crates/core/machine/src/riscv/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ impl CostEstimator for ExecutionReport {
+ self.opcode_counts[Opcode::REM]
+ self.opcode_counts[Opcode::DIVU]
+ self.opcode_counts[Opcode::REMU];

total_area += (divrem_events as u64) * costs[&RiscvAirDiscriminants::DivRem];
total_chips += 1;

Expand Down
Loading

0 comments on commit d8e89a4

Please sign in to comment.