From d94db64b6f5f12021663ddb4b4f699cce0843f04 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 2 Nov 2023 16:41:34 +0100 Subject: [PATCH] Support arrays in witgen. --- ast/src/analyzed/display.rs | 22 ++++--- ast/src/analyzed/mod.rs | 73 +++++++++++++++++------ backend/src/pilstark/json_exporter/mod.rs | 42 +++++++------ compiler/src/util.rs | 7 ++- compiler/tests/pil.rs | 8 +++ executor/src/witgen/mod.rs | 50 +++++++--------- executor/src/witgen/query_processor.rs | 1 - executor/src/witgen/util.rs | 19 +----- halo2/src/circuit_builder.rs | 2 - halo2/src/mock_prover.rs | 2 - number/src/serialize.rs | 8 ++- pil_analyzer/src/condenser.rs | 44 ++++++++++---- pil_analyzer/src/pil_analyzer.rs | 34 +++++++++++ pilopt/src/lib.rs | 9 +-- test_data/pil/fib_arrays.pil | 19 ++++++ 15 files changed, 215 insertions(+), 125 deletions(-) create mode 100644 test_data/pil/fib_arrays.pil diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index db01efc11..01101ae78 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -38,6 +38,9 @@ impl Display for Analyzed { PolynomialType::Intermediate => panic!(), }; write!(f, " col {kind}{name}")?; + if let Some(length) = symbol.length { + write!(f, "[{length}]")?; + } if let Some(value) = definition { writeln!(f, "{value};")? } else { @@ -74,8 +77,12 @@ impl Display for Analyzed { let (name, _) = update_namespace(&decl.name, 0, f)?; writeln!( f, - " public {name} = {}({});", - decl.polynomial, decl.index + " public {name} = {}{}({});", + decl.polynomial, + decl.array_index + .map(|i| format!("[{i}]")) + .unwrap_or_default(), + decl.index )?; } StatementIdentifier::Identity(i) => writeln!(f, " {}", &self.identities[*i])?, @@ -206,16 +213,7 @@ impl Display for AlgebraicBinaryOperator { impl Display for AlgebraicReference { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}{}{}", - self.name, - self.index - .as_ref() - .map(|s| format!("[{s}]")) - .unwrap_or_default(), - if self.next { "'" } else { "" }, - ) + write!(f, "{}{}", self.name, if self.next { "'" } else { "" },) } } diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 17411ec5e..b20468d3f 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -121,6 +121,7 @@ impl Analyzed { /// Removes the specified polynomials and updates the IDs of the other polynomials /// so that they are contiguous again. /// There must not be any reference to the removed polynomials left. + /// Does not support arrays or array elements. pub fn remove_polynomials(&mut self, to_remove: &BTreeSet) { let mut replacements: BTreeMap = [ // We have to do it separately because we need to re-start the counter @@ -135,17 +136,19 @@ impl Analyzed { (0, BTreeMap::new()), |(shift, mut replacements), (poly, _def)| { let poly_id = poly.into(); + let length = poly.length.unwrap_or(1); if to_remove.contains(&poly_id) { - let length = poly.length.unwrap_or(1); (shift + length, replacements) } else { - replacements.insert( - poly_id, - PolyID { - id: poly_id.id - shift, - ..poly_id - }, - ); + for (_name, id) in poly.array_elements() { + replacements.insert( + id, + PolyID { + id: id.id - shift, + ..id + }, + ); + } (shift, replacements) } }, @@ -306,6 +309,46 @@ impl Symbol { pub fn is_array(&self) -> bool { self.length.is_some() } + /// Returns an iterator producing either just the symbol (if it is not an array), + /// or all the elements of the array with their names in the form `array[index]`. + pub fn array_elements(&self) -> impl Iterator + '_ { + let SymbolKind::Poly(ptype) = self.kind else { + panic!("Expected polynomial."); + }; + let length = self.length.unwrap_or(1); + (0..length).map(move |i| { + ( + self.array_element_name(i), + PolyID { + id: self.id + i, + ptype, + }, + ) + }) + } + + /// Returns "name[index]" if this is an array or just "name" otherwise. + /// In the second case, requires index to be zero and otherwise + /// requires index to be less than length. + pub fn array_element_name(&self, index: u64) -> String { + match self.length { + Some(length) => { + assert!(index < length); + format!("{}[{index}]", self.absolute_name) + } + None => self.absolute_name.to_string(), + } + } + + /// Returns "name[length]" if this is an array or just "name" otherwise. + pub fn array_name(&self) -> String { + match self.length { + Some(length) => { + format!("{}[{length}]", self.absolute_name) + } + None => self.absolute_name.to_string(), + } + } } /// The "kind" of a symbol. In the future, this will be mostly @@ -438,10 +481,11 @@ pub enum Reference { pub struct AlgebraicReference { /// Name of the polynomial - just for informational purposes. /// Comparisons are based on polynomial ID. + /// In case of an array element, this ends in `[i]`. pub name: String, - /// Identifier for a polynomial reference. + /// Identifier for a polynomial reference, already contains + /// the element offset in case of an array element. pub poly_id: PolyID, - pub index: Option, pub next: bool, } @@ -464,18 +508,12 @@ impl PartialOrd for AlgebraicReference { impl Ord for AlgebraicReference { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - match self.poly_id.cmp(&other.poly_id) { - core::cmp::Ordering::Equal => {} - ord => return ord, - } - assert!(self.index.is_none() && other.index.is_none()); - self.next.cmp(&other.next) + (&self.poly_id, &self.next).cmp(&(&other.poly_id, &other.next)) } } impl PartialEq for AlgebraicReference { fn eq(&self, other: &Self) -> bool { - assert!(self.index.is_none() && other.index.is_none()); self.poly_id == other.poly_id && self.next == other.next } } @@ -483,7 +521,6 @@ impl PartialEq for AlgebraicReference { impl Hash for AlgebraicReference { fn hash(&self, state: &mut H) { self.poly_id.hash(state); - self.index.hash(state); self.next.hash(state); } } diff --git a/backend/src/pilstark/json_exporter/mod.rs b/backend/src/pilstark/json_exporter/mod.rs index de2d6bcc0..f621e9b18 100644 --- a/backend/src/pilstark/json_exporter/mod.rs +++ b/backend/src/pilstark/json_exporter/mod.rs @@ -3,9 +3,8 @@ use std::cmp; use std::collections::HashMap; use ast::analyzed::{ - AlgebraicBinaryOperator, AlgebraicExpression as Expression, AlgebraicReference, - AlgebraicUnaryOperator, Analyzed, IdentityKind, PolyID, PolynomialType, StatementIdentifier, - SymbolKind, + AlgebraicBinaryOperator, AlgebraicExpression as Expression, AlgebraicUnaryOperator, Analyzed, + IdentityKind, PolyID, PolynomialType, StatementIdentifier, SymbolKind, }; use starky::types::{ ConnectionIdentity, Expression as StarkyExpr, PermutationIdentity, PlookupIdentity, @@ -59,12 +58,14 @@ pub fn export(analyzed: &Analyzed) -> PIL { StatementIdentifier::PublicDeclaration(name) => { let pub_def = &analyzed.public_declarations[name]; let pub_ref = &pub_def.polynomial; - let (_, expr) = exporter.polynomial_reference_to_json(&AlgebraicReference { - name: pub_ref.name.clone(), - poly_id: pub_ref.poly_id.unwrap(), - index: pub_def.array_index, - next: false, - }); + let poly_id = pub_ref.poly_id.unwrap(); + let (_, expr) = exporter.polynomial_reference_to_json( + PolyID { + id: poly_id.id + pub_def.array_index.unwrap_or_default() as u64, + ..poly_id + }, + false, + ); let id = publics.len(); publics.push(starky::types::Public { polType: polynomial_reference_type_to_type(&expr.op).to_string(), @@ -261,7 +262,9 @@ impl<'a, T: FieldElement> Exporter<'a, T> { /// returns the degree and the JSON value (intermediate polynomial IDs) fn expression_to_json(&self, expr: &Expression) -> (u32, StarkyExpr) { match expr { - Expression::Reference(reference) => self.polynomial_reference_to_json(reference), + Expression::Reference(reference) => { + self.polynomial_reference_to_json(reference.poly_id, reference.next) + } Expression::PublicReference(name) => ( 0, StarkyExpr { @@ -326,24 +329,19 @@ impl<'a, T: FieldElement> Exporter<'a, T> { fn polynomial_reference_to_json( &self, - AlgebraicReference { - name: _, - index, - poly_id: PolyID { id, ptype }, - next, - }: &AlgebraicReference, + PolyID { id, ptype }: PolyID, + next: bool, ) -> (u32, StarkyExpr) { - let id = if *ptype == PolynomialType::Intermediate { - assert!(index.is_none()); - self.intermediate_poly_expression_ids[id] + let id = if ptype == PolynomialType::Intermediate { + self.intermediate_poly_expression_ids[&id] } else { - id + index.unwrap_or_default() as u64 + id }; let poly = StarkyExpr { id: Some(id as usize), - op: polynomial_reference_type_to_json_string(*ptype).to_string(), + op: polynomial_reference_type_to_json_string(ptype).to_string(), deg: 1, - next: Some(*next), + next: Some(next), ..DEFAULT_EXPR }; (1, poly) diff --git a/compiler/src/util.rs b/compiler/src/util.rs index 165944981..2b10a595a 100644 --- a/compiler/src/util.rs +++ b/compiler/src/util.rs @@ -35,13 +35,14 @@ pub fn read_poly_set( pil: &Analyzed, dir: &Path, ) -> (Vec<(String, Vec)>, DegreeType) { - let fixed_columns: Vec<&str> = P::get_polys(pil) + let column_names: Vec = P::get_polys(pil) .iter() - .map(|(poly, _)| poly.absolute_name.as_str()) + .flat_map(|(poly, _)| poly.array_elements()) + .map(|(name, _id)| name) .collect(); read_polys_file( &mut BufReader::new(File::open(dir.join(P::FILE_NAME)).unwrap()), - &fixed_columns, + &column_names, ) } diff --git a/compiler/tests/pil.rs b/compiler/tests/pil.rs index d0d8fc992..5d5fc3737 100644 --- a/compiler/tests/pil.rs +++ b/compiler/tests/pil.rs @@ -97,6 +97,14 @@ fn test_fibonacci_macro() { gen_estark_proof(f, Default::default()); } +#[test] +fn fib_arrays() { + let f = "fib_arrays.pil"; + verify_pil(f, None); + gen_halo2_proof(f, Default::default()); + gen_estark_proof(f, Default::default()); +} + #[test] #[should_panic = "Witness generation failed."] fn test_external_witgen_fails_if_none_provided() { diff --git a/executor/src/witgen/mod.rs b/executor/src/witgen/mod.rs index 27033c18a..58bbcc5bd 100644 --- a/executor/src/witgen/mod.rs +++ b/executor/src/witgen/mod.rs @@ -80,7 +80,7 @@ impl<'a, 'b, T: FieldElement, Q: QueryCallback> WitnessGenerator<'a, 'b, T, Q /// Generates the committed polynomial values /// @returns the values (in source order) and the degree of the polynomials. - pub fn generate(self) -> Vec<(&'a str, Vec)> { + pub fn generate(self) -> Vec<(String, Vec)> { let fixed = FixedData::new( self.analyzed, self.fixed_col_values, @@ -136,16 +136,15 @@ impl<'a, 'b, T: FieldElement, Q: QueryCallback> WitnessGenerator<'a, 'b, T, Q .chain(main_columns) .collect::>(); - // Done this way, because: - // 1. The keys need to be string references of the right lifetime. - // 2. The order needs to be the the order of declaration. + // Order columns according to the order of declaration. self.analyzed .committed_polys_in_source_order() .into_iter() - .map(|(p, _)| { - let column = columns.remove(&p.absolute_name).unwrap(); + .flat_map(|(p, _)| p.array_elements()) + .map(|(name, _id)| { + let column = columns.remove(&name).unwrap(); assert!(!column.is_empty()); - (p.absolute_name.as_str(), column) + (name, column) }) .collect() } @@ -165,25 +164,21 @@ impl<'a, T: FieldElement> FixedData<'a, T> { external_witness_values: Vec<(&'a str, Vec)>, ) -> Self { let mut external_witness_values = BTreeMap::from_iter(external_witness_values); - let witness_cols = WitnessColumnMap::from( - analyzed - .committed_polys_in_source_order() - .iter() - .enumerate() - .map(|(i, (poly, value))| { - if poly.length.is_some() { - unimplemented!("Committed arrays not implemented.") - } - assert_eq!(i as u64, poly.id); - let external_values = - external_witness_values.remove(poly.absolute_name.as_str()); - if let Some(external_values) = &external_values { - assert_eq!(external_values.len(), analyzed.degree() as usize); - } - let col = WitnessColumn::new(i, &poly.absolute_name, value, external_values); - col - }), - ); + + let witness_cols = + WitnessColumnMap::from(analyzed.committed_polys_in_source_order().iter().flat_map( + |(poly, value)| { + poly.array_elements() + .map(|(name, poly_id)| { + let external_values = external_witness_values.remove(name.as_str()); + if let Some(external_values) = &external_values { + assert_eq!(external_values.len(), analyzed.degree() as usize); + } + WitnessColumn::new(poly_id.id as usize, &name, value, external_values) + }) + .collect::>() + }, + )); if !external_witness_values.is_empty() { panic!( @@ -251,7 +246,7 @@ pub struct WitnessColumn<'a, T> { impl<'a, T> WitnessColumn<'a, T> { pub fn new( id: usize, - name: &'a str, + name: &str, value: &'a Option>, external_values: Option>, ) -> WitnessColumn<'a, T> { @@ -267,7 +262,6 @@ impl<'a, T> WitnessColumn<'a, T> { }, name: name.to_string(), next: false, - index: None, }; WitnessColumn { poly, diff --git a/executor/src/witgen/query_processor.rs b/executor/src/witgen/query_processor.rs index 2a511fa2d..8cab31bce 100644 --- a/executor/src/witgen/query_processor.rs +++ b/executor/src/witgen/query_processor.rs @@ -146,7 +146,6 @@ where let poly_ref = AlgebraicReference { name: poly.name.clone(), poly_id, - index: None, next, }; Ok(rows diff --git a/executor/src/witgen/util.rs b/executor/src/witgen/util.rs index 2a169e95e..98ae85db7 100644 --- a/executor/src/witgen/util.rs +++ b/executor/src/witgen/util.rs @@ -6,14 +6,7 @@ use ast::analyzed::{AlgebraicExpression as Expression, AlgebraicReference}; /// - not shifted with `'` /// and return the polynomial if so pub fn try_to_simple_poly(expr: &Expression) -> Option<&AlgebraicReference> { - if let Expression::Reference( - p @ AlgebraicReference { - index: None, - next: false, - .. - }, - ) = expr - { + if let Expression::Reference(p @ AlgebraicReference { next: false, .. }) = expr { Some(p) } else { None @@ -22,10 +15,7 @@ pub fn try_to_simple_poly(expr: &Expression) -> Option<&AlgebraicReference pub fn try_to_simple_poly_ref(expr: &Expression) -> Option<&AlgebraicReference> { if let Expression::Reference(poly_ref) = expr { - if poly_ref.index.is_none() && !poly_ref.next { - return Some(poly_ref); - } - None + (!poly_ref.next).then_some(poly_ref) } else { None } @@ -33,10 +23,7 @@ pub fn try_to_simple_poly_ref(expr: &Expression) -> Option<&AlgebraicRefer pub fn is_simple_poly_of_name(expr: &Expression, poly_name: &str) -> bool { if let Expression::Reference(AlgebraicReference { - name, - index: None, - next: false, - .. + name, next: false, .. }) = expr { name == poly_name diff --git a/halo2/src/circuit_builder.rs b/halo2/src/circuit_builder.rs index d0577df03..57369181a 100644 --- a/halo2/src/circuit_builder.rs +++ b/halo2/src/circuit_builder.rs @@ -239,8 +239,6 @@ fn expression_2_expr(cd: &CircuitData, expr: &Expression) match expr { Expression::Number(n) => Expr::Const(n.to_arbitrary_integer()), Expression::Reference(polyref) => { - assert_eq!(polyref.index, None); - let plonkvar = PlonkVar::Query(ColumnQuery { column: cd.col(&polyref.name), rotation: polyref.next as i32, diff --git a/halo2/src/mock_prover.rs b/halo2/src/mock_prover.rs index 89b090dce..72886845e 100644 --- a/halo2/src/mock_prover.rs +++ b/halo2/src/mock_prover.rs @@ -93,7 +93,6 @@ mod test { executor::witgen::WitnessGenerator::new(&analyzed, &fixed, query_callback).generate(); let fixed = to_owned_values(fixed); - let witness = to_owned_values(witness); mock_prove(&analyzed, &fixed, &witness); } @@ -110,7 +109,6 @@ mod test { executor::witgen::WitnessGenerator::new(&analyzed, &fixed, query_callback).generate(); let fixed = to_owned_values(fixed); - let witness = to_owned_values(witness); mock_prove(&analyzed, &fixed, &witness); } diff --git a/number/src/serialize.rs b/number/src/serialize.rs index db0d9dd1d..916c34721 100644 --- a/number/src/serialize.rs +++ b/number/src/serialize.rs @@ -103,7 +103,7 @@ pub fn write_polys_file(file: &mut impl Write, polys: &[(String pub fn read_polys_file( file: &mut impl Read, - columns: &[&str], + columns: &[String], ) -> (Vec<(String, Vec)>, DegreeType) { let width = ceil_div(T::BITS as usize, 64) * 8; @@ -156,8 +156,10 @@ mod tests { let (polys, degree) = test_polys(); write_polys_file(&mut buf, &polys); - let (read_polys, read_degree) = - read_polys_file::(&mut Cursor::new(buf), &["a", "b"]); + let (read_polys, read_degree) = read_polys_file::( + &mut Cursor::new(buf), + &["a".to_string(), "b".to_string()], + ); assert_eq!(read_polys, polys); assert_eq!(read_degree, degree); diff --git a/pil_analyzer/src/condenser.rs b/pil_analyzer/src/condenser.rs index 6ba001902..1cfefb9ea 100644 --- a/pil_analyzer/src/condenser.rs +++ b/pil_analyzer/src/condenser.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; use ast::{ analyzed::{ AlgebraicExpression, AlgebraicReference, Analyzed, Expression, FunctionValueDefinition, - Identity, PolynomialReference, PolynomialType, PublicDeclaration, Reference, + Identity, PolyID, PolynomialReference, PolynomialType, PublicDeclaration, Reference, StatementIdentifier, Symbol, SymbolKind, }, evaluate_binary_operation, evaluate_unary_operation, @@ -136,10 +136,14 @@ impl Condenser { .unwrap_or_else(|| panic!("Column {} not found.", poly.name)) .0; + assert!( + symbol.length.is_none(), + "Arrays cannot be used as a whole in this context, only individual array elements can be used." + ); + AlgebraicExpression::Reference(AlgebraicReference { name: poly.name.clone(), poly_id: symbol.into(), - index: None, next: false, }) } @@ -190,19 +194,39 @@ impl Condenser { } Expression::PublicReference(r) => AlgebraicExpression::PublicReference(r.clone()), Expression::IndexAccess(IndexAccess { array, index }) => { - let AlgebraicExpression::Reference(array) = self.condense_expression(array) else { - panic!("Expected direct reference before array index access."); + let array_symbol = match array.as_ref() { + ast::parsed::Expression::Reference(Reference::Poly(PolynomialReference { + name, + poly_id: _, + })) => { + &self + .symbols + .get(name) + .unwrap_or_else(|| panic!("Column {name} not found.")) + .0 + } + _ => panic!("Expected direct reference before array index access."), }; + let Some(length) = array_symbol.length else { + panic!("Array-access for non-array {}.", array_symbol.absolute_name); + }; + let index = evaluate_expression(&self.symbols, index) - .expect("Index needs to be constant number."); + .expect("Index needs to be constant number.") + .to_degree(); assert!( - array.index.is_none(), - "Cannot index an array twice in this context." + index < length, + "Array access to index {index} for array of length {length}: {}", + array_symbol.absolute_name, ); - assert!(index.to_degree() <= usize::MAX as u64); + let poly_id: PolyID = array_symbol.into(); AlgebraicExpression::Reference(AlgebraicReference { - index: Some(index.to_degree() as usize), - ..array + poly_id: PolyID { + id: poly_id.id + index, + ..poly_id + }, + name: array_symbol.array_element_name(index), + next: false, }) } Expression::String(_) => panic!("Strings are not allowed here."), diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 35b2fafcf..1ccc8c0ec 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -853,4 +853,38 @@ namespace N(65536); let formatted = process_pil_file_contents::(input).to_string(); assert_eq!(formatted, expected); } + + #[test] + fn reparse_arrays() { + let input = r#"namespace N(16); + col witness y[3]; + (N.y[1] - 2) = 0; + (N.y[2]' - 2) = 0; + public out = N.y[1](2); +"#; + let formatted = process_pil_file_contents::(input).to_string(); + assert_eq!(formatted, input); + } + + #[test] + #[should_panic = "Arrays cannot be used as a whole in this context"] + fn no_direct_array_references() { + let input = r#"namespace N(16); + col witness y[3]; + (N.y - 2) = 0; +"#; + let formatted = process_pil_file_contents::(input).to_string(); + assert_eq!(formatted, input); + } + + #[test] + #[should_panic = "Array access to index 3 for array of length 3"] + fn no_out_of_bounds() { + let input = r#"namespace N(16); + col witness y[3]; + (N.y[3] - 2) = 0; +"#; + let formatted = process_pil_file_contents::(input).to_string(); + assert_eq!(formatted, input); + } } diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 124632df5..7da299a22 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -268,15 +268,8 @@ fn substitute_polynomial_references( } }); pil_file.post_visit_expressions_in_identities_mut(&mut |e: &mut AlgebraicExpression<_>| { - if let AlgebraicExpression::Reference(AlgebraicReference { - name: _, - index, - next: _, - poly_id, - }) = e - { + if let AlgebraicExpression::Reference(AlgebraicReference { poly_id, .. }) = e { if let Some(value) = substitutions.get(poly_id) { - assert!(index.is_none()); *e = AlgebraicExpression::Number(*value); } } diff --git a/test_data/pil/fib_arrays.pil b/test_data/pil/fib_arrays.pil new file mode 100644 index 000000000..0fd1beccb --- /dev/null +++ b/test_data/pil/fib_arrays.pil @@ -0,0 +1,19 @@ +let N = 16; +namespace FibArrays(N); + col fixed ISLAST(i) { match i { + N - 1 => 1, + _ => 0, + } }; + col witness unused; + col witness x[2]; + col witness unused2; + + ISLAST * (x[1]' - 1) = 0; + ISLAST * (x[0]' - 1) = 0; + + (1-ISLAST) * (x[0]' - x[1]) = 0; + (1-ISLAST) * (x[1]' - (x[0] + x[1])) = 0; + + (unused - 1) * unused2 = 0; + + public out = x[1](N - 1);