From 98ebf3fbc6cef868000e77620d026213e521baaa Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 31 Jul 2024 18:19:34 +0200 Subject: [PATCH] Set hint (#1609) --- pil-analyzer/src/condenser.rs | 196 ++++++++++++++++++----- pil-analyzer/src/evaluator.rs | 24 ++- pil-analyzer/src/expression_processor.rs | 8 +- pil-analyzer/src/side_effect_checker.rs | 32 +++- pil-analyzer/src/type_builtins.rs | 4 + pil-analyzer/tests/condenser.rs | 165 ++++++++++++++++++- pil-analyzer/tests/side_effects.rs | 55 +++++++ 7 files changed, 438 insertions(+), 46 deletions(-) diff --git a/pil-analyzer/src/condenser.rs b/pil-analyzer/src/condenser.rs index e13af314a..afdceba4a 100644 --- a/pil-analyzer/src/condenser.rs +++ b/pil-analyzer/src/condenser.rs @@ -2,7 +2,8 @@ //! i.e. it turns more complex expressions in identities to simpler expressions. use std::{ - collections::{BTreeMap, HashMap, HashSet}, + collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, + fmt::Display, iter::once, str::FromStr, sync::Arc, @@ -11,7 +12,7 @@ use std::{ use powdr_ast::{ analyzed::{ self, AlgebraicExpression, AlgebraicReference, Analyzed, Expression, - FunctionValueDefinition, Identity, IdentityKind, PolynomialType, PublicDeclaration, + FunctionValueDefinition, Identity, IdentityKind, PolyID, PolynomialType, PublicDeclaration, SelectedExpressions, StatementIdentifier, Symbol, SymbolKind, }, parsed::{ @@ -19,7 +20,7 @@ use powdr_ast::{ asm::{AbsoluteSymbolPath, SymbolPath}, display::format_type_scheme_around_name, types::{ArrayType, Type}, - TypedExpression, + FunctionKind, TypedExpression, }, }; use powdr_number::{DegreeType, FieldElement}; @@ -45,6 +46,7 @@ pub fn condense( let mut condensed_identities = vec![]; let mut intermediate_columns = HashMap::new(); let mut new_columns = vec![]; + let mut new_values = HashMap::new(); // Condense identities and intermediate columns and update the source order. let source_order = source_order .into_iter() @@ -98,15 +100,14 @@ pub fn condense( } s => Some(s), }; - // Extract and prepend the new witness columns, then identities + // Extract and prepend the new columns, then identities // and finally the original statement (if it exists). let new_cols = condenser .extract_new_columns() .into_iter() - .map(|(new_col, value)| { - let name = new_col.absolute_name.clone(); - new_columns.push((new_col, value)); - StatementIdentifier::Definition(name) + .map(|new_col| { + new_columns.push(new_col.clone()); + StatementIdentifier::Definition(new_col.absolute_name) }) .collect::>(); @@ -120,6 +121,12 @@ pub fn condense( }) .collect::>(); + for (name, hint) in condenser.extract_new_column_values() { + if new_values.insert(name.clone(), hint).is_some() { + panic!("Column {name} already has a hint set, but tried to add another one.",) + } + } + new_cols .into_iter() .chain(identity_statements) @@ -128,8 +135,20 @@ pub fn condense( .collect(); definitions.retain(|name, _| !intermediate_columns.contains_key(name)); - for (symbol, value) in new_columns { - definitions.insert(symbol.absolute_name.clone(), (symbol, value)); + for symbol in new_columns { + definitions.insert(symbol.absolute_name.clone(), (symbol, None)); + } + for (name, new_value) in new_values { + if let Some((_, value)) = definitions.get_mut(&name) { + if !value.is_none() { + panic!( + "Column {name} already has a value / hint set, but tried to add another one." + ) + } + *value = Some(new_value); + } else { + panic!("Column {name} not found."); + } } for decl in public_declarations.values_mut() { @@ -164,10 +183,12 @@ pub struct Condenser<'a, T> { namespace: AbsoluteSymbolPath, /// ID dispensers. counters: Counters, - /// The generated columns since the last extraction. - new_columns: Vec<(Symbol, Option)>, - /// The names of all new olumns ever generated, to avoid duplicates. - all_new_names: HashSet, + /// The generated columns since the last extraction in creation order. + new_columns: Vec, + /// The hints and fixed column definitions added since the last extraction. + new_column_values: HashMap, + /// The names of all new columns ever generated, to avoid duplicates. + new_symbols: HashSet, new_constraints: Vec>, } @@ -181,7 +202,8 @@ impl<'a, T: FieldElement> Condenser<'a, T> { namespace: Default::default(), counters, new_columns: vec![], - all_new_names: HashSet::new(), + new_column_values: Default::default(), + new_symbols: HashSet::new(), new_constraints: vec![], } } @@ -226,11 +248,17 @@ impl<'a, T: FieldElement> Condenser<'a, T> { self.degree = degree; } - /// Returns the witness columns generated since the last call to this function. - pub fn extract_new_columns(&mut self) -> Vec<(Symbol, Option)> { + /// Returns columns generated since the last call to this function. + pub fn extract_new_columns(&mut self) -> Vec { std::mem::take(&mut self.new_columns) } + /// Return the new column values (fixed column definitions or witness column hints) + /// added since the last call to this function. + pub fn extract_new_column_values(&mut self) -> HashMap { + std::mem::take(&mut self.new_column_values) + } + /// Returns the new constraints generated since the last call to this function. pub fn extract_new_constraints(&mut self) -> Vec> { std::mem::take(&mut self.new_constraints) @@ -322,24 +350,16 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { } else { PolynomialType::Committed }); - let value = value.map(|v|{ - if let Value::Closure(evaluator::Closure { - lambda, - environment: _, - type_args: _, - }) = v.as_ref() - { - if !lambda.outer_var_references.is_empty() { - return Err(EvalError::TypeError(format!("Lambda expression for fixed column {name} must not reference outer variables."))) - } - Ok(FunctionValueDefinition::Expression(TypedExpression { - e: Expression::LambdaExpression(source.clone(), (*lambda).clone()), - type_scheme: None, - })) - } else { - Err(EvalError::TypeError(format!("Only lambda expressions are allowed for dynamically-created fixed columns. Got {v}."))) - } - }).transpose()?; + let value = value + .map(|v| { + closure_to_function(&source, v.as_ref(), FunctionKind::Pure).map_err(|e| match e { + EvalError::TypeError(e) => { + EvalError::TypeError(format!("Error creating fixed column {name}: {e}")) + } + _ => e, + }) + }) + .transpose()?; let symbol = Symbol { id: self.counters.dispense_symbol_id(kind, None), @@ -351,18 +371,71 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> { degree: Some(self.degree.unwrap()), }; - self.all_new_names.insert(name.clone()); - self.new_columns.push((symbol.clone(), value)); + self.new_symbols.insert(name.clone()); + self.new_columns.push(symbol.clone()); + if let Some(value) = value { + self.new_column_values.insert(name.clone(), value); + } Ok( Value::Expression(AlgebraicExpression::Reference(AlgebraicReference { name, - poly_id: (&symbol).into(), + poly_id: PolyID::from(&symbol), next: false, })) .into(), ) } + fn set_hint( + &mut self, + col: Arc>, + expr: Arc>, + ) -> Result<(), EvalError> { + let name = match col.as_ref() { + Value::Expression(AlgebraicExpression::Reference(AlgebraicReference { + name, + poly_id, + next: false, + })) => { + if poly_id.ptype != PolynomialType::Committed { + return Err(EvalError::TypeError(format!( + "Expected reference to witness column as first argument for std::prover::set_hint, but got {} column {name}.", + poly_id.ptype + ))); + } + if name.contains('[') { + return Err(EvalError::TypeError(format!( + "Array elements are not supported for std::prover::set_hint (called on {name})." + ))); + } + name.clone() + } + col => { + return Err(EvalError::TypeError(format!( + "Expected reference to witness column as first argument for std::prover::set_hint, but got {col}: {}", + col.type_formatted() + ))); + } + }; + + let value = closure_to_function(&SourceRef::unknown(), expr.as_ref(), FunctionKind::Query) + .map_err(|e| match e { + EvalError::TypeError(e) => { + EvalError::TypeError(format!("Error setting hint for column {col}: {e}")) + } + _ => e, + })?; + match self.new_column_values.entry(name) { + Entry::Vacant(entry) => entry.insert(value), + Entry::Occupied(_) => { + return Err(EvalError::TypeError(format!( + "Column {col} already has a hint set, but tried to add another one." + ))); + } + }; + Ok(()) + } + fn add_constraints( &mut self, constraints: Arc>, @@ -392,7 +465,7 @@ impl<'a, T: FieldElement> Condenser<'a, T> { .chain((1..).map(Some)) .map(|cnt| format!("{name}{}", cnt.map(|c| format!("_{c}")).unwrap_or_default())) .map(|name| self.namespace.with_part(&name).to_dotted_string()) - .find(|name| !self.symbols.contains_key(name) && !self.all_new_names.contains(name)) + .find(|name| !self.symbols.contains_key(name) && !self.new_symbols.contains(name)) .unwrap() } } @@ -513,3 +586,48 @@ fn to_expr(value: &Value<'_, T>) -> AlgebraicExpression { panic!() } } + +/// Turns a value of function type (i.e. a closure) into a FunctionValueDefinition +/// and sets the expected function kind. +/// Does not allow captured variables. +fn closure_to_function( + source: &SourceRef, + value: &Value<'_, T>, + expected_kind: FunctionKind, +) -> Result { + let Value::Closure(evaluator::Closure { + lambda, + environment: _, + type_args, + }) = value + else { + return Err(EvalError::TypeError(format!( + "Expected lambda expressions but got {value}." + ))); + }; + + if !type_args.is_empty() { + return Err(EvalError::TypeError( + "Lambda expression must not have type arguments.".to_string(), + )); + } + if !lambda.outer_var_references.is_empty() { + return Err(EvalError::TypeError(format!( + "Lambda expression must not reference outer variables: {lambda}" + ))); + } + if lambda.kind != FunctionKind::Pure && lambda.kind != expected_kind { + return Err(EvalError::TypeError(format!( + "Expected {expected_kind} lambda expression but got {}.", + lambda.kind + ))); + } + + let mut lambda = (*lambda).clone(); + lambda.kind = expected_kind; + + Ok(FunctionValueDefinition::Expression(TypedExpression { + e: Expression::LambdaExpression(source.clone(), lambda), + type_scheme: None, + })) +} diff --git a/pil-analyzer/src/evaluator.rs b/pil-analyzer/src/evaluator.rs index c989f28ff..c1170f8c9 100644 --- a/pil-analyzer/src/evaluator.rs +++ b/pil-analyzer/src/evaluator.rs @@ -308,7 +308,7 @@ impl<'a, T: FieldElement> Value<'a, T> { } } -const BUILTINS: [(&str, BuiltinFunction); 10] = [ +const BUILTINS: [(&str, BuiltinFunction); 11] = [ ("std::array::len", BuiltinFunction::ArrayLen), ("std::check::panic", BuiltinFunction::Panic), ("std::convert::expr", BuiltinFunction::ToExpr), @@ -317,6 +317,7 @@ const BUILTINS: [(&str, BuiltinFunction); 10] = [ ("std::debug::print", BuiltinFunction::Print), ("std::field::modulus", BuiltinFunction::Modulus), ("std::prelude::challenge", BuiltinFunction::Challenge), + ("std::prover::set_hint", BuiltinFunction::SetHint), ("std::prover::degree", BuiltinFunction::Degree), ("std::prover::eval", BuiltinFunction::Eval), ]; @@ -341,6 +342,8 @@ pub enum BuiltinFunction { ToFe, /// std::prover::challenge: int, int -> expr, constructs a challenge with a given stage and ID. Challenge, + /// std::prover::set_hint: expr, (int -> std::prover::Query) -> (), adds a hint to a witness column. + SetHint, /// std::prover::degree: -> int, returns the current column length / degree. Degree, /// std::prover::eval: expr -> fe, evaluates an expression on the current row @@ -551,6 +554,16 @@ pub trait SymbolLookup<'a, T: FieldElement> { ))) } + fn set_hint( + &mut self, + _col: Arc>, + _expr: Arc>, + ) -> Result<(), EvalError> { + Err(EvalError::Unsupported( + "Tried to add hint to column outside of statement context.".to_string(), + )) + } + fn add_constraints( &mut self, _constraints: Arc>, @@ -1105,6 +1118,7 @@ fn evaluate_builtin_function<'a, T: FieldElement>( BuiltinFunction::ToFe => 1, BuiltinFunction::ToInt => 1, BuiltinFunction::Challenge => 2, + BuiltinFunction::SetHint => 2, BuiltinFunction::Degree => 0, BuiltinFunction::Eval => 1, }; @@ -1140,7 +1154,7 @@ fn evaluate_builtin_function<'a, T: FieldElement>( } else { print!("{msg}"); } - Value::Array(Default::default()).into() + Value::Tuple(vec![]).into() } BuiltinFunction::ToExpr => { let arg = arguments.pop().unwrap(); @@ -1173,6 +1187,12 @@ fn evaluate_builtin_function<'a, T: FieldElement>( })) .into() } + BuiltinFunction::SetHint => { + let expr = arguments.pop().unwrap(); + let col = arguments.pop().unwrap(); + symbols.set_hint(col, expr)?; + Value::Tuple(vec![]).into() + } BuiltinFunction::Degree => symbols.degree()?, BuiltinFunction::Eval => { let arg = arguments.pop().unwrap(); diff --git a/pil-analyzer/src/expression_processor.rs b/pil-analyzer/src/expression_processor.rs index fc289f4fb..30f1bbb81 100644 --- a/pil-analyzer/src/expression_processor.rs +++ b/pil-analyzer/src/expression_processor.rs @@ -11,7 +11,7 @@ use powdr_ast::{ use powdr_number::DegreeType; use powdr_parser_util::SourceRef; use std::{ - collections::{HashMap, HashSet}, + collections::{BTreeSet, HashMap, HashSet}, str::FromStr, }; @@ -311,7 +311,7 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> { }: LambdaExpression, ) -> LambdaExpression { let previous_local_vars = self.save_local_variables(); - let previous_local_var_refs = self.local_var_references.clone(); + let previous_local_var_refs = std::mem::take(&mut self.local_var_references); let local_variable_height = self.local_variable_counter; let params = params @@ -326,11 +326,13 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> { } let body = Box::new(self.process_expression(*body)); - let outer_var_references = + let outer_var_references: BTreeSet = std::mem::replace(&mut self.local_var_references, previous_local_var_refs) .into_iter() .filter(|id| *id < local_variable_height) .collect(); + self.local_var_references + .extend(outer_var_references.clone()); self.reset_local_variables(previous_local_vars); LambdaExpression { kind, diff --git a/pil-analyzer/src/side_effect_checker.rs b/pil-analyzer/src/side_effect_checker.rs index 8e79686e4..23a481b9b 100644 --- a/pil-analyzer/src/side_effect_checker.rs +++ b/pil-analyzer/src/side_effect_checker.rs @@ -4,7 +4,10 @@ use powdr_ast::{ analyzed::{ Expression, FunctionValueDefinition, Reference, Symbol, SymbolKind, TypedExpression, }, - parsed::{types::Type, BlockExpression, FunctionKind, LambdaExpression, StatementInsideBlock}, + parsed::{ + types::Type, BlockExpression, FunctionCall, FunctionKind, LambdaExpression, + StatementInsideBlock, + }, }; use lazy_static::lazy_static; @@ -79,6 +82,32 @@ impl<'a> SideEffectChecker<'a> { } e.children().try_for_each(|e| self.check(e)) } + Expression::FunctionCall( + _, + FunctionCall { + function, + arguments, + }, + ) if matches!(function.as_ref(), Expression::Reference(_, Reference::Poly(r)) if r.name == "std::prover::set_hint") => + { + // The function "set_hint" is special: It expects a "query" function as + // second argument, so we switch context when descending into the second argument. + self.check(function)?; + match &arguments[..] { + [col, hint] => { + self.check(col)?; + assert_eq!(self.context, FunctionKind::Constr); + self.context = FunctionKind::Query; + let result = self.check(hint); + self.context = FunctionKind::Constr; + result + } + _ => { + // Not the correct number of arguments, will lead to a type error later. + arguments.iter().try_for_each(|e| self.check(e)) + } + } + } _ => e.children().try_for_each(|e| self.check(e)), } } @@ -117,6 +146,7 @@ lazy_static! { ("std::field::modulus", FunctionKind::Pure), ("std::prelude::challenge", FunctionKind::Constr), // strictly, only new_challenge would need "constr" ("std::prover::degree", FunctionKind::Pure), + ("std::prover::set_hint", FunctionKind::Constr), ("std::prover::eval", FunctionKind::Query), ] .into_iter() diff --git a/pil-analyzer/src/type_builtins.rs b/pil-analyzer/src/type_builtins.rs index 2a3e3ade8..0e150d8b3 100644 --- a/pil-analyzer/src/type_builtins.rs +++ b/pil-analyzer/src/type_builtins.rs @@ -48,6 +48,10 @@ lazy_static! { ("std::field::modulus", ("", "-> int")), ("std::prelude::challenge", ("", "int, int -> expr")), ("std::prover::degree", ("", "-> int")), + ( + "std::prover::set_hint", + ("", "expr, (int -> std::prover::Query) -> ()") + ), ("std::prover::eval", ("", "expr -> fe")), ] .into_iter() diff --git a/pil-analyzer/tests/condenser.rs b/pil-analyzer/tests/condenser.rs index 7c5bad0a0..0a5282261 100644 --- a/pil-analyzer/tests/condenser.rs +++ b/pil-analyzer/tests/condenser.rs @@ -230,7 +230,7 @@ fn new_fixed_column() { } #[test] -#[should_panic = "Lambda expression for fixed column N.fi must not reference outer variables."] +#[should_panic = "Error creating fixed column N.fi: Lambda expression must not reference outer variables: (|i| (i + j) * 2)"] fn new_fixed_column_as_closure() { let input = r#"namespace N(16); let f = constr |j| { @@ -243,3 +243,166 @@ fn new_fixed_column_as_closure() { "#; analyze_string::(input); } + +#[test] +fn set_hint() { + let input = r#" + namespace std::prover; + let set_hint = 8; + let eval = 8; + enum Query { Hint(fe), None, } + namespace N(16); + let x; + let y; + std::prover::set_hint(y, |i| std::prover::Query::Hint(std::prover::eval(x))); + { + let z; + std::prover::set_hint(z, query |_| std::prover::Query::Hint(1)); + }; + "#; + let expected = r#"namespace std::prover; + let set_hint = 8; + let eval = 8; + enum Query { + Hint(fe), + None, + } +namespace N(16); + col witness x; + col witness y(i) query std::prover::Query::Hint(std::prover::eval(N.x)); + col witness z(_) query std::prover::Query::Hint(1); +"#; + let formatted = analyze_string::(input).to_string(); + assert_eq!(formatted, expected); +} + +#[test] +#[should_panic = "Expected type: int -> std::prover::Query"] +fn set_hint_invalid_function() { + let input = r#" + namespace std::prover; + let set_hint = 8; + let eval = 8; + enum Query { Hint(fe), None, } + namespace N(16); + let x; + std::prover::set_hint(x, query |_, _| std::prover::Query::Hint(1)); + "#; + analyze_string::(input); +} + +#[test] +#[should_panic = "Array elements are not supported for std::prover::set_hint (called on N.x[0])."] +fn set_hint_array_element() { + let input = r#" + namespace std::prover; + let set_hint = 8; + enum Query { Hint(fe), None, } + namespace N(16); + let x: col[2]; + std::prover::set_hint(x[0], query |_| std::prover::Query::Hint(1)); + "#; + let expected = r#"namespace std::prover; + let set_hint = 8; + let eval = 8; + enum Query { + Hint(fe), + None, + } +namespace N(16); + col witness x(_) query std::prover::Query::Hint(1); + col witness y(i) query std::prover::Query::Hint(std::prover::eval(N.x)); +"#; + let formatted = analyze_string::(input).to_string(); + assert_eq!(formatted, expected); +} + +#[test] +#[should_panic = "Expected reference to witness column as first argument for std::prover::set_hint, but got intermediate column N.y."] +fn set_hint_no_col() { + let input = r#" + namespace std::prover; + let set_hint = 8; + enum Query { Hint(fe), None, } + namespace N(16); + let x; + let y: expr = x; + std::prover::set_hint(y, query |_| std::prover::Query::Hint(1)); + "#; + analyze_string::(input); +} + +#[test] +#[should_panic = "Column N.x already has a hint set, but tried to add another one."] +fn set_hint_twice() { + let input = r#" + namespace std::prover; + let set_hint = 8; + enum Query { Hint(fe), None, } + namespace N(16); + let x; + std::prover::set_hint(x, query |_| std::prover::Query::Hint(1)); + std::prover::set_hint(x, query |_| std::prover::Query::Hint(2)); + "#; + analyze_string::(input); +} + +#[test] +#[should_panic = "Column N.x already has a hint set, but tried to add another one."] +fn set_hint_twice_in_constr() { + let input = r#" + namespace std::prover; + let set_hint = 8; + enum Query { Hint(fe), None, } + namespace N(16); + let y; + { + let x; + std::prover::set_hint(x, query |_| std::prover::Query::Hint(1)); + std::prover::set_hint(x, query |_| std::prover::Query::Hint(2)); + }; + "#; + analyze_string::(input); +} + +#[test] +fn set_hint_outside() { + let input = r#" + namespace std::prover; + let set_hint = 8; + let eval = 8; + enum Query { Hint(fe), None, } + namespace N(16); + let x; + let y; + let create_wit = constr || { let w; w }; + let z = create_wit(); + let set_hint = constr |c| { std::prover::set_hint(c, query |_| std::prover::Query::Hint(8)); }; + set_hint(x); + set_hint(y); + (|| { set_hint(z); })(); + "#; + let expected = r#"namespace std::prover; + let set_hint = 8; + let eval = 8; + enum Query { + Hint(fe), + None, + } +namespace N(16); + col witness x(_) query std::prover::Query::Hint(8); + col witness y(_) query std::prover::Query::Hint(8); + let create_wit: -> expr = (constr || { + let w: col; + w + }); + let z: expr = N.create_wit(); + let set_hint: expr -> () = (constr |c| { + std::prover::set_hint(c, (query |_| std::prover::Query::Hint(8))); + + }); + col witness w(_) query std::prover::Query::Hint(8); +"#; + let formatted = analyze_string::(input).to_string(); + assert_eq!(formatted, expected); +} diff --git a/pil-analyzer/tests/side_effects.rs b/pil-analyzer/tests/side_effects.rs index 8c93de996..434f27c87 100644 --- a/pil-analyzer/tests/side_effects.rs +++ b/pil-analyzer/tests/side_effects.rs @@ -100,3 +100,58 @@ fn fixed_with_constr_type() { let input = "let x: col = constr |i| 2;"; analyze_string::(input); } + +#[test] +fn set_hint() { + let input = r#" + namespace std::prover; + let set_hint = 8; + enum Query { Hint(fe), None, } + namespace N(16); + let x; + std::prover::set_hint(x, query |i| std::prover::Query::Hint(1)); + "#; + analyze_string::(input); +} + +#[test] +fn set_hint_can_use_query() { + let input = r#" + namespace std::prover; + let set_hint = 8; + let eval = 7; + enum Query { Hint(fe), None, } + namespace N(16); + let x; + let y; + std::prover::set_hint(x, query |_| std::prover::Query::Hint(std::prover::eval(y))); + "#; + analyze_string::(input); +} + +#[test] +fn set_hint_pure() { + let input = r#" + namespace std::prover; + let set_hint = 8; + enum Query { Hint(fe), None, } + namespace N(16); + let x; + std::prover::set_hint(x, |i| std::prover::Query::Hint(1)); + "#; + analyze_string::(input); +} + +#[test] +#[should_panic = "Used a constr lambda function inside a query context"] +fn set_hint_constr() { + let input = r#" + namespace std::prover; + let set_hint = 8; + enum Query { Hint(fe), None, } + namespace N(16); + let x; + std::prover::set_hint(x, constr |i| std::prover::Query::Hint(1)); + "#; + analyze_string::(input); +}