mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-04-20 03:03:25 -04:00
Remove Custom.
This commit is contained in:
@@ -4,7 +4,7 @@ use powdr_ast::analyzed::{
|
||||
types::Type, AlgebraicExpression, AlgebraicReference, Expression, PolyID, PolynomialType,
|
||||
};
|
||||
use powdr_number::{BigInt, FieldElement};
|
||||
use powdr_pil_analyzer::evaluator::{self, Definitions, EvalError, NoCustom, SymbolLookup, Value};
|
||||
use powdr_pil_analyzer::evaluator::{self, Definitions, EvalError, SymbolLookup, Value};
|
||||
|
||||
use super::{rows::RowPair, Constraint, EvalResult, EvalValue, FixedData, IncompleteCause};
|
||||
|
||||
@@ -96,16 +96,16 @@ struct Symbols<'a, T: FieldElement> {
|
||||
rows: &'a RowPair<'a, 'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Symbols<'a, T> {
|
||||
impl<'a, T: FieldElement> SymbolLookup<'a, T> for Symbols<'a, T> {
|
||||
fn lookup<'b>(
|
||||
&self,
|
||||
name: &'a str,
|
||||
generic_args: Option<Vec<Type>>,
|
||||
) -> Result<Value<'a, T, NoCustom>, EvalError> {
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
Definitions(&self.fixed_data.analyzed.definitions).lookup(name, generic_args)
|
||||
}
|
||||
|
||||
fn eval_expr(&self, expr: AlgebraicExpression<T>) -> Result<Value<'a, T, NoCustom>, EvalError> {
|
||||
fn eval_expr(&self, expr: AlgebraicExpression<T>) -> Result<Value<'a, T>, EvalError> {
|
||||
let AlgebraicExpression::Reference(poly_ref) = expr else {
|
||||
return Err(EvalError::TypeError(format!(
|
||||
"Can use std::prover::eval only directly on columns - tried to evaluate {expr}"
|
||||
|
||||
@@ -22,36 +22,36 @@ use powdr_number::{BigInt, FieldElement, LargeInt};
|
||||
pub fn evaluate_expression<'a, T: FieldElement>(
|
||||
expr: &'a Expression<T>,
|
||||
definitions: &'a HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
|
||||
) -> Result<Value<'a, T, NoCustom>, EvalError> {
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
evaluate(expr, &Definitions(definitions))
|
||||
}
|
||||
|
||||
/// Evaluates an expression given a symbol lookup implementation
|
||||
pub fn evaluate<'a, T: FieldElement, C: Custom>(
|
||||
pub fn evaluate<'a, T: FieldElement>(
|
||||
expr: &'a Expression<T>,
|
||||
symbols: &impl SymbolLookup<'a, T, C>,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
symbols: &impl SymbolLookup<'a, T>,
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
evaluate_generic(expr, &Default::default(), symbols)
|
||||
}
|
||||
|
||||
/// Evaluates a generic expression given a symbol lookup implementation
|
||||
/// and values for the generic type parameters.
|
||||
pub fn evaluate_generic<'a, 'b, T: FieldElement, C: Custom>(
|
||||
pub fn evaluate_generic<'a, 'b, T: FieldElement>(
|
||||
expr: &'a Expression<T>,
|
||||
generic_args: &'b HashMap<String, Type>,
|
||||
symbols: &impl SymbolLookup<'a, T, C>,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
symbols: &impl SymbolLookup<'a, T>,
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
internal::evaluate(expr, &[], generic_args, symbols)
|
||||
}
|
||||
|
||||
/// Evaluates a function call.
|
||||
pub fn evaluate_function_call<'a, T: FieldElement, C: Custom>(
|
||||
function: Value<'a, T, C>,
|
||||
arguments: Vec<Rc<Value<'a, T, C>>>,
|
||||
symbols: &impl SymbolLookup<'a, T, C>,
|
||||
pub fn evaluate_function_call<'a, T: FieldElement>(
|
||||
function: Value<'a, T>,
|
||||
arguments: Vec<Rc<Value<'a, T>>>,
|
||||
symbols: &impl SymbolLookup<'a, T>,
|
||||
// TODO maybe we should also make this return an Rc<Value>.
|
||||
// Otherwise we might have to clone big nested objects.
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
match function {
|
||||
Value::BuiltinFunction(b) => internal::evaluate_builtin_function(b, arguments, symbols),
|
||||
Value::Closure(Closure {
|
||||
@@ -146,33 +146,32 @@ impl Display for EvalError {
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum Value<'a, T, C> {
|
||||
pub enum Value<'a, T> {
|
||||
Bool(bool),
|
||||
Integer(BigInt),
|
||||
FieldElement(T),
|
||||
String(String),
|
||||
Tuple(Vec<Self>),
|
||||
Array(Vec<Self>),
|
||||
Closure(Closure<'a, T, C>),
|
||||
Closure(Closure<'a, T>),
|
||||
BuiltinFunction(BuiltinFunction),
|
||||
Expression(AlgebraicExpression<T>),
|
||||
Identity(AlgebraicExpression<T>, AlgebraicExpression<T>),
|
||||
Custom(C),
|
||||
}
|
||||
|
||||
impl<'a, T: FieldElement, C> From<T> for Value<'a, T, C> {
|
||||
impl<'a, T: FieldElement> From<T> for Value<'a, T> {
|
||||
fn from(value: T) -> Self {
|
||||
Value::FieldElement(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: FieldElement, C> From<AlgebraicExpression<T>> for Value<'a, T, C> {
|
||||
impl<'a, T: FieldElement> From<AlgebraicExpression<T>> for Value<'a, T> {
|
||||
fn from(value: AlgebraicExpression<T>) -> Self {
|
||||
Value::Expression(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: FieldElement, C: Custom> Value<'a, T, C> {
|
||||
impl<'a, T: FieldElement> Value<'a, T> {
|
||||
/// Tries to convert the value to a field element. For integers, this only works
|
||||
/// if the integer is non-negative and less than the modulus.
|
||||
pub fn try_to_field_element(self) -> Result<T, EvalError> {
|
||||
@@ -228,7 +227,6 @@ impl<'a, T: FieldElement, C: Custom> Value<'a, T, C> {
|
||||
Value::BuiltinFunction(b) => format!("builtin_{b:?}"),
|
||||
Value::Expression(_) => "expr".to_string(),
|
||||
Value::Identity(_, _) => "constr".to_string(),
|
||||
Value::Custom(c) => c.type_name(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -270,7 +268,7 @@ pub trait Custom: Display + fmt::Debug + Clone + PartialEq {
|
||||
fn type_name(&self) -> String;
|
||||
}
|
||||
|
||||
impl<'a, T: Display, C: Custom> Display for Value<'a, T, C> {
|
||||
impl<'a, T: Display> Display for Value<'a, T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Value::Bool(b) => write!(f, "{b}"),
|
||||
@@ -283,34 +281,18 @@ impl<'a, T: Display, C: Custom> Display for Value<'a, T, C> {
|
||||
Value::BuiltinFunction(b) => write!(f, "{b:?}"),
|
||||
Value::Expression(e) => write!(f, "{e}"),
|
||||
Value::Identity(left, right) => write!(f, "{left} = {right}"),
|
||||
Value::Custom(c) => write!(f, "{c}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum NoCustom {}
|
||||
|
||||
impl Custom for NoCustom {
|
||||
fn type_name(&self) -> String {
|
||||
unreachable!();
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for NoCustom {
|
||||
fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Closure<'a, T, C> {
|
||||
pub struct Closure<'a, T> {
|
||||
pub lambda: &'a LambdaExpression<T, Reference>,
|
||||
pub environment: Vec<Rc<Value<'a, T, C>>>,
|
||||
pub environment: Vec<Rc<Value<'a, T>>>,
|
||||
pub generic_args: HashMap<String, Type>,
|
||||
}
|
||||
|
||||
impl<'a, T, C> PartialEq for Closure<'a, T, C> {
|
||||
impl<'a, T> PartialEq for Closure<'a, T> {
|
||||
fn eq(&self, _other: &Self) -> bool {
|
||||
// Eq is used for pattern matching.
|
||||
// In the future, we should introduce a proper pattern type.
|
||||
@@ -318,19 +300,19 @@ impl<'a, T, C> PartialEq for Closure<'a, T, C> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Display, C> Display for Closure<'a, T, C> {
|
||||
impl<'a, T: Display> Display for Closure<'a, T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.lambda)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, C> From<Closure<'a, T, C>> for Value<'a, T, C> {
|
||||
fn from(value: Closure<'a, T, C>) -> Self {
|
||||
impl<'a, T> From<Closure<'a, T>> for Value<'a, T> {
|
||||
fn from(value: Closure<'a, T>) -> Self {
|
||||
Value::Closure(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, C> Closure<'a, T, C> {
|
||||
impl<'a, T> Closure<'a, T> {
|
||||
pub fn type_name(&self) -> String {
|
||||
// TODO should use proper types as soon as we have them
|
||||
"closure".to_string()
|
||||
@@ -341,12 +323,12 @@ pub struct Definitions<'a, T>(
|
||||
pub &'a HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
|
||||
);
|
||||
|
||||
impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Definitions<'a, T> {
|
||||
impl<'a, T: FieldElement> SymbolLookup<'a, T> for Definitions<'a, T> {
|
||||
fn lookup<'b>(
|
||||
&self,
|
||||
name: &str,
|
||||
generic_args: Option<Vec<Type>>,
|
||||
) -> Result<Value<'a, T, NoCustom>, EvalError> {
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
let name = name.to_string();
|
||||
let (symbol, value) = &self
|
||||
.0
|
||||
@@ -391,7 +373,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Definitions<'a, T> {
|
||||
})
|
||||
}
|
||||
|
||||
fn lookup_public_reference(&self, name: &str) -> Result<Value<'a, T, NoCustom>, EvalError> {
|
||||
fn lookup_public_reference(&self, name: &str) -> Result<Value<'a, T>, EvalError> {
|
||||
Ok(AlgebraicExpression::PublicReference(name.to_string()).into())
|
||||
}
|
||||
}
|
||||
@@ -404,36 +386,20 @@ impl<'a, T: FieldElement> From<&'a HashMap<String, (Symbol, Option<FunctionValue
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SymbolLookup<'a, T, C> {
|
||||
pub trait SymbolLookup<'a, T> {
|
||||
fn lookup(
|
||||
&self,
|
||||
name: &'a str,
|
||||
generic_args: Option<Vec<Type>>,
|
||||
) -> Result<Value<'a, T, C>, EvalError>;
|
||||
fn lookup_public_reference(&self, name: &'a str) -> Result<Value<'a, T, C>, EvalError> {
|
||||
) -> Result<Value<'a, T>, EvalError>;
|
||||
|
||||
fn lookup_public_reference(&self, name: &'a str) -> Result<Value<'a, T>, EvalError> {
|
||||
Err(EvalError::Unsupported(format!(
|
||||
"Cannot evaluate public reference: {name}"
|
||||
)))
|
||||
}
|
||||
|
||||
fn eval_binary_operation(
|
||||
&self,
|
||||
_left: Value<'a, T, C>,
|
||||
_op: BinaryOperator,
|
||||
_right: Value<'a, T, C>,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn eval_unary_operation(
|
||||
&self,
|
||||
_op: UnaryOperator,
|
||||
_inner: C,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn eval_expr(&self, _expr: AlgebraicExpression<T>) -> Result<Value<'a, T, C>, EvalError> {
|
||||
fn eval_expr(&self, _expr: AlgebraicExpression<T>) -> Result<Value<'a, T>, EvalError> {
|
||||
Err(EvalError::DataNotAvailable)
|
||||
}
|
||||
}
|
||||
@@ -448,12 +414,12 @@ mod internal {
|
||||
|
||||
use super::*;
|
||||
|
||||
pub fn evaluate<'a, 'b, T: FieldElement, C: Custom>(
|
||||
pub fn evaluate<'a, 'b, T: FieldElement>(
|
||||
expr: &'a Expression<T>,
|
||||
locals: &[Rc<Value<'a, T, C>>],
|
||||
locals: &[Rc<Value<'a, T>>],
|
||||
generic_args: &'b HashMap<String, Type>,
|
||||
symbols: &impl SymbolLookup<'a, T, C>,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
symbols: &impl SymbolLookup<'a, T>,
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
Ok(match expr {
|
||||
Expression::Reference(reference) => {
|
||||
evaluate_reference(reference, locals, generic_args, symbols)?
|
||||
@@ -477,11 +443,10 @@ mod internal {
|
||||
Expression::BinaryOperation(left, op, right) => {
|
||||
let left = evaluate(left, locals, generic_args, symbols)?;
|
||||
let right = evaluate(right, locals, generic_args, symbols)?;
|
||||
evaluate_binary_operation(left, *op, right, symbols)?
|
||||
evaluate_binary_operation(left, *op, right)?
|
||||
}
|
||||
Expression::UnaryOperation(op, expr) => {
|
||||
match (op, evaluate(expr, locals, generic_args, symbols)?) {
|
||||
(_, Value::Custom(inner)) => symbols.eval_unary_operation(*op, inner)?,
|
||||
(UnaryOperator::Minus, Value::FieldElement(e)) => Value::FieldElement(-e),
|
||||
(UnaryOperator::LogicalNot, Value::Bool(b)) => Value::Bool(!b),
|
||||
(UnaryOperator::Minus, Value::Integer(n)) => Value::Integer(-n),
|
||||
@@ -601,11 +566,11 @@ mod internal {
|
||||
})
|
||||
}
|
||||
|
||||
fn evaluate_literal<'a, T: FieldElement, C: Custom>(
|
||||
fn evaluate_literal<'a, T: FieldElement>(
|
||||
n: &T,
|
||||
ty: &Option<TypeName<NoArrayLengths>>,
|
||||
generic_args: &HashMap<String, Type>,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
let ty = if let Some(TypeName::TypeVar(tv)) = ty {
|
||||
match &generic_args[tv] {
|
||||
Type::Fe => TypeName::Fe,
|
||||
@@ -631,12 +596,12 @@ mod internal {
|
||||
})
|
||||
}
|
||||
|
||||
fn evaluate_reference<'a, T: FieldElement, C: Custom>(
|
||||
fn evaluate_reference<'a, T: FieldElement>(
|
||||
reference: &'a Reference,
|
||||
locals: &[Rc<Value<'a, T, C>>],
|
||||
locals: &[Rc<Value<'a, T>>],
|
||||
generic_args: &HashMap<String, Type>,
|
||||
symbols: &impl SymbolLookup<'a, T, C>,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
symbols: &impl SymbolLookup<'a, T>,
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
Ok(match reference {
|
||||
Reference::LocalVar(i, _name) => (*locals[*i as usize]).clone(),
|
||||
|
||||
@@ -656,16 +621,12 @@ mod internal {
|
||||
})
|
||||
}
|
||||
|
||||
fn evaluate_binary_operation<'a, T: FieldElement, C: Custom>(
|
||||
left: Value<'a, T, C>,
|
||||
fn evaluate_binary_operation<'a, T: FieldElement>(
|
||||
left: Value<'a, T>,
|
||||
op: BinaryOperator,
|
||||
right: Value<'a, T, C>,
|
||||
symbols: &impl SymbolLookup<'a, T, C>,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
right: Value<'a, T>,
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
Ok(match (left, op, right) {
|
||||
(l @ Value::Custom(_), _, r) | (l, _, r @ Value::Custom(_)) => {
|
||||
symbols.eval_binary_operation(l, op, r)?
|
||||
}
|
||||
(Value::Array(mut l), BinaryOperator::Add, Value::Array(r)) => {
|
||||
l.extend(r);
|
||||
Value::Array(l)
|
||||
@@ -716,7 +677,7 @@ mod internal {
|
||||
(Value::Expression(l), op, Value::Expression(r)) => match (l, r) {
|
||||
(AlgebraicExpression::Number(l), AlgebraicExpression::Number(r)) => {
|
||||
let Value::FieldElement(result) =
|
||||
evaluate_binary_operation_field::<'a, T, C>(l, op, r)?
|
||||
evaluate_binary_operation_field::<'a, T>(l, op, r)?
|
||||
else {
|
||||
panic!()
|
||||
};
|
||||
@@ -738,11 +699,11 @@ mod internal {
|
||||
}
|
||||
|
||||
#[allow(clippy::print_stdout)]
|
||||
pub fn evaluate_builtin_function<'a, T: FieldElement, C: Custom>(
|
||||
pub fn evaluate_builtin_function<'a, T: FieldElement>(
|
||||
b: BuiltinFunction,
|
||||
mut arguments: Vec<Rc<Value<'a, T, C>>>,
|
||||
symbols: &impl SymbolLookup<'a, T, C>,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
mut arguments: Vec<Rc<Value<'a, T>>>,
|
||||
symbols: &impl SymbolLookup<'a, T>,
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
let params = match b {
|
||||
BuiltinFunction::ArrayLen => 1,
|
||||
BuiltinFunction::Modulus => 0,
|
||||
@@ -816,11 +777,11 @@ mod internal {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn evaluate_binary_operation_field<'a, T: FieldElement, C>(
|
||||
pub fn evaluate_binary_operation_field<'a, T: FieldElement>(
|
||||
left: T,
|
||||
op: BinaryOperator,
|
||||
right: T,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
Ok(match op {
|
||||
BinaryOperator::Add => Value::FieldElement(left + right),
|
||||
BinaryOperator::Sub => Value::FieldElement(left - right),
|
||||
@@ -833,11 +794,11 @@ pub fn evaluate_binary_operation_field<'a, T: FieldElement, C>(
|
||||
})
|
||||
}
|
||||
|
||||
pub fn evaluate_binary_operation_integer<'a, T, C>(
|
||||
pub fn evaluate_binary_operation_integer<'a, T>(
|
||||
left: &BigInt,
|
||||
op: BinaryOperator,
|
||||
right: &BigInt,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
) -> Result<Value<'a, T>, EvalError> {
|
||||
Ok(match op {
|
||||
BinaryOperator::Add => Value::Integer(left + right),
|
||||
BinaryOperator::Sub => Value::Integer(left - right),
|
||||
@@ -880,7 +841,7 @@ mod test {
|
||||
else {
|
||||
panic!()
|
||||
};
|
||||
evaluate::<_, NoCustom>(symbol, &Definitions(&analyzed.definitions))
|
||||
evaluate(symbol, &Definitions(&analyzed.definitions))
|
||||
.unwrap()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
@@ -157,8 +157,8 @@ pub fn std_analyzed<T: FieldElement>() -> Analyzed<T> {
|
||||
pub fn evaluate_function<'a, T: FieldElement>(
|
||||
analyzed: &'a Analyzed<T>,
|
||||
function: &'a str,
|
||||
arguments: Vec<Rc<evaluator::Value<'a, T, evaluator::NoCustom>>>,
|
||||
) -> evaluator::Value<'a, T, evaluator::NoCustom> {
|
||||
arguments: Vec<Rc<evaluator::Value<'a, T>>>,
|
||||
) -> evaluator::Value<'a, T> {
|
||||
let symbols = evaluator::Definitions(&analyzed.definitions);
|
||||
let function = symbols.lookup(function, None).unwrap();
|
||||
evaluator::evaluate_function_call(function, arguments, &symbols).unwrap()
|
||||
|
||||
Reference in New Issue
Block a user