diff --git a/script/src/bin/cli/mod.rs b/script/src/bin/cli/mod.rs index 45bcd56..f44aab9 100644 --- a/script/src/bin/cli/mod.rs +++ b/script/src/bin/cli/mod.rs @@ -1,2 +1,3 @@ pub mod fork; pub mod operation; +pub mod stf_mode; diff --git a/script/src/bin/cli/operation.rs b/script/src/bin/cli/operation.rs index 5dfeb62..5868bda 100644 --- a/script/src/bin/cli/operation.rs +++ b/script/src/bin/cli/operation.rs @@ -4,7 +4,7 @@ use derive_more::Display; #[derive(Debug, Clone, Parser)] pub struct OperationArgs { #[clap(long, short)] - pub operation_name: OperationName, + pub operation_name: Option, } #[derive(ValueEnum, Debug, Clone, Display)] diff --git a/script/src/bin/cli/stf_mode.rs b/script/src/bin/cli/stf_mode.rs new file mode 100644 index 0000000..148dcae --- /dev/null +++ b/script/src/bin/cli/stf_mode.rs @@ -0,0 +1,17 @@ +use clap::{Parser, ValueEnum}; +use derive_more::Display; + +#[derive(Debug, Clone, Parser)] +pub struct STFModeArgs { + #[clap(long, short)] + pub stf_mode: STFMode, +} + +#[derive(ValueEnum, Debug, Clone, Display, PartialEq, Eq)] +#[clap(rename_all = "snake_case")] +pub enum STFMode { + #[display("operation")] + Operation, + #[display("epoch_processing")] + EpochProcessing, +} diff --git a/script/src/bin/main.rs b/script/src/bin/main.rs index 5b6a1f2..0c4be52 100644 --- a/script/src/bin/main.rs +++ b/script/src/bin/main.rs @@ -1,29 +1,23 @@ use clap::Parser; -use sp1_sdk::{include_elf, ProverClient, SP1Stdin}; -use tracing::{error, info}; - -use ream_consensus::deneb::beacon_state::BeaconState; -use ream_lib::{file::read_file, input::OperationInput}; +use tracing::error; mod cli; -use cli::operation::OperationName; +mod stf; -/// The ELF (executable and linkable format) file for the Succinct RISC-V zkVM. -pub const REAM_ELF: &[u8] = include_elf!("ream-operations"); +use cli::stf_mode::STFMode; /// The arguments for the command. -#[derive(Parser, Debug)] +#[derive(Parser, Debug, Clone)] #[clap(author, version, about, long_about = None)] struct Args { - /// Argument for zkVMs - #[clap(long)] execute: bool, #[clap(long)] prove: bool, - /// Argument for STFs + #[clap(flatten)] + stf_mode: cli::stf_mode::STFModeArgs, #[clap(flatten)] fork: cli::fork::ForkArgs, @@ -61,115 +55,10 @@ fn main() { std::process::exit(1); } - let fork = args.fork.fork; - let operation_name = args.operation.operation_name; - let excluded_cases = args.excluded_cases; - - // Load the test assets. - // These assets are from consensus-specs repo. - let base_dir = test_case_dir - .join(format!("{}", fork)) - .join("operations") - .join(format!("{}", operation_name)) - .join("pyspec_tests"); - - let test_cases = ream_lib::file::get_test_cases(&base_dir); - for test_case in test_cases { - if excluded_cases.contains(&test_case) { - info!("Skipping test case: {}", test_case); - continue; - } - - info!("{}", "-".repeat(50)); - info!("[{}] Test case: {}", operation_name, test_case); - - let case_dir = &base_dir.join(&test_case); - let input_path = &case_dir.join(format!("{}.ssz_snappy", operation_name.to_input_name())); - - let pre_state: BeaconState = read_file(&case_dir.join("pre.ssz_snappy")); - let input = match operation_name { - OperationName::Attestation => OperationInput::Attestation(read_file(input_path)), - OperationName::AttesterSlashing => { - OperationInput::AttesterSlashing(read_file(input_path)) - } - OperationName::BlockHeader => OperationInput::BeaconBlock(read_file(input_path)), - OperationName::BLSToExecutionChange => { - OperationInput::SignedBLSToExecutionChange(read_file(input_path)) - } - OperationName::Deposit => OperationInput::Deposit(read_file(input_path)), - OperationName::ExecutionPayload => { - OperationInput::BeaconBlockBody(read_file(input_path)) - } - OperationName::ProposerSlashing => { - OperationInput::ProposerSlashing(read_file(input_path)) - } - OperationName::SyncAggregate => OperationInput::SyncAggregate(read_file(input_path)), - OperationName::VoluntaryExit => { - OperationInput::SignedVoluntaryExit(read_file(input_path)) - } - OperationName::Withdrawals => OperationInput::ExecutionPayload(read_file(input_path)), - }; - let post_state_opt: Option = { - if case_dir.join("post.ssz_snappy").exists() { - Some(read_file(&case_dir.join("post.ssz_snappy"))) - } else { - None - } - }; - - // Setup the prover client. - let client = ProverClient::from_env(); - - // Setup the inputs. - let mut stdin = SP1Stdin::new(); - stdin.write(&pre_state); - stdin.write(&input); - - if args.execute { - // Execute the program - let (output, report) = client.execute(REAM_ELF, &stdin).run().unwrap(); - info!("Program executed successfully."); - - // Decode the output - let result: BeaconState = ssz::Decode::from_ssz_bytes(output.as_slice()).unwrap(); - - // Match `post_state_opt`: some test cases should not mutate beacon state. - match post_state_opt { - Some(post_state) => { - assert_eq!(result, post_state); - info!("Execution is correct!: State mutated"); - } - None => { - assert_eq!(result, pre_state); - info!("Execution is correct!: State should not be mutated"); - } - } - - // Record the number of cycles executed. - info!("----- Cycle Tracker -----"); - info!("[{}] Test case: {}", operation_name, test_case); - info!("Number of cycles: {}", report.total_instruction_count()); - info!("Number of syscall count: {}", report.total_syscall_count()); - for (key, value) in report.cycle_tracker.iter() { - info!("{}: {}", key, value); - } - info!("----- Cycle Tracker End -----"); - } else { - // Setup the program for proving. - let (pk, vk) = client.setup(REAM_ELF); - - // Generate the proof - let proof = client - .prove(&pk, &stdin) - .run() - .expect("failed to generate proof"); - - info!("Successfully generated proof!"); + let stf_mode = args.clone().stf_mode.stf_mode; - // Verify the proof. - client.verify(&proof, &vk).expect("failed to verify proof"); - info!("Successfully verified proof!"); - } - info!("{}", "-".repeat(50)); + match stf_mode { + STFMode::Operation => stf::operation::run_operation(test_case_dir, args), + STFMode::EpochProcessing => todo!(), } } diff --git a/script/src/bin/stf/mod.rs b/script/src/bin/stf/mod.rs new file mode 100644 index 0000000..55bff9c --- /dev/null +++ b/script/src/bin/stf/mod.rs @@ -0,0 +1 @@ +pub mod operation; diff --git a/script/src/bin/stf/operation.rs b/script/src/bin/stf/operation.rs new file mode 100644 index 0000000..def95c6 --- /dev/null +++ b/script/src/bin/stf/operation.rs @@ -0,0 +1,133 @@ +use std::path::PathBuf; + +use ream_consensus::deneb::beacon_state::BeaconState; +use ream_lib::{file::read_file, input::OperationInput}; +use sp1_sdk::{include_elf, ProverClient, SP1Stdin}; +use tracing::info; + +use crate::{ + cli::{operation::OperationName, stf_mode::STFMode}, + Args, +}; + +/// The ELF (executable and linkable format) file for the Succinct RISC-V zkVM. +pub const REAM_ELF: &[u8] = include_elf!("ream-operations"); + +pub fn run_operation(test_case_dir: PathBuf, args: Args) { + assert!(args.stf_mode.stf_mode == STFMode::Operation); + + let operation_name = args + .operation + .operation_name + .expect("operation-name must be provided"); + let excluded_cases = args.excluded_cases; + + // Load the test assets. + // These assets are from consensus-specs repo. + let base_dir = test_case_dir + .join(format!("{}", args.fork.fork)) + .join("operations") + .join(format!("{}", operation_name)) + .join("pyspec_tests"); + + let test_cases = ream_lib::file::get_test_cases(&base_dir); + + for test_case in test_cases { + if excluded_cases.contains(&test_case) { + info!("Skipping test case: {}", test_case); + continue; + } + + info!("{}", "-".repeat(50)); + info!("[{}] Test case: {}", operation_name, test_case); + + let case_dir = &base_dir.join(&test_case); + let input_path = &case_dir.join(format!("{}.ssz_snappy", operation_name.to_input_name())); + + let pre_state: BeaconState = read_file(&case_dir.join("pre.ssz_snappy")); + let input = match operation_name { + OperationName::Attestation => OperationInput::Attestation(read_file(input_path)), + OperationName::AttesterSlashing => { + OperationInput::AttesterSlashing(read_file(input_path)) + } + OperationName::BlockHeader => OperationInput::BeaconBlock(read_file(input_path)), + OperationName::BLSToExecutionChange => { + OperationInput::SignedBLSToExecutionChange(read_file(input_path)) + } + OperationName::Deposit => OperationInput::Deposit(read_file(input_path)), + OperationName::ExecutionPayload => { + OperationInput::BeaconBlockBody(read_file(input_path)) + } + OperationName::ProposerSlashing => { + OperationInput::ProposerSlashing(read_file(input_path)) + } + OperationName::SyncAggregate => OperationInput::SyncAggregate(read_file(input_path)), + OperationName::VoluntaryExit => { + OperationInput::SignedVoluntaryExit(read_file(input_path)) + } + OperationName::Withdrawals => OperationInput::ExecutionPayload(read_file(input_path)), + }; + let post_state_opt: Option = { + if case_dir.join("post.ssz_snappy").exists() { + Some(read_file(&case_dir.join("post.ssz_snappy"))) + } else { + None + } + }; + + // Setup the prover client. + let client = ProverClient::from_env(); + + // Setup the inputs. + let mut stdin = SP1Stdin::new(); + stdin.write(&pre_state); + stdin.write(&input); + + if args.execute { + // Execute the program + let (output, report) = client.execute(REAM_ELF, &stdin).run().unwrap(); + info!("Program executed successfully."); + + // Decode the output + let result: BeaconState = ssz::Decode::from_ssz_bytes(output.as_slice()).unwrap(); + + // Match `post_state_opt`: some test cases should not mutate beacon state. + match post_state_opt { + Some(post_state) => { + assert_eq!(result, post_state); + info!("Execution is correct!: State mutated"); + } + None => { + assert_eq!(result, pre_state); + info!("Execution is correct!: State should not be mutated"); + } + } + + // Record the number of cycles executed. + info!("----- Cycle Tracker -----"); + info!("[{}] Test case: {}", operation_name, test_case); + info!("Number of cycles: {}", report.total_instruction_count()); + info!("Number of syscall count: {}", report.total_syscall_count()); + for (key, value) in report.cycle_tracker.iter() { + info!("{}: {}", key, value); + } + info!("----- Cycle Tracker End -----"); + } else { + // Setup the program for proving. + let (pk, vk) = client.setup(REAM_ELF); + + // Generate the proof + let proof = client + .prove(&pk, &stdin) + .run() + .expect("failed to generate proof"); + + info!("Successfully generated proof!"); + + // Verify the proof. + client.verify(&proof, &vk).expect("failed to verify proof"); + info!("Successfully verified proof!"); + } + info!("{}", "-".repeat(50)); + } +}