diff --git a/executor/src/witgen/query_processor.rs b/executor/src/witgen/query_processor.rs index e7ebfe57c..d10041c66 100644 --- a/executor/src/witgen/query_processor.rs +++ b/executor/src/witgen/query_processor.rs @@ -4,7 +4,7 @@ use powdr_ast::analyzed::{ types::Type, AlgebraicExpression, AlgebraicReference, Expression, PolyID, PolynomialType, }; use powdr_number::{BigInt, FieldElement}; -use powdr_pil_analyzer::evaluator::{self, Definitions, EvalError, NoCustom, SymbolLookup, Value}; +use powdr_pil_analyzer::evaluator::{self, Definitions, EvalError, SymbolLookup, Value}; use super::{rows::RowPair, Constraint, EvalResult, EvalValue, FixedData, IncompleteCause}; @@ -96,16 +96,16 @@ struct Symbols<'a, T: FieldElement> { rows: &'a RowPair<'a, 'a, T>, } -impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Symbols<'a, T> { +impl<'a, T: FieldElement> SymbolLookup<'a, T> for Symbols<'a, T> { fn lookup<'b>( &self, name: &'a str, generic_args: Option>, - ) -> Result, EvalError> { + ) -> Result, EvalError> { Definitions(&self.fixed_data.analyzed.definitions).lookup(name, generic_args) } - fn eval_expr(&self, expr: AlgebraicExpression) -> Result, EvalError> { + fn eval_expr(&self, expr: AlgebraicExpression) -> Result, EvalError> { let AlgebraicExpression::Reference(poly_ref) = expr else { return Err(EvalError::TypeError(format!( "Can use std::prover::eval only directly on columns - tried to evaluate {expr}" diff --git a/pil-analyzer/src/evaluator.rs b/pil-analyzer/src/evaluator.rs index d8f8209a0..f88bb43c5 100644 --- a/pil-analyzer/src/evaluator.rs +++ b/pil-analyzer/src/evaluator.rs @@ -22,36 +22,36 @@ use powdr_number::{BigInt, FieldElement, LargeInt}; pub fn evaluate_expression<'a, T: FieldElement>( expr: &'a Expression, definitions: &'a HashMap>)>, -) -> Result, EvalError> { +) -> Result, EvalError> { evaluate(expr, &Definitions(definitions)) } /// Evaluates an expression given a symbol lookup implementation -pub fn evaluate<'a, T: FieldElement, C: Custom>( +pub fn evaluate<'a, T: FieldElement>( expr: &'a Expression, - symbols: &impl SymbolLookup<'a, T, C>, -) -> Result, EvalError> { + symbols: &impl SymbolLookup<'a, T>, +) -> Result, EvalError> { evaluate_generic(expr, &Default::default(), symbols) } /// Evaluates a generic expression given a symbol lookup implementation /// and values for the generic type parameters. -pub fn evaluate_generic<'a, 'b, T: FieldElement, C: Custom>( +pub fn evaluate_generic<'a, 'b, T: FieldElement>( expr: &'a Expression, generic_args: &'b HashMap, - symbols: &impl SymbolLookup<'a, T, C>, -) -> Result, EvalError> { + symbols: &impl SymbolLookup<'a, T>, +) -> Result, EvalError> { internal::evaluate(expr, &[], generic_args, symbols) } /// Evaluates a function call. -pub fn evaluate_function_call<'a, T: FieldElement, C: Custom>( - function: Value<'a, T, C>, - arguments: Vec>>, - symbols: &impl SymbolLookup<'a, T, C>, +pub fn evaluate_function_call<'a, T: FieldElement>( + function: Value<'a, T>, + arguments: Vec>>, + symbols: &impl SymbolLookup<'a, T>, // TODO maybe we should also make this return an Rc. // Otherwise we might have to clone big nested objects. -) -> Result, EvalError> { +) -> Result, EvalError> { match function { Value::BuiltinFunction(b) => internal::evaluate_builtin_function(b, arguments, symbols), Value::Closure(Closure { @@ -146,33 +146,32 @@ impl Display for EvalError { } #[derive(Clone, PartialEq, Debug)] -pub enum Value<'a, T, C> { +pub enum Value<'a, T> { Bool(bool), Integer(BigInt), FieldElement(T), String(String), Tuple(Vec), Array(Vec), - Closure(Closure<'a, T, C>), + Closure(Closure<'a, T>), BuiltinFunction(BuiltinFunction), Expression(AlgebraicExpression), Identity(AlgebraicExpression, AlgebraicExpression), - Custom(C), } -impl<'a, T: FieldElement, C> From for Value<'a, T, C> { +impl<'a, T: FieldElement> From for Value<'a, T> { fn from(value: T) -> Self { Value::FieldElement(value) } } -impl<'a, T: FieldElement, C> From> for Value<'a, T, C> { +impl<'a, T: FieldElement> From> for Value<'a, T> { fn from(value: AlgebraicExpression) -> Self { Value::Expression(value) } } -impl<'a, T: FieldElement, C: Custom> Value<'a, T, C> { +impl<'a, T: FieldElement> Value<'a, T> { /// Tries to convert the value to a field element. For integers, this only works /// if the integer is non-negative and less than the modulus. pub fn try_to_field_element(self) -> Result { @@ -228,7 +227,6 @@ impl<'a, T: FieldElement, C: Custom> Value<'a, T, C> { Value::BuiltinFunction(b) => format!("builtin_{b:?}"), Value::Expression(_) => "expr".to_string(), Value::Identity(_, _) => "constr".to_string(), - Value::Custom(c) => c.type_name(), } } } @@ -270,7 +268,7 @@ pub trait Custom: Display + fmt::Debug + Clone + PartialEq { fn type_name(&self) -> String; } -impl<'a, T: Display, C: Custom> Display for Value<'a, T, C> { +impl<'a, T: Display> Display for Value<'a, T> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Value::Bool(b) => write!(f, "{b}"), @@ -283,34 +281,18 @@ impl<'a, T: Display, C: Custom> Display for Value<'a, T, C> { Value::BuiltinFunction(b) => write!(f, "{b:?}"), Value::Expression(e) => write!(f, "{e}"), Value::Identity(left, right) => write!(f, "{left} = {right}"), - Value::Custom(c) => write!(f, "{c}"), } } } -#[derive(Clone, PartialEq, Debug)] -pub enum NoCustom {} - -impl Custom for NoCustom { - fn type_name(&self) -> String { - unreachable!(); - } -} - -impl Display for NoCustom { - fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - unreachable!() - } -} - #[derive(Clone, Debug)] -pub struct Closure<'a, T, C> { +pub struct Closure<'a, T> { pub lambda: &'a LambdaExpression, - pub environment: Vec>>, + pub environment: Vec>>, pub generic_args: HashMap, } -impl<'a, T, C> PartialEq for Closure<'a, T, C> { +impl<'a, T> PartialEq for Closure<'a, T> { fn eq(&self, _other: &Self) -> bool { // Eq is used for pattern matching. // In the future, we should introduce a proper pattern type. @@ -318,19 +300,19 @@ impl<'a, T, C> PartialEq for Closure<'a, T, C> { } } -impl<'a, T: Display, C> Display for Closure<'a, T, C> { +impl<'a, T: Display> Display for Closure<'a, T> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.lambda) } } -impl<'a, T, C> From> for Value<'a, T, C> { - fn from(value: Closure<'a, T, C>) -> Self { +impl<'a, T> From> for Value<'a, T> { + fn from(value: Closure<'a, T>) -> Self { Value::Closure(value) } } -impl<'a, T, C> Closure<'a, T, C> { +impl<'a, T> Closure<'a, T> { pub fn type_name(&self) -> String { // TODO should use proper types as soon as we have them "closure".to_string() @@ -341,12 +323,12 @@ pub struct Definitions<'a, T>( pub &'a HashMap>)>, ); -impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Definitions<'a, T> { +impl<'a, T: FieldElement> SymbolLookup<'a, T> for Definitions<'a, T> { fn lookup<'b>( &self, name: &str, generic_args: Option>, - ) -> Result, EvalError> { + ) -> Result, EvalError> { let name = name.to_string(); let (symbol, value) = &self .0 @@ -391,7 +373,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Definitions<'a, T> { }) } - fn lookup_public_reference(&self, name: &str) -> Result, EvalError> { + fn lookup_public_reference(&self, name: &str) -> Result, EvalError> { Ok(AlgebraicExpression::PublicReference(name.to_string()).into()) } } @@ -404,36 +386,20 @@ impl<'a, T: FieldElement> From<&'a HashMap { +pub trait SymbolLookup<'a, T> { fn lookup( &self, name: &'a str, generic_args: Option>, - ) -> Result, EvalError>; - fn lookup_public_reference(&self, name: &'a str) -> Result, EvalError> { + ) -> Result, EvalError>; + + fn lookup_public_reference(&self, name: &'a str) -> Result, EvalError> { Err(EvalError::Unsupported(format!( "Cannot evaluate public reference: {name}" ))) } - fn eval_binary_operation( - &self, - _left: Value<'a, T, C>, - _op: BinaryOperator, - _right: Value<'a, T, C>, - ) -> Result, EvalError> { - unreachable!() - } - - fn eval_unary_operation( - &self, - _op: UnaryOperator, - _inner: C, - ) -> Result, EvalError> { - unreachable!() - } - - fn eval_expr(&self, _expr: AlgebraicExpression) -> Result, EvalError> { + fn eval_expr(&self, _expr: AlgebraicExpression) -> Result, EvalError> { Err(EvalError::DataNotAvailable) } } @@ -448,12 +414,12 @@ mod internal { use super::*; - pub fn evaluate<'a, 'b, T: FieldElement, C: Custom>( + pub fn evaluate<'a, 'b, T: FieldElement>( expr: &'a Expression, - locals: &[Rc>], + locals: &[Rc>], generic_args: &'b HashMap, - symbols: &impl SymbolLookup<'a, T, C>, - ) -> Result, EvalError> { + symbols: &impl SymbolLookup<'a, T>, + ) -> Result, EvalError> { Ok(match expr { Expression::Reference(reference) => { evaluate_reference(reference, locals, generic_args, symbols)? @@ -477,11 +443,10 @@ mod internal { Expression::BinaryOperation(left, op, right) => { let left = evaluate(left, locals, generic_args, symbols)?; let right = evaluate(right, locals, generic_args, symbols)?; - evaluate_binary_operation(left, *op, right, symbols)? + evaluate_binary_operation(left, *op, right)? } Expression::UnaryOperation(op, expr) => { match (op, evaluate(expr, locals, generic_args, symbols)?) { - (_, Value::Custom(inner)) => symbols.eval_unary_operation(*op, inner)?, (UnaryOperator::Minus, Value::FieldElement(e)) => Value::FieldElement(-e), (UnaryOperator::LogicalNot, Value::Bool(b)) => Value::Bool(!b), (UnaryOperator::Minus, Value::Integer(n)) => Value::Integer(-n), @@ -601,11 +566,11 @@ mod internal { }) } - fn evaluate_literal<'a, T: FieldElement, C: Custom>( + fn evaluate_literal<'a, T: FieldElement>( n: &T, ty: &Option>, generic_args: &HashMap, - ) -> Result, EvalError> { + ) -> Result, EvalError> { let ty = if let Some(TypeName::TypeVar(tv)) = ty { match &generic_args[tv] { Type::Fe => TypeName::Fe, @@ -631,12 +596,12 @@ mod internal { }) } - fn evaluate_reference<'a, T: FieldElement, C: Custom>( + fn evaluate_reference<'a, T: FieldElement>( reference: &'a Reference, - locals: &[Rc>], + locals: &[Rc>], generic_args: &HashMap, - symbols: &impl SymbolLookup<'a, T, C>, - ) -> Result, EvalError> { + symbols: &impl SymbolLookup<'a, T>, + ) -> Result, EvalError> { Ok(match reference { Reference::LocalVar(i, _name) => (*locals[*i as usize]).clone(), @@ -656,16 +621,12 @@ mod internal { }) } - fn evaluate_binary_operation<'a, T: FieldElement, C: Custom>( - left: Value<'a, T, C>, + fn evaluate_binary_operation<'a, T: FieldElement>( + left: Value<'a, T>, op: BinaryOperator, - right: Value<'a, T, C>, - symbols: &impl SymbolLookup<'a, T, C>, - ) -> Result, EvalError> { + right: Value<'a, T>, + ) -> Result, EvalError> { Ok(match (left, op, right) { - (l @ Value::Custom(_), _, r) | (l, _, r @ Value::Custom(_)) => { - symbols.eval_binary_operation(l, op, r)? - } (Value::Array(mut l), BinaryOperator::Add, Value::Array(r)) => { l.extend(r); Value::Array(l) @@ -716,7 +677,7 @@ mod internal { (Value::Expression(l), op, Value::Expression(r)) => match (l, r) { (AlgebraicExpression::Number(l), AlgebraicExpression::Number(r)) => { let Value::FieldElement(result) = - evaluate_binary_operation_field::<'a, T, C>(l, op, r)? + evaluate_binary_operation_field::<'a, T>(l, op, r)? else { panic!() }; @@ -738,11 +699,11 @@ mod internal { } #[allow(clippy::print_stdout)] - pub fn evaluate_builtin_function<'a, T: FieldElement, C: Custom>( + pub fn evaluate_builtin_function<'a, T: FieldElement>( b: BuiltinFunction, - mut arguments: Vec>>, - symbols: &impl SymbolLookup<'a, T, C>, - ) -> Result, EvalError> { + mut arguments: Vec>>, + symbols: &impl SymbolLookup<'a, T>, + ) -> Result, EvalError> { let params = match b { BuiltinFunction::ArrayLen => 1, BuiltinFunction::Modulus => 0, @@ -816,11 +777,11 @@ mod internal { } } -pub fn evaluate_binary_operation_field<'a, T: FieldElement, C>( +pub fn evaluate_binary_operation_field<'a, T: FieldElement>( left: T, op: BinaryOperator, right: T, -) -> Result, EvalError> { +) -> Result, EvalError> { Ok(match op { BinaryOperator::Add => Value::FieldElement(left + right), BinaryOperator::Sub => Value::FieldElement(left - right), @@ -833,11 +794,11 @@ pub fn evaluate_binary_operation_field<'a, T: FieldElement, C>( }) } -pub fn evaluate_binary_operation_integer<'a, T, C>( +pub fn evaluate_binary_operation_integer<'a, T>( left: &BigInt, op: BinaryOperator, right: &BigInt, -) -> Result, EvalError> { +) -> Result, EvalError> { Ok(match op { BinaryOperator::Add => Value::Integer(left + right), BinaryOperator::Sub => Value::Integer(left - right), @@ -880,7 +841,7 @@ mod test { else { panic!() }; - evaluate::<_, NoCustom>(symbol, &Definitions(&analyzed.definitions)) + evaluate(symbol, &Definitions(&analyzed.definitions)) .unwrap() .to_string() } diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index 0edc434bf..0dc5be1c3 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -157,8 +157,8 @@ pub fn std_analyzed() -> Analyzed { pub fn evaluate_function<'a, T: FieldElement>( analyzed: &'a Analyzed, function: &'a str, - arguments: Vec>>, -) -> evaluator::Value<'a, T, evaluator::NoCustom> { + arguments: Vec>>, +) -> evaluator::Value<'a, T> { let symbols = evaluator::Definitions(&analyzed.definitions); let function = symbols.lookup(function, None).unwrap(); evaluator::evaluate_function_call(function, arguments, &symbols).unwrap()