Merge pull request #588 from powdr-labs/support-imp

Support intermediate polynomials
This commit is contained in:
chriseth
2023-09-13 09:54:09 +02:00
committed by GitHub
13 changed files with 98 additions and 8 deletions

View File

@@ -71,6 +71,7 @@ impl<T: Display> Display for FunctionValueDefinition<T> {
write!(f, " = {}", items.iter().map(|i| i.to_string()).join(" + "))
}
FunctionValueDefinition::Query(e) => write!(f, "(i) query {e}"),
FunctionValueDefinition::Expression(e) => write!(f, " = {e}"),
}
}
}

View File

@@ -61,6 +61,12 @@ impl<T> Analyzed<T> {
self.definitions_in_source_order(PolynomialType::Committed)
}
pub fn intermediate_polys_in_source_order(
&self,
) -> Vec<&(Polynomial, Option<FunctionValueDefinition<T>>)> {
self.definitions_in_source_order(PolynomialType::Intermediate)
}
pub fn definitions_in_source_order(
&self,
poly_type: PolynomialType,
@@ -96,12 +102,12 @@ impl<T> Analyzed<T> {
/// so that they are contiguous again.
/// There must not be any reference to the removed polynomials left.
pub fn remove_polynomials(&mut self, to_remove: &BTreeSet<PolyID>) {
// TODO intermediate polys
let replacements: BTreeMap<PolyID, PolyID> = [
// We have to do it separately because we need to re-start the counter
// for each kind.
self.committed_polys_in_source_order(),
self.constant_polys_in_source_order(),
self.intermediate_polys_in_source_order(),
]
.map(|polys| {
polys
@@ -239,6 +245,7 @@ pub enum FunctionValueDefinition<T> {
Mapping(Expression<T>),
Array(Vec<RepeatedArray<T>>),
Query(Expression<T>),
Expression(Expression<T>),
}
/// An array of elements that might be repeated.

View File

@@ -22,6 +22,7 @@ where
.iter_mut()
.flat_map(|e| e.pattern.iter_mut())
.try_for_each(|e| previsit_expression_mut(e, f)),
Some(FunctionValueDefinition::Expression(e)) => previsit_expression_mut(e, f),
None => ControlFlow::Continue(()),
})?;
@@ -66,6 +67,7 @@ where
.iter_mut()
.flat_map(|e| e.pattern.iter_mut())
.try_for_each(|e| postvisit_expression_mut(e, f)),
Some(FunctionValueDefinition::Expression(e)) => postvisit_expression_mut(e, f),
None => ControlFlow::Continue(()),
})?;

View File

@@ -370,6 +370,9 @@ impl<T: Display> Display for FunctionDefinition<T> {
FunctionDefinition::Query(params, value) => {
write!(f, "({}) query {value}", params.join(", "),)
}
FunctionDefinition::Expression(e) => {
write!(f, " = {e}")
}
}
}
}

View File

@@ -289,7 +289,6 @@ pub struct FunctionCall<T> {
}
/// The definition of a function (excluding its name):
/// Either a param-value mapping or an array expression.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum FunctionDefinition<T> {
/// Parameter-value-mapping.
@@ -298,6 +297,8 @@ pub enum FunctionDefinition<T> {
Array(ArrayExpression<T>),
/// Prover query.
Query(Vec<String>, Expression<T>),
/// Expression, for intermediate polynomials
Expression(Expression<T>),
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
@@ -451,6 +452,7 @@ where
postvisit_expression_mut(e, f)
}
FunctionDefinition::Array(ae) => postvisit_expression_in_array_expression_mut(ae, f),
FunctionDefinition::Expression(e) => postvisit_expression_mut(e, f),
},
PilStatement::PolynomialCommitDeclaration(_, _, None)
| PilStatement::Include(_, _)

View File

@@ -34,7 +34,7 @@ pub fn export<T: FieldElement>(analyzed: &Analyzed<T>) -> JsonValue {
StatementIdentifier::Definition(name) => {
if let (poly, Some(value)) = &analyzed.definitions[name] {
if poly.poly_type == PolynomialType::Intermediate {
if let FunctionValueDefinition::Mapping(value) = value {
if let FunctionValueDefinition::Expression(value) = value {
let expression_id = exporter.extract_expression(value, 1);
assert_eq!(
expression_id,

View File

@@ -163,6 +163,7 @@ pub fn compile_asm_string<T: FieldElement>(
);
return Ok((pil_file_path, None));
}
fs::write(pil_file_path.clone(), format!("{pil}")).unwrap();
let pil_file_name = pil_file_path.file_name().unwrap();

View File

@@ -162,6 +162,22 @@ fn full_pil_constant() {
gen_halo2_proof(f, Default::default());
}
#[test]
#[should_panic = "assertion failed: poly.is_fixed() || poly.is_witness()"]
fn intermediate() {
let f = "intermediate.asm";
verify_asm::<GoldilocksField>(f, Default::default());
gen_halo2_proof(f, Default::default());
}
#[test]
#[should_panic = "assertion failed: poly.is_fixed() || poly.is_witness()"]
fn intermediate_nested() {
let f = "intermediate_nested.asm";
verify_asm::<GoldilocksField>(f, Default::default());
gen_halo2_proof(f, Default::default());
}
#[test]
fn book() {
for f in fs::read_dir("../test_data/asm/book/").unwrap() {

View File

@@ -74,6 +74,9 @@ fn generate_values<T: FieldElement>(
values
}
FunctionValueDefinition::Query(_) => panic!("Query used for fixed column."),
FunctionValueDefinition::Expression(_) => {
panic!("Expression used for fixed column, only expected for intermediate polynomials")
}
}
}

View File

@@ -147,7 +147,7 @@ impl<T: FieldElement> PILContext<T> {
name,
None,
PolynomialType::Intermediate,
Some(FunctionDefinition::Mapping(vec![], value)),
Some(FunctionDefinition::Expression(value)),
);
}
PilStatement::PublicDeclaration(start, name, polynomial, index) => {
@@ -321,12 +321,14 @@ impl<T: FieldElement> PILContext<T> {
let name = poly.absolute_name.clone();
let value = value.map(|v| match v {
FunctionDefinition::Expression(expr) => {
assert!(!have_array_size);
assert!(poly.poly_type == PolynomialType::Intermediate);
FunctionValueDefinition::Expression(self.process_expression(expr))
}
FunctionDefinition::Mapping(params, expr) => {
assert!(!have_array_size);
assert!(
poly.poly_type == PolynomialType::Constant
|| poly.poly_type == PolynomialType::Intermediate
);
assert!(poly.poly_type == PolynomialType::Constant);
FunctionValueDefinition::Mapping(self.process_function(params, expr))
}
FunctionDefinition::Query(params, expr) => {
@@ -660,4 +662,20 @@ namespace T(65536);
}
assert_eq!(input, formatted);
}
#[test]
fn intermediate() {
let input = r#"namespace N(65536);
col witness x;
col intermediate = x;
intermediate = intermediate;
"#;
let expected = r#"namespace N(65536);
col witness x;
col intermediate = N.x;
N.intermediate = N.intermediate;
"#;
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}
}

View File

@@ -96,6 +96,7 @@ fn constant_value<T: FieldElement>(function: &FunctionValueDefinition<T>) -> Opt
}
}
FunctionValueDefinition::Query(_) => None,
FunctionValueDefinition::Expression(_) => None,
}
}
@@ -411,6 +412,22 @@ mod test {
N.Z = ((1 + N.A) * 2);
N.A = (1 + N.A);
N.Z = (1 + N.A);
"#;
let optimized = optimize(process_pil_file_contents::<GoldilocksField>(input)).to_string();
assert_eq!(optimized, expectation);
}
#[test]
fn intermediate() {
let input = r#"namespace N(65536);
col witness x;
col intermediate = x;
intermediate = intermediate;
"#;
let expectation = r#"namespace N(65536);
col witness x;
col intermediate = N.x;
N.intermediate = N.intermediate;
"#;
let optimized = optimize(process_pil_file_contents::<GoldilocksField>(input)).to_string();
assert_eq!(optimized, expectation);

View File

@@ -0,0 +1,9 @@
machine Intermediate(latch, operation_id) {
constraints {
col fixed latch = [1]*;
col fixed operation_id = [0]*;
col witness x;
col intermediate = x;
intermediate = intermediate;
}
}

View File

@@ -0,0 +1,11 @@
machine Intermediate(latch, operation_id) {
constraints {
col fixed latch = [1]*;
col fixed operation_id = [0]*;
col witness x;
col intermediate = x;
col int2 = intermediate;
col int3 = int2 + intermediate;
int3 = 2 * x;
}
}