From 7ecf0de94b76fe63ee080d51623009f56a46ff4e Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 3 Aug 2024 15:12:57 +0200 Subject: [PATCH] Change intermediate column syntax. (#1630) Fixes #1190 Fixes https://github.com/powdr-labs/powdr/issues/1488 ``` // creates intermediate column. let x: inter = ... // same here, expects an array on the rhs let x: inter[k] = ... // Creates an intermediate column, this is printed from Analyzed col x = ...; // Creates an array of intermediate columns, this is printed from analyzed and it's actually new syntax. col x[k] = ...; // old syntax for intermediate columns, this just defines a "generic variable" after the change, // essentially an intermediate column that is always inlined. let x: expr = ...; ``` --------- Co-authored-by: Leo --- asm-to-pil/src/vm_to_constrained.rs | 2 +- ast/src/analyzed/display.rs | 2 +- ast/src/parsed/display.rs | 1 + ast/src/parsed/mod.rs | 32 +++-- ast/src/parsed/types.rs | 4 + book/src/pil/declarations.md | 6 +- book/src/pil/types.md | 25 +++- parser/src/powdr.lalrpop | 4 +- pil-analyzer/src/condenser.rs | 143 ++++++++++++++++++----- pil-analyzer/src/evaluator.rs | 52 +++++---- pil-analyzer/src/statement_processor.rs | 30 ++--- pil-analyzer/src/type_builtins.rs | 17 +-- pil-analyzer/src/type_inference.rs | 11 +- pil-analyzer/tests/condenser.rs | 71 ++++++++++- pil-analyzer/tests/parse_display.rs | 53 ++++++++- pilopt/src/lib.rs | 15 ++- std/machines/arith.asm | 26 ++--- std/machines/hash/poseidon_bn254.asm | 12 +- std/machines/hash/poseidon_gl.asm | 14 +-- std/machines/hash/poseidon_gl_memory.asm | 15 +-- test_data/pil/referencing_array.pil | 2 +- 21 files changed, 391 insertions(+), 146 deletions(-) diff --git a/asm-to-pil/src/vm_to_constrained.rs b/asm-to-pil/src/vm_to_constrained.rs index 191edb442..66dc14f4b 100644 --- a/asm-to-pil/src/vm_to_constrained.rs +++ b/asm-to-pil/src/vm_to_constrained.rs @@ -1169,7 +1169,7 @@ impl VMConverter { ); self.pil.push(PilStatement::PolynomialDefinition( SourceRef::unknown(), - intermediate_name.to_string(), + intermediate_name.clone().into(), left * right, )); (counter + 1, direct_reference(intermediate_name)) diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index 1ef834d84..881ec30a0 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -108,7 +108,7 @@ impl Display for Analyzed { writeln_indented( f, format!( - "let {name}: expr[{length}] = [{}];", + "col {name}[{length}] = [{}];", definition.iter().format(", ") ), )?; diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index bd80067d7..2985e9ce2 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -925,6 +925,7 @@ impl Display for Type { Type::Fe => write!(f, "fe"), Type::String => write!(f, "string"), Type::Col => write!(f, "col"), + Type::Inter => write!(f, "inter"), Type::Expr => write!(f, "expr"), Type::Array(array) => write!(f, "{array}"), Type::Tuple(tuple) => write!(f, "{tuple}"), diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 741723245..e606b3ba4 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -68,7 +68,7 @@ pub enum PilStatement { Option>, Option, ), - PolynomialDefinition(SourceRef, String, Expression), + PolynomialDefinition(SourceRef, PolynomialName, Expression), PublicDeclaration( SourceRef, /// The name of the public value. @@ -128,7 +128,7 @@ impl PilStatement { &self, ) -> Box, SymbolCategory)> + '_> { match self { - PilStatement::PolynomialDefinition(_, name, _) + PilStatement::PolynomialDefinition(_, PolynomialName { name, .. }, _) | PilStatement::PolynomialConstantDefinition(_, name, _) | PilStatement::PublicDeclaration(_, name, _, _, _) | PilStatement::LetStatement(_, name, _, _) => { @@ -182,9 +182,12 @@ impl Children for PilStatement { PilStatement::ConnectIdentity(_start, left, right) => { Box::new(left.iter().chain(right.iter())) } - PilStatement::Expression(_, e) - | PilStatement::Namespace(_, _, Some(e)) - | PilStatement::PolynomialDefinition(_, _, e) => Box::new(once(e)), + PilStatement::Expression(_, e) | PilStatement::Namespace(_, _, Some(e)) => { + Box::new(once(e)) + } + PilStatement::PolynomialDefinition(_, PolynomialName { array_size, .. }, e) => { + Box::new(array_size.iter().chain(once(e))) + } PilStatement::EnumDeclaration(_, enum_decl) => enum_decl.children(), PilStatement::TraitImplementation(_, trait_impl) => trait_impl.children(), @@ -218,9 +221,13 @@ impl Children for PilStatement { PilStatement::ConnectIdentity(_start, left, right) => { Box::new(left.iter_mut().chain(right.iter_mut())) } - PilStatement::Expression(_, e) - | PilStatement::Namespace(_, _, Some(e)) - | PilStatement::PolynomialDefinition(_, _, e) => Box::new(once(e)), + PilStatement::Expression(_, e) | PilStatement::Namespace(_, _, Some(e)) => { + Box::new(once(e)) + } + + PilStatement::PolynomialDefinition(_, PolynomialName { array_size, .. }, e) => { + Box::new(array_size.iter_mut().chain(once(e))) + } PilStatement::EnumDeclaration(_, enum_decl) => enum_decl.children_mut(), PilStatement::TraitImplementation(_, trait_impl) => trait_impl.children_mut(), @@ -793,6 +800,15 @@ pub struct PolynomialName { pub array_size: Option, } +impl From for PolynomialName { + fn from(name: String) -> Self { + Self { + name, + array_size: None, + } + } +} + #[derive(Debug, PartialEq, Eq, Default, Clone, PartialOrd, Ord)] /// A polynomial with an optional namespace /// This is different from SymbolPath mainly due to different formatting. diff --git a/ast/src/parsed/types.rs b/ast/src/parsed/types.rs index a618e8536..056ece68a 100644 --- a/ast/src/parsed/types.rs +++ b/ast/src/parsed/types.rs @@ -25,6 +25,8 @@ pub enum Type { String, /// Column Col, + /// Intermediate column + Inter, /// Algebraic expression Expr, Array(ArrayType), @@ -48,6 +50,7 @@ impl Type { | Type::Fe | Type::String | Type::Col + | Type::Inter | Type::Expr => true, Type::Array(_) | Type::Tuple(_) @@ -248,6 +251,7 @@ impl From>> for Type { Type::Fe => Type::Fe, Type::String => Type::String, Type::Col => Type::Col, + Type::Inter => Type::Inter, Type::Expr => Type::Expr, Type::Array(a) => Type::Array(a.into()), Type::Tuple(t) => Type::Tuple(t.into()), diff --git a/book/src/pil/declarations.md b/book/src/pil/declarations.md index ca5dd2925..96a3ec836 100644 --- a/book/src/pil/declarations.md +++ b/book/src/pil/declarations.md @@ -11,12 +11,12 @@ Symbols with a generic type can be defined using ``let : The `col` type is special in that it is only used for declaring columns, but cannot appear as the type of an expression. +> In addition, there are the `col` and `inter` types, but they are special in that +> they are only used for declaring columns, but cannot appear as the type of an expression. > See [Declaring and Referencing Columns](#declaring-and-referencing-columns) for details. Powdr-pil performs Hindley-Milner type inference. This means that, similar to Rust, the type of @@ -41,14 +42,28 @@ let add_one: T -> T = |i| i + 1; ## Declaring and Referencing Columns -A symbol declared to have type `col` (or `col[k]`) is a bit special: +A symbol declared to have type `col` or `inter` (or `col[k]` / `inter[k]`) is a bit special: -If you assign it a value, that value is expected to have type `int -> fe` or `int -> int` (or an array thereof). -This allows the simple declaration of a column `let byte: col = |i| i & 0xff;` without complicated conversions. +These symbols represent columns in the arithmetization and the types of values that can be assigned to +such symbols and the references to the symbols are different from their declared type. + +If you assign a value to a `col` symbol, that value is expected to have type `int -> fe` or `int -> int` (or an array thereof). +This allows the simple declaration of a fixed column `let byte: col = |i| i & 0xff;` without complicated conversions. The integer value is converted to a field element during evaluation, but it has to be non-negative and less than the field modulus. -If you reference such a symbol, the type of the reference is `expr`. +Symbols of declared type `col` are fixed (those with value) or witness columns (those without value). + +A symbol of declared type `inter` is an intermediate column. You can assign it a value of type `expr`. +The idea of an intermediate column is that it is an algebraic expression of other columns that you do +not want to compute multiple times. + +> Note that if you use `let x: expr = a * b;`, the symbol `x` is just a name in the PIL environment, +> this will not create an intermediate column. The difference between `inter` and `expr` in this case +> is that if you use `let x: inter = ...`, the expression might not be inlined into constraints (depending on the backend), +> while if you use `let x: expr = ...`, it will always be inlined. + +If you reference a symbol of declared type `inter` or `col`, the type of the reference is `expr` (or `expr[]`). A byte constraint is as easy as `[ X ] in [ byte ]`, since the expected types in plookup columns is `expr`. The downside is that you cannot evaluate columns as functions. If you want to do that, you either have to assign a copy to an `int -> int` symbol: `let byte_f: int -> int = |i| i & 0xff; let byte: col = byte_f;`. diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index aba625929..bc55eb348 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -142,7 +142,7 @@ LetStatement: PilStatement = { } PolynomialDefinition: PilStatement = { - PolCol "=" ";" => PilStatement::PolynomialDefinition(ctx.source_ref(start, end), id, expr) + PolCol "=" ";" => PilStatement::PolynomialDefinition(ctx.source_ref(start, end), name, expr) } PublicDeclaration: PilStatement = { @@ -777,6 +777,7 @@ TypeTerm: Type = { "fe" => Type::Fe, "string" => Type::String, "col" => Type::Col, + "inter" => Type::Inter, "expr" => Type::Expr, > "[" "]" => Type::Array(ArrayType{base: Box::new(base), length}), "(" > "," )+> > ")" => { items.push(end); Type::Tuple(TupleType{items}) }, @@ -845,6 +846,7 @@ SpecialIdentifier: &'input str = { "loc", "insn", "int", + "inter", "fe", "expr", "bool", diff --git a/pil-analyzer/src/condenser.rs b/pil-analyzer/src/condenser.rs index 5fddac3c6..aac5e4923 100644 --- a/pil-analyzer/src/condenser.rs +++ b/pil-analyzer/src/condenser.rs @@ -79,8 +79,8 @@ pub fn condense( && matches!( &scheme.unwrap().ty, Type::Array(ArrayType { base, length: _ }) - if base.as_ref() == &Type::Expr), - "Intermediate column type has to be expr[], but got: {}", + if base.as_ref() == &Type::Inter), + "Intermediate column type has to be inter[], but got: {}", format_type_scheme_around_name(&name, &e.type_scheme) ); let result = condenser.condense_to_array_of_algebraic_expressions(&e.e); @@ -89,8 +89,8 @@ pub fn condense( } else { assert_eq!( e.type_scheme, - Some(Type::Expr.into()), - "Intermediate column type has to be expr, but got: {}", + Some(Type::Inter.into()), + "Intermediate column type has to be inter, but got: {}", format_type_scheme_around_name(&name, &e.type_scheme) ); vec![condenser.condense_to_algebraic_expression(&e.e)] @@ -100,17 +100,28 @@ pub fn condense( } s => Some(s), }; + + let mut intermediate_values = condenser.extract_new_intermediate_column_values(); + // Extract and prepend the new columns, then identities // and finally the original statement (if it exists). let new_cols = condenser .extract_new_columns() .into_iter() .map(|new_col| { - new_columns.push(new_col.clone()); + if new_col.kind == SymbolKind::Poly(PolynomialType::Intermediate) { + let name = new_col.absolute_name.clone(); + let values = intermediate_values.remove(&name).unwrap(); + intermediate_columns.insert(name, (new_col.clone(), values)); + } else { + new_columns.push(new_col.clone()); + } StatementIdentifier::Definition(new_col.absolute_name) }) .collect::>(); + assert!(intermediate_values.is_empty(), ""); + let identity_statements = condenser .extract_new_constraints() .into_iter() @@ -121,8 +132,8 @@ pub fn condense( }) .collect::>(); - for (name, hint) in condenser.extract_new_column_values() { - if new_values.insert(name.clone(), hint).is_some() { + for (name, value) in condenser.extract_new_column_values() { + if new_values.insert(name.clone(), value).is_some() { panic!("Column {name} already has a hint set, but tried to add another one.",) } } @@ -187,6 +198,8 @@ pub struct Condenser<'a, T> { new_columns: Vec, /// The hints and fixed column definitions added since the last extraction. new_column_values: HashMap, + /// The values of intermediate columns generated since the last extraction. + new_intermediate_column_values: HashMap>>, /// The names of all new columns ever generated, to avoid duplicates. new_symbols: HashSet, new_constraints: Vec>, @@ -203,6 +216,7 @@ impl<'a, T: FieldElement> Condenser<'a, T> { counters, new_columns: vec![], new_column_values: Default::default(), + new_intermediate_column_values: Default::default(), new_symbols: HashSet::new(), new_constraints: vec![], } @@ -259,6 +273,13 @@ impl<'a, T: FieldElement> Condenser<'a, T> { std::mem::take(&mut self.new_column_values) } + /// Return the values of intermediate columns generated since the last call to this function. + pub fn extract_new_intermediate_column_values( + &mut self, + ) -> HashMap>> { + std::mem::take(&mut self.new_intermediate_column_values) + } + /// Returns the new constraints generated since the last call to this function. pub fn extract_new_constraints(&mut self) -> Vec> { std::mem::take(&mut self.new_constraints) @@ -346,49 +367,113 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { fn new_column( &mut self, name: &str, + ty: Option<&Type>, value: Option>>, source: SourceRef, ) -> Result>, EvalError> { let name = self.find_unused_name(name); - let kind = SymbolKind::Poly(if value.is_some() { - PolynomialType::Constant - } else { - PolynomialType::Committed - }); - let value = value - .map(|v| { - closure_to_function(&source, v.as_ref(), FunctionKind::Pure).map_err(|e| match e { - EvalError::TypeError(e) => { - EvalError::TypeError(format!("Error creating fixed column {name}: {e}")) + let mut length = None; + let mut is_array = false; + let kind = match (ty, &value) { + (Some(Type::Inter), Some(_)) => SymbolKind::Poly(PolynomialType::Intermediate), + (Some(Type::Array(ArrayType { base, length: len })), Some(_)) + if base.as_ref() == &Type::Inter => + { + is_array = true; + length = *len; + SymbolKind::Poly(PolynomialType::Intermediate) + } + (Some(Type::Col) | None, Some(_)) => SymbolKind::Poly(PolynomialType::Constant), + (Some(Type::Col) | None, None) => SymbolKind::Poly(PolynomialType::Committed), + _ => { + return Err(EvalError::TypeError(format!( + "Invalid type for new column {name}: {}.", + ty.map(|ty| ty.to_string()).unwrap_or_default(), + ))) + } + }; + + if kind == SymbolKind::Poly(PolynomialType::Intermediate) { + let expr = if is_array { + let Value::Array(exprs) = value.unwrap().as_ref().clone() else { + panic!("Expected array"); + }; + if let Some(length) = length { + if exprs.len() as u64 != length { + return Err(EvalError::TypeError(format!( + "Error creating intermediate column array {name}: Expected array of length {length} as value but it has {} elements." , + exprs.len(), + ))); } - _ => e, - }) - }) - .transpose()?; + } else { + length = Some(exprs.len() as u64); + } + exprs + .into_iter() + .map(|expr| { + let Value::Expression(expr) = expr.as_ref() else { + panic!("Expected algebraic expression"); + }; + expr.clone() + }) + .collect() + } else { + let Value::Expression(expr) = value.unwrap().as_ref().clone() else { + panic!("Expected algebraic expression"); + }; + vec![expr] + }; + self.new_intermediate_column_values + .insert(name.clone(), expr); + } else if let Some(value) = value { + let value = + closure_to_function(&source, value.as_ref(), FunctionKind::Pure).map_err(|e| { + match e { + EvalError::TypeError(e) => { + EvalError::TypeError(format!("Error creating fixed column {name}: {e}")) + } + _ => e, + } + })?; + + self.new_column_values.insert(name.clone(), value); + } let symbol = Symbol { - id: self.counters.dispense_symbol_id(kind, None), + id: self.counters.dispense_symbol_id(kind, length), source, absolute_name: name.clone(), stage: None, kind, - length: None, + length, degree: self.degree, }; self.new_symbols.insert(name.clone()); self.new_columns.push(symbol.clone()); - if let Some(value) = value { - self.new_column_values.insert(name.clone(), value); - } - Ok( + + Ok((if is_array { + Value::Array( + symbol + .array_elements() + .map(|(name, poly_id)| { + Value::Expression(AlgebraicExpression::Reference(AlgebraicReference { + name, + poly_id, + next: false, + })) + .into() + }) + .collect(), + ) + } else { Value::Expression(AlgebraicExpression::Reference(AlgebraicReference { name, poly_id: PolyID::from(&symbol), next: false, })) - .into(), - ) + }) + .into()) } fn set_hint( diff --git a/pil-analyzer/src/evaluator.rs b/pil-analyzer/src/evaluator.rs index a432f7a2a..f6e2fd416 100644 --- a/pil-analyzer/src/evaluator.rs +++ b/pil-analyzer/src/evaluator.rs @@ -15,7 +15,7 @@ use powdr_ast::{ }, parsed::{ display::quote, - types::{Type, TypeScheme}, + types::{ArrayType, Type, TypeScheme}, ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, IndexAccess, LambdaExpression, LetStatementInsideBlock, MatchArm, MatchExpression, Number, Pattern, StatementInsideBlock, UnaryOperation, UnaryOperator, @@ -546,6 +546,7 @@ pub trait SymbolLookup<'a, T: FieldElement> { fn new_column( &mut self, name: &str, + _type: Option<&Type>, _value: Option>>, _source: SourceRef, ) -> Result>, EvalError> { @@ -648,27 +649,7 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> { self.local_vars = new_locals; self.type_args = new_type_args; } - Operation::LetStatement(s) => { - let value = match (&s.ty, &s.value.as_ref()) { - (Some(Type::Col), value) | (None, value @ None) => { - let Pattern::Variable(_, name) = &s.pattern else { - unreachable!() - }; - self.symbols.new_column( - name, - value.map(|_| self.value_stack.pop().unwrap()), - SourceRef::unknown(), - )? - } - (_, Some(_)) => self.value_stack.pop().unwrap(), - _ => unreachable!(), - }; - self.local_vars.extend( - Value::try_match_pattern(&value, &s.pattern).unwrap_or_else(|| { - panic!("Irrefutable pattern did not match: {} = {value}", s.pattern) - }), - ); - } + Operation::LetStatement(s) => self.evaluate_let_statement(s)?, Operation::AddConstraint => { let result = self.value_stack.pop().unwrap(); match result.as_ref() { @@ -797,6 +778,33 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> { Ok(()) } + fn evaluate_let_statement( + &mut self, + s: &'a LetStatementInsideBlock, + ) -> Result<(), EvalError> { + let value = if s.value.is_none() + || matches!(&s.ty, Some(Type::Col) | Some(Type::Inter)) + || matches!(&s.ty, Some(Type::Array(ArrayType { base, .. })) if matches!(base.as_ref(), Type::Col | Type::Inter)) + { + // Dynamic column creation + let Pattern::Variable(_, name) = &s.pattern else { + unreachable!() + }; + let value = s.value.as_ref().map(|_| self.value_stack.pop().unwrap()); + self.symbols + .new_column(name, s.ty.as_ref(), value, SourceRef::unknown())? + } else { + // Regular local variable declaration. + self.value_stack.pop().unwrap() + }; + self.local_vars.extend( + Value::try_match_pattern(&value, &s.pattern).unwrap_or_else(|| { + panic!("Irrefutable pattern did not match: {} = {value}", s.pattern) + }), + ); + Ok(()) + } + fn evaluate_reference( &mut self, reference: &'a Reference, diff --git a/pil-analyzer/src/statement_processor.rs b/pil-analyzer/src/statement_processor.rs index d69443c4b..cc03a1dd5 100644 --- a/pil-analyzer/src/statement_processor.rs +++ b/pil-analyzer/src/statement_processor.rs @@ -124,15 +124,17 @@ where PilStatement::Namespace(_, _, _) => { panic!("Namespaces must be handled outside the statement processor.") } - PilStatement::PolynomialDefinition(source, name, value) => self - .handle_symbol_definition( + PilStatement::PolynomialDefinition(source, name, value) => { + let (name, ty) = self.name_and_type_from_polynomial_name(name, Type::Inter); + self.handle_symbol_definition( source, name, SymbolKind::Poly(PolynomialType::Intermediate), None, - Some(Type::Expr.into()), + ty, Some(FunctionDefinition::Expression(value)), - ), + ) + } PilStatement::PublicDeclaration(source, name, polynomial, array_index, index) => { self.handle_public_declaration(source, name, polynomial, array_index, index) } @@ -167,14 +169,14 @@ where ) => { assert!(polynomials.len() == 1); let (name, ty) = - self.name_and_type_from_polynomial_name(polynomials.pop().unwrap()); + self.name_and_type_from_polynomial_name(polynomials.pop().unwrap(), Type::Col); self.handle_symbol_definition( source, name, SymbolKind::Poly(PolynomialType::Committed), stage, - ty.map(Into::into), + ty, Some(definition), ) } @@ -207,9 +209,10 @@ where fn name_and_type_from_polynomial_name( &mut self, PolynomialName { name, array_size }: PolynomialName, - ) -> (String, Option) { + base_type: Type, + ) -> (String, Option) { let ty = Some(match array_size { - None => Type::Col, + None => base_type.into(), Some(len) => { let length = untyped_evaluator::evaluate_expression_to_int(self.driver, len) .map(|length| { @@ -222,9 +225,10 @@ where }) .ok(); Type::Array(ArrayType { - base: Box::new(Type::Col), + base: Box::new(base_type), length, }) + .into() } }); (name, ty) @@ -313,13 +317,13 @@ where return SymbolKind::Other(); } match &ts.ty { - Type::Expr => SymbolKind::Poly(PolynomialType::Intermediate), + Type::Inter => SymbolKind::Poly(PolynomialType::Intermediate), Type::Col => SymbolKind::Poly(PolynomialType::Constant), Type::Array(ArrayType { base, length: _ }) if base.as_ref() == &Type::Col => { // Array of fixed columns SymbolKind::Poly(PolynomialType::Constant) } - Type::Array(ArrayType { base, length: _ }) if base.as_ref() == &Type::Expr => { + Type::Array(ArrayType { base, length: _ }) if base.as_ref() == &Type::Inter => { SymbolKind::Poly(PolynomialType::Intermediate) } // Otherwise, treat it as "generic definition" @@ -390,13 +394,13 @@ where polynomials .into_iter() .flat_map(|poly_name| { - let (name, ty) = self.name_and_type_from_polynomial_name(poly_name); + let (name, ty) = self.name_and_type_from_polynomial_name(poly_name, Type::Col); self.handle_symbol_definition( source.clone(), name, SymbolKind::Poly(polynomial_type), stage, - ty.map(Into::into), + ty, None, ) }) diff --git a/pil-analyzer/src/type_builtins.rs b/pil-analyzer/src/type_builtins.rs index 0e150d8b3..a64170abc 100644 --- a/pil-analyzer/src/type_builtins.rs +++ b/pil-analyzer/src/type_builtins.rs @@ -17,21 +17,16 @@ pub fn type_for_reference(declared: &Type) -> Type { match declared { // References to columns are exprs Type::Col => Type::Expr, - // Similar for arrays of columns - Type::Array(ArrayType { base, length: _ }) if base.as_ref() == &Type::Col => { + Type::Inter => Type::Expr, + // Similarly to arrays of columns, we ignore the length. + Type::Array(ArrayType { base, length: _ }) + if matches!(base.as_ref(), &Type::Col | &Type::Inter) => + { Type::Array(ArrayType { base: Type::Expr.into(), length: None, }) } - // Arrays of intermediate columns lose their length. - Type::Array(ArrayType { - base, - length: Some(_), - }) if base.as_ref() == &Type::Expr => Type::Array(ArrayType { - base: base.clone(), - length: None, - }), t => t.clone(), } } @@ -156,7 +151,7 @@ pub fn elementary_type_bounds(ty: &Type) -> &'static [&'static str] { "Neg", "Eq", ], - Type::Col => &[], + Type::Col | Type::Inter => &[], Type::Array(_) => &["Add"], Type::Tuple(_) => &[], Type::Function(_) => &[], diff --git a/pil-analyzer/src/type_inference.rs b/pil-analyzer/src/type_inference.rs index 18406b1a3..6c4644cf8 100644 --- a/pil-analyzer/src/type_inference.rs +++ b/pil-analyzer/src/type_inference.rs @@ -271,14 +271,15 @@ impl TypeChecker { }); self.expect_type_allow_fe_or_int(&arr, value, &return_type) } - Type::Array(ArrayType { - base, - length: Some(_), - }) if base.as_ref() == &Type::Expr => { + Type::Inter => { + // Values of intermediate columns have type `expr` + self.expect_type(&Type::Expr, value) + } + Type::Array(ArrayType { base, length: _ }) if base.as_ref() == &Type::Inter => { // An array of intermediate columns with fixed length. We ignore the length. // The condenser will have to check the actual length. let arr = Type::Array(ArrayType { - base: base.clone(), + base: Type::Expr.into(), length: None, }); self.expect_type(&arr, value) diff --git a/pil-analyzer/tests/condenser.rs b/pil-analyzer/tests/condenser.rs index 0a5282261..da0034a71 100644 --- a/pil-analyzer/tests/condenser.rs +++ b/pil-analyzer/tests/condenser.rs @@ -326,7 +326,7 @@ fn set_hint_no_col() { enum Query { Hint(fe), None, } namespace N(16); let x; - let y: expr = x; + let y: inter = x; std::prover::set_hint(y, query |_| std::prover::Query::Hint(1)); "#; analyze_string::(input); @@ -406,3 +406,72 @@ namespace N(16); let formatted = analyze_string::(input).to_string(); assert_eq!(formatted, expected); } + +#[test] +fn intermediate_syntax() { + let input = r#"namespace N(65536); + col witness x[5]; + let inter: inter = x[2]; + let inter_arr: inter[5] = x; +"#; + let analyzed = analyze_string::(input); + assert_eq!(analyzed.intermediate_count(), 6); + let expected = r#"namespace N(65536); + col witness x[5]; + col inter = N.x[2]; + col inter_arr[5] = [N.x[0], N.x[1], N.x[2], N.x[3], N.x[4]]; +"#; + assert_eq!(analyzed.to_string(), expected); +} + +#[test] +fn intermediate_dynamic() { + let input = r#"namespace N(65536); + col witness x[5]; + { + let inte: inter = x[2]; + let inter_arr: inter[5] = x; + inte = 8; + inter_arr[3] = 9; + }; +"#; + let analyzed = analyze_string::(input); + assert_eq!(analyzed.intermediate_count(), 6); + let expected = r#"namespace N(65536); + col witness x[5]; + col inte = N.x[2]; + col inter_arr[5] = [N.x[0], N.x[1], N.x[2], N.x[3], N.x[4]]; + N.inte = 8; + N.inter_arr[3] = 9; +"#; + assert_eq!(analyzed.to_string(), expected); +} + +#[test] +fn intermediate_arr_no_length() { + let input = r#"namespace N(65536); + col witness x[5]; + { + let inte: inter[] = x; + }; +"#; + let analyzed = analyze_string::(input); + assert_eq!(analyzed.intermediate_count(), 5); + let expected = r#"namespace N(65536); + col witness x[5]; + col inte[5] = [N.x[0], N.x[1], N.x[2], N.x[3], N.x[4]]; +"#; + assert_eq!(analyzed.to_string(), expected); +} + +#[test] +#[should_panic = "Error creating intermediate column array N.inte: Expected array of length 6 as value but it has 2 elements."] +fn intermediate_arr_wrong_length() { + let input = r#"namespace N(65536); + col witness x[2]; + { + let inte: inter[6] = x; + }; +"#; + analyze_string::(input); +} diff --git a/pil-analyzer/tests/parse_display.rs b/pil-analyzer/tests/parse_display.rs index 780f2aa20..afa7018ce 100644 --- a/pil-analyzer/tests/parse_display.rs +++ b/pil-analyzer/tests/parse_display.rs @@ -83,6 +83,28 @@ fn intermediate() { assert_eq!(formatted, expected); } +#[test] +fn intermediate_array() { + let input = r#"namespace N(65536); + col witness x; + col intermediate[3] = [x, x + 2, x * x]; + intermediate[0] = intermediate[0]; + intermediate[1] = intermediate[1]; + intermediate[2] = intermediate[2]; +"#; + let expected = r#"namespace N(65536); + col witness x; + col intermediate[3] = [N.x, N.x + 2, N.x * N.x]; + N.intermediate[0] = N.intermediate[0]; + N.intermediate[1] = N.intermediate[1]; + N.intermediate[2] = N.intermediate[2]; +"#; + let analyzed = analyze_string::(input); + assert_eq!(analyzed.intermediate_count(), 3); + let formatted = analyzed.to_string(); + assert_eq!(formatted, expected); +} + #[test] fn intermediate_nested() { let input = r#"namespace N(65536); @@ -456,14 +478,16 @@ fn challenges() { x' = (x + 1) * (1 - first); y' = (x + a) * (1 - first); "; - let formatted = analyze_string::(input).to_string(); + let analyzed = analyze_string::(input); + assert_eq!(analyzed.intermediate_count(), 0); + let formatted = analyzed.to_string(); let expected = r#"namespace Main(8); col fixed first = [1] + [0]*; col witness x; col witness stage(2) y; - col a = std::prelude::challenge(2, 1); + let a: expr = std::prelude::challenge(2, 1); Main.x' = (Main.x + 1) * (1 - Main.first); - Main.y' = (Main.x + Main.a) * (1 - Main.first); + Main.y' = (Main.x + std::prelude::challenge(2, 1)) * (1 - Main.first); "#; assert_eq!(formatted, expected); } @@ -830,7 +854,7 @@ namespace Main(16); fn reparse_array_typed_intermediate_col() { let input = r#"namespace Main(16); col witness w; - let clocks: expr[4] = [Main.w, Main.w, Main.w, Main.w]; + col clocks[4] = [Main.w, Main.w, Main.w, Main.w]; "#; let formatted = analyze_string::(input).to_string(); assert_eq!(formatted, input); @@ -849,3 +873,24 @@ fn reparse_type_args_generic_enum() { let formatted = analyze_string::(input).to_string(); assert_eq!(formatted, input); } + +#[test] +fn intermediate_syntax() { + let input = r#"namespace X(16); + let w; + let a: inter = w; + let b: inter[1] = [w]; + col c = w; + col d[1] = [w]; +"#; + let expected = r#"namespace X(16); + col witness w; + col a = X.w; + col b[1] = [X.w]; + col c = X.w; + col d[1] = [X.w]; +"#; + let analyzed = analyze_string::(input); + assert_eq!(analyzed.intermediate_count(), 4); + assert_eq!(analyzed.to_string(), expected); +} diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 1b6063b7b..e7a8c40c0 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -721,18 +721,17 @@ namespace N(65536); fn remove_unreferenced_parts_of_arrays() { let input = r#"namespace N(65536); col witness x[5]; - let inter: expr[5] = x; - x[2] = inter[4]; + let inte: inter[5] = x; + x[2] = inte[4]; "#; - // If we change the handling of intermediate columns, - // make sure that "inter" is still an array of intermediate columns. let expectation = r#"namespace N(65536); col witness x[5]; - let inter: expr[5] = [N.x[0], N.x[1], N.x[2], N.x[3], N.x[4]]; - N.x[2] = N.inter[4]; + col inte[5] = [N.x[0], N.x[1], N.x[2], N.x[3], N.x[4]]; + N.x[2] = N.inte[4]; "#; - let optimized = optimize(analyze_string::(input)).to_string(); - assert_eq!(optimized, expectation); + let optimized = optimize(analyze_string::(input)); + assert_eq!(optimized.intermediate_count(), 5); + assert_eq!(optimized.to_string(), expectation); } #[test] diff --git a/std/machines/arith.asm b/std/machines/arith.asm index fdb21de32..994218f8a 100644 --- a/std/machines/arith.asm +++ b/std/machines/arith.asm @@ -172,7 +172,7 @@ machine Arith with col witness y1_14(i) query hint_if_eq0(quotient_hint, 14); col witness y1_15(i) query hint_if_eq0(quotient_hint, 15); - let y1: expr[16] = [y1_0, y1_1, y1_2, y1_3, y1_4, y1_5, y1_6, y1_7, y1_8, y1_9, y1_10, y1_11, y1_12, y1_13, y1_14, y1_15]; + let y1: expr[] = [y1_0, y1_1, y1_2, y1_3, y1_4, y1_5, y1_6, y1_7, y1_8, y1_9, y1_10, y1_11, y1_12, y1_13, y1_14, y1_15]; col witness x2_0(i) query hint_if_eq0(remainder_hint, 0); col witness x2_1(i) query hint_if_eq0(remainder_hint, 1); @@ -191,7 +191,7 @@ machine Arith with col witness x2_14(i) query hint_if_eq0(remainder_hint, 14); col witness x2_15(i) query hint_if_eq0(remainder_hint, 15); - let x2: expr[16] = [x2_0, x2_1, x2_2, x2_3, x2_4, x2_5, x2_6, x2_7, x2_8, x2_9, x2_10, x2_11, x2_12, x2_13, x2_14, x2_15]; + let x2: expr[] = [x2_0, x2_1, x2_2, x2_3, x2_4, x2_5, x2_6, x2_7, x2_8, x2_9, x2_10, x2_11, x2_12, x2_13, x2_14, x2_15]; col witness s_0(i) query Query::Hint(fe(select_limb(s_hint(), 0))); col witness s_1(i) query Query::Hint(fe(select_limb(s_hint(), 1))); @@ -210,7 +210,7 @@ machine Arith with col witness s_14(i) query Query::Hint(fe(select_limb(s_hint(), 14))); col witness s_15(i) query Query::Hint(fe(select_limb(s_hint(), 15))); - let s: expr[16] = [s_0, s_1, s_2, s_3, s_4, s_5, s_6, s_7, s_8, s_9, s_10, s_11, s_12, s_13, s_14, s_15]; + let s: expr[] = [s_0, s_1, s_2, s_3, s_4, s_5, s_6, s_7, s_8, s_9, s_10, s_11, s_12, s_13, s_14, s_15]; col witness q0_0(i) query Query::Hint(fe(select_limb(q0_hint(), 0))); col witness q0_1(i) query Query::Hint(fe(select_limb(q0_hint(), 1))); @@ -229,7 +229,7 @@ machine Arith with col witness q0_14(i) query Query::Hint(fe(select_limb(q0_hint(), 14))); col witness q0_15(i) query Query::Hint(fe(select_limb(q0_hint(), 15))); - let q0: expr[16] = [q0_0, q0_1, q0_2, q0_3, q0_4, q0_5, q0_6, q0_7, q0_8, q0_9, q0_10, q0_11, q0_12, q0_13, q0_14, q0_15]; + let q0: expr[] = [q0_0, q0_1, q0_2, q0_3, q0_4, q0_5, q0_6, q0_7, q0_8, q0_9, q0_10, q0_11, q0_12, q0_13, q0_14, q0_15]; col witness q1_0(i) query Query::Hint(fe(select_limb(q1_hint(), 0))); col witness q1_1(i) query Query::Hint(fe(select_limb(q1_hint(), 1))); @@ -248,7 +248,7 @@ machine Arith with col witness q1_14(i) query Query::Hint(fe(select_limb(q1_hint(), 14))); col witness q1_15(i) query Query::Hint(fe(select_limb(q1_hint(), 15))); - let q1: expr[16] = [q1_0, q1_1, q1_2, q1_3, q1_4, q1_5, q1_6, q1_7, q1_8, q1_9, q1_10, q1_11, q1_12, q1_13, q1_14, q1_15]; + let q1: expr[] = [q1_0, q1_1, q1_2, q1_3, q1_4, q1_5, q1_6, q1_7, q1_8, q1_9, q1_10, q1_11, q1_12, q1_13, q1_14, q1_15]; col witness q2_0(i) query Query::Hint(fe(select_limb(q2_hint(), 0))); col witness q2_1(i) query Query::Hint(fe(select_limb(q2_hint(), 1))); @@ -267,16 +267,16 @@ machine Arith with col witness q2_14(i) query Query::Hint(fe(select_limb(q2_hint(), 14))); col witness q2_15(i) query Query::Hint(fe(select_limb(q2_hint(), 15))); - let q2: expr[16] = [q2_0, q2_1, q2_2, q2_3, q2_4, q2_5, q2_6, q2_7, q2_8, q2_9, q2_10, q2_11, q2_12, q2_13, q2_14, q2_15]; + let q2: expr[] = [q2_0, q2_1, q2_2, q2_3, q2_4, q2_5, q2_6, q2_7, q2_8, q2_9, q2_10, q2_11, q2_12, q2_13, q2_14, q2_15]; let combine: expr[] -> expr[] = |x| array::new(array::len(x) / 2, |i| x[2 * i + 1] * 2**16 + x[2 * i]); // Intermediate polynomials, arrays of 8 columns, 32 bit per column. - let x1c: expr[8] = combine(x1); - let y1c: expr[8] = combine(y1); - let x2c: expr[8] = combine(x2); - let y2c: expr[8] = combine(y2); - let x3c: expr[8] = combine(x3); - let y3c: expr[8] = combine(y3); + col x1c[8] = combine(x1); + col y1c[8] = combine(y1); + col x2c[8] = combine(x2); + col y2c[8] = combine(y2); + col x3c[8] = combine(x3); + col y3c[8] = combine(y3); let CLK32: col[32] = array::new(32, |i| |row| if row % 32 == i { 1 } else { 0 }); let CLK32_31: expr = CLK32[31]; @@ -449,7 +449,7 @@ machine Arith with link => byte2.check(carry_high[2]); // Carries can be any integer in the range [-2**31, 2**31 - 1) - let carry: expr[3] = array::new(3, |i| carry_high[i] * 2**16 + carry_low[i] - 2 ** 31); + let carry = array::new(3, |i| carry_high[i] * 2**16 + carry_low[i] - 2 ** 31); array::map(carry, |c| c * CLK32[0] = 0); diff --git a/std/machines/hash/poseidon_bn254.asm b/std/machines/hash/poseidon_bn254.asm index 3d79ad63a..3c5c54961 100644 --- a/std/machines/hash/poseidon_bn254.asm +++ b/std/machines/hash/poseidon_bn254.asm @@ -60,15 +60,15 @@ machine PoseidonBN254 with pol commit output[OUTPUT_SIZE]; // Add round constants - let a: expr[STATE_SIZE] = array::zip(state, C, |state, C| state + C); + let a = array::zip(state, C, |state, C| state + C); // Compute S-Boxes (x^5) - let x2: expr[STATE_SIZE] = array::map(a, |a| a * a); - let x4: expr[STATE_SIZE] = array::map(x2, |x2| x2 * x2); - let x5: expr[STATE_SIZE] = array::zip(x4, a, |x4, a| x4 * a); + let x2: inter[STATE_SIZE] = array::map(a, |a| a * a); + let x4: inter[STATE_SIZE] = array::map(x2, |x2| x2 * x2); + let x5: inter[STATE_SIZE] = array::zip(x4, a, |x4, a| x4 * a); // Apply S-Boxes on the first element and otherwise if it is a full round. - let b: expr[STATE_SIZE] = array::new(STATE_SIZE, |i| if i == 0 { + let b = array::new(STATE_SIZE, |i| if i == 0 { x5[i] } else { PARTIAL * (a[i] - x5[i]) + x5[i] @@ -83,7 +83,7 @@ machine PoseidonBN254 with // Multiply with MDS Matrix let dot_product = |v1, v2| array::sum(array::zip(v1, v2, |v1_i, v2_i| v1_i * v2_i)); - let c: expr[STATE_SIZE] = array::map(M, |M_row_i| dot_product(M_row_i, b)); + let c = array::map(M, |M_row_i| dot_product(M_row_i, b)); // Copy c to state in the next row array::zip(state, c, |state, c| (state' - c) * (1-LAST) = 0); diff --git a/std/machines/hash/poseidon_gl.asm b/std/machines/hash/poseidon_gl.asm index be7b8bc57..e3080b444 100644 --- a/std/machines/hash/poseidon_gl.asm +++ b/std/machines/hash/poseidon_gl.asm @@ -64,16 +64,16 @@ machine PoseidonGL with pol commit output[OUTPUT_SIZE]; // Add round constants - let a: expr[STATE_SIZE] = array::zip(state, C, |state, C| state + C); + let a = array::zip(state, C, |state, C| state + C); // Compute S-Boxes (x^7) - let x2: expr[STATE_SIZE] = array::map(a, |a| a * a); - let x4: expr[STATE_SIZE] = array::map(x2, |x2| x2 * x2); - let x6: expr[STATE_SIZE] = array::zip(x4, x2, |x4, x2| x4 * x2); - let x7: expr[STATE_SIZE] = array::zip(x6, a, |x6, a| x6 * a); + let x2: inter[STATE_SIZE] = array::map(a, |a| a * a); + let x4: inter[STATE_SIZE] = array::map(x2, |x2| x2 * x2); + let x6: inter[STATE_SIZE] = array::zip(x4, x2, |x4, x2| x4 * x2); + let x7: inter[STATE_SIZE] = array::zip(x6, a, |x6, a| x6 * a); // Apply S-Boxes on the first element and otherwise if it is a full round. - let b: expr[STATE_SIZE] = array::new(STATE_SIZE, |i| if i == 0 { + let b = array::new(STATE_SIZE, |i| if i == 0 { x7[i] } else { PARTIAL * (a[i] - x7[i]) + x7[i] @@ -97,7 +97,7 @@ machine PoseidonGL with // Multiply with MDS Matrix let dot_product = |v1, v2| array::sum(array::zip(v1, v2, |v1_i, v2_i| v1_i * v2_i)); - let c: expr[STATE_SIZE] = array::map(M, |M_row_i| dot_product(M_row_i, b)); + let c = array::map(M, |M_row_i| dot_product(M_row_i, b)); // Copy c to state in the next row array::zip(state, c, |state, c| (state' - c) * (1-LAST) = 0); diff --git a/std/machines/hash/poseidon_gl_memory.asm b/std/machines/hash/poseidon_gl_memory.asm index 14a3ec490..3d15022a9 100644 --- a/std/machines/hash/poseidon_gl_memory.asm +++ b/std/machines/hash/poseidon_gl_memory.asm @@ -152,16 +152,17 @@ machine PoseidonGLMemory(mem: Memory, split_gl: SplitGL) with pol commit output[OUTPUT_SIZE]; // Add round constants - let a: expr[STATE_SIZE] = array::zip(state, C, |state, C| state + C); + // TODO should these be intermediate? + let a = array::zip(state, C, |state, C| state + C); // Compute S-Boxes (x^7) - let x2: expr[STATE_SIZE] = array::map(a, |a| a * a); - let x4: expr[STATE_SIZE] = array::map(x2, |x2| x2 * x2); - let x6: expr[STATE_SIZE] = array::zip(x4, x2, |x4, x2| x4 * x2); - let x7: expr[STATE_SIZE] = array::zip(x6, a, |x6, a| x6 * a); + let x2: inter[STATE_SIZE] = array::map(a, |a| a * a); + let x4: inter[STATE_SIZE] = array::map(x2, |x2| x2 * x2); + let x6: inter[STATE_SIZE] = array::zip(x4, x2, |x4, x2| x4 * x2); + let x7: inter[STATE_SIZE] = array::zip(x6, a, |x6, a| x6 * a); // Apply S-Boxes on the first element and otherwise if it is a full round. - let b: expr[STATE_SIZE] = array::new(STATE_SIZE, |i| if i == 0 { + let b: expr[] = array::new(STATE_SIZE, |i| if i == 0 { x7[i] } else { PARTIAL * (a[i] - x7[i]) + x7[i] @@ -185,7 +186,7 @@ machine PoseidonGLMemory(mem: Memory, split_gl: SplitGL) with // Multiply with MDS Matrix let dot_product = |v1, v2| array::sum(array::zip(v1, v2, |v1_i, v2_i| v1_i * v2_i)); - let c: expr[STATE_SIZE] = array::map(M, |M_row_i| dot_product(M_row_i, b)); + let c = array::map(M, |M_row_i| dot_product(M_row_i, b)); // Copy c to state in the next row array::zip(state, c, |state, c| (state' - c) * (1-LAST) = 0); diff --git a/test_data/pil/referencing_array.pil b/test_data/pil/referencing_array.pil index 5316e511b..7c141e579 100644 --- a/test_data/pil/referencing_array.pil +++ b/test_data/pil/referencing_array.pil @@ -23,6 +23,6 @@ namespace Main(N); wit[0] + wit[1] + wit[2] + wit[3] = 1; // intermediate poly array - let inter: expr[4] = make_array(4, |i| wit[i] + 10); + let inter: inter[4] = make_array(4, |i| wit[i] + 10); make_array(4, |k| inter[k] = clocks[k] + 10);