diff --git a/pil_analyzer/src/condenser.rs b/pil_analyzer/src/condenser.rs index b11cbab51..50c17375a 100644 --- a/pil_analyzer/src/condenser.rs +++ b/pil_analyzer/src/condenser.rs @@ -1,20 +1,22 @@ //! Component that turns data from the PILAnalyzer into Analyzed, //! i.e. it turns more complex expressions in identities to simpler expressions. -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Display, rc::Rc}; use ast::{ analyzed::{ AlgebraicExpression, AlgebraicReference, Analyzed, Expression, FunctionValueDefinition, - Identity, PolyID, PolynomialReference, PolynomialType, PublicDeclaration, Reference, + Identity, PolynomialReference, PolynomialType, PublicDeclaration, Reference, StatementIdentifier, Symbol, SymbolKind, }, - evaluate_binary_operation, evaluate_unary_operation, - parsed::{visitor::ExpressionVisitable, IndexAccess, SelectedExpressions, UnaryOperator}, + parsed::{visitor::ExpressionVisitable, SelectedExpressions, UnaryOperator}, }; +use itertools::Itertools; use number::{DegreeType, FieldElement}; -use crate::evaluator; +use crate::evaluator::{ + self, evaluate, evaluate_function_call, Custom, EvalError, SymbolLookup, Value, +}; pub fn condense( degree: Option, @@ -24,7 +26,6 @@ pub fn condense( source_order: Vec, ) -> Analyzed { let condenser = Condenser { - constants: compute_constants(&definitions), symbols: definitions.clone(), }; @@ -78,8 +79,6 @@ pub fn condense( pub struct Condenser { /// All the definitions from the PIL file. pub symbols: HashMap>)>, - /// Definitions that evaluate to constant numbers. - pub constants: HashMap, } impl Condenser { @@ -125,143 +124,202 @@ impl Condenser { } pub fn condense_expression(&self, e: &Expression) -> AlgebraicExpression { - match e { - Expression::Reference(Reference::Poly(poly)) => { - if let Some(value) = self.constants.get(&poly.name) { - return AlgebraicExpression::Number(*value); + evaluator::evaluate(e, &self) + .and_then(|result| { + // TODO at this point, we could also support arrays of constraints, but we would + // need to make it clear if an array is supported in this context (it is at statement level, + // but not inside a lookup for example). + match result { + Value::Custom(Condensate { expr, .. }) => Ok(expr), + Value::Number(n) => Ok(n.into()), + _ => Err(EvalError::TypeError(format!( + "Expected constraint, but got {result}" + ))), } - let symbol = &self - .symbols - .get(&poly.name) - .unwrap_or_else(|| panic!("Symbol {} not found.", poly.name)) - .0; + }) + .unwrap_or_else(|err| { + panic!("Error reducing expression to constraint:\nExpression: {e}\nError: {err:?}") + }) + } +} - assert!( - symbol.length.is_none(), - "Arrays cannot be used as a whole in this context, only individual array elements can be used." - ); +impl<'a, T: FieldElement> SymbolLookup<'a, T, Condensate> for &'a Condenser { + fn lookup(&self, name: &str) -> Result>, EvalError> { + let name = name.to_string(); + let (symbol, value) = &self + .symbols + .get(&name) + .ok_or_else(|| EvalError::SymbolNotFound(format!("Symbol {name} not found.")))?; + Ok(if matches!(symbol.kind, SymbolKind::Poly(_)) { + if symbol.is_array() { + Value::Array( + symbol + .array_elements() + .map(|(name, poly_id)| { + AlgebraicExpression::Reference(AlgebraicReference { + name, + poly_id, + next: false, + }) + .into() + }) + .collect(), + ) + } else { AlgebraicExpression::Reference(AlgebraicReference { - name: poly.name.clone(), + name, poly_id: symbol.into(), next: false, }) + .into() } - Expression::Reference(Reference::LocalVar(_, _)) => { - panic!("Local variables not allowed here.") - } - Expression::Number(n) => AlgebraicExpression::Number(*n), - Expression::BinaryOperation(left, op, right) => { - match ( - self.condense_expression(left), - self.condense_expression(right), - ) { - (AlgebraicExpression::Number(l), AlgebraicExpression::Number(r)) => { - AlgebraicExpression::Number(evaluate_binary_operation(l, *op, r)) - } - (l, r) => AlgebraicExpression::BinaryOperation( - Box::new(l), - (*op).try_into().unwrap(), - Box::new(r), - ), + } else { + match value { + Some(FunctionValueDefinition::Expression(value)) => { + evaluator::evaluate(value, self)? } + _ => Err(EvalError::Unsupported( + "Cannot evaluate arrays and queries.".to_string(), + ))?, } - Expression::UnaryOperation(op, inner) => { - let inner = self.condense_expression(inner); - if *op == UnaryOperator::Next { - let AlgebraicExpression::Reference(reference) = inner else { - panic!( - "Can apply \"'\" operator only directly to columns in this context." - ); + }) + } + + fn lookup_public_reference( + &self, + name: &str, + ) -> Result>, EvalError> { + Ok(AlgebraicExpression::PublicReference(name.to_string()).into()) + } + + fn eval_function_application( + &self, + function: Condensate, + arguments: &[Rc>>], + ) -> Result>, EvalError> { + match function.expr { + AlgebraicExpression::Reference(AlgebraicReference { + name, + poly_id, + next, + }) if poly_id.ptype == PolynomialType::Constant => { + let arguments = if next { + assert_eq!(arguments.len(), 1); + let Value::Number(arg) = *arguments[0] else { + return Err(EvalError::TypeError( + "Expected numeric argument when evaluating function with next ref." + .to_string(), + )); }; - - assert!(!reference.next, "Double application of \"'\""); - AlgebraicExpression::Reference(AlgebraicReference { - next: true, - ..reference - }) + vec![Rc::new(Value::Number(arg + 1.into()))] } else { - match inner { - AlgebraicExpression::Number(n) => { - AlgebraicExpression::Number(evaluate_unary_operation(*op, n)) - } - _ => AlgebraicExpression::UnaryOperation( - (*op).try_into().unwrap(), - Box::new(inner), - ), + arguments.to_vec() + }; + + match self.symbols[&name].1.as_ref() { + Some(FunctionValueDefinition::Expression(v)) => { + let function = evaluate(v, self)?; + evaluate_function_call(function, arguments, self) } + None => Err(EvalError::SymbolNotFound(format!( + "Symbol not found in function call: {name}" + ))), + _ => Err(EvalError::Unsupported(format!( + "Cannot evaluate arrays or queries: {name}" + ))), } } - Expression::PublicReference(r) => AlgebraicExpression::PublicReference(r.clone()), - Expression::IndexAccess(IndexAccess { array, index }) => { - let array_symbol = match array.as_ref() { - ast::parsed::Expression::Reference(Reference::Poly(PolynomialReference { - name, - poly_id: _, - })) => { - &self - .symbols - .get(name) - .unwrap_or_else(|| panic!("Symbol {name} not found.")) - .0 - } - _ => panic!("Expected direct reference before array index access."), - }; - let Some(length) = array_symbol.length else { - panic!("Array-access for non-array {}.", array_symbol.absolute_name); - }; + _ => Err(EvalError::TypeError(format!( + "Function application not supported: {function}({})", + arguments.iter().format(", ") + ))), + } + } - let index = evaluator::evaluate_expression(index, &self.symbols) - .and_then(|v| v.try_to_number()) - .expect("Index needs to be constant number.") - .to_degree(); - assert!( - index < length, - "Array access to index {index} for array of length {length}: {}", - array_symbol.absolute_name, - ); - let poly_id: PolyID = array_symbol.into(); - AlgebraicExpression::Reference(AlgebraicReference { - poly_id: PolyID { - id: poly_id.id + index, - ..poly_id - }, - name: array_symbol.array_element_name(index), - next: false, - }) + fn eval_binary_operation( + &self, + left: Value<'a, T, Condensate>, + op: ast::parsed::BinaryOperator, + right: Value<'a, T, Condensate>, + ) -> Result>, EvalError> { + let left: Condensate = left.try_into()?; + let right: Condensate = right.try_into()?; + Ok(AlgebraicExpression::BinaryOperation( + Box::new(left.expr), + op.try_into().map_err(EvalError::TypeError)?, + Box::new(right.expr), + ) + .into()) + } + + fn eval_unary_operation( + &self, + op: UnaryOperator, + inner: Condensate, + ) -> Result>, EvalError> { + if op == UnaryOperator::Next { + let AlgebraicExpression::Reference(reference) = inner.expr else { + return Err(EvalError::TypeError(format!( + "Expected column for \"'\" operator, but got: {inner}" + ))); + }; + + if reference.next { + return Err(EvalError::TypeError(format!( + "Double application of \"'\" on: {reference}" + ))); } - Expression::String(_) => panic!("Strings are not allowed here."), - Expression::Tuple(_) => panic!(), - Expression::LambdaExpression(_) => panic!(), - Expression::ArrayLiteral(_) => panic!(), - Expression::FunctionCall(_) => panic!(), - Expression::FreeInput(_) => panic!(), - Expression::MatchExpression(_, _) => panic!(), + Ok(AlgebraicExpression::Reference(AlgebraicReference { + next: true, + ..reference + }) + .into()) + } else { + Ok(AlgebraicExpression::UnaryOperation( + op.try_into().map_err(EvalError::TypeError)?, + Box::new(inner.expr), + ) + .into()) } } } -/// Returns a HashMap of all symbols that have a constant single value. -fn compute_constants( - definitions: &HashMap>)>, -) -> HashMap { - definitions - .iter() - .filter_map(|(name, (symbol, value))| { - // TODO we could try to compute anything that evaluates to a "value" here. - if symbol.kind == SymbolKind::Constant() { - let Some(FunctionValueDefinition::Expression(value)) = value else { - panic!(); - }; - Ok(value) - .and_then(|value| { - evaluator::evaluate_expression(value, definitions)?.try_to_number() - }) - .map(|value| (name.to_owned(), value)) - .ok() - } else { - None - } - }) - .collect() +#[derive(Clone)] +struct Condensate { + pub expr: AlgebraicExpression, +} + +impl PartialEq for Condensate { + fn eq(&self, other: &Self) -> bool { + self.expr == other.expr + } +} + +impl Display for Condensate { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.expr) + } +} + +impl Custom for Condensate {} + +impl<'a, T: FieldElement> TryFrom> for Condensate { + type Error = EvalError; + + fn try_from(value: Value<'a, T, Self>) -> Result { + match value { + Value::Number(n) => Ok(Self { expr: n.into() }), + Value::Custom(v) => Ok(v), + value => Err(EvalError::TypeError(format!( + "Expected algebraic expression, got {value}" + ))), + } + } +} + +impl<'a, T: FieldElement> From> for Value<'a, T, Condensate> { + fn from(expr: AlgebraicExpression) -> Self { + Value::Custom(Condensate { expr }) + } } diff --git a/pil_analyzer/src/evaluator.rs b/pil_analyzer/src/evaluator.rs index 9f9cc8031..62726fd34 100644 --- a/pil_analyzer/src/evaluator.rs +++ b/pil_analyzer/src/evaluator.rs @@ -3,7 +3,10 @@ use std::{collections::HashMap, fmt::Display, rc::Rc}; use ast::{ analyzed::{Expression, FunctionValueDefinition, Reference, Symbol}, evaluate_binary_operation, evaluate_unary_operation, - parsed::{display::quote, FunctionCall, LambdaExpression, MatchArm, MatchPattern}, + parsed::{ + display::quote, BinaryOperator, FunctionCall, LambdaExpression, MatchArm, MatchPattern, + UnaryOperator, + }, }; use itertools::Itertools; use number::FieldElement; @@ -199,6 +202,23 @@ pub trait SymbolLookup<'a, T, C> { function: C, arguments: &[Rc>], ) -> Result, EvalError>; + + fn eval_binary_operation( + &self, + _left: Value<'a, T, C>, + _op: BinaryOperator, + _right: Value<'a, T, C>, + ) -> Result, EvalError> { + unreachable!() + } + + fn eval_unary_operation( + &self, + _op: UnaryOperator, + _inner: C, + ) -> Result, EvalError> { + unreachable!() + } } mod internal { @@ -228,16 +248,27 @@ mod internal { .collect::>()?, ), Expression::BinaryOperation(left, op, right) => { - Value::Number(evaluate_binary_operation( - evaluate(left, locals, symbols)?.try_to_number()?, - *op, - evaluate(right, locals, symbols)?.try_to_number()?, - )) + let left = evaluate(left, locals, symbols)?; + let right = evaluate(right, locals, symbols)?; + match (&left, &right) { + (Value::Custom(_), _) | (_, Value::Custom(_)) => { + symbols.eval_binary_operation(left, *op, right)? + } + (Value::Number(l), Value::Number(r)) => { + Value::Number(evaluate_binary_operation(*l, *op, *r)) + } + _ => Err(EvalError::TypeError(format!( + "Operator {op} not supported on types: {left} {op} {right}" + )))?, + } } - Expression::UnaryOperation(op, expr) => Value::Number(evaluate_unary_operation( - *op, - evaluate(expr, locals, symbols)?.try_to_number()?, - )), + Expression::UnaryOperation(op, expr) => match evaluate(expr, locals, symbols)? { + Value::Custom(inner) => symbols.eval_unary_operation(*op, inner)?, + Value::Number(n) => Value::Number(evaluate_unary_operation(*op, n)), + inner => Err(EvalError::TypeError(format!( + "Operator {op} not supported on types: {op} {inner}" + )))?, + }, Expression::LambdaExpression(lambda) => { // TODO only copy the part of the environment that is actually referenced? (Closure { diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 402e37196..60d9dbec3 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -834,7 +834,7 @@ namespace N(65536); } #[test] - #[should_panic = "Arrays cannot be used as a whole in this context"] + #[should_panic = "Operator - not supported on types"] fn no_direct_array_references() { let input = r#"namespace N(16); col witness y[3]; @@ -845,7 +845,7 @@ namespace N(65536); } #[test] - #[should_panic = "Array access to index 3 for array of length 3"] + #[should_panic = "Tried to access element 3 of array of size 3."] fn no_out_of_bounds() { let input = r#"namespace N(16); col witness y[3]; @@ -865,4 +865,74 @@ namespace N(65536); let formatted = process_pil_file_contents::(input).to_string(); assert_eq!(formatted, input); } + + #[test] + fn symbolic_functions() { + let input = r#"namespace N(16); + let last_row = 15; + let ISLAST = |i| match i { last_row => 1, _ => 0 }; + let x; + let y; + let constrain_equal_expr = |A, B| A - B; + let on_regular_row = |cond| (1 - ISLAST) * cond; + on_regular_row(constrain_equal_expr(x', y)) = 0; + on_regular_row(constrain_equal_expr(y', x + y)) = 0; + "#; + let expected = r#"constant last_row = 15; +namespace N(16); + col fixed ISLAST(i) { match i { last_row => 1, _ => 0, } }; + col witness x; + col witness y; + let constrain_equal_expr = |A, B| (A - B); + col fixed on_regular_row(cond) { ((1 - N.ISLAST) * cond) }; + ((1 - N.ISLAST) * (N.x' - N.y)) = 0; + ((1 - N.ISLAST) * (N.y' - (N.x + N.y))) = 0; +"#; + let formatted = process_pil_file_contents::(input).to_string(); + assert_eq!(formatted, expected); + } + + #[test] + fn next_op_on_param() { + let input = r#"namespace N(16); + let last_row = 15; + let ISLAST = |i| match i { last_row => 1, _ => 0 }; + let x; + let y; + let next_is_seven = |t| t' - 7; + next_is_seven(y) = 0; + "#; + let expected = r#"constant last_row = 15; +namespace N(16); + col fixed ISLAST(i) { match i { last_row => 1, _ => 0, } }; + col witness x; + col witness y; + col fixed next_is_seven(t) { (t' - 7) }; + (N.y' - 7) = 0; +"#; + let formatted = process_pil_file_contents::(input).to_string(); + assert_eq!(formatted, expected); + } + + #[test] + fn fixed_concrete_and_symbolic() { + let input = r#"namespace N(16); + let last_row = 15; + let ISLAST = |i| match i { last_row => 1, _ => 0, }; + let x; + let y; + y - ISLAST(3) = 0; + x - ISLAST = 0; + "#; + let expected = r#"constant last_row = 15; +namespace N(16); + col fixed ISLAST(i) { match i { last_row => 1, _ => 0, } }; + col witness x; + col witness y; + (N.y - 0) = 0; + (N.x - N.ISLAST) = 0; +"#; + let formatted = process_pil_file_contents::(input).to_string(); + assert_eq!(formatted, expected); + } }