From 56efb771018d6cc123831af0c56becd6b6dfc582 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 7 Jun 2023 09:12:33 +0200 Subject: [PATCH 1/4] Support for multiple input assembly files. --- powdr_cli/Cargo.toml | 2 +- powdr_cli/src/main.rs | 26 ++++++++++++++++++-------- riscv/src/lib.rs | 10 ++++++---- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/powdr_cli/Cargo.toml b/powdr_cli/Cargo.toml index 522a68dc2..9ae50be4e 100644 --- a/powdr_cli/Cargo.toml +++ b/powdr_cli/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -clap = { version = "^4.1", features = ["derive"] } +clap = { version = "^4.3", features = ["derive"] } env_logger = "0.10.0" log = "0.4.17" compiler = { path = "../compiler" } diff --git a/powdr_cli/src/main.rs b/powdr_cli/src/main.rs index 902676337..99877c0cb 100644 --- a/powdr_cli/src/main.rs +++ b/powdr_cli/src/main.rs @@ -7,6 +7,7 @@ use compiler::{compile_pil_or_asm, Backend}; use env_logger::{Builder, Target}; use log::LevelFilter; use number::{Bn254Field, FieldElement, GoldilocksField}; +use std::{borrow::Cow, fs, io::Write, path::Path}; use riscv::{compile_riscv_asm, compile_rust}; use std::{fs, io::Write, path::Path}; use strum::{Display, EnumString, EnumVariantNames}; @@ -99,8 +100,9 @@ enum Commands { /// Compiles riscv assembly to powdr assembly and then to PIL /// and generates fixed and witness columns. RiscvAsm { - /// Input file - file: String, + /// Input files + #[arg(required = true)] + files: Vec, /// The field to use #[arg(long)] @@ -113,7 +115,7 @@ enum Commands { #[arg(default_value_t = String::new())] inputs: String, - /// Directory for output files. + /// Directory for output files. #[arg(short, long)] #[arg(default_value_t = String::from("."))] output_directory: String, @@ -186,22 +188,30 @@ fn main() { prove_with ), Commands::RiscvAsm { - file, + files, field, inputs, output_directory, force, prove_with, - } => call_with_field!( + } => { + assert!(!files.is_empty()); + let name = if files.len() == 1 { + Cow::Owned(files[0].clone()) + } else { + Cow::Borrowed("output") + }; + +call_with_field!( compile_riscv_asm, field, - &file, - &file, + &name, + files.into_iter(), split_inputs(&inputs), Path::new(&output_directory), force, prove_with - ), + );}, Commands::Reformat { file } => { let contents = fs::read_to_string(&file).unwrap(); match parser::parse::(Some(&file), &contents) { diff --git a/riscv/src/lib.rs b/riscv/src/lib.rs index 0e6642396..86dded0dd 100644 --- a/riscv/src/lib.rs +++ b/riscv/src/lib.rs @@ -104,17 +104,19 @@ pub fn compile_riscv_asm_bundle( /// fixed and witness columns. pub fn compile_riscv_asm( original_file_name: &str, - file_name: &str, + file_names: impl Iterator, inputs: Vec, output_dir: &Path, force_overwrite: bool, prove_with: Option, ) { - let contents = fs::read_to_string(file_name).unwrap(); compile_riscv_asm_bundle( original_file_name, - vec![(file_name.to_string(), contents)] - .into_iter() + file_names + .map(|name| { + let contents = fs::read_to_string(&name).unwrap(); + (name, contents) + }) .collect(), inputs, output_dir, From 7c67bf70d1bbec418eed302dfe45c284d3688a4d Mon Sep 17 00:00:00 2001 From: Lucas Clemente Vella Date: Thu, 25 May 2023 11:26:58 +0200 Subject: [PATCH 2/4] Support for multiple input assembly files. --- powdr_cli/src/main.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/powdr_cli/src/main.rs b/powdr_cli/src/main.rs index 99877c0cb..e521a2139 100644 --- a/powdr_cli/src/main.rs +++ b/powdr_cli/src/main.rs @@ -7,9 +7,8 @@ use compiler::{compile_pil_or_asm, Backend}; use env_logger::{Builder, Target}; use log::LevelFilter; use number::{Bn254Field, FieldElement, GoldilocksField}; -use std::{borrow::Cow, fs, io::Write, path::Path}; use riscv::{compile_riscv_asm, compile_rust}; -use std::{fs, io::Write, path::Path}; +use std::{borrow::Cow, fs, io::Write, path::Path}; use strum::{Display, EnumString, EnumVariantNames}; #[derive(Clone, EnumString, EnumVariantNames, Display)] @@ -202,16 +201,17 @@ fn main() { Cow::Borrowed("output") }; -call_with_field!( - compile_riscv_asm, - field, - &name, - files.into_iter(), - split_inputs(&inputs), - Path::new(&output_directory), - force, - prove_with - );}, + call_with_field!( + compile_riscv_asm, + field, + &name, + files.into_iter(), + split_inputs(&inputs), + Path::new(&output_directory), + force, + prove_with + ); + } Commands::Reformat { file } => { let contents = fs::read_to_string(&file).unwrap(); match parser::parse::(Some(&file), &contents) { From 4bc194430410fd2e167499fe170660adb2d2ef6a Mon Sep 17 00:00:00 2001 From: Lucas Clemente Vella Date: Wed, 31 May 2023 14:09:03 +0100 Subject: [PATCH 3/4] Fixed test, but not the bug that made it pass. --- riscv/tests/riscv_data/print.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/riscv/tests/riscv_data/print.rs b/riscv/tests/riscv_data/print.rs index cf7409362..2ab934efc 100644 --- a/riscv/tests/riscv_data/print.rs +++ b/riscv/tests/riscv_data/print.rs @@ -5,6 +5,6 @@ use runtime::{get_prover_input, print}; #[no_mangle] pub fn main() { let input = get_prover_input(0); - print(format_args!("Input in hex: {input:x}\n")); + print!("Input in hex: {input:x}\n"); assert_eq!([1, 2, 3], [4, 5, 6]); } From 0a28e3d7c81ee84293acd8acac1e5671a650efcc Mon Sep 17 00:00:00 2001 From: Lucas Clemente Vella Date: Fri, 26 May 2023 12:45:13 +0200 Subject: [PATCH 4/4] More generic RISC-V asm parsing. --- riscv/src/compiler.rs | 133 ++++++++++++++++++++---------------- riscv/src/data_parser.rs | 26 +++---- riscv/src/disambiguator.rs | 99 +++++++-------------------- riscv/src/parser.rs | 127 +++++++++++++++++++++++++++------- riscv/src/reachability.rs | 83 +++++++++------------- riscv/src/riscv_asm.lalrpop | 88 +++++++++++------------- 6 files changed, 286 insertions(+), 270 deletions(-) diff --git a/riscv/src/compiler.rs b/riscv/src/compiler.rs index 0181343fa..5a90303bb 100644 --- a/riscv/src/compiler.rs +++ b/riscv/src/compiler.rs @@ -3,10 +3,10 @@ use std::collections::{BTreeMap, BTreeSet}; use itertools::Itertools; use crate::data_parser::{self, DataValue}; -use crate::parser::{self, Argument, Register, Statement}; +use crate::parser::{self, Argument, Register, Statement, UnaryOpKind}; use crate::{disambiguator, reachability}; -use super::parser::Constant; +use super::parser::Expression; /// Compiles riscv assembly to POWDR assembly. Adds required library routines. pub fn compile_riscv_asm(mut assemblies: BTreeMap) -> String { @@ -52,7 +52,7 @@ pub fn compile_riscv_asm(mut assemblies: BTreeMap) -> String { "jump __runtime_start;".to_string(), ]) .chain( - insert_data_positions(statements, &data_positions) + substitute_symbols_with_values(statements, &data_positions) .into_iter() .flat_map(process_statement), ) @@ -125,19 +125,20 @@ fn replace_dynamic_label_reference( if instr1.as_str() != "lui" || instr2.as_str() != "addi" { return None; }; - let [Argument::Register(r1), Argument::Constant(Constant::HiDataRef(label1, offset1))] = &args1[..] else { return None; }; - let [Argument::Register(r2), Argument::Register(r3), Argument::Constant(Constant::LoDataRef(label2, offset2))] = &args2[..] else { return None; }; - if r1 != r3 - || label1 != label2 - || *offset1 != 0 - || *offset2 != 0 - || data_objects.contains_key(label1) - { + let [Argument::Register(r1), Argument::Expression(Expression::UnaryOp(UnaryOpKind::HiDataRef, expr1))] = &args1[..] else { return None; }; + // Maybe should try to reduce expr1 and expr2 before comparing deciding it is a pure symbol? + let Expression::Symbol(label1) = &expr1[0] else { return None; }; + let [Argument::Register(r2), Argument::Register(r3), Argument::Expression(Expression::UnaryOp(UnaryOpKind::LoDataRef, expr2))] = &args2[..] else { return None; }; + let Expression::Symbol(label2) = &expr2[0] else { return None; }; + if r1 != r3 || label1 != label2 || data_objects.contains_key(label1) { return None; } Some(Statement::Instruction( "load_dynamic".to_string(), - vec![Argument::Register(*r2), Argument::Symbol(label1.clone())], + vec![ + Argument::Register(*r2), + Argument::Expression(Expression::Symbol(label1.clone())), + ], )) } @@ -217,48 +218,55 @@ fn next_multiple_of_four(x: usize) -> usize { ((x + 3) / 4) * 4 } -fn insert_data_positions( +fn substitute_symbols_with_values( mut statements: Vec, data_positions: &BTreeMap, ) -> Vec { for s in &mut statements { let Statement::Instruction(_name, args) = s else { continue; }; for arg in args { - match arg { - Argument::RegOffset(_, offset) => replace_data_reference(offset, data_positions), - Argument::Constant(c) => replace_data_reference(c, data_positions), - Argument::Symbol(symb) => { + arg.post_visit_expressions_mut(&mut |expression| match expression { + Expression::Number(_) => {} + Expression::Symbol(symb) => { if let Some(pos) = data_positions.get(symb) { - *arg = Argument::Constant(Constant::Number(*pos as i64)) + *expression = Expression::Number(*pos as i64) } } - _ => {} - } + Expression::UnaryOp(op, subexpr) => { + if let Expression::Number(num) = subexpr[0] { + let result = match op { + UnaryOpKind::HiDataRef => num >> 12, + UnaryOpKind::LoDataRef => num & 0xfff, + UnaryOpKind::Negation => -num, + }; + *expression = Expression::Number(result); + }; + } + Expression::BinaryOp(op, subexprs) => { + if let (Expression::Number(a), Expression::Number(b)) = + (&subexprs[0], &subexprs[1]) + { + let result = match op { + parser::BinaryOpKind::Or => a | b, + parser::BinaryOpKind::Xor => a ^ b, + parser::BinaryOpKind::And => a & b, + parser::BinaryOpKind::LeftShift => a << b, + parser::BinaryOpKind::RightShift => a >> b, + parser::BinaryOpKind::Add => a + b, + parser::BinaryOpKind::Sub => a - b, + parser::BinaryOpKind::Mul => a * b, + parser::BinaryOpKind::Div => a / b, + parser::BinaryOpKind::Mod => a % b, + }; + *expression = Expression::Number(result); + } + } + }); } } statements } -fn replace_data_reference(constant: &mut Constant, data_positions: &BTreeMap) { - match constant { - Constant::Number(_) => {} - Constant::HiDataRef(data, offset) => { - if let Some(pos) = data_positions.get(data) { - let pos: u32 = pos + *offset as u32; - *constant = Constant::Number((pos >> 12) as i64) - } - // Otherwise, it references a code label - } - Constant::LoDataRef(data, offset) => { - if let Some(pos) = data_positions.get(data) { - let pos: u32 = pos + *offset as u32; - *constant = Constant::Number((pos & 0xfff) as i64) - } - // Otherwise, it references a code label - } - } -} - fn preamble() -> String { r#" degree 262144; @@ -601,19 +609,26 @@ fn escape_label(l: &str) -> String { } fn argument_to_number(x: &Argument) -> u32 { - if let Argument::Constant(c) = x { - constant_to_number(c) + if let Argument::Expression(expr) = x { + expression_to_number(expr) } else { - panic!("Expected number, got {x}") + panic!("Expected numeric expression, got {x}") } } -fn constant_to_number(c: &Constant) -> u32 { - match c { - Constant::Number(n) => *n as u32, - Constant::HiDataRef(n, off) | Constant::LoDataRef(n, off) => { - panic!("Data reference was not erased during preprocessing: {n} + {off}"); - } +fn expression_to_number(expr: &Expression) -> u32 { + if let Expression::Number(n) = expr { + *n as u32 + } else { + panic!("Constant expression could not be fully resolved to a number during preprocessing: {expr}"); + } +} + +fn argument_to_escaped_symbol(x: &Argument) -> String { + if let Argument::Expression(Expression::Symbol(symb)) = x { + escape_label(symb) + } else { + panic!("Expected a symbol, got {x}"); } } @@ -654,8 +669,8 @@ fn rr(args: &[Argument]) -> (Register, Register) { fn rrl(args: &[Argument]) -> (Register, Register, String) { match args { - [Argument::Register(r1), Argument::Register(r2), Argument::Symbol(l)] => { - (*r1, *r2, escape_label(l)) + [Argument::Register(r1), Argument::Register(r2), l] => { + (*r1, *r2, argument_to_escaped_symbol(l)) } _ => panic!(), } @@ -663,7 +678,7 @@ fn rrl(args: &[Argument]) -> (Register, Register, String) { fn rl(args: &[Argument]) -> (Register, String) { match args { - [Argument::Register(r1), Argument::Symbol(l)] => (*r1, escape_label(l)), + [Argument::Register(r1), l] => (*r1, argument_to_escaped_symbol(l)), _ => panic!(), } } @@ -671,7 +686,7 @@ fn rl(args: &[Argument]) -> (Register, String) { fn rro(args: &[Argument]) -> (Register, Register, u32) { match args { [Argument::Register(r1), Argument::RegOffset(r2, off)] => { - (*r1, *r2, constant_to_number(off)) + (*r1, *r2, expression_to_number(off)) } _ => panic!(), } @@ -952,8 +967,8 @@ fn process_instruction(instr: &str, args: &[Argument]) -> Vec { // jump and call "j" => { - if let [Argument::Symbol(label)] = args { - vec![format!("jump {};", escape_label(label))] + if let [label] = args { + vec![format!("jump {};", argument_to_escaped_symbol(label))] } else { panic!() } @@ -972,8 +987,8 @@ fn process_instruction(instr: &str, args: &[Argument]) -> Vec { vec![format!("jump_and_link_dyn {rs};")] } "call" => { - if let [Argument::Symbol(label)] = args { - vec![format!("call {};", escape_label(label))] + if let [label] = args { + vec![format!("call {};", argument_to_escaped_symbol(label))] } else { panic!() } @@ -989,8 +1004,8 @@ fn process_instruction(instr: &str, args: &[Argument]) -> Vec { vec!["x0 <=X= ${ (\"print_char\", x10) };\n".to_string()] } "tail" => { - if let [Argument::Symbol(label)] = args { - vec![format!("tail {};", escape_label(label))] + if let [label] = args { + vec![format!("tail {};", argument_to_escaped_symbol(label))] } else { panic!() } diff --git a/riscv/src/data_parser.rs b/riscv/src/data_parser.rs index 999b70487..23e964290 100644 --- a/riscv/src/data_parser.rs +++ b/riscv/src/data_parser.rs @@ -1,6 +1,6 @@ use std::collections::BTreeMap; -use crate::parser::{Argument, Constant, Statement}; +use crate::parser::{Argument, Expression, Statement}; pub enum DataValue { Direct(Vec), @@ -34,9 +34,10 @@ pub fn extract_data_objects( current_label = Some(l.as_str()); } Statement::Directive(dir, args) => match (dir.as_str(), &args[..]) { - (".type", [Argument::Symbol(name), Argument::Symbol(kind)]) - if kind.as_str() == "@object" => - { + ( + ".type", + [Argument::Expression(Expression::Symbol(name)), Argument::Expression(Expression::Symbol(kind))], + ) if kind.as_str() == "@object" => { object_order.push(name.clone()); assert!(objects.insert(name.clone(), vec![]).is_none()); } @@ -47,9 +48,10 @@ pub fn extract_data_objects( entry.extend(extract_data_value(dir.as_str(), args)); }); } - (".size", [Argument::Symbol(name), Argument::Constant(Constant::Number(n))]) - if Some(name.as_str()) == current_label => - { + ( + ".size", + [Argument::Expression(Expression::Symbol(name)), Argument::Expression(Expression::Number(n))], + ) if Some(name.as_str()) == current_label => { objects .entry(current_label.unwrap().into()) .and_modify(|entry| { @@ -76,9 +78,9 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec match (directive, arguments) { ( ".zero", - [Argument::Constant(Constant::Number(n))] + [Argument::Expression(Expression::Number(n))] // TODO not clear what the second argument is - | [Argument::Constant(Constant::Number(n)), _], + | [Argument::Expression(Expression::Number(n)), _], ) => { vec![DataValue::Zero(*n as usize)] } @@ -95,7 +97,7 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec .iter() .map(|x| { match x { - Argument::Constant(Constant::Number(n)) =>{ + Argument::Expression(Expression::Number(n)) =>{ let n = *n as u32; DataValue::Direct(vec![ (n & 0xff) as u8, @@ -104,7 +106,7 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec (n >> 24 & 0xff) as u8, ]) } - Argument::Symbol(sym) => { + Argument::Expression(Expression::Symbol(sym)) => { DataValue::Reference(sym.clone()) } _ => panic!("Invalid .word directive") @@ -117,7 +119,7 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec vec![DataValue::Direct(data .iter() .map(|x| { - if let Argument::Constant(Constant::Number(n)) = x { + if let Argument::Expression(Expression::Number(n)) = x { *n as u8 } else { panic!("Invalid argument to .byte directive") diff --git a/riscv/src/disambiguator.rs b/riscv/src/disambiguator.rs index 9b7d903e9..e5ab7af9d 100644 --- a/riscv/src/disambiguator.rs +++ b/riscv/src/disambiguator.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use itertools::Itertools; -use crate::parser::{Argument, Constant, Statement}; +use crate::parser::{Argument, Expression, Statement}; pub fn disambiguate(assemblies: Vec<(String, Vec)>) -> Vec { let globals = assemblies @@ -12,7 +12,10 @@ pub fn disambiguate(assemblies: Vec<(String, Vec)>) -> Vec assemblies .into_iter() - .map(|(name, statements)| disambiguate_file(&name, statements, &globals)) + .map(|(name, mut statements)| { + disambiguate_file(&name, &mut statements, &globals); + statements + }) .concat() } @@ -25,7 +28,7 @@ fn extract_globals(statements: &[Statement]) -> HashSet { return args .iter() .map(|a| { - if let Argument::Symbol(s) = a { + if let Argument::Expression(Expression::Symbol(s)) = a { s.clone() } else { panic!("Invalid .globl directive: {s}"); @@ -40,82 +43,30 @@ fn extract_globals(statements: &[Statement]) -> HashSet { .collect() } -fn disambiguate_file( - file_name: &str, - statements: Vec, - globals: &HashSet, -) -> Vec { +fn disambiguate_file(file_name: &str, statements: &mut [Statement], globals: &HashSet) { let prefix = file_name.replace('-', "_dash_"); - statements - .into_iter() - .map(|s| match s { - Statement::Label(l) => { - Statement::Label(disambiguate_symbol_if_needed(l, &prefix, globals)) + for s in statements { + match s { + Statement::Label(l) => disambiguate_symbol_if_needed(l, &prefix, globals), + Statement::Directive(_, args) | Statement::Instruction(_, args) => { + for arg in args.iter_mut() { + disambiguate_argument_if_needed(arg, &prefix, globals); + } } - Statement::Directive(dir, args) => Statement::Directive( - dir, - disambiguate_arguments_if_needed(args, &prefix, globals), - ), - Statement::Instruction(instr, args) => Statement::Instruction( - instr, - disambiguate_arguments_if_needed(args, &prefix, globals), - ), - }) - .collect() -} - -fn disambiguate_arguments_if_needed( - args: Vec, - prefix: &str, - globals: &HashSet, -) -> Vec { - args.into_iter() - .map(|a| disambiguate_argument_if_needed(a, prefix, globals)) - .collect() -} - -fn disambiguate_argument_if_needed( - arg: Argument, - prefix: &str, - globals: &HashSet, -) -> Argument { - match arg { - Argument::Register(_) | Argument::StringLiteral(_) => arg, - Argument::RegOffset(reg, constant) => Argument::RegOffset( - reg, - disambiguate_constant_if_needed(constant, prefix, globals), - ), - Argument::Constant(c) => { - Argument::Constant(disambiguate_constant_if_needed(c, prefix, globals)) - } - Argument::Symbol(s) => Argument::Symbol(disambiguate_symbol_if_needed(s, prefix, globals)), - Argument::Difference(l, r) => Argument::Difference( - disambiguate_symbol_if_needed(l, prefix, globals), - disambiguate_symbol_if_needed(r, prefix, globals), - ), - } -} - -fn disambiguate_constant_if_needed( - c: Constant, - prefix: &str, - globals: &HashSet, -) -> Constant { - match c { - Constant::Number(_) => c, - Constant::HiDataRef(s, offset) => { - Constant::HiDataRef(disambiguate_symbol_if_needed(s, prefix, globals), offset) - } - Constant::LoDataRef(s, offset) => { - Constant::LoDataRef(disambiguate_symbol_if_needed(s, prefix, globals), offset) } } } -fn disambiguate_symbol_if_needed(s: String, prefix: &str, globals: &HashSet) -> String { - if globals.contains(s.as_str()) || s.starts_with('@') { - s - } else { - format!("{prefix}__{s}") +fn disambiguate_argument_if_needed(arg: &mut Argument, prefix: &str, globals: &HashSet) { + arg.post_visit_expressions_mut(&mut |expr| { + if let Expression::Symbol(sym) = expr { + disambiguate_symbol_if_needed(sym, prefix, globals); + } + }); +} + +fn disambiguate_symbol_if_needed(s: &mut String, prefix: &str, globals: &HashSet) { + if !s.starts_with('@') && !globals.contains(s.as_str()) { + *s = format!("{prefix}__{s}"); } } diff --git a/riscv/src/parser.rs b/riscv/src/parser.rs index 9d9201ab7..c337f4320 100644 --- a/riscv/src/parser.rs +++ b/riscv/src/parser.rs @@ -20,11 +20,29 @@ pub enum Statement { #[derive(Clone)] pub enum Argument { Register(Register), - RegOffset(Register, Constant), + RegOffset(Register, Expression), StringLiteral(Vec), - Constant(Constant), - Symbol(String), - Difference(String, String), + Expression(Expression), +} + +impl Argument { + pub(crate) fn post_visit_expressions_mut(&mut self, f: &mut impl FnMut(&mut Expression)) { + match self { + Argument::Register(_) | Argument::StringLiteral(_) => (), + Argument::RegOffset(_, expr) | Argument::Expression(expr) => { + expr.post_visit_mut(f); + } + } + } + + pub(crate) fn post_visit_expressions<'a>(&'a self, f: &mut impl FnMut(&'a Expression)) { + match self { + Argument::Register(_) | Argument::StringLiteral(_) => (), + Argument::RegOffset(_, expr) | Argument::Expression(expr) => { + expr.post_visit(f); + } + } + } } #[derive(Clone, Copy, PartialEq, Eq)] @@ -36,11 +54,67 @@ impl Register { } } +#[derive(Clone, Copy)] +pub enum UnaryOpKind { + HiDataRef, + LoDataRef, + Negation, +} + +#[derive(Clone, Copy)] +pub enum BinaryOpKind { + Or, + Xor, + And, + LeftShift, + RightShift, + Add, + Sub, + Mul, + Div, + Mod, +} + #[derive(Clone)] -pub enum Constant { +pub enum Expression { Number(i64), - HiDataRef(String, i64), - LoDataRef(String, i64), + Symbol(String), + UnaryOp(UnaryOpKind, Box<[Expression]>), + BinaryOp(BinaryOpKind, Box<[Expression; 2]>), +} + +impl Expression { + fn post_visit<'a>(&'a self, f: &mut impl FnMut(&'a Expression)) { + match self { + Expression::UnaryOp(_, subexpr) => subexpr.iter(), + Expression::BinaryOp(_, subexprs) => subexprs.iter(), + _ => [].iter(), + } + .for_each(|subexpr| { + Self::post_visit(subexpr, f); + }); + f(self); + } + + fn post_visit_mut(&mut self, f: &mut impl FnMut(&mut Expression)) { + match self { + Expression::UnaryOp(_, subexpr) => subexpr.iter_mut(), + Expression::BinaryOp(_, subexprs) => subexprs.iter_mut(), + _ => [].iter_mut(), + } + .for_each(|subexpr| { + Self::post_visit_mut(subexpr, f); + }); + f(self); + } +} + +fn new_unary_op(op: UnaryOpKind, v: Expression) -> Expression { + Expression::UnaryOp(op, Box::new([v])) +} + +fn new_binary_op(op: BinaryOpKind, l: Expression, r: Expression) -> Expression { + Expression::BinaryOp(op, Box::new([l, r])) } impl Display for Statement { @@ -57,37 +131,40 @@ impl Display for Argument { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Argument::Register(r) => write!(f, "{r}"), - Argument::Constant(c) => write!(f, "{c}"), Argument::RegOffset(reg, off) => write!(f, "{off}({reg})"), Argument::StringLiteral(lit) => write!(f, "\"{}\"", String::from_utf8_lossy(lit)), - Argument::Symbol(s) => write!(f, "{s}"), - Argument::Difference(left, right) => write!(f, "{left} - {right}"), + Argument::Expression(expr) => write!(f, "{expr}"), } } } -impl Display for Constant { +impl Display for Expression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Constant::Number(n) => write!(f, "{n}"), - Constant::HiDataRef(sym, offset) => { - write!(f, "%hi({sym}{})", format_hi_lo_offset(*offset)) - } - Constant::LoDataRef(sym, offset) => { - write!(f, "%lo({sym}{})", format_hi_lo_offset(*offset)) + Expression::Number(n) => write!(f, "{n}"), + Expression::Symbol(sym) => write!(f, "{sym}"), + Expression::UnaryOp(UnaryOpKind::Negation, expr) => write!(f, "(-{})", expr[0]), + Expression::UnaryOp(UnaryOpKind::HiDataRef, expr) => write!(f, "%hi({})", expr[0]), + Expression::UnaryOp(UnaryOpKind::LoDataRef, expr) => write!(f, "%lo({})", expr[0]), + Expression::BinaryOp(op, args) => { + let symbol = match op { + BinaryOpKind::Or => "|", + BinaryOpKind::Xor => "^", + BinaryOpKind::And => "&", + BinaryOpKind::LeftShift => "<<", + BinaryOpKind::RightShift => ">>", + BinaryOpKind::Add => "+", + BinaryOpKind::Sub => "-", + BinaryOpKind::Mul => "*", + BinaryOpKind::Div => "/", + BinaryOpKind::Mod => "%", + }; + write!(f, "({} {symbol} {})", args[0], args[1]) } } } } -fn format_hi_lo_offset(offset: i64) -> String { - match offset { - 0 => String::new(), - 1.. => format!(" + {offset}"), - ..=-1 => format!(" - {}", -offset), - } -} - fn format_arguments(args: &[Argument]) -> String { args.iter() .map(|a| format!("{a}")) diff --git a/riscv/src/reachability.rs b/riscv/src/reachability.rs index 9c1e962f5..fe804132b 100644 --- a/riscv/src/reachability.rs +++ b/riscv/src/reachability.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeMap, BTreeSet, HashSet}; use itertools::Itertools; use crate::data_parser::DataValue; -use crate::parser::{Argument, Constant, Statement}; +use crate::parser::{Argument, Expression, Statement}; pub fn filter_reachable_from( label: &str, @@ -62,7 +62,10 @@ pub fn filter_reachable_from( let offset = *label_offsets.get(l).unwrap(); basic_block_code_starting_from(&statements[offset..]) .into_iter() - .map(|s| apply_replacement_to_instruction(s, &replacements)) + .map(|mut s| { + apply_replacement_to_instruction(&mut s, &replacements); + s + }) }) .collect(); let referenced_labels = referenced_labels @@ -82,7 +85,8 @@ fn extract_replacements(statements: &[Statement]) -> BTreeMap<&str, &str> { .iter() .filter_map(|s| match s { Statement::Directive(dir, args) if dir.as_str() == ".set" => { - if let [Argument::Symbol(from), Argument::Symbol(to)] = &args[..] { + if let [Argument::Expression(Expression::Symbol(from)), Argument::Expression(Expression::Symbol(to))] = &args[..] + { Some((from.as_str(), to.as_str())) } else { panic!(); @@ -137,23 +141,20 @@ pub fn extract_label_offsets(statements: &[Statement]) -> BTreeMap<&str, usize> } pub fn references_in_statement(statement: &Statement) -> BTreeSet<&str> { + let mut ret = BTreeSet::new(); match statement { - Statement::Label(_) | Statement::Directive(_, _) => Default::default(), - Statement::Instruction(_, args) => args - .iter() - .filter_map(|arg| match arg { - Argument::Register(_) | Argument::StringLiteral(_) => None, - Argument::Symbol(s) => Some(s.as_str()), - Argument::RegOffset(_, c) | Argument::Constant(c) => match c { - Constant::Number(_) => None, - Constant::HiDataRef(s, _offset) | Constant::LoDataRef(s, _offset) => { - Some(s.as_str()) + Statement::Label(_) | Statement::Directive(_, _) => (), + Statement::Instruction(_, args) => { + for arg in args { + arg.post_visit_expressions(&mut |expr| { + if let Expression::Symbol(sym) = expr { + ret.insert(sym.as_str()); } - }, - Argument::Difference(_, _) => todo!(), - }) - .collect(), - } + }); + } + } + }; + ret } fn basic_block_references_starting_from(statements: &[Statement]) -> (Vec<&str>, Vec<&str>) { @@ -208,41 +209,24 @@ fn ends_control_flow(s: &Statement) -> bool { } fn apply_replacement_to_instruction( - statement: Statement, + statement: &mut Statement, replacements: &BTreeMap<&str, &str>, -) -> Statement { +) { match statement { - Statement::Label(_) => statement, - Statement::Instruction(instr, args) => Statement::Instruction( - instr, - args.into_iter() - .map(|a| match a { - Argument::Register(_) | Argument::StringLiteral(_) => a, - Argument::Symbol(s) => Argument::Symbol(replace(s, replacements)), - Argument::RegOffset(reg, c) => { - Argument::RegOffset(reg, apply_replacement_to_constant(c, replacements)) + Statement::Label(_) => (), + Statement::Instruction(_, args) => { + for a in args { + a.post_visit_expressions_mut(&mut |expr| { + if let Expression::Symbol(s) = expr { + replace(s, replacements); } - Argument::Constant(c) => { - Argument::Constant(apply_replacement_to_constant(c, replacements)) - } - Argument::Difference(l, r) => { - Argument::Difference(replace(l, replacements), replace(r, replacements)) - } - }) - .collect(), - ), + }); + } + } _ => panic!("Expected instruction but got: {statement}"), } } -fn apply_replacement_to_constant(c: Constant, replacements: &BTreeMap<&str, &str>) -> Constant { - match c { - Constant::Number(_) => c, - Constant::HiDataRef(s, off) => Constant::HiDataRef(replace(s, replacements), off), - Constant::LoDataRef(s, off) => Constant::LoDataRef(replace(s, replacements), off), - } -} - fn apply_replacement_to_object(object: &mut Vec, replacements: &BTreeMap<&str, &str>) { for value in object { if let DataValue::Reference(reference) = value { @@ -253,9 +237,8 @@ fn apply_replacement_to_object(object: &mut Vec, replacements: &BTree } } -fn replace(s: String, replacements: &BTreeMap<&str, &str>) -> String { - match replacements.get(s.as_str()) { - Some(r) => r.to_string(), - None => s, +fn replace(s: &mut String, replacements: &BTreeMap<&str, &str>) { + if let Some(r) = replacements.get(s.as_str()) { + *s = r.to_string(); } } diff --git a/riscv/src/riscv_asm.lalrpop b/riscv/src/riscv_asm.lalrpop index cb246a875..14d7bdc3c 100644 --- a/riscv/src/riscv_asm.lalrpop +++ b/riscv/src/riscv_asm.lalrpop @@ -1,5 +1,5 @@ use std::str::FromStr; -use crate::parser::{Statement, Argument, Register, Constant, unescape_string}; +use crate::parser::{Statement, Argument, Register, Expression, unescape_string, BinaryOpKind as BOp, UnaryOpKind as UOp, new_binary_op as bin_op, new_unary_op as un_op}; grammar; @@ -55,9 +55,7 @@ Argument: Argument = { Register => Argument::Register(<>), OffsetRegister, StringLiteral => Argument::StringLiteral(<>), - Symbol => Argument::Symbol(<>), - Constant => Argument::Constant(<>), - Difference, + Expression => Argument::Expression(<>), } Register: Register = { @@ -80,68 +78,59 @@ Register: Register = { } OffsetRegister: Argument = { - "(" ")" => Argument::RegOffset(r, c), + "(" ")" => Argument::RegOffset(r, c), } -Constant: Constant = { - ConstantExpression => Constant::Number(<>), - "%hi(" ")" => Constant::HiDataRef(arg.0, arg.1), - "%lo(" ")" => Constant::LoDataRef(arg.0, arg.1), +Expression: Expression = { + ExprBinaryOr, } -HiLowArgument: (String, i64) = { - => (<>, 0), - "+" => (<>), - "-" => (n, -o), +ExprBinaryOr: Expression = { + "|" => bin_op(BOp::Or, l, r), + ExprBinaryXor, } -ConstantExpression: i64 = { - "|" => l | r, - ConstantBinaryXor, +ExprBinaryXor: Expression = { + "^" => bin_op(BOp::Xor, l, r), + ExprBinaryAnd, } -ConstantBinaryXor: i64 = { - "^" => l ^ r, - ConstantBinaryAnd, +ExprBinaryAnd: Expression = { + "&" => bin_op(BOp::And, l, r), + ExprBitShift, } -ConstantBinaryAnd: i64 = { - "&" => l & r, - ConstantBitShift, +ExprBitShift: Expression = { + "<<" => bin_op(BOp::LeftShift, l, r), + ">>" => bin_op(BOp::RightShift, l, r), + ExprSum, } -ConstantBitShift: i64 = { - "<<" => l << r, - ">>" => l >> r, - ConstantSum, +ExprSum: Expression = { + "+" => bin_op(BOp::Add, l, r), + "-" => bin_op(BOp::Sub, l, r), + ExprProduct, } -ConstantSum: i64 = { - "+" => l + r, - "-" => l - r, - ConstantProduct, +ExprProduct: Expression = { + "*" => bin_op(BOp::Mul, l, r), + "/" => bin_op(BOp::Div, l, r), + "%" => bin_op(BOp::Mod, l, r), + ExprUnary, } -ConstantProduct: i64 = { - "*" => l * r, - "/" => l / r, - "%" => l % r, - ConstantUnary, +ExprUnary: Expression = { + "-" => un_op(UOp::Negation, <>), + "+" , + ExprTerm, } -ConstantUnary: i64 = { - "-" => -<>, - "+" => <>, - ConstantTerm, -} - -ConstantTerm: i64 = { - Number, - "(" ")", -} - -Difference: Argument = { - "-" => Argument::Difference(<>) +ExprTerm: Expression = { + Number => Expression::Number(<>), + "(" ")" => <>, + "%hi(" ")" => un_op(UOp::HiDataRef, <>), + "%lo(" ")" => un_op(UOp::LoDataRef, <>), + Symbol => Expression::Symbol(<>) } StringLiteral: Vec = { @@ -164,5 +153,4 @@ Symbol: String = { Number: i64 = { r"-?[0-9][0-9_]*" => i64::from_str(<>).unwrap().into(), r"0x[0-9A-Fa-f][0-9A-Fa-f_]*" => i64::from_str_radix(&<>[2..].replace('_', ""), 16).unwrap().into(), - -} \ No newline at end of file +}