Merge pull request #313 from powdr-org/generic_riscv_asm_parsing

More generic RISC-V asm parsing
This commit is contained in:
chriseth
2023-06-07 10:02:54 +02:00
committed by GitHub
10 changed files with 319 additions and 291 deletions

View File

@@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
clap = { version = "^4.1", features = ["derive"] }
clap = { version = "^4.3", features = ["derive"] }
env_logger = "0.10.0"
log = "0.4.17"
compiler = { path = "../compiler" }

View File

@@ -8,7 +8,7 @@ use env_logger::{Builder, Target};
use log::LevelFilter;
use number::{Bn254Field, FieldElement, GoldilocksField};
use riscv::{compile_riscv_asm, compile_rust};
use std::{fs, io::Write, path::Path};
use std::{borrow::Cow, fs, io::Write, path::Path};
use strum::{Display, EnumString, EnumVariantNames};
#[derive(Clone, EnumString, EnumVariantNames, Display)]
@@ -99,8 +99,9 @@ enum Commands {
/// Compiles riscv assembly to powdr assembly and then to PIL
/// and generates fixed and witness columns.
RiscvAsm {
/// Input file
file: String,
/// Input files
#[arg(required = true)]
files: Vec<String>,
/// The field to use
#[arg(long)]
@@ -113,7 +114,7 @@ enum Commands {
#[arg(default_value_t = String::new())]
inputs: String,
/// Directory for output files.
/// Directory for output files.
#[arg(short, long)]
#[arg(default_value_t = String::from("."))]
output_directory: String,
@@ -186,22 +187,31 @@ fn main() {
prove_with
),
Commands::RiscvAsm {
file,
files,
field,
inputs,
output_directory,
force,
prove_with,
} => call_with_field!(
compile_riscv_asm,
field,
&file,
&file,
split_inputs(&inputs),
Path::new(&output_directory),
force,
prove_with
),
} => {
assert!(!files.is_empty());
let name = if files.len() == 1 {
Cow::Owned(files[0].clone())
} else {
Cow::Borrowed("output")
};
call_with_field!(
compile_riscv_asm,
field,
&name,
files.into_iter(),
split_inputs(&inputs),
Path::new(&output_directory),
force,
prove_with
);
}
Commands::Reformat { file } => {
let contents = fs::read_to_string(&file).unwrap();
match parser::parse::<GoldilocksField>(Some(&file), &contents) {

View File

@@ -3,10 +3,10 @@ use std::collections::{BTreeMap, BTreeSet};
use itertools::Itertools;
use crate::data_parser::{self, DataValue};
use crate::parser::{self, Argument, Register, Statement};
use crate::parser::{self, Argument, Register, Statement, UnaryOpKind};
use crate::{disambiguator, reachability};
use super::parser::Constant;
use super::parser::Expression;
/// Compiles riscv assembly to POWDR assembly. Adds required library routines.
pub fn compile_riscv_asm(mut assemblies: BTreeMap<String, String>) -> String {
@@ -52,7 +52,7 @@ pub fn compile_riscv_asm(mut assemblies: BTreeMap<String, String>) -> String {
"jump __runtime_start;".to_string(),
])
.chain(
insert_data_positions(statements, &data_positions)
substitute_symbols_with_values(statements, &data_positions)
.into_iter()
.flat_map(process_statement),
)
@@ -125,19 +125,20 @@ fn replace_dynamic_label_reference(
if instr1.as_str() != "lui" || instr2.as_str() != "addi" {
return None;
};
let [Argument::Register(r1), Argument::Constant(Constant::HiDataRef(label1, offset1))] = &args1[..] else { return None; };
let [Argument::Register(r2), Argument::Register(r3), Argument::Constant(Constant::LoDataRef(label2, offset2))] = &args2[..] else { return None; };
if r1 != r3
|| label1 != label2
|| *offset1 != 0
|| *offset2 != 0
|| data_objects.contains_key(label1)
{
let [Argument::Register(r1), Argument::Expression(Expression::UnaryOp(UnaryOpKind::HiDataRef, expr1))] = &args1[..] else { return None; };
// Maybe should try to reduce expr1 and expr2 before comparing deciding it is a pure symbol?
let Expression::Symbol(label1) = &expr1[0] else { return None; };
let [Argument::Register(r2), Argument::Register(r3), Argument::Expression(Expression::UnaryOp(UnaryOpKind::LoDataRef, expr2))] = &args2[..] else { return None; };
let Expression::Symbol(label2) = &expr2[0] else { return None; };
if r1 != r3 || label1 != label2 || data_objects.contains_key(label1) {
return None;
}
Some(Statement::Instruction(
"load_dynamic".to_string(),
vec![Argument::Register(*r2), Argument::Symbol(label1.clone())],
vec![
Argument::Register(*r2),
Argument::Expression(Expression::Symbol(label1.clone())),
],
))
}
@@ -217,48 +218,55 @@ fn next_multiple_of_four(x: usize) -> usize {
((x + 3) / 4) * 4
}
fn insert_data_positions(
fn substitute_symbols_with_values(
mut statements: Vec<Statement>,
data_positions: &BTreeMap<String, u32>,
) -> Vec<Statement> {
for s in &mut statements {
let Statement::Instruction(_name, args) = s else { continue; };
for arg in args {
match arg {
Argument::RegOffset(_, offset) => replace_data_reference(offset, data_positions),
Argument::Constant(c) => replace_data_reference(c, data_positions),
Argument::Symbol(symb) => {
arg.post_visit_expressions_mut(&mut |expression| match expression {
Expression::Number(_) => {}
Expression::Symbol(symb) => {
if let Some(pos) = data_positions.get(symb) {
*arg = Argument::Constant(Constant::Number(*pos as i64))
*expression = Expression::Number(*pos as i64)
}
}
_ => {}
}
Expression::UnaryOp(op, subexpr) => {
if let Expression::Number(num) = subexpr[0] {
let result = match op {
UnaryOpKind::HiDataRef => num >> 12,
UnaryOpKind::LoDataRef => num & 0xfff,
UnaryOpKind::Negation => -num,
};
*expression = Expression::Number(result);
};
}
Expression::BinaryOp(op, subexprs) => {
if let (Expression::Number(a), Expression::Number(b)) =
(&subexprs[0], &subexprs[1])
{
let result = match op {
parser::BinaryOpKind::Or => a | b,
parser::BinaryOpKind::Xor => a ^ b,
parser::BinaryOpKind::And => a & b,
parser::BinaryOpKind::LeftShift => a << b,
parser::BinaryOpKind::RightShift => a >> b,
parser::BinaryOpKind::Add => a + b,
parser::BinaryOpKind::Sub => a - b,
parser::BinaryOpKind::Mul => a * b,
parser::BinaryOpKind::Div => a / b,
parser::BinaryOpKind::Mod => a % b,
};
*expression = Expression::Number(result);
}
}
});
}
}
statements
}
fn replace_data_reference(constant: &mut Constant, data_positions: &BTreeMap<String, u32>) {
match constant {
Constant::Number(_) => {}
Constant::HiDataRef(data, offset) => {
if let Some(pos) = data_positions.get(data) {
let pos: u32 = pos + *offset as u32;
*constant = Constant::Number((pos >> 12) as i64)
}
// Otherwise, it references a code label
}
Constant::LoDataRef(data, offset) => {
if let Some(pos) = data_positions.get(data) {
let pos: u32 = pos + *offset as u32;
*constant = Constant::Number((pos & 0xfff) as i64)
}
// Otherwise, it references a code label
}
}
}
fn preamble() -> String {
r#"
degree 262144;
@@ -601,19 +609,26 @@ fn escape_label(l: &str) -> String {
}
fn argument_to_number(x: &Argument) -> u32 {
if let Argument::Constant(c) = x {
constant_to_number(c)
if let Argument::Expression(expr) = x {
expression_to_number(expr)
} else {
panic!("Expected number, got {x}")
panic!("Expected numeric expression, got {x}")
}
}
fn constant_to_number(c: &Constant) -> u32 {
match c {
Constant::Number(n) => *n as u32,
Constant::HiDataRef(n, off) | Constant::LoDataRef(n, off) => {
panic!("Data reference was not erased during preprocessing: {n} + {off}");
}
fn expression_to_number(expr: &Expression) -> u32 {
if let Expression::Number(n) = expr {
*n as u32
} else {
panic!("Constant expression could not be fully resolved to a number during preprocessing: {expr}");
}
}
fn argument_to_escaped_symbol(x: &Argument) -> String {
if let Argument::Expression(Expression::Symbol(symb)) = x {
escape_label(symb)
} else {
panic!("Expected a symbol, got {x}");
}
}
@@ -654,8 +669,8 @@ fn rr(args: &[Argument]) -> (Register, Register) {
fn rrl(args: &[Argument]) -> (Register, Register, String) {
match args {
[Argument::Register(r1), Argument::Register(r2), Argument::Symbol(l)] => {
(*r1, *r2, escape_label(l))
[Argument::Register(r1), Argument::Register(r2), l] => {
(*r1, *r2, argument_to_escaped_symbol(l))
}
_ => panic!(),
}
@@ -663,7 +678,7 @@ fn rrl(args: &[Argument]) -> (Register, Register, String) {
fn rl(args: &[Argument]) -> (Register, String) {
match args {
[Argument::Register(r1), Argument::Symbol(l)] => (*r1, escape_label(l)),
[Argument::Register(r1), l] => (*r1, argument_to_escaped_symbol(l)),
_ => panic!(),
}
}
@@ -671,7 +686,7 @@ fn rl(args: &[Argument]) -> (Register, String) {
fn rro(args: &[Argument]) -> (Register, Register, u32) {
match args {
[Argument::Register(r1), Argument::RegOffset(r2, off)] => {
(*r1, *r2, constant_to_number(off))
(*r1, *r2, expression_to_number(off))
}
_ => panic!(),
}
@@ -952,8 +967,8 @@ fn process_instruction(instr: &str, args: &[Argument]) -> Vec<String> {
// jump and call
"j" => {
if let [Argument::Symbol(label)] = args {
vec![format!("jump {};", escape_label(label))]
if let [label] = args {
vec![format!("jump {};", argument_to_escaped_symbol(label))]
} else {
panic!()
}
@@ -972,8 +987,8 @@ fn process_instruction(instr: &str, args: &[Argument]) -> Vec<String> {
vec![format!("jump_and_link_dyn {rs};")]
}
"call" => {
if let [Argument::Symbol(label)] = args {
vec![format!("call {};", escape_label(label))]
if let [label] = args {
vec![format!("call {};", argument_to_escaped_symbol(label))]
} else {
panic!()
}
@@ -989,8 +1004,8 @@ fn process_instruction(instr: &str, args: &[Argument]) -> Vec<String> {
vec!["x0 <=X= ${ (\"print_char\", x10) };\n".to_string()]
}
"tail" => {
if let [Argument::Symbol(label)] = args {
vec![format!("tail {};", escape_label(label))]
if let [label] = args {
vec![format!("tail {};", argument_to_escaped_symbol(label))]
} else {
panic!()
}

View File

@@ -1,6 +1,6 @@
use std::collections::BTreeMap;
use crate::parser::{Argument, Constant, Statement};
use crate::parser::{Argument, Expression, Statement};
pub enum DataValue {
Direct(Vec<u8>),
@@ -34,9 +34,10 @@ pub fn extract_data_objects(
current_label = Some(l.as_str());
}
Statement::Directive(dir, args) => match (dir.as_str(), &args[..]) {
(".type", [Argument::Symbol(name), Argument::Symbol(kind)])
if kind.as_str() == "@object" =>
{
(
".type",
[Argument::Expression(Expression::Symbol(name)), Argument::Expression(Expression::Symbol(kind))],
) if kind.as_str() == "@object" => {
object_order.push(name.clone());
assert!(objects.insert(name.clone(), vec![]).is_none());
}
@@ -47,9 +48,10 @@ pub fn extract_data_objects(
entry.extend(extract_data_value(dir.as_str(), args));
});
}
(".size", [Argument::Symbol(name), Argument::Constant(Constant::Number(n))])
if Some(name.as_str()) == current_label =>
{
(
".size",
[Argument::Expression(Expression::Symbol(name)), Argument::Expression(Expression::Number(n))],
) if Some(name.as_str()) == current_label => {
objects
.entry(current_label.unwrap().into())
.and_modify(|entry| {
@@ -76,9 +78,9 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec<DataValue>
match (directive, arguments) {
(
".zero",
[Argument::Constant(Constant::Number(n))]
[Argument::Expression(Expression::Number(n))]
// TODO not clear what the second argument is
| [Argument::Constant(Constant::Number(n)), _],
| [Argument::Expression(Expression::Number(n)), _],
) => {
vec![DataValue::Zero(*n as usize)]
}
@@ -95,7 +97,7 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec<DataValue>
.iter()
.map(|x| {
match x {
Argument::Constant(Constant::Number(n)) =>{
Argument::Expression(Expression::Number(n)) =>{
let n = *n as u32;
DataValue::Direct(vec![
(n & 0xff) as u8,
@@ -104,7 +106,7 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec<DataValue>
(n >> 24 & 0xff) as u8,
])
}
Argument::Symbol(sym) => {
Argument::Expression(Expression::Symbol(sym)) => {
DataValue::Reference(sym.clone())
}
_ => panic!("Invalid .word directive")
@@ -117,7 +119,7 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec<DataValue>
vec![DataValue::Direct(data
.iter()
.map(|x| {
if let Argument::Constant(Constant::Number(n)) = x {
if let Argument::Expression(Expression::Number(n)) = x {
*n as u8
} else {
panic!("Invalid argument to .byte directive")

View File

@@ -2,7 +2,7 @@ use std::collections::HashSet;
use itertools::Itertools;
use crate::parser::{Argument, Constant, Statement};
use crate::parser::{Argument, Expression, Statement};
pub fn disambiguate(assemblies: Vec<(String, Vec<Statement>)>) -> Vec<Statement> {
let globals = assemblies
@@ -12,7 +12,10 @@ pub fn disambiguate(assemblies: Vec<(String, Vec<Statement>)>) -> Vec<Statement>
assemblies
.into_iter()
.map(|(name, statements)| disambiguate_file(&name, statements, &globals))
.map(|(name, mut statements)| {
disambiguate_file(&name, &mut statements, &globals);
statements
})
.concat()
}
@@ -25,7 +28,7 @@ fn extract_globals(statements: &[Statement]) -> HashSet<String> {
return args
.iter()
.map(|a| {
if let Argument::Symbol(s) = a {
if let Argument::Expression(Expression::Symbol(s)) = a {
s.clone()
} else {
panic!("Invalid .globl directive: {s}");
@@ -40,82 +43,30 @@ fn extract_globals(statements: &[Statement]) -> HashSet<String> {
.collect()
}
fn disambiguate_file(
file_name: &str,
statements: Vec<Statement>,
globals: &HashSet<String>,
) -> Vec<Statement> {
fn disambiguate_file(file_name: &str, statements: &mut [Statement], globals: &HashSet<String>) {
let prefix = file_name.replace('-', "_dash_");
statements
.into_iter()
.map(|s| match s {
Statement::Label(l) => {
Statement::Label(disambiguate_symbol_if_needed(l, &prefix, globals))
for s in statements {
match s {
Statement::Label(l) => disambiguate_symbol_if_needed(l, &prefix, globals),
Statement::Directive(_, args) | Statement::Instruction(_, args) => {
for arg in args.iter_mut() {
disambiguate_argument_if_needed(arg, &prefix, globals);
}
}
Statement::Directive(dir, args) => Statement::Directive(
dir,
disambiguate_arguments_if_needed(args, &prefix, globals),
),
Statement::Instruction(instr, args) => Statement::Instruction(
instr,
disambiguate_arguments_if_needed(args, &prefix, globals),
),
})
.collect()
}
fn disambiguate_arguments_if_needed(
args: Vec<Argument>,
prefix: &str,
globals: &HashSet<String>,
) -> Vec<Argument> {
args.into_iter()
.map(|a| disambiguate_argument_if_needed(a, prefix, globals))
.collect()
}
fn disambiguate_argument_if_needed(
arg: Argument,
prefix: &str,
globals: &HashSet<String>,
) -> Argument {
match arg {
Argument::Register(_) | Argument::StringLiteral(_) => arg,
Argument::RegOffset(reg, constant) => Argument::RegOffset(
reg,
disambiguate_constant_if_needed(constant, prefix, globals),
),
Argument::Constant(c) => {
Argument::Constant(disambiguate_constant_if_needed(c, prefix, globals))
}
Argument::Symbol(s) => Argument::Symbol(disambiguate_symbol_if_needed(s, prefix, globals)),
Argument::Difference(l, r) => Argument::Difference(
disambiguate_symbol_if_needed(l, prefix, globals),
disambiguate_symbol_if_needed(r, prefix, globals),
),
}
}
fn disambiguate_constant_if_needed(
c: Constant,
prefix: &str,
globals: &HashSet<String>,
) -> Constant {
match c {
Constant::Number(_) => c,
Constant::HiDataRef(s, offset) => {
Constant::HiDataRef(disambiguate_symbol_if_needed(s, prefix, globals), offset)
}
Constant::LoDataRef(s, offset) => {
Constant::LoDataRef(disambiguate_symbol_if_needed(s, prefix, globals), offset)
}
}
}
fn disambiguate_symbol_if_needed(s: String, prefix: &str, globals: &HashSet<String>) -> String {
if globals.contains(s.as_str()) || s.starts_with('@') {
s
} else {
format!("{prefix}__{s}")
fn disambiguate_argument_if_needed(arg: &mut Argument, prefix: &str, globals: &HashSet<String>) {
arg.post_visit_expressions_mut(&mut |expr| {
if let Expression::Symbol(sym) = expr {
disambiguate_symbol_if_needed(sym, prefix, globals);
}
});
}
fn disambiguate_symbol_if_needed(s: &mut String, prefix: &str, globals: &HashSet<String>) {
if !s.starts_with('@') && !globals.contains(s.as_str()) {
*s = format!("{prefix}__{s}");
}
}

View File

@@ -104,17 +104,19 @@ pub fn compile_riscv_asm_bundle<T: FieldElement>(
/// fixed and witness columns.
pub fn compile_riscv_asm<T: FieldElement>(
original_file_name: &str,
file_name: &str,
file_names: impl Iterator<Item = String>,
inputs: Vec<T>,
output_dir: &Path,
force_overwrite: bool,
prove_with: Option<Backend>,
) {
let contents = fs::read_to_string(file_name).unwrap();
compile_riscv_asm_bundle(
original_file_name,
vec![(file_name.to_string(), contents)]
.into_iter()
file_names
.map(|name| {
let contents = fs::read_to_string(&name).unwrap();
(name, contents)
})
.collect(),
inputs,
output_dir,

View File

@@ -20,11 +20,29 @@ pub enum Statement {
#[derive(Clone)]
pub enum Argument {
Register(Register),
RegOffset(Register, Constant),
RegOffset(Register, Expression),
StringLiteral(Vec<u8>),
Constant(Constant),
Symbol(String),
Difference(String, String),
Expression(Expression),
}
impl Argument {
pub(crate) fn post_visit_expressions_mut(&mut self, f: &mut impl FnMut(&mut Expression)) {
match self {
Argument::Register(_) | Argument::StringLiteral(_) => (),
Argument::RegOffset(_, expr) | Argument::Expression(expr) => {
expr.post_visit_mut(f);
}
}
}
pub(crate) fn post_visit_expressions<'a>(&'a self, f: &mut impl FnMut(&'a Expression)) {
match self {
Argument::Register(_) | Argument::StringLiteral(_) => (),
Argument::RegOffset(_, expr) | Argument::Expression(expr) => {
expr.post_visit(f);
}
}
}
}
#[derive(Clone, Copy, PartialEq, Eq)]
@@ -36,11 +54,67 @@ impl Register {
}
}
#[derive(Clone, Copy)]
pub enum UnaryOpKind {
HiDataRef,
LoDataRef,
Negation,
}
#[derive(Clone, Copy)]
pub enum BinaryOpKind {
Or,
Xor,
And,
LeftShift,
RightShift,
Add,
Sub,
Mul,
Div,
Mod,
}
#[derive(Clone)]
pub enum Constant {
pub enum Expression {
Number(i64),
HiDataRef(String, i64),
LoDataRef(String, i64),
Symbol(String),
UnaryOp(UnaryOpKind, Box<[Expression]>),
BinaryOp(BinaryOpKind, Box<[Expression; 2]>),
}
impl Expression {
fn post_visit<'a>(&'a self, f: &mut impl FnMut(&'a Expression)) {
match self {
Expression::UnaryOp(_, subexpr) => subexpr.iter(),
Expression::BinaryOp(_, subexprs) => subexprs.iter(),
_ => [].iter(),
}
.for_each(|subexpr| {
Self::post_visit(subexpr, f);
});
f(self);
}
fn post_visit_mut(&mut self, f: &mut impl FnMut(&mut Expression)) {
match self {
Expression::UnaryOp(_, subexpr) => subexpr.iter_mut(),
Expression::BinaryOp(_, subexprs) => subexprs.iter_mut(),
_ => [].iter_mut(),
}
.for_each(|subexpr| {
Self::post_visit_mut(subexpr, f);
});
f(self);
}
}
fn new_unary_op(op: UnaryOpKind, v: Expression) -> Expression {
Expression::UnaryOp(op, Box::new([v]))
}
fn new_binary_op(op: BinaryOpKind, l: Expression, r: Expression) -> Expression {
Expression::BinaryOp(op, Box::new([l, r]))
}
impl Display for Statement {
@@ -57,37 +131,40 @@ impl Display for Argument {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Argument::Register(r) => write!(f, "{r}"),
Argument::Constant(c) => write!(f, "{c}"),
Argument::RegOffset(reg, off) => write!(f, "{off}({reg})"),
Argument::StringLiteral(lit) => write!(f, "\"{}\"", String::from_utf8_lossy(lit)),
Argument::Symbol(s) => write!(f, "{s}"),
Argument::Difference(left, right) => write!(f, "{left} - {right}"),
Argument::Expression(expr) => write!(f, "{expr}"),
}
}
}
impl Display for Constant {
impl Display for Expression {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Constant::Number(n) => write!(f, "{n}"),
Constant::HiDataRef(sym, offset) => {
write!(f, "%hi({sym}{})", format_hi_lo_offset(*offset))
}
Constant::LoDataRef(sym, offset) => {
write!(f, "%lo({sym}{})", format_hi_lo_offset(*offset))
Expression::Number(n) => write!(f, "{n}"),
Expression::Symbol(sym) => write!(f, "{sym}"),
Expression::UnaryOp(UnaryOpKind::Negation, expr) => write!(f, "(-{})", expr[0]),
Expression::UnaryOp(UnaryOpKind::HiDataRef, expr) => write!(f, "%hi({})", expr[0]),
Expression::UnaryOp(UnaryOpKind::LoDataRef, expr) => write!(f, "%lo({})", expr[0]),
Expression::BinaryOp(op, args) => {
let symbol = match op {
BinaryOpKind::Or => "|",
BinaryOpKind::Xor => "^",
BinaryOpKind::And => "&",
BinaryOpKind::LeftShift => "<<",
BinaryOpKind::RightShift => ">>",
BinaryOpKind::Add => "+",
BinaryOpKind::Sub => "-",
BinaryOpKind::Mul => "*",
BinaryOpKind::Div => "/",
BinaryOpKind::Mod => "%",
};
write!(f, "({} {symbol} {})", args[0], args[1])
}
}
}
}
fn format_hi_lo_offset(offset: i64) -> String {
match offset {
0 => String::new(),
1.. => format!(" + {offset}"),
..=-1 => format!(" - {}", -offset),
}
}
fn format_arguments(args: &[Argument]) -> String {
args.iter()
.map(|a| format!("{a}"))

View File

@@ -3,7 +3,7 @@ use std::collections::{BTreeMap, BTreeSet, HashSet};
use itertools::Itertools;
use crate::data_parser::DataValue;
use crate::parser::{Argument, Constant, Statement};
use crate::parser::{Argument, Expression, Statement};
pub fn filter_reachable_from(
label: &str,
@@ -62,7 +62,10 @@ pub fn filter_reachable_from(
let offset = *label_offsets.get(l).unwrap();
basic_block_code_starting_from(&statements[offset..])
.into_iter()
.map(|s| apply_replacement_to_instruction(s, &replacements))
.map(|mut s| {
apply_replacement_to_instruction(&mut s, &replacements);
s
})
})
.collect();
let referenced_labels = referenced_labels
@@ -82,7 +85,8 @@ fn extract_replacements(statements: &[Statement]) -> BTreeMap<&str, &str> {
.iter()
.filter_map(|s| match s {
Statement::Directive(dir, args) if dir.as_str() == ".set" => {
if let [Argument::Symbol(from), Argument::Symbol(to)] = &args[..] {
if let [Argument::Expression(Expression::Symbol(from)), Argument::Expression(Expression::Symbol(to))] = &args[..]
{
Some((from.as_str(), to.as_str()))
} else {
panic!();
@@ -137,23 +141,20 @@ pub fn extract_label_offsets(statements: &[Statement]) -> BTreeMap<&str, usize>
}
pub fn references_in_statement(statement: &Statement) -> BTreeSet<&str> {
let mut ret = BTreeSet::new();
match statement {
Statement::Label(_) | Statement::Directive(_, _) => Default::default(),
Statement::Instruction(_, args) => args
.iter()
.filter_map(|arg| match arg {
Argument::Register(_) | Argument::StringLiteral(_) => None,
Argument::Symbol(s) => Some(s.as_str()),
Argument::RegOffset(_, c) | Argument::Constant(c) => match c {
Constant::Number(_) => None,
Constant::HiDataRef(s, _offset) | Constant::LoDataRef(s, _offset) => {
Some(s.as_str())
Statement::Label(_) | Statement::Directive(_, _) => (),
Statement::Instruction(_, args) => {
for arg in args {
arg.post_visit_expressions(&mut |expr| {
if let Expression::Symbol(sym) = expr {
ret.insert(sym.as_str());
}
},
Argument::Difference(_, _) => todo!(),
})
.collect(),
}
});
}
}
};
ret
}
fn basic_block_references_starting_from(statements: &[Statement]) -> (Vec<&str>, Vec<&str>) {
@@ -208,41 +209,24 @@ fn ends_control_flow(s: &Statement) -> bool {
}
fn apply_replacement_to_instruction(
statement: Statement,
statement: &mut Statement,
replacements: &BTreeMap<&str, &str>,
) -> Statement {
) {
match statement {
Statement::Label(_) => statement,
Statement::Instruction(instr, args) => Statement::Instruction(
instr,
args.into_iter()
.map(|a| match a {
Argument::Register(_) | Argument::StringLiteral(_) => a,
Argument::Symbol(s) => Argument::Symbol(replace(s, replacements)),
Argument::RegOffset(reg, c) => {
Argument::RegOffset(reg, apply_replacement_to_constant(c, replacements))
Statement::Label(_) => (),
Statement::Instruction(_, args) => {
for a in args {
a.post_visit_expressions_mut(&mut |expr| {
if let Expression::Symbol(s) = expr {
replace(s, replacements);
}
Argument::Constant(c) => {
Argument::Constant(apply_replacement_to_constant(c, replacements))
}
Argument::Difference(l, r) => {
Argument::Difference(replace(l, replacements), replace(r, replacements))
}
})
.collect(),
),
});
}
}
_ => panic!("Expected instruction but got: {statement}"),
}
}
fn apply_replacement_to_constant(c: Constant, replacements: &BTreeMap<&str, &str>) -> Constant {
match c {
Constant::Number(_) => c,
Constant::HiDataRef(s, off) => Constant::HiDataRef(replace(s, replacements), off),
Constant::LoDataRef(s, off) => Constant::LoDataRef(replace(s, replacements), off),
}
}
fn apply_replacement_to_object(object: &mut Vec<DataValue>, replacements: &BTreeMap<&str, &str>) {
for value in object {
if let DataValue::Reference(reference) = value {
@@ -253,9 +237,8 @@ fn apply_replacement_to_object(object: &mut Vec<DataValue>, replacements: &BTree
}
}
fn replace(s: String, replacements: &BTreeMap<&str, &str>) -> String {
match replacements.get(s.as_str()) {
Some(r) => r.to_string(),
None => s,
fn replace(s: &mut String, replacements: &BTreeMap<&str, &str>) {
if let Some(r) = replacements.get(s.as_str()) {
*s = r.to_string();
}
}

View File

@@ -1,5 +1,5 @@
use std::str::FromStr;
use crate::parser::{Statement, Argument, Register, Constant, unescape_string};
use crate::parser::{Statement, Argument, Register, Expression, unescape_string, BinaryOpKind as BOp, UnaryOpKind as UOp, new_binary_op as bin_op, new_unary_op as un_op};
grammar;
@@ -55,9 +55,7 @@ Argument: Argument = {
Register => Argument::Register(<>),
OffsetRegister,
StringLiteral => Argument::StringLiteral(<>),
Symbol => Argument::Symbol(<>),
Constant => Argument::Constant(<>),
Difference,
Expression => Argument::Expression(<>),
}
Register: Register = {
@@ -80,68 +78,59 @@ Register: Register = {
}
OffsetRegister: Argument = {
<c:Constant> "(" <r:Register> ")" => Argument::RegOffset(r, c),
<c:Expression> "(" <r:Register> ")" => Argument::RegOffset(r, c),
}
Constant: Constant = {
ConstantExpression => Constant::Number(<>),
"%hi(" <arg:HiLowArgument> ")" => Constant::HiDataRef(arg.0, arg.1),
"%lo(" <arg:HiLowArgument> ")" => Constant::LoDataRef(arg.0, arg.1),
Expression: Expression = {
ExprBinaryOr,
}
HiLowArgument: (String, i64) = {
<Symbol> => (<>, 0),
<Symbol> "+" <ConstantExpression> => (<>),
<n: Symbol> "-" <o: ConstantExpression> => (n, -o),
ExprBinaryOr: Expression = {
<l:ExprBinaryOr> "|" <r:ExprBinaryXor> => bin_op(BOp::Or, l, r),
ExprBinaryXor,
}
ConstantExpression: i64 = {
<l:ConstantExpression> "|" <r:ConstantBinaryXor> => l | r,
ConstantBinaryXor,
ExprBinaryXor: Expression = {
<l:ExprBinaryXor> "^" <r:ExprBinaryAnd> => bin_op(BOp::Xor, l, r),
ExprBinaryAnd,
}
ConstantBinaryXor: i64 = {
<l:ConstantBinaryXor> "^" <r:ConstantBinaryAnd> => l ^ r,
ConstantBinaryAnd,
ExprBinaryAnd: Expression = {
<l:ExprBinaryAnd> "&" <r:ExprBitShift> => bin_op(BOp::And, l, r),
ExprBitShift,
}
ConstantBinaryAnd: i64 = {
<l:ConstantBinaryAnd> "&" <r:ConstantBitShift> => l & r,
ConstantBitShift,
ExprBitShift: Expression = {
<l:ExprBitShift> "<<" <r:ExprSum> => bin_op(BOp::LeftShift, l, r),
<l:ExprBitShift> ">>" <r:ExprSum> => bin_op(BOp::RightShift, l, r),
ExprSum,
}
ConstantBitShift: i64 = {
<l:ConstantBitShift> "<<" <r:ConstantSum> => l << r,
<l:ConstantBitShift> ">>" <r:ConstantSum> => l >> r,
ConstantSum,
ExprSum: Expression = {
<l:ExprSum> "+" <r:ExprProduct> => bin_op(BOp::Add, l, r),
<l:ExprSum> "-" <r:ExprProduct> => bin_op(BOp::Sub, l, r),
ExprProduct,
}
ConstantSum: i64 = {
<l:ConstantSum> "+" <r:ConstantProduct> => l + r,
<l:ConstantSum> "-" <r:ConstantProduct> => l - r,
ConstantProduct,
ExprProduct: Expression = {
<l:ExprProduct> "*" <r:ExprUnary> => bin_op(BOp::Mul, l, r),
<l:ExprProduct> "/" <r:ExprUnary> => bin_op(BOp::Div, l, r),
<l:ExprProduct> "%" <r:ExprUnary> => bin_op(BOp::Mod, l, r),
ExprUnary,
}
ConstantProduct: i64 = {
<l:ConstantProduct> "*" <r:ConstantUnary> => l * r,
<l:ConstantProduct> "/" <r:ConstantUnary> => l / r,
<l:ConstantProduct> "%" <r:ConstantUnary> => l % r,
ConstantUnary,
ExprUnary: Expression = {
"-" <ExprTerm> => un_op(UOp::Negation, <>),
"+" <ExprTerm>,
ExprTerm,
}
ConstantUnary: i64 = {
"-" <ConstantTerm> => -<>,
"+" <ConstantTerm> => <>,
ConstantTerm,
}
ConstantTerm: i64 = {
Number,
"(" <ConstantExpression> ")",
}
Difference: Argument = {
<Symbol> "-" <Symbol> => Argument::Difference(<>)
ExprTerm: Expression = {
Number => Expression::Number(<>),
"(" <Expression> ")" => <>,
"%hi(" <Expression> ")" => un_op(UOp::HiDataRef, <>),
"%lo(" <Expression> ")" => un_op(UOp::LoDataRef, <>),
Symbol => Expression::Symbol(<>)
}
StringLiteral: Vec<u8> = {
@@ -164,5 +153,4 @@ Symbol: String = {
Number: i64 = {
r"-?[0-9][0-9_]*" => i64::from_str(<>).unwrap().into(),
r"0x[0-9A-Fa-f][0-9A-Fa-f_]*" => i64::from_str_radix(&<>[2..].replace('_', ""), 16).unwrap().into(),
}
}

View File

@@ -5,6 +5,6 @@ use runtime::{get_prover_input, print};
#[no_mangle]
pub fn main() {
let input = get_prover_input(0);
print(format_args!("Input in hex: {input:x}\n"));
print!("Input in hex: {input:x}\n");
assert_eq!([1, 2, 3], [4, 5, 6]);
}