mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-04-20 03:03:25 -04:00
Merge pull request #313 from powdr-org/generic_riscv_asm_parsing
More generic RISC-V asm parsing
This commit is contained in:
@@ -4,7 +4,7 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "^4.1", features = ["derive"] }
|
||||
clap = { version = "^4.3", features = ["derive"] }
|
||||
env_logger = "0.10.0"
|
||||
log = "0.4.17"
|
||||
compiler = { path = "../compiler" }
|
||||
|
||||
@@ -8,7 +8,7 @@ use env_logger::{Builder, Target};
|
||||
use log::LevelFilter;
|
||||
use number::{Bn254Field, FieldElement, GoldilocksField};
|
||||
use riscv::{compile_riscv_asm, compile_rust};
|
||||
use std::{fs, io::Write, path::Path};
|
||||
use std::{borrow::Cow, fs, io::Write, path::Path};
|
||||
use strum::{Display, EnumString, EnumVariantNames};
|
||||
|
||||
#[derive(Clone, EnumString, EnumVariantNames, Display)]
|
||||
@@ -99,8 +99,9 @@ enum Commands {
|
||||
/// Compiles riscv assembly to powdr assembly and then to PIL
|
||||
/// and generates fixed and witness columns.
|
||||
RiscvAsm {
|
||||
/// Input file
|
||||
file: String,
|
||||
/// Input files
|
||||
#[arg(required = true)]
|
||||
files: Vec<String>,
|
||||
|
||||
/// The field to use
|
||||
#[arg(long)]
|
||||
@@ -113,7 +114,7 @@ enum Commands {
|
||||
#[arg(default_value_t = String::new())]
|
||||
inputs: String,
|
||||
|
||||
/// Directory for output files.
|
||||
/// Directory for output files.
|
||||
#[arg(short, long)]
|
||||
#[arg(default_value_t = String::from("."))]
|
||||
output_directory: String,
|
||||
@@ -186,22 +187,31 @@ fn main() {
|
||||
prove_with
|
||||
),
|
||||
Commands::RiscvAsm {
|
||||
file,
|
||||
files,
|
||||
field,
|
||||
inputs,
|
||||
output_directory,
|
||||
force,
|
||||
prove_with,
|
||||
} => call_with_field!(
|
||||
compile_riscv_asm,
|
||||
field,
|
||||
&file,
|
||||
&file,
|
||||
split_inputs(&inputs),
|
||||
Path::new(&output_directory),
|
||||
force,
|
||||
prove_with
|
||||
),
|
||||
} => {
|
||||
assert!(!files.is_empty());
|
||||
let name = if files.len() == 1 {
|
||||
Cow::Owned(files[0].clone())
|
||||
} else {
|
||||
Cow::Borrowed("output")
|
||||
};
|
||||
|
||||
call_with_field!(
|
||||
compile_riscv_asm,
|
||||
field,
|
||||
&name,
|
||||
files.into_iter(),
|
||||
split_inputs(&inputs),
|
||||
Path::new(&output_directory),
|
||||
force,
|
||||
prove_with
|
||||
);
|
||||
}
|
||||
Commands::Reformat { file } => {
|
||||
let contents = fs::read_to_string(&file).unwrap();
|
||||
match parser::parse::<GoldilocksField>(Some(&file), &contents) {
|
||||
|
||||
@@ -3,10 +3,10 @@ use std::collections::{BTreeMap, BTreeSet};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::data_parser::{self, DataValue};
|
||||
use crate::parser::{self, Argument, Register, Statement};
|
||||
use crate::parser::{self, Argument, Register, Statement, UnaryOpKind};
|
||||
use crate::{disambiguator, reachability};
|
||||
|
||||
use super::parser::Constant;
|
||||
use super::parser::Expression;
|
||||
|
||||
/// Compiles riscv assembly to POWDR assembly. Adds required library routines.
|
||||
pub fn compile_riscv_asm(mut assemblies: BTreeMap<String, String>) -> String {
|
||||
@@ -52,7 +52,7 @@ pub fn compile_riscv_asm(mut assemblies: BTreeMap<String, String>) -> String {
|
||||
"jump __runtime_start;".to_string(),
|
||||
])
|
||||
.chain(
|
||||
insert_data_positions(statements, &data_positions)
|
||||
substitute_symbols_with_values(statements, &data_positions)
|
||||
.into_iter()
|
||||
.flat_map(process_statement),
|
||||
)
|
||||
@@ -125,19 +125,20 @@ fn replace_dynamic_label_reference(
|
||||
if instr1.as_str() != "lui" || instr2.as_str() != "addi" {
|
||||
return None;
|
||||
};
|
||||
let [Argument::Register(r1), Argument::Constant(Constant::HiDataRef(label1, offset1))] = &args1[..] else { return None; };
|
||||
let [Argument::Register(r2), Argument::Register(r3), Argument::Constant(Constant::LoDataRef(label2, offset2))] = &args2[..] else { return None; };
|
||||
if r1 != r3
|
||||
|| label1 != label2
|
||||
|| *offset1 != 0
|
||||
|| *offset2 != 0
|
||||
|| data_objects.contains_key(label1)
|
||||
{
|
||||
let [Argument::Register(r1), Argument::Expression(Expression::UnaryOp(UnaryOpKind::HiDataRef, expr1))] = &args1[..] else { return None; };
|
||||
// Maybe we should try to reduce expr1 and expr2 before deciding whether each one is a pure symbol?
|
||||
let Expression::Symbol(label1) = &expr1[0] else { return None; };
|
||||
let [Argument::Register(r2), Argument::Register(r3), Argument::Expression(Expression::UnaryOp(UnaryOpKind::LoDataRef, expr2))] = &args2[..] else { return None; };
|
||||
let Expression::Symbol(label2) = &expr2[0] else { return None; };
|
||||
if r1 != r3 || label1 != label2 || data_objects.contains_key(label1) {
|
||||
return None;
|
||||
}
|
||||
Some(Statement::Instruction(
|
||||
"load_dynamic".to_string(),
|
||||
vec![Argument::Register(*r2), Argument::Symbol(label1.clone())],
|
||||
vec![
|
||||
Argument::Register(*r2),
|
||||
Argument::Expression(Expression::Symbol(label1.clone())),
|
||||
],
|
||||
))
|
||||
}
|
||||
|
||||
@@ -217,48 +218,55 @@ fn next_multiple_of_four(x: usize) -> usize {
|
||||
((x + 3) / 4) * 4
|
||||
}
|
||||
|
||||
fn insert_data_positions(
|
||||
fn substitute_symbols_with_values(
|
||||
mut statements: Vec<Statement>,
|
||||
data_positions: &BTreeMap<String, u32>,
|
||||
) -> Vec<Statement> {
|
||||
for s in &mut statements {
|
||||
let Statement::Instruction(_name, args) = s else { continue; };
|
||||
for arg in args {
|
||||
match arg {
|
||||
Argument::RegOffset(_, offset) => replace_data_reference(offset, data_positions),
|
||||
Argument::Constant(c) => replace_data_reference(c, data_positions),
|
||||
Argument::Symbol(symb) => {
|
||||
arg.post_visit_expressions_mut(&mut |expression| match expression {
|
||||
Expression::Number(_) => {}
|
||||
Expression::Symbol(symb) => {
|
||||
if let Some(pos) = data_positions.get(symb) {
|
||||
*arg = Argument::Constant(Constant::Number(*pos as i64))
|
||||
*expression = Expression::Number(*pos as i64)
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
Expression::UnaryOp(op, subexpr) => {
|
||||
if let Expression::Number(num) = subexpr[0] {
|
||||
let result = match op {
|
||||
UnaryOpKind::HiDataRef => num >> 12,
|
||||
UnaryOpKind::LoDataRef => num & 0xfff,
|
||||
UnaryOpKind::Negation => -num,
|
||||
};
|
||||
*expression = Expression::Number(result);
|
||||
};
|
||||
}
|
||||
Expression::BinaryOp(op, subexprs) => {
|
||||
if let (Expression::Number(a), Expression::Number(b)) =
|
||||
(&subexprs[0], &subexprs[1])
|
||||
{
|
||||
let result = match op {
|
||||
parser::BinaryOpKind::Or => a | b,
|
||||
parser::BinaryOpKind::Xor => a ^ b,
|
||||
parser::BinaryOpKind::And => a & b,
|
||||
parser::BinaryOpKind::LeftShift => a << b,
|
||||
parser::BinaryOpKind::RightShift => a >> b,
|
||||
parser::BinaryOpKind::Add => a + b,
|
||||
parser::BinaryOpKind::Sub => a - b,
|
||||
parser::BinaryOpKind::Mul => a * b,
|
||||
parser::BinaryOpKind::Div => a / b,
|
||||
parser::BinaryOpKind::Mod => a % b,
|
||||
};
|
||||
*expression = Expression::Number(result);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
statements
|
||||
}
|
||||
|
||||
fn replace_data_reference(constant: &mut Constant, data_positions: &BTreeMap<String, u32>) {
|
||||
match constant {
|
||||
Constant::Number(_) => {}
|
||||
Constant::HiDataRef(data, offset) => {
|
||||
if let Some(pos) = data_positions.get(data) {
|
||||
let pos: u32 = pos + *offset as u32;
|
||||
*constant = Constant::Number((pos >> 12) as i64)
|
||||
}
|
||||
// Otherwise, it references a code label
|
||||
}
|
||||
Constant::LoDataRef(data, offset) => {
|
||||
if let Some(pos) = data_positions.get(data) {
|
||||
let pos: u32 = pos + *offset as u32;
|
||||
*constant = Constant::Number((pos & 0xfff) as i64)
|
||||
}
|
||||
// Otherwise, it references a code label
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn preamble() -> String {
|
||||
r#"
|
||||
degree 262144;
|
||||
@@ -601,19 +609,26 @@ fn escape_label(l: &str) -> String {
|
||||
}
|
||||
|
||||
fn argument_to_number(x: &Argument) -> u32 {
|
||||
if let Argument::Constant(c) = x {
|
||||
constant_to_number(c)
|
||||
if let Argument::Expression(expr) = x {
|
||||
expression_to_number(expr)
|
||||
} else {
|
||||
panic!("Expected number, got {x}")
|
||||
panic!("Expected numeric expression, got {x}")
|
||||
}
|
||||
}
|
||||
|
||||
fn constant_to_number(c: &Constant) -> u32 {
|
||||
match c {
|
||||
Constant::Number(n) => *n as u32,
|
||||
Constant::HiDataRef(n, off) | Constant::LoDataRef(n, off) => {
|
||||
panic!("Data reference was not erased during preprocessing: {n} + {off}");
|
||||
}
|
||||
fn expression_to_number(expr: &Expression) -> u32 {
|
||||
if let Expression::Number(n) = expr {
|
||||
*n as u32
|
||||
} else {
|
||||
panic!("Constant expression could not be fully resolved to a number during preprocessing: {expr}");
|
||||
}
|
||||
}
|
||||
|
||||
fn argument_to_escaped_symbol(x: &Argument) -> String {
|
||||
if let Argument::Expression(Expression::Symbol(symb)) = x {
|
||||
escape_label(symb)
|
||||
} else {
|
||||
panic!("Expected a symbol, got {x}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -654,8 +669,8 @@ fn rr(args: &[Argument]) -> (Register, Register) {
|
||||
|
||||
fn rrl(args: &[Argument]) -> (Register, Register, String) {
|
||||
match args {
|
||||
[Argument::Register(r1), Argument::Register(r2), Argument::Symbol(l)] => {
|
||||
(*r1, *r2, escape_label(l))
|
||||
[Argument::Register(r1), Argument::Register(r2), l] => {
|
||||
(*r1, *r2, argument_to_escaped_symbol(l))
|
||||
}
|
||||
_ => panic!(),
|
||||
}
|
||||
@@ -663,7 +678,7 @@ fn rrl(args: &[Argument]) -> (Register, Register, String) {
|
||||
|
||||
fn rl(args: &[Argument]) -> (Register, String) {
|
||||
match args {
|
||||
[Argument::Register(r1), Argument::Symbol(l)] => (*r1, escape_label(l)),
|
||||
[Argument::Register(r1), l] => (*r1, argument_to_escaped_symbol(l)),
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
@@ -671,7 +686,7 @@ fn rl(args: &[Argument]) -> (Register, String) {
|
||||
fn rro(args: &[Argument]) -> (Register, Register, u32) {
|
||||
match args {
|
||||
[Argument::Register(r1), Argument::RegOffset(r2, off)] => {
|
||||
(*r1, *r2, constant_to_number(off))
|
||||
(*r1, *r2, expression_to_number(off))
|
||||
}
|
||||
_ => panic!(),
|
||||
}
|
||||
@@ -952,8 +967,8 @@ fn process_instruction(instr: &str, args: &[Argument]) -> Vec<String> {
|
||||
|
||||
// jump and call
|
||||
"j" => {
|
||||
if let [Argument::Symbol(label)] = args {
|
||||
vec![format!("jump {};", escape_label(label))]
|
||||
if let [label] = args {
|
||||
vec![format!("jump {};", argument_to_escaped_symbol(label))]
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
@@ -972,8 +987,8 @@ fn process_instruction(instr: &str, args: &[Argument]) -> Vec<String> {
|
||||
vec![format!("jump_and_link_dyn {rs};")]
|
||||
}
|
||||
"call" => {
|
||||
if let [Argument::Symbol(label)] = args {
|
||||
vec![format!("call {};", escape_label(label))]
|
||||
if let [label] = args {
|
||||
vec![format!("call {};", argument_to_escaped_symbol(label))]
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
@@ -989,8 +1004,8 @@ fn process_instruction(instr: &str, args: &[Argument]) -> Vec<String> {
|
||||
vec!["x0 <=X= ${ (\"print_char\", x10) };\n".to_string()]
|
||||
}
|
||||
"tail" => {
|
||||
if let [Argument::Symbol(label)] = args {
|
||||
vec![format!("tail {};", escape_label(label))]
|
||||
if let [label] = args {
|
||||
vec![format!("tail {};", argument_to_escaped_symbol(label))]
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::parser::{Argument, Constant, Statement};
|
||||
use crate::parser::{Argument, Expression, Statement};
|
||||
|
||||
pub enum DataValue {
|
||||
Direct(Vec<u8>),
|
||||
@@ -34,9 +34,10 @@ pub fn extract_data_objects(
|
||||
current_label = Some(l.as_str());
|
||||
}
|
||||
Statement::Directive(dir, args) => match (dir.as_str(), &args[..]) {
|
||||
(".type", [Argument::Symbol(name), Argument::Symbol(kind)])
|
||||
if kind.as_str() == "@object" =>
|
||||
{
|
||||
(
|
||||
".type",
|
||||
[Argument::Expression(Expression::Symbol(name)), Argument::Expression(Expression::Symbol(kind))],
|
||||
) if kind.as_str() == "@object" => {
|
||||
object_order.push(name.clone());
|
||||
assert!(objects.insert(name.clone(), vec![]).is_none());
|
||||
}
|
||||
@@ -47,9 +48,10 @@ pub fn extract_data_objects(
|
||||
entry.extend(extract_data_value(dir.as_str(), args));
|
||||
});
|
||||
}
|
||||
(".size", [Argument::Symbol(name), Argument::Constant(Constant::Number(n))])
|
||||
if Some(name.as_str()) == current_label =>
|
||||
{
|
||||
(
|
||||
".size",
|
||||
[Argument::Expression(Expression::Symbol(name)), Argument::Expression(Expression::Number(n))],
|
||||
) if Some(name.as_str()) == current_label => {
|
||||
objects
|
||||
.entry(current_label.unwrap().into())
|
||||
.and_modify(|entry| {
|
||||
@@ -76,9 +78,9 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec<DataValue>
|
||||
match (directive, arguments) {
|
||||
(
|
||||
".zero",
|
||||
[Argument::Constant(Constant::Number(n))]
|
||||
[Argument::Expression(Expression::Number(n))]
|
||||
// TODO not clear what the second argument is
|
||||
| [Argument::Constant(Constant::Number(n)), _],
|
||||
| [Argument::Expression(Expression::Number(n)), _],
|
||||
) => {
|
||||
vec![DataValue::Zero(*n as usize)]
|
||||
}
|
||||
@@ -95,7 +97,7 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec<DataValue>
|
||||
.iter()
|
||||
.map(|x| {
|
||||
match x {
|
||||
Argument::Constant(Constant::Number(n)) =>{
|
||||
Argument::Expression(Expression::Number(n)) =>{
|
||||
let n = *n as u32;
|
||||
DataValue::Direct(vec![
|
||||
(n & 0xff) as u8,
|
||||
@@ -104,7 +106,7 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec<DataValue>
|
||||
(n >> 24 & 0xff) as u8,
|
||||
])
|
||||
}
|
||||
Argument::Symbol(sym) => {
|
||||
Argument::Expression(Expression::Symbol(sym)) => {
|
||||
DataValue::Reference(sym.clone())
|
||||
}
|
||||
_ => panic!("Invalid .word directive")
|
||||
@@ -117,7 +119,7 @@ fn extract_data_value(directive: &str, arguments: &[Argument]) -> Vec<DataValue>
|
||||
vec![DataValue::Direct(data
|
||||
.iter()
|
||||
.map(|x| {
|
||||
if let Argument::Constant(Constant::Number(n)) = x {
|
||||
if let Argument::Expression(Expression::Number(n)) = x {
|
||||
*n as u8
|
||||
} else {
|
||||
panic!("Invalid argument to .byte directive")
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::HashSet;
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::parser::{Argument, Constant, Statement};
|
||||
use crate::parser::{Argument, Expression, Statement};
|
||||
|
||||
pub fn disambiguate(assemblies: Vec<(String, Vec<Statement>)>) -> Vec<Statement> {
|
||||
let globals = assemblies
|
||||
@@ -12,7 +12,10 @@ pub fn disambiguate(assemblies: Vec<(String, Vec<Statement>)>) -> Vec<Statement>
|
||||
|
||||
assemblies
|
||||
.into_iter()
|
||||
.map(|(name, statements)| disambiguate_file(&name, statements, &globals))
|
||||
.map(|(name, mut statements)| {
|
||||
disambiguate_file(&name, &mut statements, &globals);
|
||||
statements
|
||||
})
|
||||
.concat()
|
||||
}
|
||||
|
||||
@@ -25,7 +28,7 @@ fn extract_globals(statements: &[Statement]) -> HashSet<String> {
|
||||
return args
|
||||
.iter()
|
||||
.map(|a| {
|
||||
if let Argument::Symbol(s) = a {
|
||||
if let Argument::Expression(Expression::Symbol(s)) = a {
|
||||
s.clone()
|
||||
} else {
|
||||
panic!("Invalid .globl directive: {s}");
|
||||
@@ -40,82 +43,30 @@ fn extract_globals(statements: &[Statement]) -> HashSet<String> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn disambiguate_file(
|
||||
file_name: &str,
|
||||
statements: Vec<Statement>,
|
||||
globals: &HashSet<String>,
|
||||
) -> Vec<Statement> {
|
||||
fn disambiguate_file(file_name: &str, statements: &mut [Statement], globals: &HashSet<String>) {
|
||||
let prefix = file_name.replace('-', "_dash_");
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|s| match s {
|
||||
Statement::Label(l) => {
|
||||
Statement::Label(disambiguate_symbol_if_needed(l, &prefix, globals))
|
||||
for s in statements {
|
||||
match s {
|
||||
Statement::Label(l) => disambiguate_symbol_if_needed(l, &prefix, globals),
|
||||
Statement::Directive(_, args) | Statement::Instruction(_, args) => {
|
||||
for arg in args.iter_mut() {
|
||||
disambiguate_argument_if_needed(arg, &prefix, globals);
|
||||
}
|
||||
}
|
||||
Statement::Directive(dir, args) => Statement::Directive(
|
||||
dir,
|
||||
disambiguate_arguments_if_needed(args, &prefix, globals),
|
||||
),
|
||||
Statement::Instruction(instr, args) => Statement::Instruction(
|
||||
instr,
|
||||
disambiguate_arguments_if_needed(args, &prefix, globals),
|
||||
),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn disambiguate_arguments_if_needed(
|
||||
args: Vec<Argument>,
|
||||
prefix: &str,
|
||||
globals: &HashSet<String>,
|
||||
) -> Vec<Argument> {
|
||||
args.into_iter()
|
||||
.map(|a| disambiguate_argument_if_needed(a, prefix, globals))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn disambiguate_argument_if_needed(
|
||||
arg: Argument,
|
||||
prefix: &str,
|
||||
globals: &HashSet<String>,
|
||||
) -> Argument {
|
||||
match arg {
|
||||
Argument::Register(_) | Argument::StringLiteral(_) => arg,
|
||||
Argument::RegOffset(reg, constant) => Argument::RegOffset(
|
||||
reg,
|
||||
disambiguate_constant_if_needed(constant, prefix, globals),
|
||||
),
|
||||
Argument::Constant(c) => {
|
||||
Argument::Constant(disambiguate_constant_if_needed(c, prefix, globals))
|
||||
}
|
||||
Argument::Symbol(s) => Argument::Symbol(disambiguate_symbol_if_needed(s, prefix, globals)),
|
||||
Argument::Difference(l, r) => Argument::Difference(
|
||||
disambiguate_symbol_if_needed(l, prefix, globals),
|
||||
disambiguate_symbol_if_needed(r, prefix, globals),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn disambiguate_constant_if_needed(
|
||||
c: Constant,
|
||||
prefix: &str,
|
||||
globals: &HashSet<String>,
|
||||
) -> Constant {
|
||||
match c {
|
||||
Constant::Number(_) => c,
|
||||
Constant::HiDataRef(s, offset) => {
|
||||
Constant::HiDataRef(disambiguate_symbol_if_needed(s, prefix, globals), offset)
|
||||
}
|
||||
Constant::LoDataRef(s, offset) => {
|
||||
Constant::LoDataRef(disambiguate_symbol_if_needed(s, prefix, globals), offset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn disambiguate_symbol_if_needed(s: String, prefix: &str, globals: &HashSet<String>) -> String {
|
||||
if globals.contains(s.as_str()) || s.starts_with('@') {
|
||||
s
|
||||
} else {
|
||||
format!("{prefix}__{s}")
|
||||
fn disambiguate_argument_if_needed(arg: &mut Argument, prefix: &str, globals: &HashSet<String>) {
|
||||
arg.post_visit_expressions_mut(&mut |expr| {
|
||||
if let Expression::Symbol(sym) = expr {
|
||||
disambiguate_symbol_if_needed(sym, prefix, globals);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Rewrites `s` to `{prefix}__{s}` unless it has to keep its name.
///
/// A symbol keeps its name when it starts with `'@'` or when it is listed in
/// `globals`.
fn disambiguate_symbol_if_needed(s: &mut String, prefix: &str, globals: &HashSet<String>) {
    let keeps_name = s.starts_with('@') || globals.contains(s.as_str());
    if keeps_name {
        return;
    }
    let qualified = format!("{prefix}__{s}");
    *s = qualified;
}
|
||||
|
||||
@@ -104,17 +104,19 @@ pub fn compile_riscv_asm_bundle<T: FieldElement>(
|
||||
/// fixed and witness columns.
|
||||
pub fn compile_riscv_asm<T: FieldElement>(
|
||||
original_file_name: &str,
|
||||
file_name: &str,
|
||||
file_names: impl Iterator<Item = String>,
|
||||
inputs: Vec<T>,
|
||||
output_dir: &Path,
|
||||
force_overwrite: bool,
|
||||
prove_with: Option<Backend>,
|
||||
) {
|
||||
let contents = fs::read_to_string(file_name).unwrap();
|
||||
compile_riscv_asm_bundle(
|
||||
original_file_name,
|
||||
vec![(file_name.to_string(), contents)]
|
||||
.into_iter()
|
||||
file_names
|
||||
.map(|name| {
|
||||
let contents = fs::read_to_string(&name).unwrap();
|
||||
(name, contents)
|
||||
})
|
||||
.collect(),
|
||||
inputs,
|
||||
output_dir,
|
||||
|
||||
@@ -20,11 +20,29 @@ pub enum Statement {
|
||||
#[derive(Clone)]
|
||||
pub enum Argument {
|
||||
Register(Register),
|
||||
RegOffset(Register, Constant),
|
||||
RegOffset(Register, Expression),
|
||||
StringLiteral(Vec<u8>),
|
||||
Constant(Constant),
|
||||
Symbol(String),
|
||||
Difference(String, String),
|
||||
Expression(Expression),
|
||||
}
|
||||
|
||||
impl Argument {
|
||||
pub(crate) fn post_visit_expressions_mut(&mut self, f: &mut impl FnMut(&mut Expression)) {
|
||||
match self {
|
||||
Argument::Register(_) | Argument::StringLiteral(_) => (),
|
||||
Argument::RegOffset(_, expr) | Argument::Expression(expr) => {
|
||||
expr.post_visit_mut(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn post_visit_expressions<'a>(&'a self, f: &mut impl FnMut(&'a Expression)) {
|
||||
match self {
|
||||
Argument::Register(_) | Argument::StringLiteral(_) => (),
|
||||
Argument::RegOffset(_, expr) | Argument::Expression(expr) => {
|
||||
expr.post_visit(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
@@ -36,11 +54,67 @@ impl Register {
|
||||
}
|
||||
}
|
||||
|
||||
/// Unary operators that can be applied to an assembly expression.
#[derive(Clone, Copy)]
|
||||
pub enum UnaryOpKind {
|
||||
/// Displayed as `%hi(expr)`; on a known number it folds to `n >> 12`.
HiDataRef,
|
||||
/// Displayed as `%lo(expr)`; on a known number it folds to `n & 0xfff`.
LoDataRef,
|
||||
/// Arithmetic negation, displayed as `(-expr)`.
Negation,
|
||||
}
|
||||
|
||||
/// Binary operators that can combine two assembly expressions.
#[derive(Clone, Copy)]
|
||||
pub enum BinaryOpKind {
|
||||
/// Bitwise or, written `|`.
Or,
|
||||
/// Bitwise exclusive or, written `^`.
Xor,
|
||||
/// Bitwise and, written `&`.
And,
|
||||
/// Left shift, written `<<`.
LeftShift,
|
||||
/// Right shift, written `>>`.
RightShift,
|
||||
/// Addition, written `+`.
Add,
|
||||
/// Subtraction, written `-`.
Sub,
|
||||
/// Multiplication, written `*`.
Mul,
|
||||
/// Division, written `/`.
Div,
|
||||
/// Remainder, written `%`.
Mod,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum Constant {
|
||||
pub enum Expression {
|
||||
Number(i64),
|
||||
HiDataRef(String, i64),
|
||||
LoDataRef(String, i64),
|
||||
Symbol(String),
|
||||
UnaryOp(UnaryOpKind, Box<[Expression]>),
|
||||
BinaryOp(BinaryOpKind, Box<[Expression; 2]>),
|
||||
}
|
||||
|
||||
impl Expression {
|
||||
fn post_visit<'a>(&'a self, f: &mut impl FnMut(&'a Expression)) {
|
||||
match self {
|
||||
Expression::UnaryOp(_, subexpr) => subexpr.iter(),
|
||||
Expression::BinaryOp(_, subexprs) => subexprs.iter(),
|
||||
_ => [].iter(),
|
||||
}
|
||||
.for_each(|subexpr| {
|
||||
Self::post_visit(subexpr, f);
|
||||
});
|
||||
f(self);
|
||||
}
|
||||
|
||||
fn post_visit_mut(&mut self, f: &mut impl FnMut(&mut Expression)) {
|
||||
match self {
|
||||
Expression::UnaryOp(_, subexpr) => subexpr.iter_mut(),
|
||||
Expression::BinaryOp(_, subexprs) => subexprs.iter_mut(),
|
||||
_ => [].iter_mut(),
|
||||
}
|
||||
.for_each(|subexpr| {
|
||||
Self::post_visit_mut(subexpr, f);
|
||||
});
|
||||
f(self);
|
||||
}
|
||||
}
|
||||
|
||||
fn new_unary_op(op: UnaryOpKind, v: Expression) -> Expression {
|
||||
Expression::UnaryOp(op, Box::new([v]))
|
||||
}
|
||||
|
||||
fn new_binary_op(op: BinaryOpKind, l: Expression, r: Expression) -> Expression {
|
||||
Expression::BinaryOp(op, Box::new([l, r]))
|
||||
}
|
||||
|
||||
impl Display for Statement {
|
||||
@@ -57,37 +131,40 @@ impl Display for Argument {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Argument::Register(r) => write!(f, "{r}"),
|
||||
Argument::Constant(c) => write!(f, "{c}"),
|
||||
Argument::RegOffset(reg, off) => write!(f, "{off}({reg})"),
|
||||
Argument::StringLiteral(lit) => write!(f, "\"{}\"", String::from_utf8_lossy(lit)),
|
||||
Argument::Symbol(s) => write!(f, "{s}"),
|
||||
Argument::Difference(left, right) => write!(f, "{left} - {right}"),
|
||||
Argument::Expression(expr) => write!(f, "{expr}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Constant {
|
||||
impl Display for Expression {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Constant::Number(n) => write!(f, "{n}"),
|
||||
Constant::HiDataRef(sym, offset) => {
|
||||
write!(f, "%hi({sym}{})", format_hi_lo_offset(*offset))
|
||||
}
|
||||
Constant::LoDataRef(sym, offset) => {
|
||||
write!(f, "%lo({sym}{})", format_hi_lo_offset(*offset))
|
||||
Expression::Number(n) => write!(f, "{n}"),
|
||||
Expression::Symbol(sym) => write!(f, "{sym}"),
|
||||
Expression::UnaryOp(UnaryOpKind::Negation, expr) => write!(f, "(-{})", expr[0]),
|
||||
Expression::UnaryOp(UnaryOpKind::HiDataRef, expr) => write!(f, "%hi({})", expr[0]),
|
||||
Expression::UnaryOp(UnaryOpKind::LoDataRef, expr) => write!(f, "%lo({})", expr[0]),
|
||||
Expression::BinaryOp(op, args) => {
|
||||
let symbol = match op {
|
||||
BinaryOpKind::Or => "|",
|
||||
BinaryOpKind::Xor => "^",
|
||||
BinaryOpKind::And => "&",
|
||||
BinaryOpKind::LeftShift => "<<",
|
||||
BinaryOpKind::RightShift => ">>",
|
||||
BinaryOpKind::Add => "+",
|
||||
BinaryOpKind::Sub => "-",
|
||||
BinaryOpKind::Mul => "*",
|
||||
BinaryOpKind::Div => "/",
|
||||
BinaryOpKind::Mod => "%",
|
||||
};
|
||||
write!(f, "({} {symbol} {})", args[0], args[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn format_hi_lo_offset(offset: i64) -> String {
|
||||
match offset {
|
||||
0 => String::new(),
|
||||
1.. => format!(" + {offset}"),
|
||||
..=-1 => format!(" - {}", -offset),
|
||||
}
|
||||
}
|
||||
|
||||
fn format_arguments(args: &[Argument]) -> String {
|
||||
args.iter()
|
||||
.map(|a| format!("{a}"))
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::collections::{BTreeMap, BTreeSet, HashSet};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::data_parser::DataValue;
|
||||
use crate::parser::{Argument, Constant, Statement};
|
||||
use crate::parser::{Argument, Expression, Statement};
|
||||
|
||||
pub fn filter_reachable_from(
|
||||
label: &str,
|
||||
@@ -62,7 +62,10 @@ pub fn filter_reachable_from(
|
||||
let offset = *label_offsets.get(l).unwrap();
|
||||
basic_block_code_starting_from(&statements[offset..])
|
||||
.into_iter()
|
||||
.map(|s| apply_replacement_to_instruction(s, &replacements))
|
||||
.map(|mut s| {
|
||||
apply_replacement_to_instruction(&mut s, &replacements);
|
||||
s
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let referenced_labels = referenced_labels
|
||||
@@ -82,7 +85,8 @@ fn extract_replacements(statements: &[Statement]) -> BTreeMap<&str, &str> {
|
||||
.iter()
|
||||
.filter_map(|s| match s {
|
||||
Statement::Directive(dir, args) if dir.as_str() == ".set" => {
|
||||
if let [Argument::Symbol(from), Argument::Symbol(to)] = &args[..] {
|
||||
if let [Argument::Expression(Expression::Symbol(from)), Argument::Expression(Expression::Symbol(to))] = &args[..]
|
||||
{
|
||||
Some((from.as_str(), to.as_str()))
|
||||
} else {
|
||||
panic!();
|
||||
@@ -137,23 +141,20 @@ pub fn extract_label_offsets(statements: &[Statement]) -> BTreeMap<&str, usize>
|
||||
}
|
||||
|
||||
pub fn references_in_statement(statement: &Statement) -> BTreeSet<&str> {
|
||||
let mut ret = BTreeSet::new();
|
||||
match statement {
|
||||
Statement::Label(_) | Statement::Directive(_, _) => Default::default(),
|
||||
Statement::Instruction(_, args) => args
|
||||
.iter()
|
||||
.filter_map(|arg| match arg {
|
||||
Argument::Register(_) | Argument::StringLiteral(_) => None,
|
||||
Argument::Symbol(s) => Some(s.as_str()),
|
||||
Argument::RegOffset(_, c) | Argument::Constant(c) => match c {
|
||||
Constant::Number(_) => None,
|
||||
Constant::HiDataRef(s, _offset) | Constant::LoDataRef(s, _offset) => {
|
||||
Some(s.as_str())
|
||||
Statement::Label(_) | Statement::Directive(_, _) => (),
|
||||
Statement::Instruction(_, args) => {
|
||||
for arg in args {
|
||||
arg.post_visit_expressions(&mut |expr| {
|
||||
if let Expression::Symbol(sym) = expr {
|
||||
ret.insert(sym.as_str());
|
||||
}
|
||||
},
|
||||
Argument::Difference(_, _) => todo!(),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
ret
|
||||
}
|
||||
|
||||
fn basic_block_references_starting_from(statements: &[Statement]) -> (Vec<&str>, Vec<&str>) {
|
||||
@@ -208,41 +209,24 @@ fn ends_control_flow(s: &Statement) -> bool {
|
||||
}
|
||||
|
||||
fn apply_replacement_to_instruction(
|
||||
statement: Statement,
|
||||
statement: &mut Statement,
|
||||
replacements: &BTreeMap<&str, &str>,
|
||||
) -> Statement {
|
||||
) {
|
||||
match statement {
|
||||
Statement::Label(_) => statement,
|
||||
Statement::Instruction(instr, args) => Statement::Instruction(
|
||||
instr,
|
||||
args.into_iter()
|
||||
.map(|a| match a {
|
||||
Argument::Register(_) | Argument::StringLiteral(_) => a,
|
||||
Argument::Symbol(s) => Argument::Symbol(replace(s, replacements)),
|
||||
Argument::RegOffset(reg, c) => {
|
||||
Argument::RegOffset(reg, apply_replacement_to_constant(c, replacements))
|
||||
Statement::Label(_) => (),
|
||||
Statement::Instruction(_, args) => {
|
||||
for a in args {
|
||||
a.post_visit_expressions_mut(&mut |expr| {
|
||||
if let Expression::Symbol(s) = expr {
|
||||
replace(s, replacements);
|
||||
}
|
||||
Argument::Constant(c) => {
|
||||
Argument::Constant(apply_replacement_to_constant(c, replacements))
|
||||
}
|
||||
Argument::Difference(l, r) => {
|
||||
Argument::Difference(replace(l, replacements), replace(r, replacements))
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => panic!("Expected instruction but got: {statement}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_replacement_to_constant(c: Constant, replacements: &BTreeMap<&str, &str>) -> Constant {
|
||||
match c {
|
||||
Constant::Number(_) => c,
|
||||
Constant::HiDataRef(s, off) => Constant::HiDataRef(replace(s, replacements), off),
|
||||
Constant::LoDataRef(s, off) => Constant::LoDataRef(replace(s, replacements), off),
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_replacement_to_object(object: &mut Vec<DataValue>, replacements: &BTreeMap<&str, &str>) {
|
||||
for value in object {
|
||||
if let DataValue::Reference(reference) = value {
|
||||
@@ -253,9 +237,8 @@ fn apply_replacement_to_object(object: &mut Vec<DataValue>, replacements: &BTree
|
||||
}
|
||||
}
|
||||
|
||||
fn replace(s: String, replacements: &BTreeMap<&str, &str>) -> String {
|
||||
match replacements.get(s.as_str()) {
|
||||
Some(r) => r.to_string(),
|
||||
None => s,
|
||||
/// Overwrites `s` with its replacement when `replacements` has an entry keyed
/// by the current contents of `s`; otherwise `s` is left untouched.
fn replace(s: &mut String, replacements: &BTreeMap<&str, &str>) {
    if let Some(&target) = replacements.get(s.as_str()) {
        // Rewrite in place: this reuses the existing buffer and avoids calling
        // `to_string` on a `&&str`.
        s.clear();
        s.push_str(target);
    }
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::str::FromStr;
|
||||
use crate::parser::{Statement, Argument, Register, Constant, unescape_string};
|
||||
use crate::parser::{Statement, Argument, Register, Expression, unescape_string, BinaryOpKind as BOp, UnaryOpKind as UOp, new_binary_op as bin_op, new_unary_op as un_op};
|
||||
|
||||
grammar;
|
||||
|
||||
@@ -55,9 +55,7 @@ Argument: Argument = {
|
||||
Register => Argument::Register(<>),
|
||||
OffsetRegister,
|
||||
StringLiteral => Argument::StringLiteral(<>),
|
||||
Symbol => Argument::Symbol(<>),
|
||||
Constant => Argument::Constant(<>),
|
||||
Difference,
|
||||
Expression => Argument::Expression(<>),
|
||||
}
|
||||
|
||||
Register: Register = {
|
||||
@@ -80,68 +78,59 @@ Register: Register = {
|
||||
}
|
||||
|
||||
OffsetRegister: Argument = {
|
||||
<c:Constant> "(" <r:Register> ")" => Argument::RegOffset(r, c),
|
||||
<c:Expression> "(" <r:Register> ")" => Argument::RegOffset(r, c),
|
||||
}
|
||||
|
||||
Constant: Constant = {
|
||||
ConstantExpression => Constant::Number(<>),
|
||||
"%hi(" <arg:HiLowArgument> ")" => Constant::HiDataRef(arg.0, arg.1),
|
||||
"%lo(" <arg:HiLowArgument> ")" => Constant::LoDataRef(arg.0, arg.1),
|
||||
Expression: Expression = {
|
||||
ExprBinaryOr,
|
||||
}
|
||||
|
||||
HiLowArgument: (String, i64) = {
|
||||
<Symbol> => (<>, 0),
|
||||
<Symbol> "+" <ConstantExpression> => (<>),
|
||||
<n: Symbol> "-" <o: ConstantExpression> => (n, -o),
|
||||
ExprBinaryOr: Expression = {
|
||||
<l:ExprBinaryOr> "|" <r:ExprBinaryXor> => bin_op(BOp::Or, l, r),
|
||||
ExprBinaryXor,
|
||||
}
|
||||
|
||||
ConstantExpression: i64 = {
|
||||
<l:ConstantExpression> "|" <r:ConstantBinaryXor> => l | r,
|
||||
ConstantBinaryXor,
|
||||
ExprBinaryXor: Expression = {
|
||||
<l:ExprBinaryXor> "^" <r:ExprBinaryAnd> => bin_op(BOp::Xor, l, r),
|
||||
ExprBinaryAnd,
|
||||
}
|
||||
|
||||
ConstantBinaryXor: i64 = {
|
||||
<l:ConstantBinaryXor> "^" <r:ConstantBinaryAnd> => l ^ r,
|
||||
ConstantBinaryAnd,
|
||||
ExprBinaryAnd: Expression = {
|
||||
<l:ExprBinaryAnd> "&" <r:ExprBitShift> => bin_op(BOp::And, l, r),
|
||||
ExprBitShift,
|
||||
}
|
||||
|
||||
ConstantBinaryAnd: i64 = {
|
||||
<l:ConstantBinaryAnd> "&" <r:ConstantBitShift> => l & r,
|
||||
ConstantBitShift,
|
||||
ExprBitShift: Expression = {
|
||||
<l:ExprBitShift> "<<" <r:ExprSum> => bin_op(BOp::LeftShift, l, r),
|
||||
<l:ExprBitShift> ">>" <r:ExprSum> => bin_op(BOp::RightShift, l, r),
|
||||
ExprSum,
|
||||
}
|
||||
|
||||
ConstantBitShift: i64 = {
|
||||
<l:ConstantBitShift> "<<" <r:ConstantSum> => l << r,
|
||||
<l:ConstantBitShift> ">>" <r:ConstantSum> => l >> r,
|
||||
ConstantSum,
|
||||
ExprSum: Expression = {
|
||||
<l:ExprSum> "+" <r:ExprProduct> => bin_op(BOp::Add, l, r),
|
||||
<l:ExprSum> "-" <r:ExprProduct> => bin_op(BOp::Sub, l, r),
|
||||
ExprProduct,
|
||||
}
|
||||
|
||||
ConstantSum: i64 = {
|
||||
<l:ConstantSum> "+" <r:ConstantProduct> => l + r,
|
||||
<l:ConstantSum> "-" <r:ConstantProduct> => l - r,
|
||||
ConstantProduct,
|
||||
ExprProduct: Expression = {
|
||||
<l:ExprProduct> "*" <r:ExprUnary> => bin_op(BOp::Mul, l, r),
|
||||
<l:ExprProduct> "/" <r:ExprUnary> => bin_op(BOp::Div, l, r),
|
||||
<l:ExprProduct> "%" <r:ExprUnary> => bin_op(BOp::Mod, l, r),
|
||||
ExprUnary,
|
||||
}
|
||||
|
||||
ConstantProduct: i64 = {
|
||||
<l:ConstantProduct> "*" <r:ConstantUnary> => l * r,
|
||||
<l:ConstantProduct> "/" <r:ConstantUnary> => l / r,
|
||||
<l:ConstantProduct> "%" <r:ConstantUnary> => l % r,
|
||||
ConstantUnary,
|
||||
ExprUnary: Expression = {
|
||||
"-" <ExprTerm> => un_op(UOp::Negation, <>),
|
||||
"+" <ExprTerm>,
|
||||
ExprTerm,
|
||||
}
|
||||
|
||||
ConstantUnary: i64 = {
|
||||
"-" <ConstantTerm> => -<>,
|
||||
"+" <ConstantTerm> => <>,
|
||||
ConstantTerm,
|
||||
}
|
||||
|
||||
ConstantTerm: i64 = {
|
||||
Number,
|
||||
"(" <ConstantExpression> ")",
|
||||
}
|
||||
|
||||
Difference: Argument = {
|
||||
<Symbol> "-" <Symbol> => Argument::Difference(<>)
|
||||
ExprTerm: Expression = {
|
||||
Number => Expression::Number(<>),
|
||||
"(" <Expression> ")" => <>,
|
||||
"%hi(" <Expression> ")" => un_op(UOp::HiDataRef, <>),
|
||||
"%lo(" <Expression> ")" => un_op(UOp::LoDataRef, <>),
|
||||
Symbol => Expression::Symbol(<>)
|
||||
}
|
||||
|
||||
StringLiteral: Vec<u8> = {
|
||||
@@ -164,5 +153,4 @@ Symbol: String = {
|
||||
Number: i64 = {
|
||||
r"-?[0-9][0-9_]*" => i64::from_str(<>).unwrap().into(),
|
||||
r"0x[0-9A-Fa-f][0-9A-Fa-f_]*" => i64::from_str_radix(&<>[2..].replace('_', ""), 16).unwrap().into(),
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,6 @@ use runtime::{get_prover_input, print};
|
||||
#[no_mangle]
|
||||
pub fn main() {
|
||||
let input = get_prover_input(0);
|
||||
print(format_args!("Input in hex: {input:x}\n"));
|
||||
print!("Input in hex: {input:x}\n");
|
||||
assert_eq!([1, 2, 3], [4, 5, 6]);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user