diff --git a/Cargo.toml b/Cargo.toml index f7f7e51a8..d4c9d6f18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ log = "0.4.17" mktemp = "0.5.0" num-bigint = "^0.4" regex = "^1.7.0" +walkdir = "2.3.3" [build-dependencies] lalrpop = "^0.19" diff --git a/src/asm_compiler/mod.rs b/src/asm_compiler/mod.rs index b9c87646f..1d339cc48 100644 --- a/src/asm_compiler/mod.rs +++ b/src/asm_compiler/mod.rs @@ -314,7 +314,10 @@ impl ASMPILConverter { args: Vec, ) { assert!(write_regs.len() == 1); - let instr = &self.instructions[&instr_name]; + let instr = &self + .instructions + .get(&instr_name) + .unwrap_or_else(|| panic!("Intruction not found: {instr_name}")); assert_eq!(instr.outputs.len(), 1); let output = instr.outputs[0].clone(); assert!( @@ -331,7 +334,10 @@ impl ASMPILConverter { } fn handle_instruction(&mut self, instr_name: String, args: Vec) { - let instr = &self.instructions[&instr_name]; + let instr = &self + .instructions + .get(&instr_name) + .unwrap_or_else(|| panic!("Instruction not found: {instr_name}")); assert_eq!(instr.inputs.len() + instr.outputs.len(), args.len()); let mut args = args.into_iter(); diff --git a/src/bin/compiler.rs b/src/bin/compiler.rs index 1a38ae003..b138e2f09 100644 --- a/src/bin/compiler.rs +++ b/src/bin/compiler.rs @@ -17,9 +17,14 @@ enum Commands { /// and finally to PIL and generates fixed and witness columns. /// Needs `rustup target add riscv32imc-unknown-none-elf`. Rust { - /// Input file + /// Input file or directory. file: String, + /// Compile a full cargo crate with dependencies. + #[arg(long)] + #[arg(default_value_t = false)] + cargo: bool, + /// Comma-separated list of free inputs (numbers). #[arg(short, long)] #[arg(default_value_t = String::new())] @@ -117,12 +122,14 @@ fn main() { match command { Commands::Rust { file, + cargo, inputs, output_directory, force, } => { powdr::riscv::compile_rust( &file, + cargo, split_inputs(&inputs), Path::new(&output_directory), force, diff --git a/src/riscv/compiler.rs b/src/riscv/compiler.rs index 0c71536ed..b5121541b 100644 --- a/src/riscv/compiler.rs +++ b/src/riscv/compiler.rs @@ -1,4 +1,4 @@ -use std::{fs, path::Path}; +use std::{collections::BTreeMap, fs, path::Path}; use lazy_static::lazy_static; use regex::Regex; @@ -15,6 +15,11 @@ pub fn compile_file(file: &Path) { /// Compiles riscv assembly to POWDR assembly. Adds required library routines. pub fn compile_riscv_asm(data: &str) -> String { + // stack grows towards zero + let stack_start = 0x10000; + // data grows away from zero + let data_start = 0x20000; + let statements = parser::parse_asm(data); let labels = parser::extract_labels(&statements); let label_references = parser::extract_label_references(&statements); @@ -26,20 +31,88 @@ pub fn compile_riscv_asm(data: &str) -> String { .map(|label| library_routine(label)) .collect::>() .join("\n"); - let mut output = preamble(); + let statements = parser::parse_asm(&data); + let (data_code, data_positions) = store_data_objects( + parser::extract_data_objects(&statements).into_iter(), + data_start, + ); + let mut output = preamble() + + &data_code + + &format!("// Set stack pointer\nx2 <=X= {stack_start};\n") + + "jump main;\n"; - for s in parser::parse_asm(&data) { + let statements = insert_data_positions(statements, &data_positions); + + for s in statements { output += &process_statement(s); } output } +fn store_data_objects( + objects: impl Iterator)>, + mut memory_start: u32, +) -> (String, BTreeMap) { + memory_start = ((memory_start + 7) / 8) * 8; + let mut code = String::new(); + let mut positions = BTreeMap::new(); + for (name, data) in objects { + code += &format!("// data {name}\n"); + positions.insert(name, memory_start); + for i in 0..((data.len() + 3) / 4) { + let v = (0..4) + .map(|j| (data.get(i * 4 + j).cloned().unwrap_or_default() as u32) << (j * 8)) + .reduce(|a, b| a | b) + .unwrap(); + code += &format!( + "addr <=X= 0x{:x};\nmstore 0x{v:x};\n", + memory_start + (i * 4) as u32 + ); + } + memory_start += (((data.len() + 7) / 8) * 8) as u32; + } + (code, positions) +} + +fn insert_data_positions( + mut statements: Vec, + data_positions: &BTreeMap, +) -> Vec { + for s in &mut statements { + let Statement::Instruction(_name, args) = s else { continue; }; + for arg in args { + match arg { + Argument::RegOffset(_, offset) => replace_data_reference(offset, data_positions), + Argument::Constant(c) => replace_data_reference(c, data_positions), + _ => {} + } + } + } + statements +} + +fn replace_data_reference(constant: &mut Constant, data_positions: &BTreeMap) { + match constant { + Constant::Number(_) => {} + Constant::HiDataRef(data) => { + *constant = Constant::Number((data_positions[data] >> 16) as i64) + } + Constant::LoDataRef(data) => { + *constant = Constant::Number((data_positions[data] & 0xffff) as i64) + } + } +} + fn preamble() -> String { r#" +degree 262144; reg pc[@pc]; reg X[<=]; reg Y[<=]; reg Z[<=]; +reg tmp1; +reg tmp2; +reg tmp3; "# .to_string() + &(0..32) @@ -116,7 +189,9 @@ instr mload -> X { { addr, STEP, X } is m_is_read { m_addr, m_step, m_value } } // ============== control-flow instructions ============== instr jump l: label { pc' = l } -instr call l: label { pc' = l, x1' = pc + 1, x6' = l } +instr call l: label { pc' = l, x1' = pc + 1 } +// TODO x6 actually stores some relative address, but only part of it. +instr tail l: label { pc' = l, x6' = l } instr ret { pc' = x1 } instr branch_if_nonzero X, l: label { pc' = (1 - XIsZero) * l + XIsZero * (pc + 1) } @@ -133,13 +208,97 @@ instr branch_if_positive X, l: label { instr is_equal_zero X -> Y { Y = XIsZero } -// ================= arith/bitwise instructions ================= +// ================= binary/bitwise instructions ================= -// instr xor X, Y, Z { -// {X, Y, Z} in 1 { binary.X, binary.Y, binary.RESULT, 1 } -// } -// we wanted better synatx: { binary(X, Y, Z) } -// maybe alternate syntax: instr xor a(Y), b(Z) -> X +instr and <=Y= a, <=Z= b, c <=X= { + {Y, Z, X, 0} in binary_RESET { binary_A, binary_B, binary_C, binary_operation } +} + +instr or <=Y= a, <=Z= b, c <=X= { + {Y, Z, X, 1} in binary_RESET { binary_A, binary_B, binary_C, binary_operation } +} + +instr xor <=Y= a, <=Z= b, c <=X= { + {Y, Z, X, 2} in binary_RESET { binary_A, binary_B, binary_C, binary_operation } +} + +pil{ + macro is_nonzero(X) { X / X }; // 0 / 0 == 0 makes this work... + macro is_zero(X) { 1 - is_nonzero(X) }; + + col fixed binary_RESET(i) { is_zero((i % 4) - 3) }; + col fixed binary_FACTOR(i) { 1 << (((i + 1) % 4) * 8) }; + + col fixed binary_P_A(i) { i % 256 }; + col fixed binary_P_B(i) { (i >> 8) % 256 }; + col fixed binary_P_operation(i) { (i / (256 * 256)) % 3 }; + col fixed binary_P_C(i) { + match binary_P_operation(i) { + 0 => binary_P_A(i) & binary_P_B(i), + 1 => binary_P_A(i) | binary_P_B(i), + 2 => binary_P_A(i) ^ binary_P_B(i), + } & 0xff + }; + + col witness binary_A_byte; + col witness binary_B_byte; + col witness binary_C_byte; + + col witness binary_A; + col witness binary_B; + col witness binary_C; + col witness binary_operation; + + binary_A' = binary_A * (1 - binary_RESET) + binary_A_byte * binary_FACTOR; + binary_B' = binary_B * (1 - binary_RESET) + binary_B_byte * binary_FACTOR; + binary_C' = binary_C * (1 - binary_RESET) + binary_C_byte * binary_FACTOR; + (binary_operation' - binary_operation) * (1 - binary_RESET) = 0; + + {binary_operation', binary_A_byte, binary_B_byte, binary_C_byte} in {binary_P_operation, binary_P_A, binary_P_B, binary_P_C}; +} + +// ================= shift instructions ================= + +instr shl <=Y= a, <=Z= b, c <=X= { + {Y, Z, X, 0} in shift_RESET { shift_A, shift_B, shift_C, shift_operation } +} + +instr shr <=Y= a, <=Z= b, c <=X= { + {Y, Z, X, 1} in shift_RESET { shift_A, shift_B, shift_C, shift_operation } +} + +pil{ + col fixed shift_RESET(i) { is_zero((i % 4) - 3) }; + col fixed shift_FACTOR_ROW(i) { (i + 1) % 4 }; + col fixed shift_FACTOR(i) { 1 << (((i + 1) % 4) * 8) }; + + col fixed shift_P_A(i) { i % 256 }; + col fixed shift_P_B(i) { (i / 256) % 32 }; + col fixed shift_P_ROW(i) { (i / (256 * 32)) % 4 }; + col fixed shift_P_operation(i) { (i / (256 * 32 * 4)) % 2 }; + col fixed shift_P_C(i) { + match shift_P_operation(i) { + 0 => (shift_P_A(i) << (shift_P_B(i) + (shift_P_ROW(i) * 8))), + 1 => (shift_P_A(i) << (shift_P_ROW(i) * 8)) >> shift_P_B(i), + } & 0xffffffff + }; + + col witness shift_A_byte; + col witness shift_C_part; + + col witness shift_A; + col witness shift_B; + col witness shift_C; + col witness shift_operation; + + shift_A' = shift_A * (1 - shift_RESET) + shift_A_byte * shift_FACTOR; + (shift_B' - shift_B) * (1 - shift_RESET) = 0; + shift_C' = shift_C * (1 - shift_RESET) + shift_C_part; + (shift_operation' - shift_operation) * (1 - shift_RESET) = 0; + + // TODO this way, we cannot prove anything that shifts by more than 31 bits. + {shift_operation', shift_A_byte, shift_B', shift_FACTOR_ROW, shift_C_part} in {shift_P_operation, shift_P_A, shift_P_B, shift_P_ROW, shift_P_C}; +} // ================== wrapping instructions ============== @@ -160,6 +319,26 @@ pil{ wrap_bit * (1 - wrap_bit) = 0; } +// Input is a 32 bit unsigned number. We check the 7th bit and set all higher bits to that value. +instr sign_extend_byte <=Y= v, x <=X= { + // wrap_bit is used as sign_bit here. + Y = Y_7bit + wrap_bit * 0x80 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000, + X = Y_7bit + wrap_bit * 0xffffff80 +} +pil{ + col fixed seven_bit(i) { i & 0x7f }; + col witness Y_7bit; + { Y_7bit } in { seven_bit }; +} + +// Input is a 32 but unsined number (0 <= Y < 2**32) interpreted as a two's complement numbers. +// Returns a signed number (-2**31 <= X < 2**31). +instr to_signed <=Y= v, x <=X= { + // wrap_bit is used as sign_bit here. + Y = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + Y_7bit * 0x1000000 + wrap_bit * 0x80000000, + X = Y - wrap_bit * 2**31 +} + // ======================= assertions ========================= instr fail { 1 = 0 } @@ -173,10 +352,6 @@ pil{ { Y_b5 } in { bytes }; { Y_b6 } in { bytes }; } - -// set the stack pointer. -// TODO other things to initialize? -x2 <=X= 0x10000; "# } @@ -195,6 +370,31 @@ lazy_static! { .unwrap(), "unimp" ), + ( + Regex::new(r"^_ZN5alloc5alloc18handle_alloc_error17h[0-9a-f]{16}E$").unwrap(), + "unimp" + ), + ( + Regex::new(r"^_ZN5alloc7raw_vec17capacity_overflow17h[0-9a-f]{16}E$").unwrap(), + "unimp" + ), + ( + Regex::new(r"^_ZN5alloc7raw_vec17capacity_overflow17h[0-9a-f]{16}E$").unwrap(), + "unimp" + ), + ( + Regex::new(r"^_ZN5alloc5alloc18handle_alloc_error17h[0-9a-f]{16}E$").unwrap(), + "unimp" + ), + ( + Regex::new(r"^_ZN4core5slice5index26slice_start_index_len_fail17h[0-9a-f]{16}E$") + .unwrap(), + "unimp" + ), + // TODO rust alloc calls the global allocator - not sure why this is not automatic. + (Regex::new(r"^__rust_alloc$").unwrap(), "jmp __rg_alloc"), + (Regex::new(r"^__rust_realloc$").unwrap(), "jmp __rg_realloc"), + (Regex::new(r"^__rust_dealloc$").unwrap(), "jmp __rg_dealloc"), ( Regex::new(r"^memset@plt$").unwrap(), r#" @@ -209,6 +409,62 @@ lazy_static! { j memset@plt ___end_memset: ret +"# + ), + ( + Regex::new(r"^memcpy@plt$").unwrap(), +/* Source code for memcpy: +pub unsafe extern "C" fn memcpy(dest: *mut u8, src: *const u8, n: usize) -> *mut u8 { + // We only access u32 because then we do not have to deal with + // un-aligned memory access. + // TODO this does not really enforce that the pointers are u32-aligned. + let mut i: isize = 0; + while i + 3 < n as isize { + *((dest.offset(i)) as *mut u32) = *((src.offset(i)) as *mut u32); + i += 4; + } + if i < n as isize { + let value = *((src.offset(i)) as *mut u32); + let dest_value = (dest.offset(i)) as *mut u32; + let mask = (1 << (((n as isize - i) * 8) as u32)) - 1; + *dest_value = (*dest_value & !mask) | (value & mask); + } + dest +} +*/ + r#" + li a3, 4 + blt a2, a3, __memcpy_LBB2_5 + li a4, 0 +__memcpy_LBB2_2: + add a3, a1, a4 + lw a6, 0(a3) + add a7, a0, a4 + addi a3, a4, 4 + addi a5, a4, 7 + sw a6, 0(a7) + mv a4, a3 + blt a5, a2, __memcpy_LBB2_2 + bge a3, a2, __memcpy_LBB2_6 +__memcpy_LBB2_4: + add a1, a1, a3 + lw a1, 0(a1) + add a3, a3, a0 + slli a2, a2, 3 + lw a4, 0(a3) + li a5, -1 + sll a2, a5, a2 + not a5, a2 + and a2, a2, a4 + and a1, a1, a5 + or a1, a1, a2 + sw a1, 0(a3) + ret +__memcpy_LBB2_5: + li a3, 0 + blt a3, a2, __memcpy_LBB2_4 +__memcpy_LBB2_6: + ret "# ), ]; @@ -253,8 +509,9 @@ fn argument_to_number(x: &Argument) -> u32 { fn constant_to_number(c: &Constant) -> u32 { match c { Constant::Number(n) => *n as u32, - Constant::HiDataRef(_) => 0, // TODO - Constant::LoDataRef(_) => 0, // TODO + Constant::HiDataRef(n) | Constant::LoDataRef(n) => { + panic!("Data reference should have been replaced by number: {n}") + } } } @@ -313,6 +570,21 @@ fn rro(args: &[Argument]) -> (Register, Register, u32) { fn process_instruction(instr: &str, args: &[Argument]) -> String { match instr { + // load/store registers + "li" => { + let (rd, imm) = ri(args); + format!("{rd} <=X= {imm};\n") + } + "lui" => { + let (rd, imm) = ri(args); + format!("{rd} <=X= {};\n", imm << 12) + } + "mv" => { + let (rd, rs) = rr(args); + format!("{rd} <=X= {rs};\n") + } + + // Arithmetic "add" => { let (rd, r1, r2) = rrr(args); format!("{rd} <=X= wrap({r1} + {r2});\n") @@ -321,6 +593,71 @@ fn process_instruction(instr: &str, args: &[Argument]) -> String { let (rd, rs, imm) = rri(args); format!("{rd} <=X= wrap({rs} + {imm});\n") } + "sub" => { + let (rd, r1, r2) = rrr(args); + format!("{rd} <=X= wrap({r1} - {r2});\n") + } + "neg" => { + let (rd, r1) = rr(args); + format!("{rd} <=X= wrap(0 - {r1});\n") + } + + // bitwise + "xor" => { + let (rd, r1, r2) = rrr(args); + format!("{rd} <=X= xor({r1}, {r2});\n") + } + "xori" => { + let (rd, r1, imm) = rri(args); + format!("{rd} <=X= xor({r1}, {imm});\n") + } + "and" => { + let (rd, r1, r2) = rrr(args); + format!("{rd} <=X= and({r1}, {r2});\n") + } + "or" => { + let (rd, r1, r2) = rrr(args); + format!("{rd} <=X= or({r1}, {r2});\n") + } + "not" => { + let (rd, rs) = rr(args); + format!("{rd} <=X= xor({rs}, 0xffffffff);\n") + } + + // shift + "slli" => { + let (rd, rs, amount) = rri(args); + assert!(amount <= 31); + if amount <= 16 { + format!("{rd} <=X= wrap16({rs} * {});\n", 1 << amount) + } else { + format!("tmp1 <=X= wrap16({rs} * {});\n", 1 << 16) + + &format!("{rd} <=X= wrap16(tmp1 * {});\n", 1 << (amount - 16)) + } + } + "sll" => { + let (rd, r1, r2) = rrr(args); + format!("{rd} <=X= shl({r1}, {r2});\n") + } + "srli" => { + // logical shift right + let (rd, rs, amount) = rri(args); + assert!(amount <= 31); + format!("{rd} <=X= shr({rs}, {amount});\n") + } + "srl" => { + // logical shift right + let (rd, r1, r2) = rrr(args); + format!("{rd} <=X= shr({r1}, {r2});\n") + } + + // comparison + "seqz" => { + let (rd, rs) = rr(args); + format!("{rd} <=Y= is_equal_zero({rs});\n") + } + + // branching "beq" => { let (r1, r2, label) = rrl(args); format!("branch_if_zero {r1} - {r2}, {label};\n") @@ -331,12 +668,47 @@ fn process_instruction(instr: &str, args: &[Argument]) -> String { } "bgeu" => { let (r1, r2, label) = rrl(args); - format!("branch_if_positive {r1} - {r2}, {label};\n") + // TODO does this fulfill the input requirements for branch_if_positive? + format!("branch_if_positive {r1} - {r2} + 1, {label};\n") } "bltu" => { let (r1, r2, label) = rrl(args); format!("branch_if_positive {r2} - {r1}, {label};\n") } + "blt" => { + let (r1, r2, label) = rrl(args); + // Branch if r1 < r2 (signed). + // TODO does this fulfill the input requirements for branch_if_positive? + format!("tmp1 <=X= to_signed({r1});\n") + + &format!("tmp2 <=X= to_signed({r2});\n") + + &format!("branch_if_positive tmp2 - tmp1, {label};\n") + } + "bge" => { + let (r1, r2, label) = rrl(args); + // Branch if r1 >= r2 (signed). + // TODO does this fulfill the input requirements for branch_if_positive? + format!("tmp1 <=X= to_signed({r1});\n") + + &format!("tmp2 <=X= to_signed({r2});\n") + + &format!("branch_if_positive tmp1 - tmp2 + 1, {label};\n") + } + "bltz" => { + // branch if 2**31 <= r1 < 2**32 + let (r1, label) = rl(args); + format!("branch_if_positive {r1} - 2**31 + 1, {label};\n") + } + + "blez" => { + // branch less or equal zero + let (r1, label) = rl(args); + format!("tmp1 <=X= to_signed({r1});\n") + + &format!("branch_if_positive -tmp1 + 1, {label};\n") + } + "bgtz" => { + // branch if 0 < r1 < 2**31 + let (r1, label) = rl(args); + format!("tmp1 <=X= to_signed({r1});\n") + + &format!("branch_if_positive tmp1, {label};\n") + } "bne" => { let (r1, r2, label) = rrl(args); format!("branch_if_nonzero {r1} - {r2}, {label};\n") @@ -345,6 +717,8 @@ fn process_instruction(instr: &str, args: &[Argument]) -> String { let (r1, label) = rl(args); format!("branch_if_nonzero {r1}, {label};\n") } + + // jump and call "j" => { if let [Argument::Symbol(label)] = args { format!("jump {};\n", escape_label(label)) @@ -352,6 +726,10 @@ fn process_instruction(instr: &str, args: &[Argument]) -> String { panic!() } } + "jal" => { + let (_rd, _label) = rl(args); + todo!(); + } "call" => { if let [Argument::Symbol(label)] = args { format!("call {};\n", escape_label(label)) @@ -363,49 +741,85 @@ fn process_instruction(instr: &str, args: &[Argument]) -> String { assert!(args.is_empty()); "x10 <=X= ${ (\"input\", x10) };\n".to_string() } - "li" => { - let (rd, imm) = ri(args); - format!("{rd} <=X= {imm};\n") - } - "lui" => { - let (rd, imm) = ri(args); - format!("{rd} <=X= {};\n", imm << 12) - } - "lw" => { - let (rd, rs, off) = rro(args); - format!("addr <=X= wrap({rs} + {off});\n") + &format!("{rd} <=X= mload();\n") - } - "sw" => { - let (r1, r2, off) = rro(args); - format!("addr <=X= wrap({r2} + {off});\n") + &format!("mstore {r1};\n") - } - "mv" => { - let (rd, rs) = rr(args); - format!("{rd} <=X= {rs};\n") + "tail" => { + if let [Argument::Symbol(label)] = args { + format!("tail {};\n", escape_label(label)) + } else { + panic!() + } } "ret" => { assert!(args.is_empty()); "ret;\n".to_string() } - "seqz" => { - let (rd, rs) = rr(args); - format!("{rd} <=Y= is_equal_zero {rs};\n") + + // memory access + "lw" => { + let (rd, rs, off) = rro(args); + // TODO we need to consider misaligned loads / stores + format!("addr <=X= wrap({rs} + {off});\n") + &format!("{rd} <=X= mload();\n") } - "slli" => { - let (rd, rs, amount) = rri(args); - assert!(amount <= 31); - if amount <= 16 { - format!("{rd} <=X= wrap16({rs} * {});\n", 1 << amount) - } else { - todo!(); - } + "lb" => { + // load byte and sign-extend. the memory is little-endian. + let (rd, rs, off) = rro(args); + format!("tmp1 <=X= wrap({rs} + {off});\n") + + "addr <=X= and(tmp1, 0xfffffffc);\n" + + "tmp2 <=X= and(tmp1, 0x3);\n" + + &format!("{rd} <=X= mload();\n") + + &format!("{rd} <=X= shr({rd}, 8 * tmp2);\n") + + &format!("{rd} <=X= sign_extend_byte({rd});\n") + } + "lbu" => { + // load byte and zero-extend. the memory is little-endian. + let (rd, rs, off) = rro(args); + format!("tmp1 <=X= wrap({rs} + {off});\n") + + "addr <=X= and(tmp1, 0xfffffffc);\n" + + "tmp2 <=X= and(tmp1, 0x3);\n" + + &format!("{rd} <=X= mload();\n") + + &format!("{rd} <=X= shr({rd}, 8 * tmp2);\n") + + &format!("{rd} <=X= and({rd}, 0xff);\n") + } + "sw" => { + let (r1, r2, off) = rro(args); + format!("addr <=X= wrap({r2} + {off});\n") + &format!("mstore {r1};\n") + } + "sh" => { + // store half word (two bytes) + // TODO this code assumes it is at least aligned on + // a two-byte boundary + + let (rs, rd, off) = rro(args); + format!("tmp1 <=X= wrap({rd} + {off});\n") + + "addr <=X= and(tmp1, 0xfffffffc);\n" + + "tmp2 <=X= and(tmp1, 0x3);\n" + + "tmp1 <=X= mload();\n" + + "tmp3 <=X= shl(0xffff, 8 * tmp2);\n" + + "tmp3 <=X= xor(tmp3, 0xffffffff);\n" + + "tmp1 <=X= and(tmp1, tmp3);\n" + + &format!("tmp3 <=X= and({rs}, 0xffff);\n") + + "tmp3 <=X= shl(tmp3, 8 * tmp2);\n" + + "tmp1 <=X= or(tmp1, tmp3);\n" + + "mstore tmp1;\n" + } + "sb" => { + // store byte + let (rs, rd, off) = rro(args); + format!("tmp1 <=X= wrap({rd} + {off});\n") + + "addr <=X= and(tmp1, 0xfffffffc);\n" + + "tmp2 <=X= and(tmp1, 0x3);\n" + + "tmp1 <=X= mload();\n" + + "tmp3 <=X= shl(0xff, 8 * tmp2);\n" + + "tmp3 <=X= xor(tmp3, 0xffffffff);\n" + + "tmp1 <=X= and(tmp1, tmp3);\n" + + &format!("tmp3 <=X= and({rs}, 0xff);\n") + + "tmp3 <=X= shl(tmp3, 8 * tmp2);\n" + + "tmp1 <=X= or(tmp1, tmp3);\n" + + "mstore tmp1;\n" } "unimp" => "fail;\n".to_string(), - "xor" => { - todo!(); - // let (rd, r1, r2) = rrr(args); - // format!("{rd} <=X= xor {r1}, {r2};\n") + + _ => { + panic!("Unknown instruction: {instr}"); } - _ => todo!("Unknown instruction: {instr}"), } } diff --git a/src/riscv/mod.rs b/src/riscv/mod.rs index a3260d059..f1aa42377 100644 --- a/src/riscv/mod.rs +++ b/src/riscv/mod.rs @@ -2,6 +2,7 @@ use std::{path::Path, process::Command}; use mktemp::Temp; use std::fs; +use walkdir::WalkDir; use crate::number::AbstractNumberType; @@ -12,11 +13,21 @@ pub mod parser; /// fixed and witness columns. pub fn compile_rust( file_name: &str, + full_crate: bool, inputs: Vec, output_dir: &Path, force_overwrite: bool, ) { - let riscv_asm = compile_rust_to_riscv_asm(file_name); + let riscv_asm = if full_crate { + let cargo_toml = if file_name.ends_with("Cargo.toml") { + file_name.to_string() + } else { + format!("{file_name}/Cargo.toml") + }; + compile_rust_crate_to_riscv_asm(&cargo_toml) + } else { + compile_rust_to_riscv_asm(file_name) + }; let riscv_asm_file_name = output_dir.join(format!( "{}_riscv.asm", Path::new(file_name).file_stem().unwrap().to_str().unwrap() @@ -99,3 +110,34 @@ pub fn compile_rust_to_riscv_asm(input_file: &str) -> String { assert!(rustc_status.success()); fs::read_to_string(temp_file.to_str().unwrap()).unwrap() } + +pub fn compile_rust_crate_to_riscv_asm(input_dir: &str) -> String { + let temp_dir = Temp::new_dir().unwrap(); + + let cargo_status = Command::new("cargo") + .env("RUSTFLAGS", "--emit=asm") + .args([ + "build", + "--release", + "--target", + "riscv32imc-unknown-none-elf", + "--lib", + "--target-dir", + temp_dir.to_str().unwrap(), + "--manifest-path", + input_dir, + ]) + .status() + .unwrap(); + assert!(cargo_status.success()); + + let mut combined_assembly = String::new(); + for entry in WalkDir::new(&temp_dir) { + let entry = entry.unwrap(); + // TODO search only in certain subdir? + if entry.file_name().to_str().unwrap().ends_with(".s") { + combined_assembly += &fs::read_to_string(entry.path()).unwrap(); + } + } + combined_assembly +} diff --git a/src/riscv/parser.rs b/src/riscv/parser.rs index 42f3534b5..51b013a00 100644 --- a/src/riscv/parser.rs +++ b/src/riscv/parser.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeSet; +use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{self, Display}; use lalrpop_util::*; @@ -19,7 +19,7 @@ pub enum Statement { pub enum Argument { Register(Register), RegOffset(Register, Constant), - StringLiteral(String), + StringLiteral(Vec), Constant(Constant), Symbol(String), Difference(String, String), @@ -50,7 +50,7 @@ impl Display for Argument { Argument::Register(r) => write!(f, "{r}"), Argument::Constant(c) => write!(f, "{c}"), Argument::RegOffset(reg, off) => write!(f, "{off}({reg})"), - Argument::StringLiteral(lit) => write!(f, "\"{lit}\""), + Argument::StringLiteral(lit) => write!(f, "\"{}\"", String::from_utf8_lossy(lit)), Argument::Symbol(s) => write!(f, "{s}"), Argument::Difference(left, right) => write!(f, "{left} - {right}"), } @@ -127,12 +127,48 @@ pub fn extract_label_references(statements: &[Statement]) -> BTreeSet<&str> { .collect() } -// TODO it actually parses to a byte array... -pub fn unescape_string(s: &str) -> String { +pub fn extract_data_objects(statements: &[Statement]) -> BTreeMap> { + let mut current_label = None; + let mut objects = BTreeMap::>>::new(); + for s in statements { + match s { + Statement::Label(l) => { + current_label = Some(l.as_str()); + } + // TODO We ignore size and alignment directives. + Statement::Directive(dir, args) => match (dir.as_str(), &args[..]) { + (".type", [Argument::Symbol(name), Argument::Symbol(kind)]) + if kind.as_str() == "@object" => + { + objects.insert(name.clone(), None); + } + (".ascii" | ".asciz", [Argument::StringLiteral(data)]) => { + if let Some(entry) = objects.get_mut(current_label.unwrap()) { + assert!(entry.is_none()); + *entry = Some(data.clone()); + } + } + _ => {} + }, + _ => {} + } + } + objects + .into_iter() + .map(|(k, v)| { + ( + k.clone(), + v.unwrap_or_else(|| panic!("Label for announced object {k} not found.")), + ) + }) + .collect() +} + +pub fn unescape_string(s: &str) -> Vec { assert!(s.len() >= 2); assert!(s.starts_with('"') && s.ends_with('"')); let mut chars = s[1..s.len() - 1].chars(); - let mut result = String::new(); + let mut result = vec![]; while let Some(c) = chars.next() { result.push(if c == '\\' { let next = chars.next().unwrap(); @@ -141,21 +177,21 @@ pub fn unescape_string(s: &str) -> String { let n = next as u8 - b'0'; let nn = chars.next().unwrap() as u8 - b'0'; let nnn = chars.next().unwrap() as u8 - b'0'; - (nnn + nn * 8 + n * 64) as char + nnn + nn * 8 + n * 64 } else if next == 'x' { todo!("Parse hex digit"); } else { - match next { + (match next { 'n' => '\n', 'r' => '\r', 't' => '\t', 'b' => 8 as char, 'f' => 12 as char, other => other, - } + }) as u8 } } else { - c + c as u8 }) } result diff --git a/src/riscv/riscv_asm.lalrpop b/src/riscv/riscv_asm.lalrpop index d7e93b478..9a1949006 100644 --- a/src/riscv/riscv_asm.lalrpop +++ b/src/riscv/riscv_asm.lalrpop @@ -86,7 +86,7 @@ Difference: Argument = { "-" => Argument::Difference(<>) } -StringLiteral: String = { +StringLiteral: Vec = { r#""[^\\"\n\r]*(\\[tnfbrx'"\\0-9][^\\"\n\r]*)*""# => unescape_string(<>) } diff --git a/src/witness_generator/machines/block_machine.rs b/src/witness_generator/machines/block_machine.rs index 96eaf2524..6f0bb8a08 100644 --- a/src/witness_generator/machines/block_machine.rs +++ b/src/witness_generator/machines/block_machine.rs @@ -276,7 +276,7 @@ impl BlockMachine { match constraint { Constraint::Assignment(a) => { let values = self.data.get_mut(&poly).unwrap(); - if r as usize <= values.len() { + if (r as usize) < values.len() { // do not write to other rows for now values[r as usize] = Some(a); } diff --git a/src/witness_generator/machines/double_sorted_witness_machine.rs b/src/witness_generator/machines/double_sorted_witness_machine.rs index cf0469714..fd9b76fe6 100644 --- a/src/witness_generator/machines/double_sorted_witness_machine.rs +++ b/src/witness_generator/machines/double_sorted_witness_machine.rs @@ -190,7 +190,7 @@ impl DoubleSortedWitnesses { })?; log::debug!( - "Query addr={addr}, step={step}, write: {is_write}, left: {}", + "Query addr={addr:x}, step={step}, write: {is_write}, left: {}", left[2].format(fixed_data) ); @@ -201,7 +201,7 @@ impl DoubleSortedWitnesses { Some(v) => v, None => return Ok(vec![]), }; - log::trace!("Memory write: addr={addr}, step={step}, value={value}"); + log::trace!("Memory write: addr={addr:x}, step={step}, value={value:x}"); self.data.insert(addr.clone(), value.clone()); self.trace .insert((addr, step), Operation { is_write, value }); @@ -214,7 +214,7 @@ impl DoubleSortedWitnesses { value: value.clone(), }, ); - log::trace!("Memory read: addr={addr}, step={step}, value={value}"); + log::trace!("Memory read: addr={addr:x}, step={step}, value={value:x}"); assignments.extend(match (left[2].clone() - value.clone().into()).solve() { Ok(ass) => ass, Err(_) => return Ok(vec![]), diff --git a/src/witness_generator/symbolic_witness_evaluator.rs b/src/witness_generator/symbolic_witness_evaluator.rs index 294ee68c4..7055d0a2d 100644 --- a/src/witness_generator/symbolic_witness_evaluator.rs +++ b/src/witness_generator/symbolic_witness_evaluator.rs @@ -51,7 +51,11 @@ where self.witness_access.value(name, next) } else { // Constant polynomial (or something else) - let values = self.fixed_data.fixed_cols[name]; + let values = self + .fixed_data + .fixed_cols + .get(name) + .unwrap_or_else(|| panic!("unknown col: {name}")); let row = if next { let degree = values.len() as DegreeType; (self.row + 1) % degree diff --git a/tests/riscv.rs b/tests/riscv.rs index b74214217..eac6ac668 100644 --- a/tests/riscv.rs +++ b/tests/riscv.rs @@ -6,15 +6,36 @@ mod common; #[test] fn test_sum() { let case = "sum.rs"; - verify( + verify_file( case, [16, 4, 1, 2, 8, 5].iter().map(|&x| x.into()).collect(), ); } -fn verify(case: &str, inputs: Vec) { +#[test] +fn test_byte_access() { + let case = "byte_access.rs"; + verify_file(case, [0, 104, 707].iter().map(|&x| x.into()).collect()); +} + +#[test] +fn test_keccak() { + let case = "keccak"; + verify_crate(case, vec![]); +} + +fn verify_file(case: &str, inputs: Vec) { let riscv_asm = powdr::riscv::compile_rust_to_riscv_asm(&format!("tests/riscv_data/{case}")); let powdr_asm = powdr::riscv::compiler::compile_riscv_asm(&riscv_asm); verify_asm_string(&format!("{case}.asm"), &powdr_asm, inputs); } + +fn verify_crate(case: &str, inputs: Vec) { + let riscv_asm = powdr::riscv::compile_rust_crate_to_riscv_asm(&format!( + "tests/riscv_data/{case}/Cargo.toml" + )); + let powdr_asm = powdr::riscv::compiler::compile_riscv_asm(&riscv_asm); + + verify_asm_string(&format!("{case}.asm"), &powdr_asm, inputs); +} diff --git a/tests/riscv_data/byte_access.rs b/tests/riscv_data/byte_access.rs new file mode 100644 index 000000000..b9865508c --- /dev/null +++ b/tests/riscv_data/byte_access.rs @@ -0,0 +1,29 @@ +#![no_std] + +use core::arch::asm; + +const X: &'static str = "abcdefg"; + +#[no_mangle] +pub extern "C" fn main() -> ! { + let replacement_index = get_prover_input(0) as usize; + let replacement_value = get_prover_input(1) as u8; + let mut x = [0; 10]; + for (i, c) in X.as_bytes().iter().enumerate() { + x[i] = *c; + } + x[replacement_index] = replacement_value; + let claimed_sum = get_prover_input(2) as u32; + let computed_sum = x.iter().map(|c| *c as u32).sum(); + assert!(claimed_sum == computed_sum); + loop {} +} + +#[inline] +fn get_prover_input(index: u32) -> u32 { + let mut value: u32; + unsafe { + asm!("ecall", lateout("a0") value, in("a0") index); + } + value +} diff --git a/tests/riscv_data/keccak/Cargo.toml b/tests/riscv_data/keccak/Cargo.toml new file mode 100644 index 000000000..b6df56e1a --- /dev/null +++ b/tests/riscv_data/keccak/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "keccak" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tiny-keccak = { version = "2.0.2", features = ["keccak"] } diff --git a/tests/riscv_data/keccak/src/lib.rs b/tests/riscv_data/keccak/src/lib.rs new file mode 100644 index 000000000..ed91acacb --- /dev/null +++ b/tests/riscv_data/keccak/src/lib.rs @@ -0,0 +1,20 @@ +#![no_std] +use core::panic::PanicInfo; + +use tiny_keccak::{Hasher, Keccak}; + +#[panic_handler] +fn panic(_info: &PanicInfo) -> ! { + loop {} +} + +#[no_mangle] +pub extern "C" fn main() -> ! { + let input = b"Solidity"; + let mut output = [0u8; 32]; + let mut hasher = Keccak::v256(); + //hasher.update(input); + hasher.finalize(&mut output); + // println!("{output:x?}"); + loop {} +}