From effdda0febe385172afee420ecb8d34f6d3ddb5a Mon Sep 17 00:00:00 2001 From: Emily Herbert <17410721+emilyaherbert@users.noreply.github.com> Date: Wed, 12 Jul 2023 19:08:02 +0200 Subject: [PATCH] Generalize assembly handling. --- Cargo.toml | 1 + asm_utils/Cargo.toml | 7 + asm_utils/src/ast.rs | 222 ++++++++++++++++++++++ asm_utils/src/compiler.rs | 45 +++++ {riscv => asm_utils}/src/data_parser.rs | 11 +- asm_utils/src/lib.rs | 7 + asm_utils/src/parser.rs | 17 ++ {riscv => asm_utils}/src/reachability.rs | 42 +++-- riscv/Cargo.toml | 1 + riscv/src/compiler.rs | 231 ++++++++++++++--------- riscv/src/disambiguator.rs | 2 +- riscv/src/lib.rs | 11 +- riscv/src/parser.rs | 227 ++-------------------- riscv/src/riscv_asm.lalrpop | 42 +++-- riscv/tests/instructions.rs | 5 +- riscv/tests/riscv.rs | 5 +- 16 files changed, 532 insertions(+), 344 deletions(-) create mode 100644 asm_utils/Cargo.toml create mode 100644 asm_utils/src/ast.rs create mode 100644 asm_utils/src/compiler.rs rename {riscv => asm_utils}/src/data_parser.rs (94%) create mode 100644 asm_utils/src/lib.rs create mode 100644 asm_utils/src/parser.rs rename {riscv => asm_utils}/src/reachability.rs (85%) diff --git a/Cargo.toml b/Cargo.toml index cfff306ab..fd35fb4c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ members = [ "ast", "analysis", "linker", + "asm_utils" ] [patch."https://github.com/privacy-scaling-explorations/halo2.git"] diff --git a/asm_utils/Cargo.toml b/asm_utils/Cargo.toml new file mode 100644 index 000000000..10ff37ccc --- /dev/null +++ b/asm_utils/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "asm_utils" +version = "0.1.0" +edition = "2021" + +[dependencies] +itertools = "^0.10" diff --git a/asm_utils/src/ast.rs b/asm_utils/src/ast.rs new file mode 100644 index 000000000..3983aa59d --- /dev/null +++ b/asm_utils/src/ast.rs @@ -0,0 +1,222 @@ +//! Common AST for the frontend architecture inputs. + +use std::fmt::{self, Display}; + +#[derive(Clone)] +pub enum Statement { + Label(String), + Directive(String, Vec>), + Instruction(String, Vec>), +} + +#[derive(Clone)] +pub enum Argument { + Register(R), + RegOffset(R, Expression), + StringLiteral(Vec), + Expression(Expression), +} + +impl Argument { + pub fn post_visit_expressions_mut(&mut self, f: &mut impl FnMut(&mut Expression)) { + match self { + Argument::Register(_) | Argument::StringLiteral(_) => (), + Argument::RegOffset(_, expr) | Argument::Expression(expr) => { + expr.post_visit_mut(f); + } + } + } + + pub fn post_visit_expressions<'a>(&'a self, f: &mut impl FnMut(&'a Expression)) { + match self { + Argument::Register(_) | Argument::StringLiteral(_) => (), + Argument::RegOffset(_, expr) | Argument::Expression(expr) => { + expr.post_visit(f); + } + } + } +} + +pub trait Register: Display {} + +#[derive(Clone, Copy)] +pub enum UnaryOpKind { + Negation, +} + +impl Display for UnaryOpKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + UnaryOpKind::Negation => write!(f, "-"), + } + } +} + +#[derive(Clone, Copy)] +pub enum BinaryOpKind { + Or, + Xor, + And, + LeftShift, + RightShift, + Add, + Sub, + Mul, + Div, + Mod, +} + +pub trait FunctionOpKind: Display {} + +#[derive(Clone)] +pub enum Expression { + Number(i64), + Symbol(String), + UnaryOp(UnaryOpKind, Box>), + BinaryOp(BinaryOpKind, Box<[Expression; 2]>), + FunctionOp(F, Box>), +} + +impl Expression { + fn post_visit<'a>(&'a self, f: &mut impl FnMut(&'a Expression)) { + match self { + Expression::Number(_) => {} + Expression::Symbol(_) => {} + Expression::UnaryOp(_, subexpr) => { + Self::post_visit(subexpr, f); + } + Expression::BinaryOp(_, subexprs) => { + subexprs.iter().for_each(|subexpr| { + Self::post_visit(subexpr, f); + }); + } + Expression::FunctionOp(_, subexpr) => { + Self::post_visit(subexpr, f); + } + } + f(self); + } + + fn post_visit_mut(&mut self, f: &mut impl FnMut(&mut Expression)) { + match self { + Expression::Number(_) => {} + Expression::Symbol(_) => {} + Expression::UnaryOp(_, subexpr) => { + Self::post_visit_mut(subexpr, f); + } + Expression::BinaryOp(_, subexprs) => { + subexprs.iter_mut().for_each(|subexpr| { + Self::post_visit_mut(subexpr, f); + }); + } + Expression::FunctionOp(_, subexpr) => { + Self::post_visit_mut(subexpr, f); + } + } + f(self); + } +} + +pub fn new_unary_op(op: UnaryOpKind, v: Expression) -> Expression { + Expression::UnaryOp(op, Box::new(v)) +} + +pub fn new_binary_op( + op: BinaryOpKind, + l: Expression, + r: Expression, +) -> Expression { + Expression::BinaryOp(op, Box::new([l, r])) +} + +pub fn new_function_op(op: F, v: Expression) -> Expression { + Expression::FunctionOp(op, Box::new(v)) +} + +impl Display for Statement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Statement::Label(l) => writeln!(f, "{l}:"), + Statement::Directive(d, args) => writeln!(f, " {d} {}", format_arguments(args)), + Statement::Instruction(i, args) => writeln!(f, " {i} {}", format_arguments(args)), + } + } +} + +impl Display for Argument { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Argument::Register(r) => write!(f, "{r}"), + Argument::RegOffset(reg, off) => write!(f, "{off}({reg})"), + Argument::StringLiteral(lit) => write!(f, "\"{}\"", String::from_utf8_lossy(lit)), + Argument::Expression(expr) => write!(f, "{expr}"), + } + } +} + +impl Display for Expression { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Expression::Number(n) => write!(f, "{n}"), + Expression::Symbol(sym) => write!(f, "{sym}"), + Expression::UnaryOp(kind, expr) => write!(f, "({}{})", kind, expr), + Expression::BinaryOp(op, args) => { + let symbol = match op { + BinaryOpKind::Or => "|", + BinaryOpKind::Xor => "^", + BinaryOpKind::And => "&", + BinaryOpKind::LeftShift => "<<", + BinaryOpKind::RightShift => ">>", + BinaryOpKind::Add => "+", + BinaryOpKind::Sub => "-", + BinaryOpKind::Mul => "*", + BinaryOpKind::Div => "/", + BinaryOpKind::Mod => "%", + }; + write!(f, "({} {symbol} {})", args[0], args[1]) + } + Expression::FunctionOp(kind, expr) => write!(f, "{}({})", kind, expr), + } + } +} + +fn format_arguments(args: &[Argument]) -> String { + args.iter() + .map(|a| format!("{a}")) + .collect::>() + .join(", ") +} + +/// Parse an escaped string - used in the grammar. +pub fn unescape_string(s: &str) -> Vec { + assert!(s.len() >= 2); + assert!(s.starts_with('"') && s.ends_with('"')); + let mut chars = s[1..s.len() - 1].chars(); + let mut result = vec![]; + while let Some(c) = chars.next() { + result.push(if c == '\\' { + let next = chars.next().unwrap(); + if next.is_ascii_digit() { + // octal number. + let n = next as u8 - b'0'; + let nn = chars.next().unwrap() as u8 - b'0'; + let nnn = chars.next().unwrap() as u8 - b'0'; + nnn + nn * 8 + n * 64 + } else if next == 'x' { + todo!("Parse hex digit"); + } else { + (match next { + 'n' => '\n', + 'r' => '\r', + 't' => '\t', + 'b' => 8 as char, + 'f' => 12 as char, + other => other, + }) as u8 + } + } else { + c as u8 + }) + } + result +} diff --git a/asm_utils/src/compiler.rs b/asm_utils/src/compiler.rs new file mode 100644 index 000000000..070659fed --- /dev/null +++ b/asm_utils/src/compiler.rs @@ -0,0 +1,45 @@ +use std::collections::BTreeMap; + +use crate::ast::{Argument, Expression, FunctionOpKind, Register}; + +pub trait Compiler { + fn compile(assemblies: BTreeMap) -> String; +} + +pub fn next_multiple_of_four(x: usize) -> usize { + ((x + 3) / 4) * 4 +} + +pub fn quote(s: &str) -> String { + // TODO more things to quote + format!("\"{}\"", s.replace('\\', "\\\\").replace('\"', "\\\"")) +} + +pub fn escape_label(l: &str) -> String { + // TODO make this proper + l.replace('.', "_dot_").replace('/', "_slash_") +} + +pub fn argument_to_escaped_symbol(x: &Argument) -> String { + if let Argument::Expression(Expression::Symbol(symb)) = x { + escape_label(symb) + } else { + panic!("Expected a symbol, got {x}"); + } +} + +pub fn argument_to_number(x: &Argument) -> u32 { + if let Argument::Expression(expr) = x { + expression_to_number(expr) + } else { + panic!("Expected numeric expression, got {x}") + } +} + +pub fn expression_to_number(expr: &Expression) -> u32 { + if let Expression::Number(n) = expr { + *n as u32 + } else { + panic!("Constant expression could not be fully resolved to a number during preprocessing: {expr}"); + } +} diff --git a/riscv/src/data_parser.rs b/asm_utils/src/data_parser.rs similarity index 94% rename from riscv/src/data_parser.rs rename to asm_utils/src/data_parser.rs index 23e964290..e9b549c0e 100644 --- a/riscv/src/data_parser.rs +++ b/asm_utils/src/data_parser.rs @@ -1,6 +1,6 @@ use std::collections::BTreeMap; -use crate::parser::{Argument, Expression, Statement}; +use crate::ast::{Argument, Expression, FunctionOpKind, Register, Statement}; pub enum DataValue { Direct(Vec), @@ -22,8 +22,8 @@ impl DataValue { /// Extract all data objects from the list of statements. /// Returns the named data objects themselves and a vector of the names /// in the order in which they occur in the statements. -pub fn extract_data_objects( - statements: &[Statement], +pub fn extract_data_objects( + statements: &[Statement], ) -> (BTreeMap>, Vec) { let mut object_order = vec![]; let mut current_label = None; @@ -74,7 +74,10 @@ pub fn extract_data_objects( (objects, object_order) } -fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec { +fn extract_data_value( + directive: &str, + arguments: &[Argument], +) -> Vec { match (directive, arguments) { ( ".zero", diff --git a/asm_utils/src/lib.rs b/asm_utils/src/lib.rs new file mode 100644 index 000000000..10a0960fc --- /dev/null +++ b/asm_utils/src/lib.rs @@ -0,0 +1,7 @@ +//! Common crate for generalized assembly handling. + +pub mod ast; +pub mod compiler; +pub mod data_parser; +pub mod parser; +pub mod reachability; diff --git a/asm_utils/src/parser.rs b/asm_utils/src/parser.rs new file mode 100644 index 000000000..3f0f29dcd --- /dev/null +++ b/asm_utils/src/parser.rs @@ -0,0 +1,17 @@ +use crate::ast::{FunctionOpKind, Register, Statement}; + +pub fn parse_asm>( + parser: P, + input: &str, +) -> Vec> { + input + .split('\n') + .map(|l| l.trim()) + .filter(|l| !l.is_empty()) + .flat_map(|line| parser.parse(line).unwrap()) + .collect() +} + +pub trait Parser { + fn parse(&self, input: &str) -> Result>, String>; +} diff --git a/riscv/src/reachability.rs b/asm_utils/src/reachability.rs similarity index 85% rename from riscv/src/reachability.rs rename to asm_utils/src/reachability.rs index 821a79da2..45cbff708 100644 --- a/riscv/src/reachability.rs +++ b/asm_utils/src/reachability.rs @@ -3,14 +3,15 @@ use std::collections::{BTreeMap, BTreeSet, HashSet}; use itertools::Itertools; use crate::data_parser::DataValue; -use crate::parser::{Argument, Expression, Statement}; + +use crate::ast::{Argument, Expression, FunctionOpKind, Register, Statement}; /// Processes the statements and removes all statements and objects that are /// not reachable from the label `label`. /// Keeps the order of the statements. -pub fn filter_reachable_from( +pub fn filter_reachable_from( label: &str, - statements: &mut Vec, + statements: &mut Vec>, objects: &mut BTreeMap>, ) { let replacements = extract_replacements(statements); @@ -53,9 +54,9 @@ pub fn filter_reachable_from( .collect(); } -pub fn find_reachable_labels<'a>( +pub fn find_reachable_labels<'a, R: Register, F: FunctionOpKind>( label: &'a str, - statements: &'a [Statement], + statements: &'a [Statement], objects: &'a mut BTreeMap>, replacements: &BTreeMap<&str, &'a str>, ) -> BTreeSet<&'a str> { @@ -85,7 +86,9 @@ pub fn find_reachable_labels<'a>( processed_labels.extend(seen_labels_in_block); referenced_labels_in_block } else { - eprintln!("The RISCV assembly code references an external routine / label that is not available:"); + eprintln!( + "The assembly code references an external routine / label that is not available:" + ); eprintln!("{l}"); panic!(); }; @@ -99,7 +102,9 @@ pub fn find_reachable_labels<'a>( processed_labels } -fn extract_replacements(statements: &[Statement]) -> BTreeMap { +fn extract_replacements( + statements: &[Statement], +) -> BTreeMap { let mut replacements = statements .iter() .filter_map(|s| match s { @@ -143,7 +148,9 @@ fn extract_replacements(statements: &[Statement]) -> BTreeMap { replacements } -pub fn extract_label_offsets(statements: &[Statement]) -> BTreeMap<&str, usize> { +pub fn extract_label_offsets( + statements: &[Statement], +) -> BTreeMap<&str, usize> { statements .iter() .enumerate() @@ -159,7 +166,9 @@ pub fn extract_label_offsets(statements: &[Statement]) -> BTreeMap<&str, usize> }) } -pub fn references_in_statement(statement: &Statement) -> BTreeSet<&str> { +pub fn references_in_statement( + statement: &Statement, +) -> BTreeSet<&str> { let mut ret = BTreeSet::new(); match statement { Statement::Label(_) | Statement::Directive(_, _) => (), @@ -176,7 +185,9 @@ pub fn references_in_statement(statement: &Statement) -> BTreeSet<&str> { ret } -fn basic_block_references_starting_from(statements: &[Statement]) -> (Vec<&str>, Vec<&str>) { +fn basic_block_references_starting_from( + statements: &[Statement], +) -> (Vec<&str>, Vec<&str>) { let mut seen_labels = vec![]; let mut referenced_labels = BTreeSet::<&str>::new(); iterate_basic_block(statements, |s| { @@ -189,7 +200,10 @@ fn basic_block_references_starting_from(statements: &[Statement]) -> (Vec<&str>, (referenced_labels.into_iter().collect(), seen_labels) } -fn iterate_basic_block<'a>(statements: &'a [Statement], mut fun: impl FnMut(&'a Statement)) { +fn iterate_basic_block<'a, R: Register, F: FunctionOpKind>( + statements: &'a [Statement], + mut fun: impl FnMut(&'a Statement), +) { for s in statements { fun(s); if ends_control_flow(s) { @@ -198,7 +212,7 @@ fn iterate_basic_block<'a>(statements: &'a [Statement], mut fun: impl FnMut(&'a } } -fn ends_control_flow(s: &Statement) -> bool { +fn ends_control_flow(s: &Statement) -> bool { match s { Statement::Instruction(instruction, _) => match instruction.as_str() { "li" | "lui" | "la" | "mv" | "add" | "addi" | "sub" | "neg" | "mul" | "mulhu" @@ -216,8 +230,8 @@ fn ends_control_flow(s: &Statement) -> bool { } } -fn apply_replacement_to_instruction( - statement: &mut Statement, +fn apply_replacement_to_instruction( + statement: &mut Statement, replacements: &BTreeMap<&str, &str>, ) { match statement { diff --git a/riscv/Cargo.toml b/riscv/Cargo.toml index ad4952a60..e86d58ab7 100644 --- a/riscv/Cargo.toml +++ b/riscv/Cargo.toml @@ -13,6 +13,7 @@ walkdir = "2.3.3" number = { path = "../number" } compiler = { path = "../compiler" } parser_util = { path = "../parser_util" } +asm_utils = { path = "../asm_utils" } # This is only here to work around https://github.com/lalrpop/lalrpop/issues/750 # It should be removed once that workaround is no longer needed. regex-syntax = { version = "0.6", default_features = false, features = ["unicode"] } diff --git a/riscv/src/compiler.rs b/riscv/src/compiler.rs index e3dc4e28d..275fb46c3 100644 --- a/riscv/src/compiler.rs +++ b/riscv/src/compiler.rs @@ -1,81 +1,133 @@ -use std::collections::{BTreeMap, BTreeSet}; +use std::{ + collections::{BTreeMap, BTreeSet}, + fmt, +}; +use asm_utils::{ + ast::{BinaryOpKind, UnaryOpKind}, + data_parser::{self, DataValue}, + parser::parse_asm, + reachability, +}; use itertools::Itertools; -use crate::data_parser::{self, DataValue}; -use crate::parser::{self, Argument, Register, Statement, UnaryOpKind}; -use crate::{disambiguator, reachability}; +use crate::disambiguator; +use crate::parser::RiscParser; +use crate::{Argument, Expression, Statement}; -use super::parser::Expression; +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct Register { + value: u8, +} -/// Compiles riscv assembly to POWDR assembly. Adds required library routines. -pub fn compile_riscv_asm(mut assemblies: BTreeMap) -> String { - // stack grows towards zero - let stack_start = 0x10000; - // data grows away from zero - let data_start = 0x10100; +impl Register { + pub fn new(value: u8) -> Self { + Self { value } + } - assert!(assemblies - .insert("__runtime".to_string(), runtime().to_string()) - .is_none()); + pub fn is_zero(&self) -> bool { + self.value == 0 + } +} - // TODO remove unreferenced files. - let (mut statements, file_ids) = disambiguator::disambiguate( - assemblies +impl asm_utils::ast::Register for Register {} + +impl fmt::Display for Register { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "x{}", self.value) + } +} + +#[derive(Clone, Copy)] +pub enum FunctionKind { + HiDataRef, + LoDataRef, +} + +impl asm_utils::ast::FunctionOpKind for FunctionKind {} + +impl fmt::Display for FunctionKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FunctionKind::HiDataRef => write!(f, "%hi"), + FunctionKind::LoDataRef => write!(f, "%lo"), + } + } +} + +#[derive(Default)] +pub struct Risc {} + +impl asm_utils::compiler::Compiler for Risc { + /// Compiles riscv assembly to POWDR assembly. Adds required library routines. + fn compile(mut assemblies: BTreeMap) -> String { + // stack grows towards zero + let stack_start = 0x10000; + // data grows away from zero + let data_start = 0x10100; + + assert!(assemblies + .insert("__runtime".to_string(), runtime().to_string()) + .is_none()); + + // TODO remove unreferenced files. + let (mut statements, file_ids) = disambiguator::disambiguate( + assemblies + .into_iter() + .map(|(name, contents)| (name, parse_asm(RiscParser::default(), &contents))) + .collect(), + ); + let (mut objects, mut object_order) = data_parser::extract_data_objects(&statements); + + // Reduce to the code that is actually reachable from main + // (and the objects that are referred from there) + reachability::filter_reachable_from("__runtime_start", &mut statements, &mut objects); + + // Replace dynamic references to code labels + replace_dynamic_label_references(&mut statements, &objects); + + // Sort the objects according to the order of the names in object_order. + // With the single exception: If there is large object, put that at the end. + // The idea behind this is that there might be a single gigantic object representing the heap + // and putting that at the end should keep memory addresses small. + let mut large_objects = objects + .iter() + .filter(|(_name, data)| data.iter().map(|d| d.size()).sum::() > 0x2000); + if let (Some((heap, _)), None) = (large_objects.next(), large_objects.next()) { + let heap_pos = object_order.iter().position(|o| o == heap).unwrap(); + object_order.remove(heap_pos); + object_order.push(heap.clone()); + }; + let sorted_objects = object_order .into_iter() - .map(|(name, contents)| (name, parser::parse_asm(&contents))) - .collect(), - ); - let (mut objects, mut object_order) = data_parser::extract_data_objects(&statements); + .filter_map(|n| { + let value = objects.get_mut(&n).map(std::mem::take); + value.map(|v| (n, v)) + }) + .collect::>(); + let (data_code, data_positions) = store_data_objects(&sorted_objects, data_start); - // Reduce to the code that is actually reachable from main - // (and the objects that are referred from there) - reachability::filter_reachable_from("__runtime_start", &mut statements, &mut objects); - - // Replace dynamic references to code labels - replace_dynamic_label_references(&mut statements, &objects); - - // Sort the objects according to the order of the names in object_order. - // With the single exception: If there is large object, put that at the end. - // The idea behind this is that there might be a single gigantic object representing the heap - // and putting that at the end should keep memory addresses small. - let mut large_objects = objects - .iter() - .filter(|(_name, data)| data.iter().map(|d| d.size()).sum::() > 0x2000); - if let (Some((heap, _)), None) = (large_objects.next(), large_objects.next()) { - let heap_pos = object_order.iter().position(|o| o == heap).unwrap(); - object_order.remove(heap_pos); - object_order.push(heap.clone()); - }; - let sorted_objects = object_order - .into_iter() - .filter_map(|n| { - let value = objects.get_mut(&n).map(std::mem::take); - value.map(|v| (n, v)) - }) - .collect::>(); - let (data_code, data_positions) = store_data_objects(&sorted_objects, data_start); - - risv_machine( - &preamble(), - &file_ids - .into_iter() - .map(|(id, dir, file)| format!("debug file {id} {} {};", quote(&dir), quote(&file))) - .chain(["call __data_init;".to_string()]) - .chain([ - format!("// Set stack pointer\nx2 <=X= {stack_start};"), - "jump __runtime_start;".to_string(), - ]) - .chain( - substitute_symbols_with_values(statements, &data_positions) - .into_iter() - .flat_map(process_statement), - ) - .chain(["// This is the data initialization routine.\n__data_init::".to_string()]) - .chain(data_code) - .chain(["// This is the end of the data initialization routine.\nret;".to_string()]) - .join("\n"), - ) + risv_machine( + &preamble(), + &file_ids + .into_iter() + .map(|(id, dir, file)| format!("debug file {id} {} {};", quote(&dir), quote(&file))) + .chain(["call __data_init;".to_string()]) + .chain([ + format!("// Set stack pointer\nx2 <=X= {stack_start};"), + "jump __runtime_start;".to_string(), + ]) + .chain( + substitute_symbols_with_values(statements, &data_positions) + .into_iter() + .flat_map(process_statement), + ) + .chain(["// This is the data initialization routine.\n__data_init::".to_string()]) + .chain(data_code) + .chain(["// This is the end of the data initialization routine.\nret;".to_string()]) + .join("\n"), + ) + } } /// Replace certain patterns of references to code labels by @@ -133,11 +185,11 @@ fn replace_dynamic_label_reference( if instr1.as_str() != "lui" || instr2.as_str() != "addi" { return None; }; - let [Argument::Register(r1), Argument::Expression(Expression::UnaryOp(UnaryOpKind::HiDataRef, expr1))] = &args1[..] else { return None; }; + let [Argument::Register(r1), Argument::Expression(Expression::FunctionOp(FunctionKind::HiDataRef, expr1))] = &args1[..] else { return None; }; // Maybe should try to reduce expr1 and expr2 before comparing deciding it is a pure symbol? - let Expression::Symbol(label1) = &expr1[0] else { return None; }; - let [Argument::Register(r2), Argument::Register(r3), Argument::Expression(Expression::UnaryOp(UnaryOpKind::LoDataRef, expr2))] = &args2[..] else { return None; }; - let Expression::Symbol(label2) = &expr2[0] else { return None; }; + let Expression::Symbol(label1) = expr1.as_ref() else { return None; }; + let [Argument::Register(r2), Argument::Register(r3), Argument::Expression(Expression::FunctionOp(FunctionKind::LoDataRef, expr2))] = &args2[..] else { return None; }; + let Expression::Symbol(label2) = expr2.as_ref() else { return None; }; if r1 != r3 || label1 != label2 || data_objects.contains_key(label1) { return None; } @@ -241,10 +293,8 @@ fn substitute_symbols_with_values( } } Expression::UnaryOp(op, subexpr) => { - if let Expression::Number(num) = subexpr[0] { + if let Expression::Number(num) = subexpr.as_ref() { let result = match op { - UnaryOpKind::HiDataRef => num >> 12, - UnaryOpKind::LoDataRef => num & 0xfff, UnaryOpKind::Negation => -num, }; *expression = Expression::Number(result); @@ -255,20 +305,29 @@ fn substitute_symbols_with_values( (&subexprs[0], &subexprs[1]) { let result = match op { - parser::BinaryOpKind::Or => a | b, - parser::BinaryOpKind::Xor => a ^ b, - parser::BinaryOpKind::And => a & b, - parser::BinaryOpKind::LeftShift => a << b, - parser::BinaryOpKind::RightShift => a >> b, - parser::BinaryOpKind::Add => a + b, - parser::BinaryOpKind::Sub => a - b, - parser::BinaryOpKind::Mul => a * b, - parser::BinaryOpKind::Div => a / b, - parser::BinaryOpKind::Mod => a % b, + BinaryOpKind::Or => a | b, + BinaryOpKind::Xor => a ^ b, + BinaryOpKind::And => a & b, + BinaryOpKind::LeftShift => a << b, + BinaryOpKind::RightShift => a >> b, + BinaryOpKind::Add => a + b, + BinaryOpKind::Sub => a - b, + BinaryOpKind::Mul => a * b, + BinaryOpKind::Div => a / b, + BinaryOpKind::Mod => a % b, }; *expression = Expression::Number(result); } } + Expression::FunctionOp(op, subexpr) => { + if let Expression::Number(num) = subexpr.as_ref() { + let result = match op { + FunctionKind::HiDataRef => num >> 12, + FunctionKind::LoDataRef => num & 0xfff, + }; + *expression = Expression::Number(result); + }; + } }); } } diff --git a/riscv/src/disambiguator.rs b/riscv/src/disambiguator.rs index 7db7016dc..8359e8fd7 100644 --- a/riscv/src/disambiguator.rs +++ b/riscv/src/disambiguator.rs @@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet}; use itertools::Itertools; -use crate::parser::{Argument, Expression, Statement}; +use crate::{Argument, Expression, Statement}; /// Disambiguates the collection of assembly files and concatenates it to a single list of statements. /// Also disambiguates file ids (debugging information) and returns a list of all files with new IDs. diff --git a/riscv/src/lib.rs b/riscv/src/lib.rs index d7e1aece9..60be8552b 100644 --- a/riscv/src/lib.rs +++ b/riscv/src/lib.rs @@ -3,17 +3,22 @@ use std::{collections::BTreeMap, path::Path, process::Command}; use ::compiler::{compile_asm_string, Backend}; +use asm_utils::compiler::Compiler; use mktemp::Temp; use std::fs; use walkdir::WalkDir; use number::FieldElement; +use crate::compiler::{FunctionKind, Register}; + pub mod compiler; -mod data_parser; mod disambiguator; pub mod parser; -mod reachability; + +type Statement = asm_utils::ast::Statement; +type Argument = asm_utils::ast::Argument; +type Expression = asm_utils::ast::Expression; /// Compiles a rust file all the way down to PIL and generates /// fixed and witness columns. @@ -85,7 +90,7 @@ pub fn compile_riscv_asm_bundle( return; } - let powdr_asm = compiler::compile_riscv_asm(riscv_asm_files); + let powdr_asm = compiler::Risc::compile(riscv_asm_files); fs::write(powdr_asm_file_name.clone(), &powdr_asm).unwrap(); log::info!("Wrote {}", powdr_asm_file_name.to_str().unwrap()); diff --git a/riscv/src/parser.rs b/riscv/src/parser.rs index c337f4320..3c4c3c4fb 100644 --- a/riscv/src/parser.rs +++ b/riscv/src/parser.rs @@ -1,7 +1,9 @@ -use std::fmt::{self, Display}; - use lalrpop_util::*; +use crate::{ + compiler::{FunctionKind, Register}, + Statement, +}; use parser_util::handle_parse_error; lalrpop_mod!( @@ -10,222 +12,23 @@ lalrpop_mod!( "/riscv_asm.rs" ); -#[derive(Clone)] -pub enum Statement { - Label(String), - Directive(String, Vec), - Instruction(String, Vec), +pub struct RiscParser { + parser: riscv_asm::StatementsParser, } -#[derive(Clone)] -pub enum Argument { - Register(Register), - RegOffset(Register, Expression), - StringLiteral(Vec), - Expression(Expression), -} - -impl Argument { - pub(crate) fn post_visit_expressions_mut(&mut self, f: &mut impl FnMut(&mut Expression)) { - match self { - Argument::Register(_) | Argument::StringLiteral(_) => (), - Argument::RegOffset(_, expr) | Argument::Expression(expr) => { - expr.post_visit_mut(f); - } - } - } - - pub(crate) fn post_visit_expressions<'a>(&'a self, f: &mut impl FnMut(&'a Expression)) { - match self { - Argument::Register(_) | Argument::StringLiteral(_) => (), - Argument::RegOffset(_, expr) | Argument::Expression(expr) => { - expr.post_visit(f); - } +impl Default for RiscParser { + fn default() -> Self { + Self { + parser: riscv_asm::StatementsParser::new(), } } } -#[derive(Clone, Copy, PartialEq, Eq)] -pub struct Register(u8); - -impl Register { - pub fn is_zero(&self) -> bool { - self.0 == 0 - } -} - -#[derive(Clone, Copy)] -pub enum UnaryOpKind { - HiDataRef, - LoDataRef, - Negation, -} - -#[derive(Clone, Copy)] -pub enum BinaryOpKind { - Or, - Xor, - And, - LeftShift, - RightShift, - Add, - Sub, - Mul, - Div, - Mod, -} - -#[derive(Clone)] -pub enum Expression { - Number(i64), - Symbol(String), - UnaryOp(UnaryOpKind, Box<[Expression]>), - BinaryOp(BinaryOpKind, Box<[Expression; 2]>), -} - -impl Expression { - fn post_visit<'a>(&'a self, f: &mut impl FnMut(&'a Expression)) { - match self { - Expression::UnaryOp(_, subexpr) => subexpr.iter(), - Expression::BinaryOp(_, subexprs) => subexprs.iter(), - _ => [].iter(), - } - .for_each(|subexpr| { - Self::post_visit(subexpr, f); - }); - f(self); - } - - fn post_visit_mut(&mut self, f: &mut impl FnMut(&mut Expression)) { - match self { - Expression::UnaryOp(_, subexpr) => subexpr.iter_mut(), - Expression::BinaryOp(_, subexprs) => subexprs.iter_mut(), - _ => [].iter_mut(), - } - .for_each(|subexpr| { - Self::post_visit_mut(subexpr, f); - }); - f(self); - } -} - -fn new_unary_op(op: UnaryOpKind, v: Expression) -> Expression { - Expression::UnaryOp(op, Box::new([v])) -} - -fn new_binary_op(op: BinaryOpKind, l: Expression, r: Expression) -> Expression { - Expression::BinaryOp(op, Box::new([l, r])) -} - -impl Display for Statement { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Statement::Label(l) => writeln!(f, "{l}:"), - Statement::Directive(d, args) => writeln!(f, " {d} {}", format_arguments(args)), - Statement::Instruction(i, args) => writeln!(f, " {i} {}", format_arguments(args)), - } - } -} - -impl Display for Argument { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Argument::Register(r) => write!(f, "{r}"), - Argument::RegOffset(reg, off) => write!(f, "{off}({reg})"), - Argument::StringLiteral(lit) => write!(f, "\"{}\"", String::from_utf8_lossy(lit)), - Argument::Expression(expr) => write!(f, "{expr}"), - } - } -} - -impl Display for Expression { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Expression::Number(n) => write!(f, "{n}"), - Expression::Symbol(sym) => write!(f, "{sym}"), - Expression::UnaryOp(UnaryOpKind::Negation, expr) => write!(f, "(-{})", expr[0]), - Expression::UnaryOp(UnaryOpKind::HiDataRef, expr) => write!(f, "%hi({})", expr[0]), - Expression::UnaryOp(UnaryOpKind::LoDataRef, expr) => write!(f, "%lo({})", expr[0]), - Expression::BinaryOp(op, args) => { - let symbol = match op { - BinaryOpKind::Or => "|", - BinaryOpKind::Xor => "^", - BinaryOpKind::And => "&", - BinaryOpKind::LeftShift => "<<", - BinaryOpKind::RightShift => ">>", - BinaryOpKind::Add => "+", - BinaryOpKind::Sub => "-", - BinaryOpKind::Mul => "*", - BinaryOpKind::Div => "/", - BinaryOpKind::Mod => "%", - }; - write!(f, "({} {symbol} {})", args[0], args[1]) - } - } - } -} - -fn format_arguments(args: &[Argument]) -> String { - args.iter() - .map(|a| format!("{a}")) - .collect::>() - .join(", ") -} - -impl Display for Register { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "x{}", self.0) - } -} - -pub fn parse_asm(input: &str) -> Vec { - let parser = riscv_asm::StatementsParser::new(); - input - .split('\n') - .map(|l| l.trim()) - .filter(|l| !l.is_empty()) - .flat_map(|line| { - parser - .parse(line) - .map_err(|err| { - handle_parse_error(err, None, line).output_to_stderr(); - panic!("RISCV assembly parse error"); - }) - .unwrap() - }) - .collect() -} - -/// Parse an escaped string - used in the grammar. -fn unescape_string(s: &str) -> Vec { - assert!(s.len() >= 2); - assert!(s.starts_with('"') && s.ends_with('"')); - let mut chars = s[1..s.len() - 1].chars(); - let mut result = vec![]; - while let Some(c) = chars.next() { - result.push(if c == '\\' { - let next = chars.next().unwrap(); - if next.is_ascii_digit() { - // octal number. - let n = next as u8 - b'0'; - let nn = chars.next().unwrap() as u8 - b'0'; - let nnn = chars.next().unwrap() as u8 - b'0'; - nnn + nn * 8 + n * 64 - } else if next == 'x' { - todo!("Parse hex digit"); - } else { - (match next { - 'n' => '\n', - 'r' => '\r', - 't' => '\t', - 'b' => 8 as char, - 'f' => 12 as char, - other => other, - }) as u8 - } - } else { - c as u8 +impl asm_utils::parser::Parser for RiscParser { + fn parse(&self, input: &str) -> Result, String> { + self.parser.parse(input).map_err(|err| { + handle_parse_error(err, None, input).output_to_stderr(); + panic!("RISCV assembly parse error"); }) } - result } diff --git a/riscv/src/riscv_asm.lalrpop b/riscv/src/riscv_asm.lalrpop index 83b8a4905..0a83be5f6 100644 --- a/riscv/src/riscv_asm.lalrpop +++ b/riscv/src/riscv_asm.lalrpop @@ -1,5 +1,7 @@ use std::str::FromStr; -use crate::parser::{Statement, Argument, Register, Expression, unescape_string, BinaryOpKind as BOp, UnaryOpKind as UOp, new_binary_op as bin_op, new_unary_op as un_op}; +use asm_utils::ast::{unescape_string, BinaryOpKind as BOp, UnaryOpKind as UOp, + new_binary_op as bin_op, new_unary_op as un_op, new_function_op as fn_op}; +use crate::{Argument, Register, Statement, FunctionKind as FOp, Expression}; grammar; @@ -75,22 +77,22 @@ Argument: Argument = { } Register: Register = { - r"x[0-9]" => Register(<>[1..].parse().unwrap()), - r"x1[0-9]" => Register(<>[1..].parse().unwrap()), - r"x2[0-9]" => Register(<>[1..].parse().unwrap()), - r"x3[0-1]" => Register(<>[1..].parse().unwrap()), - "zero" => Register(0), - "ra" => Register(1), - "sp" => Register(2), - "gp" => Register(3), - "tp" => Register(4), - r"a[0-7]" => Register(10 + <>[1..].parse::().unwrap()), - "fp" => Register(8), - r"s[0-1]" => Register(8 + <>[1..].parse::().unwrap()), - r"s[2-9]" => Register(16 + <>[1..].parse::().unwrap()), - r"s1[0-1]" => Register(16 + <>[1..].parse::().unwrap()), - r"t[0-2]" => Register(5 + <>[1..].parse::().unwrap()), - r"t[3-6]" => Register(25 + <>[1..].parse::().unwrap()), + r"x[0-9]" => Register::new(<>[1..].parse().unwrap()), + r"x1[0-9]" => Register::new(<>[1..].parse().unwrap()), + r"x2[0-9]" => Register::new(<>[1..].parse().unwrap()), + r"x3[0-1]" => Register::new(<>[1..].parse().unwrap()), + "zero" => Register::new(0), + "ra" => Register::new(1), + "sp" => Register::new(2), + "gp" => Register::new(3), + "tp" => Register::new(4), + r"a[0-7]" => Register::new(10 + <>[1..].parse::().unwrap()), + "fp" => Register::new(8), + r"s[0-1]" => Register::new(8 + <>[1..].parse::().unwrap()), + r"s[2-9]" => Register::new(16 + <>[1..].parse::().unwrap()), + r"s1[0-1]" => Register::new(16 + <>[1..].parse::().unwrap()), + r"t[0-2]" => Register::new(5 + <>[1..].parse::().unwrap()), + r"t[3-6]" => Register::new(25 + <>[1..].parse::().unwrap()), } OffsetRegister: Argument = { @@ -144,8 +146,8 @@ ExprUnary: Expression = { ExprTerm: Expression = { Number => Expression::Number(<>), "(" ")" => <>, - "%hi(" ")" => un_op(UOp::HiDataRef, <>), - "%lo(" ")" => un_op(UOp::LoDataRef, <>), + "%hi(" ")" => fn_op(FOp::HiDataRef, <>), + "%lo(" ")" => fn_op(FOp::LoDataRef, <>), Symbol => Expression::Symbol(<>) } @@ -169,4 +171,4 @@ Symbol: String = { Number: i64 = { r"-?[0-9][0-9_]*" => i64::from_str(<>).unwrap().into(), r"0x[0-9A-Fa-f][0-9A-Fa-f_]*" => i64::from_str_radix(&<>[2..].replace('_', ""), 16).unwrap().into(), -} +} \ No newline at end of file diff --git a/riscv/tests/instructions.rs b/riscv/tests/instructions.rs index fa67f8d46..7ec8b3d9b 100644 --- a/riscv/tests/instructions.rs +++ b/riscv/tests/instructions.rs @@ -1,12 +1,13 @@ mod instruction_tests { + use asm_utils::compiler::Compiler; use compiler::verify_asm_string; use number::GoldilocksField; - use riscv::compiler::compile_riscv_asm; + use riscv::compiler::Risc; use test_log::test; fn run_instruction_test(assembly: &str, name: &str) { // TODO Should we create one powdr asm from all tests or keep them separate? - let powdr_asm = compile_riscv_asm([(name.to_string(), assembly.to_string())].into()); + let powdr_asm = Risc::compile([(name.to_string(), assembly.to_string())].into()); verify_asm_string::(&format!("{name}.asm"), &powdr_asm, vec![]); } diff --git a/riscv/tests/riscv.rs b/riscv/tests/riscv.rs index 2a121b83f..f00a4ce0d 100644 --- a/riscv/tests/riscv.rs +++ b/riscv/tests/riscv.rs @@ -1,3 +1,4 @@ +use asm_utils::compiler::Compiler; use compiler::verify_asm_string; use number::GoldilocksField; use test_log::test; @@ -95,7 +96,7 @@ fn test_print() { fn verify_file(case: &str, inputs: Vec) { let riscv_asm = riscv::compile_rust_to_riscv_asm(&format!("tests/riscv_data/{case}")); - let powdr_asm = riscv::compiler::compile_riscv_asm(riscv_asm); + let powdr_asm = riscv::compiler::Risc::compile(riscv_asm); verify_asm_string(&format!("{case}.asm"), &powdr_asm, inputs); } @@ -103,7 +104,7 @@ fn verify_file(case: &str, inputs: Vec) { fn verify_crate(case: &str, inputs: Vec) { let riscv_asm = riscv::compile_rust_crate_to_riscv_asm(&format!("tests/riscv_data/{case}/Cargo.toml")); - let powdr_asm = riscv::compiler::compile_riscv_asm(riscv_asm); + let powdr_asm = riscv::compiler::Risc::compile(riscv_asm); verify_asm_string(&format!("{case}.asm"), &powdr_asm, inputs); }