diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index 2d7777911..39a81e540 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -210,6 +210,57 @@ pub fn convert_analyzed_to_pil( )) } +#[allow(clippy::too_many_arguments)] +pub fn convert_analyzed_to_pil_with_callback>( + file_name: &str, + monitor: &mut DiffMonitor, + analyzed: AnalysisASMFile, + query_callback: Q, + output_dir: &Path, + force_overwrite: bool, + prove_with: Option, + external_witness_values: Vec<(&str, Vec)>, +) -> Result<(PathBuf, Option>), Vec> { + let constraints = convert_analyzed_to_pil_constraints(analyzed, monitor); + log::debug!("Run airgen"); + let graph = airgen::compile(constraints); + log::debug!("Airgen done"); + log::trace!("{graph}"); + log::debug!("Run linker"); + let pil = linker::link(graph)?; + log::debug!("Linker done"); + log::trace!("{pil}"); + + let pil_file_name = format!( + "{}.pil", + Path::new(file_name).file_stem().unwrap().to_str().unwrap() + ); + + let pil_file_path = output_dir.join(pil_file_name); + if pil_file_path.exists() && !force_overwrite { + eprintln!( + "Target file {} already exists. Not overwriting.", + pil_file_path.to_str().unwrap() + ); + return Ok((pil_file_path, None)); + } + + fs::write(&pil_file_path, format!("{pil}")).unwrap(); + + let pil_file_name = pil_file_path.file_name().unwrap(); + Ok(( + pil_file_path.clone(), + Some(compile_pil_ast( + &pil, + pil_file_name, + output_dir, + query_callback, + prove_with, + external_witness_values, + )), + )) +} + pub type AnalyzedASTHook<'a, T> = &'a mut dyn FnMut(&AnalysisASMFile); /// Compiles the contents of a .asm file, outputs the PIL on stdout and tries to generate @@ -244,6 +295,34 @@ pub fn compile_asm_string( ) } +#[allow(clippy::too_many_arguments)] +pub fn compile_asm_string_with_callback>( + file_name: &str, + contents: &str, + query_callback: Q, + analyzed_hook: Option>, + output_dir: &Path, + force_overwrite: bool, + prove_with: Option, + external_witness_values: Vec<(&str, Vec)>, +) -> Result<(PathBuf, Option>), Vec> { + let mut monitor = DiffMonitor::default(); + let analyzed = compile_asm_string_to_analyzed_ast(file_name, contents, Some(&mut monitor))?; + if let Some(hook) = analyzed_hook { + hook(&analyzed); + }; + convert_analyzed_to_pil_with_callback( + file_name, + &mut monitor, + analyzed, + query_callback, + output_dir, + force_overwrite, + prove_with, + external_witness_values, + ) +} + pub struct CompilationResult { /// Constant columns, potentially incomplete (if success is false) pub constants: Vec<(String, Vec)>, @@ -364,6 +443,26 @@ pub fn inputs_to_query_callback(inputs: Vec) -> impl QueryCa )) } } + ["\"data\"", index, what] => { + let index = index + .parse::() + .map_err(|e| format!("Error parsing index: {e})"))?; + let what = what + .parse::() + .map_err(|e| format!("Error parsing what: {e})"))?; + assert_eq!(what, 0); + + let value = inputs.get(index).cloned(); + if let Some(value) = value { + log::trace!("Input query: Index {index} -> {value}"); + Ok(Some(value)) + } else { + Err(format!( + "Error accessing prover inputs: Index {index} out of bounds {}", + inputs.len() + )) + } + } ["\"bootloader_input\"", index] => { let index = index .parse::() diff --git a/powdr_cli/src/main.rs b/powdr_cli/src/main.rs index 0380c39e7..f62b318eb 100644 --- a/powdr_cli/src/main.rs +++ b/powdr_cli/src/main.rs @@ -426,6 +426,7 @@ fn run_command(command: Commands) { continuations, } => match (just_execute, continuations) { (true, true) => { + assert!(matches!(field, FieldArgument::Gl)); let contents = fs::read_to_string(&file).unwrap(); let inputs = split_inputs::(&inputs); rust_continuations(file.as_str(), contents.as_str(), inputs); @@ -433,6 +434,10 @@ fn run_command(command: Commands) { (true, false) => { let contents = fs::read_to_string(&file).unwrap(); let inputs = split_inputs::(&inputs); + let inputs: HashMap> = + vec![(GoldilocksField::from(0), inputs)] + .into_iter() + .collect(); riscv_executor::execute::(&contents, &inputs, &default_input()); } (false, true) => { @@ -581,7 +586,9 @@ fn handle_riscv_asm( rust_continuations(file_name, contents, inputs); } (true, false) => { - riscv_executor::execute::(contents, &inputs, &default_input()); + let mut inputs_hash: HashMap> = HashMap::default(); + inputs_hash.insert(0u32.into(), inputs); + riscv_executor::execute::(contents, &inputs_hash, &default_input()); } (false, true) => { unimplemented!("Running witgen with continuations is not supported yet.") @@ -626,6 +633,8 @@ fn rust_continuations(file_name: &str, contents: &str, inputs: let program = compiler::compile_asm_string_to_analyzed_ast::(file_name, contents, None).unwrap(); + let inputs: HashMap> = vec![(F::from(0), inputs)].into_iter().collect(); + log::info!("Executing powdr-asm..."); let (full_trace, memory_accesses) = { let trace = diff --git a/riscv/runtime/src/coprocessors.rs b/riscv/runtime/src/coprocessors.rs index 9e0c7eced..7232be0f3 100644 --- a/riscv/runtime/src/coprocessors.rs +++ b/riscv/runtime/src/coprocessors.rs @@ -1,10 +1,25 @@ -// This is a dummy implementation of Poseidon hash, -// which will be replaced with a call to the Poseidon -// coprocessor during compilation. -// The function itself will be removed by the compiler -// during the reachability analysis. extern "C" { + // This is a dummy implementation of Poseidon hash, + // which will be replaced with a call to the Poseidon + // coprocessor during compilation. + // The function itself will be removed by the compiler + // during the reachability analysis. fn poseidon_gl_coprocessor(data: *mut [u64; 12]); + + // This will be replaced by a call to prover input. + fn input_coprocessor(index: u32, what: u32) -> u32; +} + +extern crate alloc; + +pub fn get_data(what: u32, data: &mut [u32]) { + for (i, d) in data.iter_mut().enumerate() { + *d = unsafe { input_coprocessor(what, (i + 1) as u32) }; + } +} + +pub fn get_data_len(what: u32) -> usize { + unsafe { input_coprocessor(what, 0) as usize } } const GOLDILOCKS: u64 = 0xffffffff00000001; diff --git a/riscv/src/coprocessors.rs b/riscv/src/coprocessors.rs index a5c62868f..6e0c37eb7 100644 --- a/riscv/src/coprocessors.rs +++ b/riscv/src/coprocessors.rs @@ -65,11 +65,20 @@ instr poseidon_gl A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11 -> X, Y, Z, W runtime_function_impl: Some(("poseidon_gl_coprocessor", poseidon_gl_call)), }; -static ALL_COPROCESSORS: [(&str, &CoProcessor); 4] = [ +static INPUT_COPROCESSOR: CoProcessor = CoProcessor { + name: "prover_input", + ty: "", + import: "", + instructions: "", + runtime_function_impl: Some(("input_coprocessor", prover_input_call)), +}; + +static ALL_COPROCESSORS: [(&str, &CoProcessor); 5] = [ (BINARY_COPROCESSOR.name, &BINARY_COPROCESSOR), (SHIFT_COPROCESSOR.name, &SHIFT_COPROCESSOR), (SPLIT_GL_COPROCESSOR.name, &SPLIT_GL_COPROCESSOR), (POSEIDON_GL_COPROCESSOR.name, &POSEIDON_GL_COPROCESSOR), + (INPUT_COPROCESSOR.name, &INPUT_COPROCESSOR), ]; /// Defines which coprocessors should be used by the RISCV machine. @@ -116,6 +125,7 @@ impl CoProcessors { coprocessors: BTreeMap::from([ (BINARY_COPROCESSOR.name, &BINARY_COPROCESSOR), (SHIFT_COPROCESSOR.name, &SHIFT_COPROCESSOR), + (INPUT_COPROCESSOR.name, &INPUT_COPROCESSOR), ]), } } @@ -134,7 +144,11 @@ impl CoProcessors { } pub fn declarations(&self) -> Vec<(&'static str, &'static str)> { - self.coprocessors.values().map(|c| (c.name, c.ty)).collect() + self.coprocessors + .values() + .filter(|c| !c.ty.is_empty()) + .map(|c| (c.name, c.ty)) + .collect() } pub fn machine_imports(&self) -> Vec<&'static str> { @@ -239,6 +253,10 @@ fn poseidon_gl_call() -> String { .collect() } +fn prover_input_call() -> String { + "x10 <=X= ${ (\"data\", x11, x10) };".to_string() +} + // This could also potentially go in the impl of CoProcessors, // but I purposefully left it outside because it should be removed eventually. pub fn call_every_submachine(coprocessors: &CoProcessors) -> Vec { diff --git a/riscv/tests/common/mod.rs b/riscv/tests/common/mod.rs index b616741ba..d34d06213 100644 --- a/riscv/tests/common/mod.rs +++ b/riscv/tests/common/mod.rs @@ -4,16 +4,26 @@ use compiler::{ }; use number::GoldilocksField; use riscv::bootloader::default_input; +use std::collections::HashMap; /// Like compiler::verify::verify_asm_string, but also runs RISCV executor. pub fn verify_riscv_asm_string(file_name: &str, contents: &str, inputs: Vec) { let temp_dir = mktemp::Temp::new_dir().unwrap().release(); + + let mut inputs_hash: HashMap> = HashMap::default(); + inputs_hash.insert(0u32.into(), inputs.clone()); + let (_, result) = compile_asm_string( file_name, contents, inputs.clone(), Some(&mut |analyzed| { - riscv_executor::execute_ast(analyzed, &inputs.clone(), &default_input(), usize::MAX); + riscv_executor::execute_ast( + analyzed, + &inputs_hash.clone(), + &default_input(), + usize::MAX, + ); }), &temp_dir, true, diff --git a/riscv/tests/riscv_data/evm/src/lib.rs b/riscv/tests/riscv_data/evm/src/lib.rs index a46e11003..8ac35351f 100644 --- a/riscv/tests/riscv_data/evm/src/lib.rs +++ b/riscv/tests/riscv_data/evm/src/lib.rs @@ -7,9 +7,10 @@ use revm::{ }, EVM, }; -use runtime::{print, get_prover_input}; +use runtime::{print, coprocessors::{get_data, get_data_len}}; extern crate alloc; +use alloc::vec; use alloc::vec::Vec; #[no_mangle] @@ -19,8 +20,10 @@ fn main() { b256!("e3c84e69bac71c159b2ff0d62b9a5c231887a809a96cb4a262a4b96ed78a1db2"); let mut db = CacheDB::new(EmptyDB::default()); - let bytecode_len = get_prover_input(0); - let bytecode: Vec<_> = (1..(bytecode_len + 1)).map(|idx| get_prover_input(idx) as u8).collect(); + let bytecode_len = get_data_len(0); + let mut bytecode = vec![0; bytecode_len]; + get_data(0, &mut bytecode); + let bytecode: Vec = bytecode.into_iter().map(|x| x as u8).collect(); // Fill database: let bytecode = Bytes::from(bytecode); diff --git a/riscv_executor/src/lib.rs b/riscv_executor/src/lib.rs index 8bba51a3c..a90be88fe 100644 --- a/riscv_executor/src/lib.rs +++ b/riscv_executor/src/lib.rs @@ -440,7 +440,7 @@ fn preprocess_main_function(machine: &Machine) -> Preprocess struct Executor<'a, 'b, F: FieldElement> { proc: TraceBuilder<'a, 'b>, label_map: HashMap<&'a str, Elem>, - inputs: &'b [F], + inputs: HashMap>, bootloader_inputs: &'b [F], stdout: io::Stdout, } @@ -717,7 +717,13 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { break 'input vec![match name.as_str() { "input" => { let idx = val.u() as usize; - to_u32(&self.inputs[idx]).unwrap().into() + to_u32(&self.inputs[&F::zero()][idx]).unwrap().into() + } + "data" => { + let idx = val.u() as usize; + let what = self.eval_expression(&t[2])[0]; + let what = what.u(); + to_u32(&self.inputs[&what.into()][idx]).unwrap().into() } "bootloader_input" => { let idx = val.u() as usize; @@ -745,7 +751,7 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { pub fn execute_ast<'a, T: FieldElement>( program: &'a AnalysisASMFile, - inputs: &[T], + inputs: &HashMap>, bootloader_inputs: &[T], max_steps_to_execute: usize, ) -> (ExecutionTrace<'a>, MemoryState) { @@ -765,7 +771,7 @@ pub fn execute_ast<'a, T: FieldElement>( let mut e = Executor { proc, label_map, - inputs, + inputs: inputs.clone(), bootloader_inputs, stdout: io::stdout(), }; @@ -814,7 +820,7 @@ pub fn execute_ast<'a, T: FieldElement>( curr_pc = match e.proc.advance(is_nop) { Some(pc) => pc, None => break, - } + }; } e.proc.finish() @@ -824,7 +830,11 @@ pub fn execute_ast<'a, T: FieldElement>( /// /// Generic argument F is just used by the parser, before everything is /// converted to i64, so it is important to the execution itself. -pub fn execute(asm_source: &str, inputs: &[F], bootloader_inputs: &[F]) { +pub fn execute( + asm_source: &str, + inputs: &HashMap>, + bootloader_inputs: &[F], +) { log::info!("Parsing..."); let parsed = parser::parse_asm::(None, asm_source).unwrap(); log::info!("Resolving imports...");