diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index aaf943ce7..e4d812253 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -71,6 +71,7 @@ impl Display for FunctionValueDefinition { write!(f, " = {}", items.iter().map(|i| i.to_string()).join(" + ")) } FunctionValueDefinition::Query(e) => write!(f, "(i) query {e}"), + FunctionValueDefinition::Expression(e) => write!(f, " = {e}"), } } } diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index c7565505e..30c27bc2f 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -61,6 +61,12 @@ impl Analyzed { self.definitions_in_source_order(PolynomialType::Committed) } + pub fn intermediate_polys_in_source_order( + &self, + ) -> Vec<&(Polynomial, Option>)> { + self.definitions_in_source_order(PolynomialType::Intermediate) + } + pub fn definitions_in_source_order( &self, poly_type: PolynomialType, @@ -96,12 +102,12 @@ impl Analyzed { /// so that they are contiguous again. /// There must not be any reference to the removed polynomials left. pub fn remove_polynomials(&mut self, to_remove: &BTreeSet) { - // TODO intermediate polys let replacements: BTreeMap = [ // We have to do it separately because we need to re-start the counter // for each kind. self.committed_polys_in_source_order(), self.constant_polys_in_source_order(), + self.intermediate_polys_in_source_order(), ] .map(|polys| { polys @@ -239,6 +245,7 @@ pub enum FunctionValueDefinition { Mapping(Expression), Array(Vec>), Query(Expression), + Expression(Expression), } /// An array of elements that might be repeated. diff --git a/ast/src/analyzed/util.rs b/ast/src/analyzed/util.rs index 9636a82aa..bb7730f2c 100644 --- a/ast/src/analyzed/util.rs +++ b/ast/src/analyzed/util.rs @@ -22,6 +22,7 @@ where .iter_mut() .flat_map(|e| e.pattern.iter_mut()) .try_for_each(|e| previsit_expression_mut(e, f)), + Some(FunctionValueDefinition::Expression(e)) => previsit_expression_mut(e, f), None => ControlFlow::Continue(()), })?; @@ -66,6 +67,7 @@ where .iter_mut() .flat_map(|e| e.pattern.iter_mut()) .try_for_each(|e| postvisit_expression_mut(e, f)), + Some(FunctionValueDefinition::Expression(e)) => postvisit_expression_mut(e, f), None => ControlFlow::Continue(()), })?; diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index 920488520..074cdb932 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -370,6 +370,9 @@ impl Display for FunctionDefinition { FunctionDefinition::Query(params, value) => { write!(f, "({}) query {value}", params.join(", "),) } + FunctionDefinition::Expression(e) => { + write!(f, " = {e}") + } } } } diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 8a5b5f205..d886c90ad 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -289,7 +289,6 @@ pub struct FunctionCall { } /// The definition of a function (excluding its name): -/// Either a param-value mapping or an array expression. #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub enum FunctionDefinition { /// Parameter-value-mapping. @@ -298,6 +297,8 @@ pub enum FunctionDefinition { Array(ArrayExpression), /// Prover query. Query(Vec, Expression), + /// Expression, for intermediate polynomials + Expression(Expression), } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] @@ -451,6 +452,7 @@ where postvisit_expression_mut(e, f) } FunctionDefinition::Array(ae) => postvisit_expression_in_array_expression_mut(ae, f), + FunctionDefinition::Expression(e) => postvisit_expression_mut(e, f), }, PilStatement::PolynomialCommitDeclaration(_, _, None) | PilStatement::Include(_, _) diff --git a/backend/src/pilcom_cli/json_exporter/mod.rs b/backend/src/pilcom_cli/json_exporter/mod.rs index e04329f87..fc8eb1cac 100644 --- a/backend/src/pilcom_cli/json_exporter/mod.rs +++ b/backend/src/pilcom_cli/json_exporter/mod.rs @@ -34,7 +34,7 @@ pub fn export(analyzed: &Analyzed) -> JsonValue { StatementIdentifier::Definition(name) => { if let (poly, Some(value)) = &analyzed.definitions[name] { if poly.poly_type == PolynomialType::Intermediate { - if let FunctionValueDefinition::Mapping(value) = value { + if let FunctionValueDefinition::Expression(value) = value { let expression_id = exporter.extract_expression(value, 1); assert_eq!( expression_id, diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index 57140fbd9..346a40d18 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -163,6 +163,7 @@ pub fn compile_asm_string( ); return Ok((pil_file_path, None)); } + fs::write(pil_file_path.clone(), format!("{pil}")).unwrap(); let pil_file_name = pil_file_path.file_name().unwrap(); diff --git a/compiler/tests/asm.rs b/compiler/tests/asm.rs index 079d8e44b..cb54d5681 100644 --- a/compiler/tests/asm.rs +++ b/compiler/tests/asm.rs @@ -162,6 +162,22 @@ fn full_pil_constant() { gen_halo2_proof(f, Default::default()); } +#[test] +#[should_panic = "assertion failed: poly.is_fixed() || poly.is_witness()"] +fn intermediate() { + let f = "intermediate.asm"; + verify_asm::(f, Default::default()); + gen_halo2_proof(f, Default::default()); +} + +#[test] +#[should_panic = "assertion failed: poly.is_fixed() || poly.is_witness()"] +fn intermediate_nested() { + let f = "intermediate_nested.asm"; + verify_asm::(f, Default::default()); + gen_halo2_proof(f, Default::default()); +} + #[test] fn book() { for f in fs::read_dir("../test_data/asm/book/").unwrap() { diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index afda27ea8..e140cce2c 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -74,6 +74,9 @@ fn generate_values( values } FunctionValueDefinition::Query(_) => panic!("Query used for fixed column."), + FunctionValueDefinition::Expression(_) => { + panic!("Expression used for fixed column, only expected for intermediate polynomials") + } } } diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index fd7e21379..22fa80e8c 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -147,7 +147,7 @@ impl PILContext { name, None, PolynomialType::Intermediate, - Some(FunctionDefinition::Mapping(vec![], value)), + Some(FunctionDefinition::Expression(value)), ); } PilStatement::PublicDeclaration(start, name, polynomial, index) => { @@ -321,12 +321,14 @@ impl PILContext { let name = poly.absolute_name.clone(); let value = value.map(|v| match v { + FunctionDefinition::Expression(expr) => { + assert!(!have_array_size); + assert!(poly.poly_type == PolynomialType::Intermediate); + FunctionValueDefinition::Expression(self.process_expression(expr)) + } FunctionDefinition::Mapping(params, expr) => { assert!(!have_array_size); - assert!( - poly.poly_type == PolynomialType::Constant - || poly.poly_type == PolynomialType::Intermediate - ); + assert!(poly.poly_type == PolynomialType::Constant); FunctionValueDefinition::Mapping(self.process_function(params, expr)) } FunctionDefinition::Query(params, expr) => { @@ -660,4 +662,20 @@ namespace T(65536); } assert_eq!(input, formatted); } + + #[test] + fn intermediate() { + let input = r#"namespace N(65536); + col witness x; + col intermediate = x; + intermediate = intermediate; +"#; + let expected = r#"namespace N(65536); + col witness x; + col intermediate = N.x; + N.intermediate = N.intermediate; +"#; + let formatted = process_pil_file_contents::(input).to_string(); + assert_eq!(formatted, expected); + } } diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 3c86e1e3c..7f6617681 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -96,6 +96,7 @@ fn constant_value(function: &FunctionValueDefinition) -> Opt } } FunctionValueDefinition::Query(_) => None, + FunctionValueDefinition::Expression(_) => None, } } @@ -411,6 +412,22 @@ mod test { N.Z = ((1 + N.A) * 2); N.A = (1 + N.A); N.Z = (1 + N.A); +"#; + let optimized = optimize(process_pil_file_contents::(input)).to_string(); + assert_eq!(optimized, expectation); + } + + #[test] + fn intermediate() { + let input = r#"namespace N(65536); + col witness x; + col intermediate = x; + intermediate = intermediate; + "#; + let expectation = r#"namespace N(65536); + col witness x; + col intermediate = N.x; + N.intermediate = N.intermediate; "#; let optimized = optimize(process_pil_file_contents::(input)).to_string(); assert_eq!(optimized, expectation); diff --git a/test_data/asm/intermediate.asm b/test_data/asm/intermediate.asm new file mode 100644 index 000000000..979e3f0f6 --- /dev/null +++ b/test_data/asm/intermediate.asm @@ -0,0 +1,9 @@ +machine Intermediate(latch, operation_id) { + constraints { + col fixed latch = [1]*; + col fixed operation_id = [0]*; + col witness x; + col intermediate = x; + intermediate = intermediate; + } +} \ No newline at end of file diff --git a/test_data/asm/intermediate_nested.asm b/test_data/asm/intermediate_nested.asm new file mode 100644 index 000000000..54326c671 --- /dev/null +++ b/test_data/asm/intermediate_nested.asm @@ -0,0 +1,11 @@ +machine Intermediate(latch, operation_id) { + constraints { + col fixed latch = [1]*; + col fixed operation_id = [0]*; + col witness x; + col intermediate = x; + col int2 = intermediate; + col int3 = int2 + intermediate; + int3 = 2 * x; + } +} \ No newline at end of file