Support arrays in witgen.

This commit is contained in:
chriseth
2023-11-02 16:41:34 +01:00
parent 13d1f66e02
commit d94db64b6f
15 changed files with 215 additions and 125 deletions

View File

@@ -38,6 +38,9 @@ impl<T: Display> Display for Analyzed<T> {
PolynomialType::Intermediate => panic!(),
};
write!(f, " col {kind}{name}")?;
if let Some(length) = symbol.length {
write!(f, "[{length}]")?;
}
if let Some(value) = definition {
writeln!(f, "{value};")?
} else {
@@ -74,8 +77,12 @@ impl<T: Display> Display for Analyzed<T> {
let (name, _) = update_namespace(&decl.name, 0, f)?;
writeln!(
f,
" public {name} = {}({});",
decl.polynomial, decl.index
" public {name} = {}{}({});",
decl.polynomial,
decl.array_index
.map(|i| format!("[{i}]"))
.unwrap_or_default(),
decl.index
)?;
}
StatementIdentifier::Identity(i) => writeln!(f, " {}", &self.identities[*i])?,
@@ -206,16 +213,7 @@ impl Display for AlgebraicBinaryOperator {
impl Display for AlgebraicReference {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}{}{}",
self.name,
self.index
.as_ref()
.map(|s| format!("[{s}]"))
.unwrap_or_default(),
if self.next { "'" } else { "" },
)
write!(f, "{}{}", self.name, if self.next { "'" } else { "" },)
}
}

View File

@@ -121,6 +121,7 @@ impl<T> Analyzed<T> {
/// Removes the specified polynomials and updates the IDs of the other polynomials
/// so that they are contiguous again.
/// There must not be any reference to the removed polynomials left.
/// Does not support arrays or array elements.
pub fn remove_polynomials(&mut self, to_remove: &BTreeSet<PolyID>) {
let mut replacements: BTreeMap<PolyID, PolyID> = [
// We have to do it separately because we need to re-start the counter
@@ -135,17 +136,19 @@ impl<T> Analyzed<T> {
(0, BTreeMap::new()),
|(shift, mut replacements), (poly, _def)| {
let poly_id = poly.into();
let length = poly.length.unwrap_or(1);
if to_remove.contains(&poly_id) {
let length = poly.length.unwrap_or(1);
(shift + length, replacements)
} else {
replacements.insert(
poly_id,
PolyID {
id: poly_id.id - shift,
..poly_id
},
);
for (_name, id) in poly.array_elements() {
replacements.insert(
id,
PolyID {
id: id.id - shift,
..id
},
);
}
(shift, replacements)
}
},
@@ -306,6 +309,46 @@ impl Symbol {
pub fn is_array(&self) -> bool {
self.length.is_some()
}
/// Returns an iterator producing either just the symbol (if it is not an array),
/// or all the elements of the array with their names in the form `array[index]`.
pub fn array_elements(&self) -> impl Iterator<Item = (String, PolyID)> + '_ {
let SymbolKind::Poly(ptype) = self.kind else {
panic!("Expected polynomial.");
};
let length = self.length.unwrap_or(1);
(0..length).map(move |i| {
(
self.array_element_name(i),
PolyID {
id: self.id + i,
ptype,
},
)
})
}
/// Returns "name[index]" if this is an array or just "name" otherwise.
/// In the second case, requires index to be zero and otherwise
/// requires index to be less than length.
pub fn array_element_name(&self, index: u64) -> String {
match self.length {
Some(length) => {
assert!(index < length);
format!("{}[{index}]", self.absolute_name)
}
None => self.absolute_name.to_string(),
}
}
/// Returns "name[length]" if this is an array or just "name" otherwise.
pub fn array_name(&self) -> String {
match self.length {
Some(length) => {
format!("{}[{length}]", self.absolute_name)
}
None => self.absolute_name.to_string(),
}
}
}
/// The "kind" of a symbol. In the future, this will be mostly
@@ -438,10 +481,11 @@ pub enum Reference {
pub struct AlgebraicReference {
/// Name of the polynomial - just for informational purposes.
/// Comparisons are based on polynomial ID.
/// In case of an array element, this ends in `[i]`.
pub name: String,
/// Identifier for a polynomial reference.
/// Identifier for a polynomial reference, already contains
/// the element offset in case of an array element.
pub poly_id: PolyID,
pub index: Option<usize>,
pub next: bool,
}
@@ -464,18 +508,12 @@ impl PartialOrd for AlgebraicReference {
impl Ord for AlgebraicReference {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
match self.poly_id.cmp(&other.poly_id) {
core::cmp::Ordering::Equal => {}
ord => return ord,
}
assert!(self.index.is_none() && other.index.is_none());
self.next.cmp(&other.next)
(&self.poly_id, &self.next).cmp(&(&other.poly_id, &other.next))
}
}
impl PartialEq for AlgebraicReference {
fn eq(&self, other: &Self) -> bool {
assert!(self.index.is_none() && other.index.is_none());
self.poly_id == other.poly_id && self.next == other.next
}
}
@@ -483,7 +521,6 @@ impl PartialEq for AlgebraicReference {
impl Hash for AlgebraicReference {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.poly_id.hash(state);
self.index.hash(state);
self.next.hash(state);
}
}

View File

@@ -3,9 +3,8 @@ use std::cmp;
use std::collections::HashMap;
use ast::analyzed::{
AlgebraicBinaryOperator, AlgebraicExpression as Expression, AlgebraicReference,
AlgebraicUnaryOperator, Analyzed, IdentityKind, PolyID, PolynomialType, StatementIdentifier,
SymbolKind,
AlgebraicBinaryOperator, AlgebraicExpression as Expression, AlgebraicUnaryOperator, Analyzed,
IdentityKind, PolyID, PolynomialType, StatementIdentifier, SymbolKind,
};
use starky::types::{
ConnectionIdentity, Expression as StarkyExpr, PermutationIdentity, PlookupIdentity,
@@ -59,12 +58,14 @@ pub fn export<T: FieldElement>(analyzed: &Analyzed<T>) -> PIL {
StatementIdentifier::PublicDeclaration(name) => {
let pub_def = &analyzed.public_declarations[name];
let pub_ref = &pub_def.polynomial;
let (_, expr) = exporter.polynomial_reference_to_json(&AlgebraicReference {
name: pub_ref.name.clone(),
poly_id: pub_ref.poly_id.unwrap(),
index: pub_def.array_index,
next: false,
});
let poly_id = pub_ref.poly_id.unwrap();
let (_, expr) = exporter.polynomial_reference_to_json(
PolyID {
id: poly_id.id + pub_def.array_index.unwrap_or_default() as u64,
..poly_id
},
false,
);
let id = publics.len();
publics.push(starky::types::Public {
polType: polynomial_reference_type_to_type(&expr.op).to_string(),
@@ -261,7 +262,9 @@ impl<'a, T: FieldElement> Exporter<'a, T> {
/// returns the degree and the JSON value (intermediate polynomial IDs)
fn expression_to_json(&self, expr: &Expression<T>) -> (u32, StarkyExpr) {
match expr {
Expression::Reference(reference) => self.polynomial_reference_to_json(reference),
Expression::Reference(reference) => {
self.polynomial_reference_to_json(reference.poly_id, reference.next)
}
Expression::PublicReference(name) => (
0,
StarkyExpr {
@@ -326,24 +329,19 @@ impl<'a, T: FieldElement> Exporter<'a, T> {
fn polynomial_reference_to_json(
&self,
AlgebraicReference {
name: _,
index,
poly_id: PolyID { id, ptype },
next,
}: &AlgebraicReference,
PolyID { id, ptype }: PolyID,
next: bool,
) -> (u32, StarkyExpr) {
let id = if *ptype == PolynomialType::Intermediate {
assert!(index.is_none());
self.intermediate_poly_expression_ids[id]
let id = if ptype == PolynomialType::Intermediate {
self.intermediate_poly_expression_ids[&id]
} else {
id + index.unwrap_or_default() as u64
id
};
let poly = StarkyExpr {
id: Some(id as usize),
op: polynomial_reference_type_to_json_string(*ptype).to_string(),
op: polynomial_reference_type_to_json_string(ptype).to_string(),
deg: 1,
next: Some(*next),
next: Some(next),
..DEFAULT_EXPR
};
(1, poly)

View File

@@ -35,13 +35,14 @@ pub fn read_poly_set<P: PolySet, T: FieldElement>(
pil: &Analyzed<T>,
dir: &Path,
) -> (Vec<(String, Vec<T>)>, DegreeType) {
let fixed_columns: Vec<&str> = P::get_polys(pil)
let column_names: Vec<String> = P::get_polys(pil)
.iter()
.map(|(poly, _)| poly.absolute_name.as_str())
.flat_map(|(poly, _)| poly.array_elements())
.map(|(name, _id)| name)
.collect();
read_polys_file(
&mut BufReader::new(File::open(dir.join(P::FILE_NAME)).unwrap()),
&fixed_columns,
&column_names,
)
}

View File

@@ -97,6 +97,14 @@ fn test_fibonacci_macro() {
gen_estark_proof(f, Default::default());
}
#[test]
fn fib_arrays() {
let f = "fib_arrays.pil";
verify_pil(f, None);
gen_halo2_proof(f, Default::default());
gen_estark_proof(f, Default::default());
}
#[test]
#[should_panic = "Witness generation failed."]
fn test_external_witgen_fails_if_none_provided() {

View File

@@ -80,7 +80,7 @@ impl<'a, 'b, T: FieldElement, Q: QueryCallback<T>> WitnessGenerator<'a, 'b, T, Q
/// Generates the committed polynomial values
/// @returns the values (in source order) and the degree of the polynomials.
pub fn generate(self) -> Vec<(&'a str, Vec<T>)> {
pub fn generate(self) -> Vec<(String, Vec<T>)> {
let fixed = FixedData::new(
self.analyzed,
self.fixed_col_values,
@@ -136,16 +136,15 @@ impl<'a, 'b, T: FieldElement, Q: QueryCallback<T>> WitnessGenerator<'a, 'b, T, Q
.chain(main_columns)
.collect::<BTreeMap<_, _>>();
// Done this way, because:
// 1. The keys need to be string references of the right lifetime.
// 2. The order needs to be the the order of declaration.
// Order columns according to the order of declaration.
self.analyzed
.committed_polys_in_source_order()
.into_iter()
.map(|(p, _)| {
let column = columns.remove(&p.absolute_name).unwrap();
.flat_map(|(p, _)| p.array_elements())
.map(|(name, _id)| {
let column = columns.remove(&name).unwrap();
assert!(!column.is_empty());
(p.absolute_name.as_str(), column)
(name, column)
})
.collect()
}
@@ -165,25 +164,21 @@ impl<'a, T: FieldElement> FixedData<'a, T> {
external_witness_values: Vec<(&'a str, Vec<T>)>,
) -> Self {
let mut external_witness_values = BTreeMap::from_iter(external_witness_values);
let witness_cols = WitnessColumnMap::from(
analyzed
.committed_polys_in_source_order()
.iter()
.enumerate()
.map(|(i, (poly, value))| {
if poly.length.is_some() {
unimplemented!("Committed arrays not implemented.")
}
assert_eq!(i as u64, poly.id);
let external_values =
external_witness_values.remove(poly.absolute_name.as_str());
if let Some(external_values) = &external_values {
assert_eq!(external_values.len(), analyzed.degree() as usize);
}
let col = WitnessColumn::new(i, &poly.absolute_name, value, external_values);
col
}),
);
let witness_cols =
WitnessColumnMap::from(analyzed.committed_polys_in_source_order().iter().flat_map(
|(poly, value)| {
poly.array_elements()
.map(|(name, poly_id)| {
let external_values = external_witness_values.remove(name.as_str());
if let Some(external_values) = &external_values {
assert_eq!(external_values.len(), analyzed.degree() as usize);
}
WitnessColumn::new(poly_id.id as usize, &name, value, external_values)
})
.collect::<Vec<_>>()
},
));
if !external_witness_values.is_empty() {
panic!(
@@ -251,7 +246,7 @@ pub struct WitnessColumn<'a, T> {
impl<'a, T> WitnessColumn<'a, T> {
pub fn new(
id: usize,
name: &'a str,
name: &str,
value: &'a Option<FunctionValueDefinition<T>>,
external_values: Option<Vec<T>>,
) -> WitnessColumn<'a, T> {
@@ -267,7 +262,6 @@ impl<'a, T> WitnessColumn<'a, T> {
},
name: name.to_string(),
next: false,
index: None,
};
WitnessColumn {
poly,

View File

@@ -146,7 +146,6 @@ where
let poly_ref = AlgebraicReference {
name: poly.name.clone(),
poly_id,
index: None,
next,
};
Ok(rows

View File

@@ -6,14 +6,7 @@ use ast::analyzed::{AlgebraicExpression as Expression, AlgebraicReference};
/// - not shifted with `'`
/// and return the polynomial if so
pub fn try_to_simple_poly<T>(expr: &Expression<T>) -> Option<&AlgebraicReference> {
if let Expression::Reference(
p @ AlgebraicReference {
index: None,
next: false,
..
},
) = expr
{
if let Expression::Reference(p @ AlgebraicReference { next: false, .. }) = expr {
Some(p)
} else {
None
@@ -22,10 +15,7 @@ pub fn try_to_simple_poly<T>(expr: &Expression<T>) -> Option<&AlgebraicReference
pub fn try_to_simple_poly_ref<T>(expr: &Expression<T>) -> Option<&AlgebraicReference> {
if let Expression::Reference(poly_ref) = expr {
if poly_ref.index.is_none() && !poly_ref.next {
return Some(poly_ref);
}
None
(!poly_ref.next).then_some(poly_ref)
} else {
None
}
@@ -33,10 +23,7 @@ pub fn try_to_simple_poly_ref<T>(expr: &Expression<T>) -> Option<&AlgebraicRefer
pub fn is_simple_poly_of_name<T>(expr: &Expression<T>, poly_name: &str) -> bool {
if let Expression::Reference(AlgebraicReference {
name,
index: None,
next: false,
..
name, next: false, ..
}) = expr
{
name == poly_name

View File

@@ -239,8 +239,6 @@ fn expression_2_expr<T: FieldElement>(cd: &CircuitData<T>, expr: &Expression<T>)
match expr {
Expression::Number(n) => Expr::Const(n.to_arbitrary_integer()),
Expression::Reference(polyref) => {
assert_eq!(polyref.index, None);
let plonkvar = PlonkVar::Query(ColumnQuery {
column: cd.col(&polyref.name),
rotation: polyref.next as i32,

View File

@@ -93,7 +93,6 @@ mod test {
executor::witgen::WitnessGenerator::new(&analyzed, &fixed, query_callback).generate();
let fixed = to_owned_values(fixed);
let witness = to_owned_values(witness);
mock_prove(&analyzed, &fixed, &witness);
}
@@ -110,7 +109,6 @@ mod test {
executor::witgen::WitnessGenerator::new(&analyzed, &fixed, query_callback).generate();
let fixed = to_owned_values(fixed);
let witness = to_owned_values(witness);
mock_prove(&analyzed, &fixed, &witness);
}

View File

@@ -103,7 +103,7 @@ pub fn write_polys_file<T: FieldElement>(file: &mut impl Write, polys: &[(String
pub fn read_polys_file<T: FieldElement>(
file: &mut impl Read,
columns: &[&str],
columns: &[String],
) -> (Vec<(String, Vec<T>)>, DegreeType) {
let width = ceil_div(T::BITS as usize, 64) * 8;
@@ -156,8 +156,10 @@ mod tests {
let (polys, degree) = test_polys();
write_polys_file(&mut buf, &polys);
let (read_polys, read_degree) =
read_polys_file::<Bn254Field>(&mut Cursor::new(buf), &["a", "b"]);
let (read_polys, read_degree) = read_polys_file::<Bn254Field>(
&mut Cursor::new(buf),
&["a".to_string(), "b".to_string()],
);
assert_eq!(read_polys, polys);
assert_eq!(read_degree, degree);

View File

@@ -6,7 +6,7 @@ use std::collections::HashMap;
use ast::{
analyzed::{
AlgebraicExpression, AlgebraicReference, Analyzed, Expression, FunctionValueDefinition,
Identity, PolynomialReference, PolynomialType, PublicDeclaration, Reference,
Identity, PolyID, PolynomialReference, PolynomialType, PublicDeclaration, Reference,
StatementIdentifier, Symbol, SymbolKind,
},
evaluate_binary_operation, evaluate_unary_operation,
@@ -136,10 +136,14 @@ impl<T: FieldElement> Condenser<T> {
.unwrap_or_else(|| panic!("Column {} not found.", poly.name))
.0;
assert!(
symbol.length.is_none(),
"Arrays cannot be used as a whole in this context, only individual array elements can be used."
);
AlgebraicExpression::Reference(AlgebraicReference {
name: poly.name.clone(),
poly_id: symbol.into(),
index: None,
next: false,
})
}
@@ -190,19 +194,39 @@ impl<T: FieldElement> Condenser<T> {
}
Expression::PublicReference(r) => AlgebraicExpression::PublicReference(r.clone()),
Expression::IndexAccess(IndexAccess { array, index }) => {
let AlgebraicExpression::Reference(array) = self.condense_expression(array) else {
panic!("Expected direct reference before array index access.");
let array_symbol = match array.as_ref() {
ast::parsed::Expression::Reference(Reference::Poly(PolynomialReference {
name,
poly_id: _,
})) => {
&self
.symbols
.get(name)
.unwrap_or_else(|| panic!("Column {name} not found."))
.0
}
_ => panic!("Expected direct reference before array index access."),
};
let Some(length) = array_symbol.length else {
panic!("Array-access for non-array {}.", array_symbol.absolute_name);
};
let index = evaluate_expression(&self.symbols, index)
.expect("Index needs to be constant number.");
.expect("Index needs to be constant number.")
.to_degree();
assert!(
array.index.is_none(),
"Cannot index an array twice in this context."
index < length,
"Array access to index {index} for array of length {length}: {}",
array_symbol.absolute_name,
);
assert!(index.to_degree() <= usize::MAX as u64);
let poly_id: PolyID = array_symbol.into();
AlgebraicExpression::Reference(AlgebraicReference {
index: Some(index.to_degree() as usize),
..array
poly_id: PolyID {
id: poly_id.id + index,
..poly_id
},
name: array_symbol.array_element_name(index),
next: false,
})
}
Expression::String(_) => panic!("Strings are not allowed here."),

View File

@@ -853,4 +853,38 @@ namespace N(65536);
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}
#[test]
fn reparse_arrays() {
let input = r#"namespace N(16);
col witness y[3];
(N.y[1] - 2) = 0;
(N.y[2]' - 2) = 0;
public out = N.y[1](2);
"#;
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, input);
}
#[test]
#[should_panic = "Arrays cannot be used as a whole in this context"]
fn no_direct_array_references() {
let input = r#"namespace N(16);
col witness y[3];
(N.y - 2) = 0;
"#;
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, input);
}
#[test]
#[should_panic = "Array access to index 3 for array of length 3"]
fn no_out_of_bounds() {
let input = r#"namespace N(16);
col witness y[3];
(N.y[3] - 2) = 0;
"#;
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, input);
}
}

View File

@@ -268,15 +268,8 @@ fn substitute_polynomial_references<T: FieldElement>(
}
});
pil_file.post_visit_expressions_in_identities_mut(&mut |e: &mut AlgebraicExpression<_>| {
if let AlgebraicExpression::Reference(AlgebraicReference {
name: _,
index,
next: _,
poly_id,
}) = e
{
if let AlgebraicExpression::Reference(AlgebraicReference { poly_id, .. }) = e {
if let Some(value) = substitutions.get(poly_id) {
assert!(index.is_none());
*e = AlgebraicExpression::Number(*value);
}
}

View File

@@ -0,0 +1,19 @@
let N = 16;
namespace FibArrays(N);
col fixed ISLAST(i) { match i {
N - 1 => 1,
_ => 0,
} };
col witness unused;
col witness x[2];
col witness unused2;
ISLAST * (x[1]' - 1) = 0;
ISLAST * (x[0]' - 1) = 0;
(1-ISLAST) * (x[0]' - x[1]) = 0;
(1-ISLAST) * (x[1]' - (x[0] + x[1])) = 0;
(unused - 1) * unused2 = 0;
public out = x[1](N - 1);