diff --git a/Cargo.toml b/Cargo.toml index a51b6a1a5..3ba25ed36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,8 @@ members = [ "pilgen", "halo2", "backend", + "ast", + "analysis", ] [patch."https://github.com/privacy-scaling-explorations/halo2.git"] diff --git a/analysis/Cargo.toml b/analysis/Cargo.toml new file mode 100644 index 000000000..e6a7c0e36 --- /dev/null +++ b/analysis/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "analysis" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ast = { path = "../ast" } +itertools = "0.10.5" +log = "0.4.18" +number = { version = "0.1.0", path = "../number" } + +[dev-dependencies] +parser = { path = "../parser" } +pretty_assertions = "1.3.0" diff --git a/pilgen/src/batcher.rs b/analysis/src/batcher.rs similarity index 65% rename from pilgen/src/batcher.rs rename to analysis/src/batcher.rs index 556a66aea..043b9a758 100644 --- a/pilgen/src/batcher.rs +++ b/analysis/src/batcher.rs @@ -1,19 +1,22 @@ use std::marker::PhantomData; +use ast::analysis::{ + AnalysisASMFile, BatchMetadata, Incompatible, IncompatibleSet, ProgramStatement, +}; use itertools::Itertools; use number::FieldElement; -use parser::asm_ast::{ - batched::{ASMStatementBatch, BatchedASMFile, Incompatible, IncompatibleSet}, - ASMFile, ASMStatement, -}; -#[derive(Default)] -struct Batch { - statements: Vec>, +pub fn batch(file: AnalysisASMFile) -> AnalysisASMFile { + ProgramBatcher::default().batch(file) } -impl Batch { - fn from_statement(s: ASMStatement) -> Batch { +#[derive(Default)] +struct Batch<'a, T> { + statements: Vec<&'a ProgramStatement>, +} + +impl<'a, T: FieldElement> Batch<'a, T> { + fn from_statement(s: &'a ProgramStatement) -> Batch { Batch { statements: vec![s], } @@ -23,17 +26,20 @@ impl Batch { fn is_only_labels(&self) -> bool { self.statements .iter() - .all(|s| matches!(s, ASMStatement::Label(..))) + .all(|s| matches!(s, 
ProgramStatement::Label(..))) } /// Returns true iff this batch contains at least one label fn contains_labels(&self) -> bool { self.statements .iter() - .any(|s| matches!(s, ASMStatement::Label(..))) + .any(|s| matches!(s, ProgramStatement::Label(..))) } - fn try_absorb(&mut self, s: ASMStatement) -> Result<(), (ASMStatement, IncompatibleSet)> { + fn try_absorb( + &mut self, + s: &'a ProgramStatement, + ) -> Result<(), (&'a ProgramStatement, IncompatibleSet)> { let batch = Self::from_statement(s); self.try_join(batch) .map_err(|(b, incompatible)| (b.statements.into_iter().next().unwrap(), incompatible)) @@ -55,16 +61,16 @@ impl Batch { } #[derive(Default)] -pub struct ASMBatcher { +struct ProgramBatcher { marker: PhantomData, } -impl ASMBatcher { +impl ProgramBatcher { /// split a list of statements into compatible batches - fn to_compatible_batches( + fn extract_batches<'a>( &self, - statements: impl IntoIterator>, - ) -> Vec> { + statements: impl IntoIterator>, + ) -> Vec { statements .into_iter() .peekable() @@ -74,13 +80,13 @@ impl ASMBatcher { // look at the next statement match it.peek() { // try to add it to this batch - Some(new_s) => match batch.try_absorb(new_s.clone()) { + Some(new_s) => match batch.try_absorb(new_s) { Ok(()) => { it.next().unwrap(); } Err((_, reason)) => { - let res = ASMStatementBatch { - statements: batch.statements, + let res = BatchMetadata { + size: batch.statements.len(), reason: Some(reason), }; break Some(res); @@ -89,8 +95,8 @@ impl ASMBatcher { None => { break match batch.statements.len() { 0 => None, - _ => Some(ASMStatementBatch { - statements: batch.statements, + _ => Some(BatchMetadata { + size: batch.statements.len(), reason: None, }), } @@ -101,22 +107,10 @@ impl ASMBatcher { .collect() } - pub fn convert(&mut self, asm_file: ASMFile) -> BatchedASMFile { - let statements = asm_file.0.into_iter().peekable(); + pub fn batch(&mut self, mut asm_file: AnalysisASMFile) -> AnalysisASMFile { + let batches = 
self.extract_batches(&asm_file.program); - let (declarations, statements) = statements.into_iter().partition(|s| match s { - ASMStatement::Degree(..) - | ASMStatement::RegisterDeclaration(..) - | ASMStatement::InstructionDeclaration(..) - | ASMStatement::InlinePil(..) => true, - ASMStatement::Assignment(_, _, _, _) - | ASMStatement::Instruction(_, _, _) - | ASMStatement::Label(_, _) => false, - }); - - let batches = self.to_compatible_batches(statements); - - let lines_before = batches.iter().map(ASMStatementBatch::size).sum::() as f32; + let lines_before = batches.iter().map(BatchMetadata::size).sum::() as f32; let lines_after = batches.len() as f32; log::debug!( @@ -124,10 +118,9 @@ impl ASMBatcher { (1. - lines_after / lines_before) * 100. ); - BatchedASMFile { - declarations, - batches, - } + asm_file.batches = Some(batches); + + asm_file } } @@ -139,18 +132,20 @@ mod tests { use number::GoldilocksField; use pretty_assertions::assert_eq; - use super::*; + use crate::{batcher, macro_expansion, type_check}; fn test_batching(path: &str) { let base_path = PathBuf::from("../test_data/asm/batching"); let file_name = base_path.join(path); let contents = fs::read_to_string(&file_name).unwrap(); - let batched_asm = parser::parse_asm::( + let parsed = parser::parse_asm::( Some(file_name.as_os_str().to_str().unwrap()), &contents, ) - .map(|ast| ASMBatcher::default().convert(ast)) .unwrap(); + let expanded = macro_expansion::expand(parsed); + let checked = type_check::check(expanded).unwrap(); + let batched = batcher::batch(checked); let mut expected_file_name = file_name; expected_file_name.set_file_name(format!( "{}_batched.asm", @@ -159,7 +154,7 @@ mod tests { let expected = fs::read_to_string(expected_file_name).unwrap(); assert_eq!( - format!("{batched_asm}").replace("\n\n", "\n"), + format!("{batched}").replace("\n\n", "\n"), expected.replace("\n\n", "\n") ); } diff --git a/analysis/src/lib.rs b/analysis/src/lib.rs new file mode 100644 index 000000000..555462818 --- 
/dev/null +++ b/analysis/src/lib.rs @@ -0,0 +1,16 @@ +mod batcher; +mod macro_expansion; +mod type_check; + +/// expose the macro expander for use in the pil_analyzer +pub use macro_expansion::MacroExpander; + +use ast::{analysis::AnalysisASMFile, parsed::asm::ASMFile}; +use number::FieldElement; + +pub fn analyze(file: ASMFile) -> Result, String> { + let expanded = macro_expansion::expand(file); + let checked = type_check::check(expanded)?; + let batched = batcher::batch(checked); + Ok(batched) +} diff --git a/parser/src/macro_expansion.rs b/analysis/src/macro_expansion.rs similarity index 71% rename from parser/src/macro_expansion.rs rename to analysis/src/macro_expansion.rs index 11cc9420f..d3a355451 100644 --- a/parser/src/macro_expansion.rs +++ b/analysis/src/macro_expansion.rs @@ -3,9 +3,18 @@ use std::{ ops::ControlFlow, }; -use crate::ast::*; +use ast::parsed::{ + asm::{ASMFile, ASMStatement, InstructionBodyElement}, + postvisit_expression_in_statement_mut, postvisit_expression_mut, Expression, + FunctionDefinition, SelectedExpressions, Statement, +}; use number::FieldElement; +pub fn expand(file: ASMFile) -> ASMFile { + let mut expander = MacroExpander::default(); + expander.expand_asm(file) +} + #[derive(Debug, Default)] pub struct MacroExpander { macros: HashMap>, @@ -26,8 +35,29 @@ impl MacroExpander where T: FieldElement, { - pub fn new() -> Self { - Default::default() + fn expand_asm(&mut self, mut file: ASMFile) -> ASMFile { + let mut expander = MacroExpander::default(); + file.0.iter_mut().for_each(|s| match s { + ASMStatement::InstructionDeclaration(_, _, _, body) => { + body.iter_mut().for_each(|e| match e { + InstructionBodyElement::Expression(e) => { + self.process_expression(e); + } + InstructionBodyElement::PlookupIdentity(left, _, right) => { + self.process_selected_expressions(left); + self.process_selected_expressions(right); + } + InstructionBodyElement::FunctionCall(_, inputs) => { + self.process_expressions(inputs); + } + }); + } + 
ASMStatement::InlinePil(_, statements) => { + *statements = expander.expand_macros(std::mem::take(statements)); + } + _ => {} + }); + file } /// Expands all macro references inside the statements and also adds @@ -42,7 +72,7 @@ where std::mem::take(&mut self.statements) } - pub fn handle_statement(&mut self, mut statement: Statement) { + fn handle_statement(&mut self, mut statement: Statement) { let mut added_locals = false; if let Statement::PolynomialConstantDefinition(_, _, f) | Statement::PolynomialCommitDeclaration(_, _, Some(f)) = &statement @@ -139,4 +169,21 @@ where ControlFlow::<()>::Continue(()) } + + fn process_expressions(&mut self, exprs: &mut [Expression]) -> ControlFlow<()> { + for e in exprs.iter_mut() { + self.process_expression(e)?; + } + ControlFlow::Continue(()) + } + + fn process_selected_expressions( + &mut self, + exprs: &mut SelectedExpressions, + ) -> ControlFlow<()> { + if let Some(e) = &mut exprs.selector { + self.process_expression(e)?; + }; + self.process_expressions(&mut exprs.expressions) + } } diff --git a/analysis/src/type_check.rs b/analysis/src/type_check.rs new file mode 100644 index 000000000..d15e2a065 --- /dev/null +++ b/analysis/src/type_check.rs @@ -0,0 +1,74 @@ +use ast::{ + analysis::{ + AnalysisASMFile, AssignmentStatement, DegreeStatement, InstructionDefinitionStatement, + InstructionStatement, LabelStatement, PilBlock, RegisterDeclarationStatement, + }, + parsed::asm::{ASMFile, ASMStatement}, +}; +use number::FieldElement; + +/// A very stupid type checker. 
TODO: make it smart +pub fn check(file: ASMFile) -> Result, String> { + let mut degree = None; + let mut registers = vec![]; + let mut pil = vec![]; + let mut instructions = vec![]; + let mut program = vec![]; + + for s in file.0 { + match s { + ASMStatement::Degree(_, degree_value) => { + degree = Some(DegreeStatement { + degree: degree_value, + }); + } + ASMStatement::RegisterDeclaration(start, name, flag) => { + registers.push(RegisterDeclarationStatement { start, name, flag }); + } + ASMStatement::InstructionDeclaration(start, name, params, body) => { + instructions.push(InstructionDefinitionStatement { + start, + name, + params, + body, + }); + } + ASMStatement::InlinePil(start, statements) => { + pil.push(PilBlock { start, statements }); + } + ASMStatement::Assignment(start, lhs, using_reg, rhs) => { + program.push( + AssignmentStatement { + start, + lhs, + using_reg, + rhs, + } + .into(), + ); + } + ASMStatement::Instruction(start, instruction, inputs) => { + program.push( + InstructionStatement { + start, + instruction, + inputs, + } + .into(), + ); + } + ASMStatement::Label(start, name) => { + program.push(LabelStatement { start, name }.into()); + } + } + } + + Ok(AnalysisASMFile { + degree, + registers, + pil, + instructions, + program, + batches: None, + }) +} diff --git a/ast/Cargo.toml b/ast/Cargo.toml new file mode 100644 index 000000000..bb9c62a06 --- /dev/null +++ b/ast/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "ast" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +num-bigint = "0.4.3" +number = { path = "../number" } + diff --git a/ast/src/analysis/display.rs b/ast/src/analysis/display.rs new file mode 100644 index 000000000..2262d21d4 --- /dev/null +++ b/ast/src/analysis/display.rs @@ -0,0 +1,174 @@ +use std::fmt::{Display, Formatter, Result}; + +use super::{ + AnalysisASMFile, AssignmentStatement, DegreeStatement, Incompatible, 
IncompatibleSet, + InstructionDefinitionStatement, InstructionStatement, LabelStatement, PilBlock, + ProgramStatement, RegisterDeclarationStatement, +}; + +impl Display for AnalysisASMFile { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + for s in &self.degree { + writeln!(f, "{s}")?; + } + for s in &self.registers { + writeln!(f, "{s}")?; + } + for s in &self.pil { + writeln!(f, "{s}")?; + } + for i in &self.instructions { + writeln!(f, "{i}")?; + } + + let mut statements = self.program.iter(); + + match self.batches.as_ref() { + Some(batches) => { + for batch in batches { + for s in (&mut statements).take(batch.size) { + writeln!(f, "{s}")?; + } + writeln!( + f, + "// END BATCH{}", + batch + .reason + .as_ref() + .map(|reason| format!(" {reason}")) + .unwrap_or_default() + )?; + } + } + None => { + for s in statements { + writeln!(f, "{s}")?; + } + } + } + Ok(()) + } +} + +impl Display for DegreeStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!(f, "degree {};", self.degree) + } +} + +impl Display for ProgramStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + match self { + ProgramStatement::Assignment(s) => write!(f, "{s}"), + ProgramStatement::Instruction(s) => write!(f, "{s}"), + ProgramStatement::Label(s) => write!(f, "{s}"), + } + } +} + +impl Display for AssignmentStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!( + f, + "{} <={}= {};", + self.lhs.join(", "), + self.using_reg + .as_ref() + .map(ToString::to_string) + .unwrap_or_default(), + self.rhs + ) + } +} + +impl Display for InstructionStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!( + f, + "{}{};", + self.instruction, + if self.inputs.is_empty() { + "".to_string() + } else { + format!( + " {}", + self.inputs + .iter() + .map(|i| i.to_string()) + .collect::>() + .join(", ") + ) + } + ) + } +} + +impl Display for LabelStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!(f, "{}::", self.name) + } +} + 
+impl Display for PilBlock { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!( + f, + "pil{{\n{}\n}}", + self.statements + .iter() + .map(|s| format!("{}", s)) + .collect::>() + .join("\n") + ) + } +} + +impl Display for RegisterDeclarationStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!( + f, + "reg {}{};", + self.name, + self.flag + .as_ref() + .map(|flag| format!("[{flag}]")) + .unwrap_or_default() + ) + } +} + +impl Display for InstructionDefinitionStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!( + f, + "instr {}{} {{{}}}", + self.name, + self.params, + self.body + .iter() + .map(|e| format!("{e}")) + .collect::>() + .join(" ") + ) + } +} + +impl Display for Incompatible { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!(f, "{:?}", self) + } +} + +impl Display for IncompatibleSet { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!( + f, + "{}", + self.0 + .iter() + .map(|r| r.to_string()) + .collect::>() + .join(", ") + ) + } +} diff --git a/ast/src/analysis/mod.rs b/ast/src/analysis/mod.rs new file mode 100644 index 000000000..8991000bc --- /dev/null +++ b/ast/src/analysis/mod.rs @@ -0,0 +1,106 @@ +mod display; + +use std::collections::BTreeSet; + +use num_bigint::BigUint; + +use crate::parsed::{ + asm::{InstructionBodyElement, InstructionParams, RegisterFlag}, + Expression, Statement, +}; + +pub struct RegisterDeclarationStatement { + pub start: usize, + pub name: String, + pub flag: Option, +} + +pub struct InstructionDefinitionStatement { + pub start: usize, + pub name: String, + pub params: InstructionParams, + pub body: Vec>, +} + +pub struct DegreeStatement { + pub degree: BigUint, +} + +pub enum ProgramStatement { + Assignment(AssignmentStatement), + Instruction(InstructionStatement), + Label(LabelStatement), +} + +impl From> for ProgramStatement { + fn from(value: AssignmentStatement) -> Self { + Self::Assignment(value) + } +} + +impl From> for ProgramStatement { + fn 
from(value: InstructionStatement) -> Self { + Self::Instruction(value) + } +} + +impl From for ProgramStatement { + fn from(value: LabelStatement) -> Self { + Self::Label(value) + } +} + +pub struct AssignmentStatement { + pub start: usize, + pub lhs: Vec, + pub using_reg: Option, + pub rhs: Box>, +} + +pub struct InstructionStatement { + pub start: usize, + pub instruction: String, + pub inputs: Vec>, +} + +pub struct LabelStatement { + pub start: usize, + pub name: String, +} + +pub struct PilBlock { + pub start: usize, + pub statements: Vec>, +} + +pub struct AnalysisASMFile { + pub degree: Option, + pub registers: Vec, + pub pil: Vec>, + pub instructions: Vec>, + pub program: Vec>, + pub batches: Option>, +} + +#[derive(Default, Debug, PartialEq, Eq, Clone)] +pub struct BatchMetadata { + // the set of compatible statements + pub size: usize, + // the reason why this batch ended (for debugging purposes), None if we ran out of statements to batch + pub reason: Option, +} + +impl BatchMetadata { + pub fn size(&self) -> usize { + self.size + } +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub enum Incompatible { + Label, + Unimplemented, +} + +#[derive(Debug, PartialEq, Eq, Default, Clone)] +pub struct IncompatibleSet(pub BTreeSet); diff --git a/ast/src/lib.rs b/ast/src/lib.rs new file mode 100644 index 000000000..c74fa8425 --- /dev/null +++ b/ast/src/lib.rs @@ -0,0 +1,4 @@ +/// A typed-checked ASM + PIL AST optimised for analysis +pub mod analysis; +/// A parsed ASM + PIL AST +pub mod parsed; diff --git a/parser/src/asm_ast.rs b/ast/src/parsed/asm.rs similarity index 97% rename from parser/src/asm_ast.rs rename to ast/src/parsed/asm.rs index e1c9026ab..813615f37 100644 --- a/parser/src/asm_ast.rs +++ b/ast/src/parsed/asm.rs @@ -1,6 +1,6 @@ use number::AbstractNumberType; -use super::ast::{Expression, SelectedExpressions, Statement}; +use super::{Expression, SelectedExpressions, Statement}; pub mod batched { use std::collections::BTreeSet; diff 
--git a/parser/src/display.rs b/ast/src/parsed/display.rs similarity index 75% rename from parser/src/display.rs rename to ast/src/parsed/display.rs index acbac008b..c6bbc9e3b 100644 --- a/parser/src/display.rs +++ b/ast/src/parsed/display.rs @@ -1,14 +1,8 @@ use std::fmt::{Display, Formatter, Result}; -use parser_util::quote; +use crate::parsed::{BinaryOperator, UnaryOperator}; -use crate::asm_ast::{ - batched::{ASMStatementBatch, BatchedASMFile, Incompatible, IncompatibleSet}, - ASMFile, ASMStatement, InstructionBodyElement, InstructionParam, InstructionParamList, - InstructionParams, PlookupOperator, RegisterFlag, -}; - -use super::ast::*; +use super::{asm::*, *}; // TODO indentation @@ -30,22 +24,10 @@ impl Display for ASMFile { } } -impl Display for BatchedASMFile { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - for s in &self.declarations { - writeln!(f, "{s}")?; - } - for s in &self.batches { - writeln!(f, "{s}")?; - } - Ok(()) - } -} - impl Display for ASMStatement { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self { - ASMStatement::Degree(_, degree) => write!(f, "degree {degree};"), + ASMStatement::Degree(_, degree) => write!(f, "degree {};", degree), ASMStatement::RegisterDeclaration(_, name, flag) => write!( f, "reg {}{};", @@ -54,25 +36,29 @@ impl Display for ASMStatement { .map(|flag| format!("[{flag}]")) .unwrap_or_default() ), - ASMStatement::InstructionDeclaration(_, name, params, body) => write!( - f, - "instr {}{} {{{}}}", - name, - params, - body.iter() - .map(|e| format!("{e}")) - .collect::>() - .join(" ") - ), - ASMStatement::InlinePil(_, statements) => write!( - f, - "pil{{{}}}", - statements - .iter() - .map(|s| format!("{}", s)) - .collect::>() - .join("\n") - ), + ASMStatement::InstructionDeclaration(_, name, params, body) => { + write!( + f, + "instr {}{} {{{}}}", + name, + params, + body.iter() + .map(|e| format!("{e}")) + .collect::>() + .join(" ") + ) + } + ASMStatement::InlinePil(_, statements) => { + write!( + f, + 
"pil{{\n{}\n}}", + statements + .iter() + .map(|s| format!("{}", s)) + .collect::>() + .join("\n") + ) + } ASMStatement::Assignment(_, write_regs, assignment_reg, expression) => write!( f, "{} <={}= {};", @@ -195,40 +181,8 @@ impl Display for InstructionParam { } } -impl Display for ASMStatementBatch { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - for s in &self.statements { - writeln!(f, "{s}")?; - } - write!( - f, - "// END BATCH{}", - self.reason - .as_ref() - .map(|reason| format!(" {reason}")) - .unwrap_or_default() - ) - } -} - -impl Display for Incompatible { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - write!(f, "{:?}", self) - } -} - -impl Display for IncompatibleSet { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - write!( - f, - "{}", - self.0 - .iter() - .map(|r| r.to_string()) - .collect::>() - .join(", ") - ) - } +pub fn quote(input: &str) -> String { + format!("\"{}\"", input.replace('\\', "\\\\").replace('"', "\\\"")) } impl Display for Statement { @@ -292,14 +246,6 @@ impl Display for Statement { Statement::FunctionCall(_, name, args) => { write!(f, "{name}({});", format_expressions(args)) } - Statement::ASMBlock(_, statements) => { - writeln!(f, "assembly {{")?; - for _s in statements { - // TODO display for asm statements - //writeln!(f, "{s}")?; - } - writeln!(f, "}}") - } } } } @@ -461,60 +407,3 @@ impl Display for UnaryOperator { ) } } - -#[cfg(test)] -mod test { - use number::GoldilocksField; - - use crate::parse; - - #[test] - fn reparse() { - let input = r#" -constant %N = 16; -namespace Fibonacci(%N); -constant %last_row = (%N - 1); -macro bool(X) { (X * (1 - X)) = 0; }; -macro is_nonzero(X) { match X { 0 => 0, _ => 1, } }; -macro is_zero(X) { (1 - is_nonzero(X)) }; -macro is_equal(A, B) { is_zero((A - B)) }; -macro is_one(X) { is_equal(X, 1) }; -macro ite(C, A, B) { ((is_nonzero(C) * A) + (is_zero(C) * B)) }; -macro one_hot(i, index) { ite(is_equal(i, index), 1, 0) }; -pol constant ISLAST(i) { one_hot(i, %last_row) }; 
-pol commit x, y; -macro constrain_equal_expr(A, B) { (A - B) }; -macro force_equal_on_last_row(poly, value) { (ISLAST * constrain_equal_expr(poly, value)) = 0; }; -force_equal_on_last_row(x', 1); -force_equal_on_last_row(y', 1); -macro on_regular_row(cond) { ((1 - ISLAST) * cond) = 0; }; -on_regular_row(constrain_equal_expr(x', y)); -on_regular_row(constrain_equal_expr(y', (x + y))); -public out = y(%last_row);"#; - let printed = format!( - "{}", - parse::(Some("input"), input).unwrap() - ); - assert_eq!(input.trim(), printed.trim()); - } - - #[test] - fn reparse_witness_query() { - let input = r#"pol commit wit(i) query (x(i), y(i));"#; - let printed = format!( - "{}", - parse::(Some("input"), input).unwrap() - ); - assert_eq!(input.trim(), printed.trim()); - } - - #[test] - fn reparse_strings_and_tuples() { - let input = r#"constant %N = ("abc", 3);"#; - let printed = format!( - "{}", - parse::(Some("input"), input).unwrap() - ); - assert_eq!(input.trim(), printed.trim()); - } -} diff --git a/parser/src/ast.rs b/ast/src/parsed/mod.rs similarity index 97% rename from parser/src/ast.rs rename to ast/src/parsed/mod.rs index cb8fe4e47..83307bc51 100644 --- a/parser/src/ast.rs +++ b/ast/src/parsed/mod.rs @@ -1,9 +1,9 @@ +pub mod asm; +pub mod display; use std::{iter::once, ops::ControlFlow}; use number::{DegreeType, FieldElement}; -use crate::asm_ast::ASMStatement; - #[derive(Debug, PartialEq, Eq)] pub struct PILFile(pub Vec>); @@ -31,7 +31,6 @@ pub enum Statement { Option>, ), FunctionCall(usize, String, Vec>), - ASMBlock(usize, Vec>), } #[derive(Debug, PartialEq, Eq, Clone)] @@ -217,7 +216,7 @@ where } /// Traverses the expression trees of the statement and calls `f` in post-order. -/// Does not enter ASMBlocks or macro definitions. +/// Does not enter macro definitions. 
pub fn postvisit_expression_in_statement_mut( statement: &mut Statement, f: &mut F, @@ -258,8 +257,7 @@ where Statement::PolynomialCommitDeclaration(_, _, None) | Statement::Include(_, _) | Statement::PolynomialConstantDeclaration(_, _) - | Statement::MacroDefinition(_, _, _, _, _) - | Statement::ASMBlock(_, _) => ControlFlow::Continue(()), + | Statement::MacroDefinition(_, _, _, _, _) => ControlFlow::Continue(()), } } diff --git a/compiler/Cargo.toml b/compiler/Cargo.toml index 2065a0248..004c9a30a 100644 --- a/compiler/Cargo.toml +++ b/compiler/Cargo.toml @@ -15,4 +15,6 @@ executor = { path = "../executor" } pilgen = { path = "../pilgen" } pil_analyzer = { path = "../pil_analyzer" } halo2 = { path = "../halo2" } -json = "^0.12" \ No newline at end of file +json = "^0.12" +ast = { version = "0.1.0", path = "../ast" } +analysis = { version = "0.1.0", path = "../analysis" } diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index 7bd35e735..f9ba22ad7 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -12,14 +12,15 @@ use json::JsonValue; pub mod util; mod verify; +use analysis::analyze; pub use backend::{Backend, Proof}; use number::write_polys_file; use pil_analyzer::{json_exporter, Analyzed}; pub use verify::{verify, verify_asm_string}; +use ast::parsed::PILFile; use executor::constant_evaluator; use number::FieldElement; -use parser::ast::PILFile; pub fn no_callback() -> Option Option> { None @@ -125,12 +126,13 @@ pub fn compile_asm_string( force_overwrite: bool, prove_with: Option, ) -> String { - let pil = pilgen::compile(Some(file_name), contents).unwrap_or_else(|err| { + let parsed = parser::parse_asm(Some(file_name), contents).unwrap_or_else(|err| { eprintln!("Error parsing .asm file:"); err.output_to_stderr(); panic!(); }); - + let analysed = analyze(parsed).unwrap(); + let pil = pilgen::compile(analysed); let pil_file_name = format!( "{}.pil", Path::new(file_name).file_stem().unwrap().to_str().unwrap() diff --git a/compiler/tests/pil.rs 
b/compiler/tests/pil.rs index 840715e82..592a9e4e0 100644 --- a/compiler/tests/pil.rs +++ b/compiler/tests/pil.rs @@ -76,31 +76,3 @@ fn test_pair_lookup() { fn test_block_lookup_or() { verify_pil("block_lookup_or.pil", None); } - -#[test] -fn test_simple_sum_asm_pil() { - verify_pil( - "simple_sum_asm.pil", - Some(|q| match q { - "\"input\", 0" => Some(13.into()), - "\"input\", 1" => Some(2.into()), - "\"input\", 2" => Some(11.into()), - "\"input\", 3" => Some(2.into()), - _ => Some(7.into()), - }), - ) -} - -#[test] -fn test_simple_sum_asm_macro_pil() { - verify_pil( - "simple_sum_asm_macro.pil", - Some(|q| match q { - "\"input\", 0" => Some(13.into()), - "\"input\", 1" => Some(2.into()), - "\"input\", 2" => Some(11.into()), - "\"input\", 3" => Some(2.into()), - _ => Some(7.into()), - }), - ) -} diff --git a/executor/Cargo.toml b/executor/Cargo.toml index c04b10bf5..20a8ad348 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -8,7 +8,7 @@ itertools = "^0.10" log = "0.4.17" number = { path = "../number" } parser_util = { path = "../parser_util" } -parser = { path = "../parser" } pil_analyzer = { path = "../pil_analyzer" } rayon = "1.7.0" num-traits = "0.2.15" +ast = { version = "0.1.0", path = "../ast" } diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index 2c6cb5ffc..b7a517c68 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -1,8 +1,9 @@ use std::collections::HashMap; +use ast::parsed::{BinaryOperator, UnaryOperator}; use itertools::Itertools; use number::{DegreeType, FieldElement}; -use pil_analyzer::{Analyzed, BinaryOperator, Expression, FunctionValueDefinition, UnaryOperator}; +use pil_analyzer::{Analyzed, Expression, FunctionValueDefinition}; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; /// Generates the constant polynomial values for all constant polynomials diff --git a/executor/src/witgen/bit_constraints.rs 
b/executor/src/witgen/bit_constraints.rs index e591101c7..a937ed720 100644 --- a/executor/src/witgen/bit_constraints.rs +++ b/executor/src/witgen/bit_constraints.rs @@ -1,8 +1,9 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Debug, Display, Formatter}; +use ast::parsed::BinaryOperator; use number::{log2_exact, BigInt, FieldElement}; -use pil_analyzer::{BinaryOperator, Expression, Identity, IdentityKind, PolynomialReference}; +use pil_analyzer::{Expression, Identity, IdentityKind, PolynomialReference}; use super::expression_evaluator::ExpressionEvaluator; use super::symbolic_evaluator::SymbolicEvaluator; diff --git a/executor/src/witgen/expression_evaluator.rs b/executor/src/witgen/expression_evaluator.rs index 6981e58b8..0eaa1dddf 100644 --- a/executor/src/witgen/expression_evaluator.rs +++ b/executor/src/witgen/expression_evaluator.rs @@ -1,7 +1,8 @@ use std::marker::PhantomData; +use ast::parsed::{BinaryOperator, UnaryOperator}; use number::FieldElement; -use pil_analyzer::{BinaryOperator, Expression, PolynomialReference, UnaryOperator}; +use pil_analyzer::{Expression, PolynomialReference}; use super::{affine_expression::AffineResult, IncompleteCause}; diff --git a/halo2/Cargo.toml b/halo2/Cargo.toml index fa6de6b1f..8a65be643 100644 --- a/halo2/Cargo.toml +++ b/halo2/Cargo.toml @@ -14,7 +14,10 @@ itertools = "0.10.5" num-bigint = "^0.4" log = "0.4.17" rand = "0.8.5" +ast = { version = "0.1.0", path = "../ast" } [dev-dependencies] +analysis = { path = "../analysis" } executor = { path = "../executor" } -pilgen = { path = "../pilgen" } \ No newline at end of file +parser = { path = "../parser" } +pilgen = { path = "../pilgen" } diff --git a/halo2/src/circuit_builder.rs b/halo2/src/circuit_builder.rs index 69f6cfb20..f4640caac 100644 --- a/halo2/src/circuit_builder.rs +++ b/halo2/src/circuit_builder.rs @@ -1,3 +1,4 @@ +use ast::parsed::BinaryOperator; use num_bigint::BigUint; use polyexen::expr::{ColumnKind, ColumnQuery, Expr, PlonkVar}; use 
polyexen::plaf::backends::halo2::PlafH2Circuit; @@ -7,7 +8,7 @@ use polyexen::plaf::{ use num_traits::One; use number::{BigInt, FieldElement}; -use pil_analyzer::{self, BinaryOperator, Expression, IdentityKind, SelectedExpressions}; +use pil_analyzer::{self, Expression, IdentityKind, SelectedExpressions}; use super::circuit_data::CircuitData; diff --git a/halo2/src/mock_prover.rs b/halo2/src/mock_prover.rs index eded1267c..d859ba8c0 100644 --- a/halo2/src/mock_prover.rs +++ b/halo2/src/mock_prover.rs @@ -35,7 +35,9 @@ pub fn mock_prove( mod test { use std::fs; + use analysis::analyze; use number::Bn254Field; + use parser::parse_asm; use super::*; @@ -43,11 +45,9 @@ mod test { // read and compile PIL. let contents = fs::read_to_string(file_name).unwrap(); - let pil = pilgen::compile::(Some(file_name), &contents).unwrap_or_else(|err| { - eprintln!("Error parsing .asm file:"); - err.output_to_stderr(); - panic!(); - }); + let parsed = parse_asm::(Some(file_name), &contents).unwrap(); + let analysed = analyze(parsed).unwrap(); + let pil = pilgen::compile(analysed); let query_callback = |query: &str| -> Option { let items = query.split(',').map(|s| s.trim()).collect::>(); diff --git a/parser/Cargo.toml b/parser/Cargo.toml index 568754af9..540d0ce49 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -6,12 +6,17 @@ build = "build.rs" [dependencies] lalrpop-util = {version = "^0.19", features = ["lexer"]} +num-bigint = "0.4.3" num-traits = "0.2.15" number = { path = "../number" } +ast = { path = "../ast" } parser_util = { path = "../parser_util" } # This is only here to work around https://github.com/lalrpop/lalrpop/issues/750 # It should be removed once that workaround is no longer needed. 
regex-syntax = { version = "0.6", default_features = false, features = ["unicode"] } +[dev-dependencies] +pretty_assertions = "1.3.0" + [build-dependencies] lalrpop = "^0.19" diff --git a/parser/src/lib.rs b/parser/src/lib.rs index 1a2cdd074..dcc4bd948 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -5,11 +5,6 @@ use lalrpop_util::*; use number::FieldElement; use parser_util::{handle_parse_error, ParseError}; -pub mod asm_ast; -pub mod ast; -pub mod display; -pub mod macro_expansion; - lalrpop_mod!( #[allow(clippy::all)] powdr, @@ -19,7 +14,7 @@ lalrpop_mod!( pub fn parse<'a, T: FieldElement>( file_name: Option<&str>, input: &'a str, -) -> Result, ParseError<'a>> { +) -> Result, ParseError<'a>> { powdr::PILFileParser::new() .parse(input) .map_err(|err| handle_parse_error(err, file_name, input)) @@ -28,7 +23,7 @@ pub fn parse<'a, T: FieldElement>( pub fn parse_asm<'a, T: FieldElement>( file_name: Option<&str>, input: &'a str, -) -> Result, ParseError<'a>> { +) -> Result, ParseError<'a>> { powdr::ASMFileParser::new() .parse(input) .map_err(|err| handle_parse_error(err, file_name, input)) @@ -36,11 +31,13 @@ pub fn parse_asm<'a, T: FieldElement>( #[cfg(test)] mod test { - use std::fs; - - use super::{asm_ast::ASMFile, *}; - use ast::*; + use super::*; + use ast::parsed::{ + asm::ASMFile, BinaryOperator, Expression, PILFile, PolynomialName, PolynomialReference, + SelectedExpressions, Statement, + }; use number::GoldilocksField; + use std::fs; #[test] fn empty() { @@ -197,8 +194,61 @@ mod test { parse_asm_file("asm/simple_sum.asm"); } - #[test] - fn parse_mixed_pil_asm_files() { - parse_file("pil/simple_sum_asm.pil"); + mod display { + use number::GoldilocksField; + + use pretty_assertions::assert_eq; + + use crate::parse; + + #[test] + fn reparse() { + let input = r#" +constant %N = 16; +namespace Fibonacci(%N); +constant %last_row = (%N - 1); +macro bool(X) { (X * (1 - X)) = 0; }; +macro is_nonzero(X) { match X { 0 => 0, _ => 1, } }; +macro is_zero(X) { (1 - 
is_nonzero(X)) }; +macro is_equal(A, B) { is_zero((A - B)) }; +macro is_one(X) { is_equal(X, 1) }; +macro ite(C, A, B) { ((is_nonzero(C) * A) + (is_zero(C) * B)) }; +macro one_hot(i, index) { ite(is_equal(i, index), 1, 0) }; +pol constant ISLAST(i) { one_hot(i, %last_row) }; +pol commit x, y; +macro constrain_equal_expr(A, B) { (A - B) }; +macro force_equal_on_last_row(poly, value) { (ISLAST * constrain_equal_expr(poly, value)) = 0; }; +force_equal_on_last_row(x', 1); +force_equal_on_last_row(y', 1); +macro on_regular_row(cond) { ((1 - ISLAST) * cond) = 0; }; +on_regular_row(constrain_equal_expr(x', y)); +on_regular_row(constrain_equal_expr(y', (x + y))); +public out = y(%last_row);"#; + let printed = format!( + "{}", + parse::(Some("input"), input).unwrap() + ); + assert_eq!(input.trim(), printed.trim()); + } + + #[test] + fn reparse_witness_query() { + let input = r#"pol commit wit(i) query (x(i), y(i));"#; + let printed = format!( + "{}", + parse::(Some("input"), input).unwrap() + ); + assert_eq!(input.trim(), printed.trim()); + } + + #[test] + fn reparse_strings_and_tuples() { + let input = r#"constant %N = ("abc", 3);"#; + let printed = format!( + "{}", + parse::(Some("input"), input).unwrap() + ); + assert_eq!(input.trim(), printed.trim()); + } } } diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index bd3556393..7917a7ad9 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -1,6 +1,5 @@ use std::str::FromStr; -use crate::ast::*; -use crate::asm_ast::*; +use ast::parsed::{*, asm::*}; use number::{AbstractNumberType, FieldElement}; use num_traits::Num; @@ -39,7 +38,6 @@ Statement = { ConnectIdentity, MacroDefinition, FunctionCallStatement, - ASMBlock, }; Include: Statement = { @@ -134,10 +132,6 @@ FunctionCallStatement: Statement = { <@L> "(" ")" => Statement::FunctionCall(<>) } -ASMBlock: Statement = { - <@L> "assembly" "{" <(ASMStatement)*> "}" => Statement::ASMBlock(<>) -} - PolCol = { "pol", "col" } diff --git 
a/parser_util/src/lib.rs b/parser_util/src/lib.rs index ce3c6d7ae..d22d77848 100644 --- a/parser_util/src/lib.rs +++ b/parser_util/src/lib.rs @@ -2,10 +2,6 @@ pub mod lines; -pub fn quote(input: &str) -> String { - format!("\"{}\"", input.replace('\\', "\\\\").replace('"', "\\\"")) -} - #[derive(Debug)] pub struct ParseError<'a> { start: usize, diff --git a/pil_analyzer/Cargo.toml b/pil_analyzer/Cargo.toml index 2b27924d9..8bdf6e3a9 100644 --- a/pil_analyzer/Cargo.toml +++ b/pil_analyzer/Cargo.toml @@ -9,5 +9,6 @@ mktemp = "0.5.0" number = { path = "../number" } parser_util = { path = "../parser_util" } parser = { path = "../parser" } -pilgen = { path = "../pilgen" } -itertools = "^0.10" \ No newline at end of file +itertools = "^0.10" +ast = { version = "0.1.0", path = "../ast" } +analysis = { version = "0.1", path = "../analysis" } diff --git a/pil_analyzer/src/lib.rs b/pil_analyzer/src/lib.rs index 9f663c555..ad6f3ec36 100644 --- a/pil_analyzer/src/lib.rs +++ b/pil_analyzer/src/lib.rs @@ -7,8 +7,9 @@ use std::hash::Hash; use std::path::Path; use std::{collections::HashMap, fmt::Display}; +use ast::parsed::{BinaryOperator, UnaryOperator}; use number::{DegreeType, FieldElement}; -pub use parser::ast::{BinaryOperator, UnaryOperator}; + use util::expr_any; pub fn analyze(path: &Path) -> Analyzed { diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 449b923d2..0b78deb19 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -2,15 +2,18 @@ use std::collections::{HashMap, HashSet}; use std::fs; use std::path::{Path, PathBuf}; +use analysis::MacroExpander; +use ast::parsed::{ + ArrayExpression, BinaryOperator, FunctionDefinition, PolynomialName, Statement, UnaryOperator, +}; use number::{BigInt, DegreeType, FieldElement}; -use parser::asm_ast::ASMStatement; -use parser::ast; -pub use parser::ast::{BinaryOperator, UnaryOperator}; -use parser::macro_expansion::MacroExpander; use 
crate::util::previsit_expressions_in_pil_file_mut; - -use super::*; +use crate::{ + Analyzed, Expression, FunctionValueDefinition, Identity, IdentityKind, Polynomial, + PolynomialReference, PolynomialType, PublicDeclaration, RepeatedArray, SelectedExpressions, + SourceRef, StatementIdentifier, +}; pub fn process_pil_file(path: &Path) -> Analyzed { let mut ctx = PILContext::new(); @@ -105,6 +108,7 @@ impl PILContext { return; } let contents = fs::read_to_string(path.clone()).unwrap(); + self.process_file_contents(&path, &contents); } @@ -132,8 +136,7 @@ impl PILContext { self.line_starts = old_line_starts; } - fn handle_statement(&mut self, statement: ast::Statement) { - use ast::Statement; + fn handle_statement(&mut self, statement: Statement) { match statement { Statement::Include(_, include) => self.handle_include(include), Statement::Namespace(_, name, degree) => self.handle_namespace(name, degree), @@ -143,7 +146,7 @@ impl PILContext { name, None, PolynomialType::Intermediate, - Some(ast::FunctionDefinition::Mapping(vec![], value)), + Some(FunctionDefinition::Mapping(vec![], value)), ); } Statement::PublicDeclaration(start, name, polynomial, index) => { @@ -187,9 +190,6 @@ impl PILContext { Statement::MacroDefinition(_, _, _, _, _) => { panic!("Macros should have been eliminated."); } - Statement::ASMBlock(start, asm_statements) => { - self.handle_assembly(self.to_source_ref(start), asm_statements) - } _ => { self.handle_identity_statement(statement); } @@ -204,9 +204,9 @@ impl PILContext { } } - fn handle_identity_statement(&mut self, statement: ast::Statement) { + fn handle_identity_statement(&mut self, statement: Statement) { let (start, kind, left, right) = match statement { - ast::Statement::PolynomialIdentity(start, expression) => ( + Statement::PolynomialIdentity(start, expression) => ( start, IdentityKind::Polynomial, SelectedExpressions { @@ -215,19 +215,19 @@ impl PILContext { }, SelectedExpressions::default(), ), - 
ast::Statement::PlookupIdentity(start, key, haystack) => ( + Statement::PlookupIdentity(start, key, haystack) => ( start, IdentityKind::Plookup, self.process_selected_expression(key), self.process_selected_expression(haystack), ), - ast::Statement::PermutationIdentity(start, left, right) => ( + Statement::PermutationIdentity(start, left, right) => ( start, IdentityKind::Permutation, self.process_selected_expression(left), self.process_selected_expression(right), ), - ast::Statement::ConnectIdentity(start, left, right) => ( + Statement::ConnectIdentity(start, left, right) => ( start, IdentityKind::Connect, SelectedExpressions { @@ -263,7 +263,7 @@ impl PILContext { self.process_file(&dir); } - fn handle_namespace(&mut self, name: String, degree: ast::Expression) { + fn handle_namespace(&mut self, name: String, degree: ::ast::parsed::Expression) { // TODO: the polynomial degree should be handled without going through a field element. This requires having types in Expression self.polynomial_degree = self.evaluate_expression(&degree).unwrap().to_degree(); self.namespace = name; @@ -272,10 +272,10 @@ impl PILContext { fn handle_polynomial_declarations( &mut self, source: SourceRef, - polynomials: Vec>, + polynomials: Vec>, polynomial_type: PolynomialType, ) { - for ast::PolynomialName { name, array_size } in polynomials { + for PolynomialName { name, array_size } in polynomials { self.handle_polynomial_definition( source.clone(), name, @@ -290,9 +290,9 @@ impl PILContext { &mut self, source: SourceRef, name: String, - array_size: Option>, + array_size: Option<::ast::parsed::Expression>, polynomial_type: PolynomialType, - value: Option>, + value: Option>, ) -> u64 { let have_array_size = array_size.is_some(); let length = array_size @@ -320,7 +320,7 @@ impl PILContext { let name = poly.absolute_name.clone(); let value = value.map(|v| match v { - ast::FunctionDefinition::Mapping(params, expr) => { + FunctionDefinition::Mapping(params, expr) => { assert!(!have_array_size);
assert!( poly.poly_type == PolynomialType::Constant @@ -328,12 +328,12 @@ impl PILContext { ); FunctionValueDefinition::Mapping(self.process_function(params, expr)) } - ast::FunctionDefinition::Query(params, expr) => { + FunctionDefinition::Query(params, expr) => { assert!(!have_array_size); assert_eq!(poly.poly_type, PolynomialType::Committed); FunctionValueDefinition::Query(self.process_function(params, expr)) } - ast::FunctionDefinition::Array(value) => { + FunctionDefinition::Array(value) => { let star_value = value.solve(self.polynomial_degree); let expression = self.process_array_expression(value, star_value); assert_eq!( @@ -356,7 +356,7 @@ impl PILContext { fn process_function( &mut self, params: Vec, - expression: ast::Expression, + expression: ::ast::parsed::Expression, ) -> Expression { assert!(self.local_variables.is_empty()); self.local_variables = params @@ -373,8 +373,8 @@ impl PILContext { &mut self, source: SourceRef, name: String, - poly: ast::PolynomialReference, - index: ast::Expression, + poly: ::ast::parsed::PolynomialReference, + index: ::ast::parsed::Expression, ) { let id = self.public_declarations.len() as u64; self.public_declarations.insert( @@ -391,7 +391,7 @@ impl PILContext { .push(StatementIdentifier::PublicDeclaration(name)); } - fn handle_constant_definition(&mut self, name: String, value: ast::Expression) { + fn handle_constant_definition(&mut self, name: String, value: ::ast::parsed::Expression) { // TODO does the order matter here? 
let is_new = self .constants @@ -407,13 +407,6 @@ impl PILContext { id } - fn handle_assembly(&mut self, _source: SourceRef, asm_statements: Vec>) { - let statements = pilgen::asm_to_pil(asm_statements.into_iter(), &mut self.macro_expander); - for s in statements { - self.handle_statement(s) - } - } - fn namespaced(&self, name: &str) -> String { self.namespaced_ref(&None, name) } @@ -424,7 +417,7 @@ impl PILContext { fn process_selected_expression( &mut self, - expr: ast::SelectedExpressions, + expr: ::ast::parsed::SelectedExpressions, ) -> SelectedExpressions { SelectedExpressions { selector: expr.selector.map(|e| self.process_expression(e)), @@ -434,15 +427,15 @@ impl PILContext { fn process_array_expression( &mut self, - array_expression: ast::ArrayExpression, + array_expression: ::ast::parsed::ArrayExpression, star_value: Option, ) -> Vec> { match array_expression { - ast::ArrayExpression::Value(expressions) => vec![RepeatedArray { + ArrayExpression::Value(expressions) => vec![RepeatedArray { values: self.process_expressions(expressions), repetitions: 1, }], - ast::ArrayExpression::RepeatedValue(expressions) => { + ArrayExpression::RepeatedValue(expressions) => { if star_value.unwrap() == 0 { vec![] } else { @@ -452,7 +445,7 @@ impl PILContext { }] } } - ast::ArrayExpression::Concat(left, right) => self + ArrayExpression::Concat(left, right) => self .process_array_expression(*left, star_value) .into_iter() .chain(self.process_array_expression(*right, star_value)) @@ -460,17 +453,21 @@ impl PILContext { } } - fn process_expressions(&mut self, exprs: Vec>) -> Vec> { + fn process_expressions( + &mut self, + exprs: Vec<::ast::parsed::Expression>, + ) -> Vec> { exprs .into_iter() .map(|e| self.process_expression(e)) .collect() } - fn process_expression(&mut self, expr: ast::Expression) -> Expression { + fn process_expression(&mut self, expr: ::ast::parsed::Expression) -> Expression { + use ::ast::parsed::Expression::*; match expr { - ast::Expression::Constant(name) 
=> Expression::Constant(name), - ast::Expression::PolynomialReference(poly) => { + Constant(name) => Expression::Constant(name), + PolynomialReference(poly) => { if poly.namespace.is_none() && self.local_variables.contains_key(&poly.name) { let id = self.local_variables[&poly.name]; assert!(!poly.next); @@ -480,11 +477,11 @@ impl PILContext { Expression::PolynomialReference(self.process_polynomial_reference(poly)) } } - ast::Expression::PublicReference(name) => Expression::PublicReference(name), - ast::Expression::Number(n) => Expression::Number(n), - ast::Expression::String(value) => Expression::String(value), - ast::Expression::Tuple(items) => Expression::Tuple(self.process_expressions(items)), - ast::Expression::BinaryOperation(left, op, right) => { + PublicReference(name) => Expression::PublicReference(name), + Number(n) => Expression::Number(n), + String(value) => Expression::String(value), + Tuple(items) => Expression::Tuple(self.process_expressions(items)), + BinaryOperation(left, op, right) => { if let Some(value) = self.evaluate_binary_operation(&left, op, &right) { Expression::Number(value) } else { @@ -495,18 +492,18 @@ impl PILContext { ) } } - ast::Expression::UnaryOperation(op, value) => { + UnaryOperation(op, value) => { if let Some(value) = self.evaluate_unary_operation(op, &value) { Expression::Number(value) } else { Expression::UnaryOperation(op, Box::new(self.process_expression(*value))) } } - ast::Expression::FunctionCall(name, arguments) => Expression::FunctionCall( + FunctionCall(name, arguments) => Expression::FunctionCall( self.namespaced(&name), self.process_expressions(arguments), ), - ast::Expression::MatchExpression(scrutinee, arms) => Expression::MatchExpression( + MatchExpression(scrutinee, arms) => Expression::MatchExpression( Box::new(self.process_expression(*scrutinee)), arms.into_iter() .map(|(n, e)| { @@ -521,13 +518,13 @@ impl PILContext { }) .collect(), ), - ast::Expression::FreeInput(_) => panic!(), + FreeInput(_) => panic!(), 
} } fn process_polynomial_reference( &self, - poly: ast::PolynomialReference, + poly: ::ast::parsed::PolynomialReference, ) -> PolynomialReference { let index = poly .index @@ -542,34 +539,33 @@ impl PILContext { } } - fn evaluate_expression(&self, expr: &ast::Expression) -> Option { + fn evaluate_expression(&self, expr: &::ast::parsed::Expression) -> Option { + use ::ast::parsed::Expression::*; match expr { - ast::Expression::Constant(name) => Some( + Constant(name) => Some( *self .constants .get(name) .unwrap_or_else(|| panic!("Constant {name} not found.")), ), - ast::Expression::PolynomialReference(_) => None, - ast::Expression::PublicReference(_) => None, - ast::Expression::Number(n) => Some(*n), - ast::Expression::String(_) => None, - ast::Expression::Tuple(_) => None, - ast::Expression::BinaryOperation(left, op, right) => { - self.evaluate_binary_operation(left, *op, right) - } - ast::Expression::UnaryOperation(op, value) => self.evaluate_unary_operation(*op, value), - ast::Expression::FunctionCall(_, _) => None, - ast::Expression::FreeInput(_) => panic!(), - ast::Expression::MatchExpression(_, _) => None, + PolynomialReference(_) => None, + PublicReference(_) => None, + Number(n) => Some(*n), + String(_) => None, + Tuple(_) => None, + BinaryOperation(left, op, right) => self.evaluate_binary_operation(left, *op, right), + UnaryOperation(op, value) => self.evaluate_unary_operation(*op, value), + FunctionCall(_, _) => None, + FreeInput(_) => panic!(), + MatchExpression(_, _) => None, } } fn evaluate_binary_operation( &self, - left: &ast::Expression, + left: &::ast::parsed::Expression, op: BinaryOperator, - right: &ast::Expression, + right: &::ast::parsed::Expression, ) -> Option { if let (Some(left), Some(right)) = ( self.evaluate_expression(left), @@ -607,7 +603,11 @@ impl PILContext { } } - fn evaluate_unary_operation(&self, op: UnaryOperator, value: &ast::Expression) -> Option { + fn evaluate_unary_operation( + &self, + op: UnaryOperator, + value: 
&::ast::parsed::Expression, + ) -> Option { self.evaluate_expression(value).map(|v| match op { UnaryOperator::Plus => v, UnaryOperator::Minus => -v, diff --git a/pilgen/Cargo.toml b/pilgen/Cargo.toml index 1d0757976..4f0370744 100644 --- a/pilgen/Cargo.toml +++ b/pilgen/Cargo.toml @@ -9,4 +9,8 @@ number = { path = "../number" } parser_util = { path = "../parser_util" } parser = { path = "../parser" } itertools = "0.10.5" -pretty_assertions = "1.3.0" +num-bigint = "0.4.3" +ast = { path = "../ast" } + +[dev-dependencies] +analysis = { path = "../analysis" } diff --git a/pilgen/src/lib.rs b/pilgen/src/lib.rs index 2976ed56d..efff85606 100644 --- a/pilgen/src/lib.rs +++ b/pilgen/src/lib.rs @@ -1,47 +1,24 @@ //! Compilation from powdr assembly to PIL -mod batcher; +use std::collections::{BTreeMap, BTreeSet, HashMap}; -use std::collections::BTreeMap; -use std::collections::BTreeSet; -use std::collections::HashMap; - -use batcher::ASMBatcher; +use ast::{ + analysis::{ + AnalysisASMFile, AssignmentStatement, BatchMetadata, InstructionDefinitionStatement, + InstructionStatement, LabelStatement, PilBlock, ProgramStatement, + RegisterDeclarationStatement, + }, + parsed::{ + asm::{InstructionBodyElement, PlookupOperator, RegisterFlag}, + postvisit_expression_in_statement_mut, ArrayExpression, BinaryOperator, Expression, + FunctionDefinition, PILFile, PolynomialName, PolynomialReference, SelectedExpressions, + Statement, UnaryOperator, + }, +}; use number::FieldElement; -use parser::asm_ast::batched::{ASMStatementBatch, BatchedASMFile}; -use parser::asm_ast::*; -use parser::ast::*; -use parser::macro_expansion::MacroExpander; -use parser_util::ParseError; - -/// Compiles a stand-alone assembly file to PIL. 
-pub fn compile<'a, T: FieldElement>( - file_name: Option<&str>, - input: &'a str, -) -> Result, ParseError<'a>> { - let statements = parser::parse_asm(file_name, input) - .map(|ast| ASMBatcher::default().convert(ast)) - .map(|batched_asm| { - let mut macro_expander = MacroExpander::new(); - ASMPILConverter::new(&mut macro_expander).convert(batched_asm, ASMKind::StandAlone) - })?; - Ok(PILFile(statements)) -} - -/// Compiles inline assembly to PIL. -pub fn asm_to_pil( - statements: impl IntoIterator>, - macro_expander: &mut MacroExpander, -) -> Vec> { - let batched_asm = ASMBatcher::default().convert(ASMFile(statements.into_iter().collect())); - ASMPILConverter::new(macro_expander).convert(batched_asm, ASMKind::Inline) -} - -#[derive(PartialEq)] -pub enum ASMKind { - Inline, - StandAlone, +pub fn compile(analysis: AnalysisASMFile) -> PILFile { + PILFile(ASMPILConverter::default().compile(analysis)) } pub enum Input { @@ -55,8 +32,8 @@ pub enum LiteralKind { UnsignedConstant, } -struct ASMPILConverter<'a, T> { - macro_expander: &'a mut MacroExpander, +#[derive(Default)] +struct ASMPILConverter { pil: Vec>, pc_name: Option, registers: BTreeMap>, @@ -68,41 +45,27 @@ struct ASMPILConverter<'a, T> { program_constant_names: Vec, } -impl<'a, T: FieldElement> ASMPILConverter<'a, T> { - fn new(macro_expander: &'a mut MacroExpander) -> Self { - Self { - macro_expander, - pil: Default::default(), - pc_name: None, - registers: Default::default(), - instructions: Default::default(), - code_lines: Default::default(), - line_lookup: Default::default(), - program_constant_names: Default::default(), - } +impl ASMPILConverter { + fn handle_inline_pil(&mut self, PilBlock { statements, .. 
}: PilBlock) { + self.pil.extend(statements) } - fn convert(&mut self, batched_asm: BatchedASMFile, asm_kind: ASMKind) -> Vec> { - let mut declarations = batched_asm.declarations.into_iter().peekable(); - if asm_kind == ASMKind::StandAlone { - let degree = if let Some(ASMStatement::Degree(_, deg)) = declarations.peek() { - let deg = T::from(deg.clone()).to_degree(); - declarations.next(); - deg - } else { - 1024 - }; + fn compile(&mut self, input: AnalysisASMFile) -> Vec> { + let degree = if let Some(s) = input.degree { + T::from(s.degree).to_degree() + } else { + 1024 + }; - assert!( - degree.is_power_of_two(), - "Degree should be a power of two, found {degree}", - ); - self.pil.push(Statement::Namespace( - 0, - "Assembly".to_string(), - Expression::Number(degree.into()), - )); - } + assert!( + degree.is_power_of_two(), + "Degree should be a power of two, found {degree}", + ); + self.pil.push(Statement::Namespace( + 0, + "Assembly".to_string(), + Expression::Number(degree.into()), + )); self.pil.push(Statement::PolynomialConstantDefinition( 0, @@ -112,24 +75,34 @@ impl<'a, T: FieldElement> ASMPILConverter<'a, T> { ), )); - for s in declarations { - match s { - ASMStatement::RegisterDeclaration(start, name, flags) => { - self.handle_register_declaration(flags, &name, start); - } - ASMStatement::InstructionDeclaration(start, name, params, body) => { - self.handle_instruction_def(start, body, name, params); - } - ASMStatement::InlinePil(start, statements) => { - self.handle_inline_pil(start, statements); - } - _ => unreachable!(), - } + for reg in input.registers { + self.handle_register_declaration(reg); } - for batch in batched_asm.batches { - self.handle_batch(batch); + for block in input.pil { + self.handle_inline_pil(block); } + + for instr in input.instructions { + self.handle_instruction_def(instr); + } + + let batches = input.batches.unwrap_or_else(|| { + vec![ + BatchMetadata { + size: 1, + reason: None + }; + input.program.len() + ] + }); + + let mut 
statements = input.program.into_iter(); + + for batch in batches { + self.handle_batch(batch, &mut statements); + } + let assignment_registers = self.assignment_registers().cloned().collect::>(); for reg in assignment_registers { self.create_constraints_for_assignment_reg(reg); @@ -181,9 +154,13 @@ impl<'a, T: FieldElement> ASMPILConverter<'a, T> { std::mem::take(&mut self.pil) } - fn handle_batch(&mut self, batch: ASMStatementBatch) { - let code_lines = batch - .into_statements() + fn handle_batch( + &mut self, + batch: BatchMetadata, + statements: &mut impl Iterator>, + ) { + let code_lines = statements + .take(batch.size) .filter_map(|s| self.handle_statement(s)) .reduce(|mut acc, e| { // we write to the union of the target registers. @@ -203,49 +180,44 @@ impl<'a, T: FieldElement> ASMPILConverter<'a, T> { self.code_lines.extend(code_lines); } - fn handle_statement(&mut self, statement: ASMStatement) -> Option> { + fn handle_statement(&mut self, statement: ProgramStatement) -> Option> { match statement { - ASMStatement::Assignment(start, write_regs, assign_reg, value) => Some(match *value { - Expression::FunctionCall(function_name, args) => self - .handle_functional_instruction( - write_regs, - assign_reg.unwrap(), - function_name, - args, - ), - _ => self.handle_assignment(start, write_regs, assign_reg, *value), + ProgramStatement::Assignment(AssignmentStatement { + start, + lhs, + using_reg, + rhs, + }) => Some(match *rhs { + Expression::FunctionCall(function_name, args) => { + self.handle_functional_instruction(lhs, using_reg.unwrap(), function_name, args) + } + _ => self.handle_assignment(start, lhs, using_reg, *rhs), }), - ASMStatement::Instruction(_start, instr_name, args) => { - Some(self.handle_instruction(instr_name, args)) - } - ASMStatement::Label(_start, name) => Some(CodeLine { + ProgramStatement::Instruction(InstructionStatement { + instruction, + inputs, + .. 
+ }) => Some(self.handle_instruction(instruction, inputs)), + ProgramStatement::Label(LabelStatement { name, .. }) => Some(CodeLine { labels: [name].into(), ..Default::default() }), - s => unreachable!("{:?}", s), } } - fn handle_inline_pil(&mut self, _size: usize, statements: Vec>) { - self.pil - .extend(self.macro_expander.expand_macros(statements)) - } - fn handle_register_declaration( &mut self, - flags: Option, - name: &str, - start: usize, + RegisterDeclarationStatement { start, flag, name }: RegisterDeclarationStatement, ) { let mut conditioned_updates = vec![]; let mut default_update = None; - match flags { + match flag { Some(RegisterFlag::IsPC) => { assert_eq!(self.pc_name, None); self.pc_name = Some(name.to_string()); self.line_lookup .push((name.to_string(), "p_line".to_string())); - default_update = Some(build_add(direct_reference(name), build_number(1u64))); + default_update = Some(build_add(direct_reference(&name), build_number(1u64))); } Some(RegisterFlag::IsAssignment) => { // no updates @@ -255,7 +227,7 @@ impl<'a, T: FieldElement> ASMPILConverter<'a, T> { // be zero in the first row. 
self.pil.push(Statement::PolynomialIdentity( start, - build_mul(direct_reference("first_step"), direct_reference(name)), + build_mul(direct_reference("first_step"), direct_reference(&name)), )); conditioned_updates = vec![ // The value here is actually irrelevant, it is only important @@ -270,7 +242,7 @@ impl<'a, T: FieldElement> ASMPILConverter<'a, T> { conditioned_updates .push((direct_reference(&write_flag), direct_reference(&reg))); } - default_update = Some(direct_reference(name)); + default_update = Some(direct_reference(&name)); } }; self.registers.insert( @@ -278,7 +250,7 @@ impl<'a, T: FieldElement> ASMPILConverter<'a, T> { Register { conditioned_updates, default_update, - is_assignment: flags == Some(RegisterFlag::IsAssignment), + is_assignment: flag == Some(RegisterFlag::IsAssignment), }, ); self.pil.push(witness_column(start, name, None)); @@ -286,10 +258,12 @@ impl<'a, T: FieldElement> ASMPILConverter<'a, T> { fn handle_instruction_def( &mut self, - start: usize, - body: Vec>, - name: String, - params: InstructionParams, + InstructionDefinitionStatement { + start, + body, + name, + params, + }: InstructionDefinitionStatement, ) { let instruction_flag = format!("instr_{name}"); self.create_witness_fixed_pair(start, &instruction_flag); @@ -373,7 +347,7 @@ impl<'a, T: FieldElement> ASMPILConverter<'a, T> { }); // Expand macros and analyze resulting statements.
- for mut statement in self.macro_expander.expand_macros(statements) { + for mut statement in statements { if let Statement::PolynomialIdentity(_start, expr) = statement { match extract_update(expr) { (Some(var), expr) => { @@ -994,11 +968,20 @@ fn extract_update(expr: Expression) -> (Option, Expr mod test { use std::fs; - use pretty_assertions::assert_eq; + use ast::analysis::AnalysisASMFile; + use number::{FieldElement, GoldilocksField}; - use number::GoldilocksField; + use analysis::analyze; + use parser::parse_asm; - use super::compile; + use crate::compile; + + fn parse_and_analyse<'a, T: FieldElement>( + file_name: Option<&str>, + input: &'a str, + ) -> AnalysisASMFile { + analyze(parse_asm(file_name, input).unwrap()).unwrap() + } #[test] pub fn compile_simple_sum() { @@ -1054,7 +1037,8 @@ pol constant p_reg_write_X_CNT = [1, 0, 0, 0, 0, 0, 0, 0] + [0]*; "#; let file_name = "../test_data/asm/simple_sum.asm"; let contents = fs::read_to_string(file_name).unwrap(); - let pil = compile::(Some(file_name), &contents).unwrap(); + let analysed = parse_and_analyse::(Some(file_name), &contents); + let pil = compile(analysed); assert_eq!(format!("{pil}").trim(), expectation.trim()); } @@ -1092,7 +1076,9 @@ pol constant p_instr_inc_fp = [1, 0] + [0]*; pol constant p_instr_inc_fp_param_amount = [7, 0] + [0]*; { pc, instr_inc_fp, instr_inc_fp_param_amount, instr_adjust_fp, instr_adjust_fp_param_amount, instr_adjust_fp_param_t } in { p_line, p_instr_inc_fp, p_instr_inc_fp_param_amount, p_instr_adjust_fp, p_instr_adjust_fp_param_amount, p_instr_adjust_fp_param_t }; "#; - let pil = compile::(None, source).unwrap(); + let parsed = parse_asm::(None, source).unwrap(); + let analysis = analyze(parsed).unwrap(); + let pil = compile(analysis); assert_eq!(format!("{pil}").trim(), expectation.trim()); } @@ -1107,6 +1093,8 @@ instr instro x: unsigned { pc' = pc + x } instro 9223372034707292161; "#; - compile::(None, source).unwrap(); + let parsed = parse_asm::(None, source).unwrap(); 
+ let analysis = analyze(parsed).unwrap(); + let _ = compile(analysis); } } diff --git a/test_data/pil/simple_sum_asm.pil b/test_data/pil/simple_sum_asm.pil deleted file mode 100644 index 44f5b5949..000000000 --- a/test_data/pil/simple_sum_asm.pil +++ /dev/null @@ -1,38 +0,0 @@ -namespace Main(2**10); - col witness XInv; - col witness XIsZero; - XIsZero * (1 - XIsZero) = 0; - -assembly { - reg pc[@pc]; - reg X[<=]; - reg A; - reg CNT; - - pil { - // Just to test if pil-inside-assembly-inside-pil works. - XIsZero = 1 - X * XInv; - XIsZero * X = 0; - } - - instr jmpz X, l: label { pc' = XIsZero * l + (1 - XIsZero) * (pc + 1) } - instr jmp l: label { pc' = l } - instr dec_CNT { CNT' = CNT - 1 } - instr assert_zero X { XIsZero = 1 } - - CNT <=X= ${ ("input", 1) }; - - start:: - jmpz CNT, check; - A <=X= A + ${ ("input", CNT + 1) }; - // Could use "CNT <=X= CNT - 1", but that would need X. - dec_CNT; - jmp start; - - check:: - A <=X= A - ${ ("input", 0) }; - assert_zero A; - - end:: - jmp end; -}; \ No newline at end of file diff --git a/test_data/pil/simple_sum_asm_macro.pil b/test_data/pil/simple_sum_asm_macro.pil deleted file mode 100644 index 43454380e..000000000 --- a/test_data/pil/simple_sum_asm_macro.pil +++ /dev/null @@ -1,42 +0,0 @@ -namespace Main(2**10); - col witness XInv; - col witness XIsZero; - XIsZero * (1 - XIsZero) = 0; - - macro if_then_else(condition, true_value, false_value) { condition * true_value + (1 - condition) * false_value }; - macro jump_to(target) { pc' = target; }; - macro jump_to_if(condition, target) { jump_to(if_then_else(condition, target, pc + 1)); }; - -assembly { - reg pc[@pc]; - reg X[<=]; - reg A; - reg CNT; - - pil { - // Just to test if pil-inside-assembly-inside-pil works. 
- XIsZero = 1 - X * XInv; - XIsZero * X = 0; - } - - instr jmpz X, l: label { jump_to_if(XIsZero, l) } - instr jmp l: label { jump_to(l) } - instr dec_CNT { CNT' = CNT - 1 } - instr assert_zero X { XIsZero = 1 } - - CNT <=X= ${ ("input", 1) }; - - start:: - jmpz CNT, check; - A <=X= A + ${ ("input", CNT + 1) }; - // Could use "CNT <=X= CNT - 1", but that would need X. - dec_CNT; - jmp start; - - check:: - A <=X= A - ${ ("input", 0) }; - assert_zero A; - - end:: - jmp end; -}; \ No newline at end of file