Remove Custom.

This commit is contained in:
chriseth
2024-03-01 08:33:35 +01:00
parent ce015328ec
commit cb548153a6
3 changed files with 64 additions and 103 deletions

View File

@@ -4,7 +4,7 @@ use powdr_ast::analyzed::{
types::Type, AlgebraicExpression, AlgebraicReference, Expression, PolyID, PolynomialType,
};
use powdr_number::{BigInt, FieldElement};
use powdr_pil_analyzer::evaluator::{self, Definitions, EvalError, NoCustom, SymbolLookup, Value};
use powdr_pil_analyzer::evaluator::{self, Definitions, EvalError, SymbolLookup, Value};
use super::{rows::RowPair, Constraint, EvalResult, EvalValue, FixedData, IncompleteCause};
@@ -96,16 +96,16 @@ struct Symbols<'a, T: FieldElement> {
rows: &'a RowPair<'a, 'a, T>,
}
impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Symbols<'a, T> {
impl<'a, T: FieldElement> SymbolLookup<'a, T> for Symbols<'a, T> {
fn lookup<'b>(
&self,
name: &'a str,
generic_args: Option<Vec<Type>>,
) -> Result<Value<'a, T, NoCustom>, EvalError> {
) -> Result<Value<'a, T>, EvalError> {
Definitions(&self.fixed_data.analyzed.definitions).lookup(name, generic_args)
}
fn eval_expr(&self, expr: AlgebraicExpression<T>) -> Result<Value<'a, T, NoCustom>, EvalError> {
fn eval_expr(&self, expr: AlgebraicExpression<T>) -> Result<Value<'a, T>, EvalError> {
let AlgebraicExpression::Reference(poly_ref) = expr else {
return Err(EvalError::TypeError(format!(
"Can use std::prover::eval only directly on columns - tried to evaluate {expr}"

View File

@@ -22,36 +22,36 @@ use powdr_number::{BigInt, FieldElement, LargeInt};
pub fn evaluate_expression<'a, T: FieldElement>(
expr: &'a Expression<T>,
definitions: &'a HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
) -> Result<Value<'a, T, NoCustom>, EvalError> {
) -> Result<Value<'a, T>, EvalError> {
evaluate(expr, &Definitions(definitions))
}
/// Evaluates an expression given a symbol lookup implementation
pub fn evaluate<'a, T: FieldElement, C: Custom>(
pub fn evaluate<'a, T: FieldElement>(
expr: &'a Expression<T>,
symbols: &impl SymbolLookup<'a, T, C>,
) -> Result<Value<'a, T, C>, EvalError> {
symbols: &impl SymbolLookup<'a, T>,
) -> Result<Value<'a, T>, EvalError> {
evaluate_generic(expr, &Default::default(), symbols)
}
/// Evaluates a generic expression given a symbol lookup implementation
/// and values for the generic type parameters.
pub fn evaluate_generic<'a, 'b, T: FieldElement, C: Custom>(
pub fn evaluate_generic<'a, 'b, T: FieldElement>(
expr: &'a Expression<T>,
generic_args: &'b HashMap<String, Type>,
symbols: &impl SymbolLookup<'a, T, C>,
) -> Result<Value<'a, T, C>, EvalError> {
symbols: &impl SymbolLookup<'a, T>,
) -> Result<Value<'a, T>, EvalError> {
internal::evaluate(expr, &[], generic_args, symbols)
}
/// Evaluates a function call.
pub fn evaluate_function_call<'a, T: FieldElement, C: Custom>(
function: Value<'a, T, C>,
arguments: Vec<Rc<Value<'a, T, C>>>,
symbols: &impl SymbolLookup<'a, T, C>,
pub fn evaluate_function_call<'a, T: FieldElement>(
function: Value<'a, T>,
arguments: Vec<Rc<Value<'a, T>>>,
symbols: &impl SymbolLookup<'a, T>,
// TODO maybe we should also make this return an Rc<Value>.
// Otherwise we might have to clone big nested objects.
) -> Result<Value<'a, T, C>, EvalError> {
) -> Result<Value<'a, T>, EvalError> {
match function {
Value::BuiltinFunction(b) => internal::evaluate_builtin_function(b, arguments, symbols),
Value::Closure(Closure {
@@ -146,33 +146,32 @@ impl Display for EvalError {
}
#[derive(Clone, PartialEq, Debug)]
pub enum Value<'a, T, C> {
pub enum Value<'a, T> {
Bool(bool),
Integer(BigInt),
FieldElement(T),
String(String),
Tuple(Vec<Self>),
Array(Vec<Self>),
Closure(Closure<'a, T, C>),
Closure(Closure<'a, T>),
BuiltinFunction(BuiltinFunction),
Expression(AlgebraicExpression<T>),
Identity(AlgebraicExpression<T>, AlgebraicExpression<T>),
Custom(C),
}
impl<'a, T: FieldElement, C> From<T> for Value<'a, T, C> {
impl<'a, T: FieldElement> From<T> for Value<'a, T> {
fn from(value: T) -> Self {
Value::FieldElement(value)
}
}
impl<'a, T: FieldElement, C> From<AlgebraicExpression<T>> for Value<'a, T, C> {
impl<'a, T: FieldElement> From<AlgebraicExpression<T>> for Value<'a, T> {
fn from(value: AlgebraicExpression<T>) -> Self {
Value::Expression(value)
}
}
impl<'a, T: FieldElement, C: Custom> Value<'a, T, C> {
impl<'a, T: FieldElement> Value<'a, T> {
/// Tries to convert the value to a field element. For integers, this only works
/// if the integer is non-negative and less than the modulus.
pub fn try_to_field_element(self) -> Result<T, EvalError> {
@@ -228,7 +227,6 @@ impl<'a, T: FieldElement, C: Custom> Value<'a, T, C> {
Value::BuiltinFunction(b) => format!("builtin_{b:?}"),
Value::Expression(_) => "expr".to_string(),
Value::Identity(_, _) => "constr".to_string(),
Value::Custom(c) => c.type_name(),
}
}
}
@@ -270,7 +268,7 @@ pub trait Custom: Display + fmt::Debug + Clone + PartialEq {
fn type_name(&self) -> String;
}
impl<'a, T: Display, C: Custom> Display for Value<'a, T, C> {
impl<'a, T: Display> Display for Value<'a, T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Value::Bool(b) => write!(f, "{b}"),
@@ -283,34 +281,18 @@ impl<'a, T: Display, C: Custom> Display for Value<'a, T, C> {
Value::BuiltinFunction(b) => write!(f, "{b:?}"),
Value::Expression(e) => write!(f, "{e}"),
Value::Identity(left, right) => write!(f, "{left} = {right}"),
Value::Custom(c) => write!(f, "{c}"),
}
}
}
#[derive(Clone, PartialEq, Debug)]
pub enum NoCustom {}
impl Custom for NoCustom {
fn type_name(&self) -> String {
unreachable!();
}
}
impl Display for NoCustom {
fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
unreachable!()
}
}
#[derive(Clone, Debug)]
pub struct Closure<'a, T, C> {
pub struct Closure<'a, T> {
pub lambda: &'a LambdaExpression<T, Reference>,
pub environment: Vec<Rc<Value<'a, T, C>>>,
pub environment: Vec<Rc<Value<'a, T>>>,
pub generic_args: HashMap<String, Type>,
}
impl<'a, T, C> PartialEq for Closure<'a, T, C> {
impl<'a, T> PartialEq for Closure<'a, T> {
fn eq(&self, _other: &Self) -> bool {
// Eq is used for pattern matching.
// In the future, we should introduce a proper pattern type.
@@ -318,19 +300,19 @@ impl<'a, T, C> PartialEq for Closure<'a, T, C> {
}
}
impl<'a, T: Display, C> Display for Closure<'a, T, C> {
impl<'a, T: Display> Display for Closure<'a, T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.lambda)
}
}
impl<'a, T, C> From<Closure<'a, T, C>> for Value<'a, T, C> {
fn from(value: Closure<'a, T, C>) -> Self {
impl<'a, T> From<Closure<'a, T>> for Value<'a, T> {
fn from(value: Closure<'a, T>) -> Self {
Value::Closure(value)
}
}
impl<'a, T, C> Closure<'a, T, C> {
impl<'a, T> Closure<'a, T> {
pub fn type_name(&self) -> String {
// TODO should use proper types as soon as we have them
"closure".to_string()
@@ -341,12 +323,12 @@ pub struct Definitions<'a, T>(
pub &'a HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
);
impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Definitions<'a, T> {
impl<'a, T: FieldElement> SymbolLookup<'a, T> for Definitions<'a, T> {
fn lookup<'b>(
&self,
name: &str,
generic_args: Option<Vec<Type>>,
) -> Result<Value<'a, T, NoCustom>, EvalError> {
) -> Result<Value<'a, T>, EvalError> {
let name = name.to_string();
let (symbol, value) = &self
.0
@@ -391,7 +373,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Definitions<'a, T> {
})
}
fn lookup_public_reference(&self, name: &str) -> Result<Value<'a, T, NoCustom>, EvalError> {
fn lookup_public_reference(&self, name: &str) -> Result<Value<'a, T>, EvalError> {
Ok(AlgebraicExpression::PublicReference(name.to_string()).into())
}
}
@@ -404,36 +386,20 @@ impl<'a, T: FieldElement> From<&'a HashMap<String, (Symbol, Option<FunctionValue
}
}
pub trait SymbolLookup<'a, T, C> {
pub trait SymbolLookup<'a, T> {
fn lookup(
&self,
name: &'a str,
generic_args: Option<Vec<Type>>,
) -> Result<Value<'a, T, C>, EvalError>;
fn lookup_public_reference(&self, name: &'a str) -> Result<Value<'a, T, C>, EvalError> {
) -> Result<Value<'a, T>, EvalError>;
fn lookup_public_reference(&self, name: &'a str) -> Result<Value<'a, T>, EvalError> {
Err(EvalError::Unsupported(format!(
"Cannot evaluate public reference: {name}"
)))
}
fn eval_binary_operation(
&self,
_left: Value<'a, T, C>,
_op: BinaryOperator,
_right: Value<'a, T, C>,
) -> Result<Value<'a, T, C>, EvalError> {
unreachable!()
}
fn eval_unary_operation(
&self,
_op: UnaryOperator,
_inner: C,
) -> Result<Value<'a, T, C>, EvalError> {
unreachable!()
}
fn eval_expr(&self, _expr: AlgebraicExpression<T>) -> Result<Value<'a, T, C>, EvalError> {
fn eval_expr(&self, _expr: AlgebraicExpression<T>) -> Result<Value<'a, T>, EvalError> {
Err(EvalError::DataNotAvailable)
}
}
@@ -448,12 +414,12 @@ mod internal {
use super::*;
pub fn evaluate<'a, 'b, T: FieldElement, C: Custom>(
pub fn evaluate<'a, 'b, T: FieldElement>(
expr: &'a Expression<T>,
locals: &[Rc<Value<'a, T, C>>],
locals: &[Rc<Value<'a, T>>],
generic_args: &'b HashMap<String, Type>,
symbols: &impl SymbolLookup<'a, T, C>,
) -> Result<Value<'a, T, C>, EvalError> {
symbols: &impl SymbolLookup<'a, T>,
) -> Result<Value<'a, T>, EvalError> {
Ok(match expr {
Expression::Reference(reference) => {
evaluate_reference(reference, locals, generic_args, symbols)?
@@ -477,11 +443,10 @@ mod internal {
Expression::BinaryOperation(left, op, right) => {
let left = evaluate(left, locals, generic_args, symbols)?;
let right = evaluate(right, locals, generic_args, symbols)?;
evaluate_binary_operation(left, *op, right, symbols)?
evaluate_binary_operation(left, *op, right)?
}
Expression::UnaryOperation(op, expr) => {
match (op, evaluate(expr, locals, generic_args, symbols)?) {
(_, Value::Custom(inner)) => symbols.eval_unary_operation(*op, inner)?,
(UnaryOperator::Minus, Value::FieldElement(e)) => Value::FieldElement(-e),
(UnaryOperator::LogicalNot, Value::Bool(b)) => Value::Bool(!b),
(UnaryOperator::Minus, Value::Integer(n)) => Value::Integer(-n),
@@ -601,11 +566,11 @@ mod internal {
})
}
fn evaluate_literal<'a, T: FieldElement, C: Custom>(
fn evaluate_literal<'a, T: FieldElement>(
n: &T,
ty: &Option<TypeName<NoArrayLengths>>,
generic_args: &HashMap<String, Type>,
) -> Result<Value<'a, T, C>, EvalError> {
) -> Result<Value<'a, T>, EvalError> {
let ty = if let Some(TypeName::TypeVar(tv)) = ty {
match &generic_args[tv] {
Type::Fe => TypeName::Fe,
@@ -631,12 +596,12 @@ mod internal {
})
}
fn evaluate_reference<'a, T: FieldElement, C: Custom>(
fn evaluate_reference<'a, T: FieldElement>(
reference: &'a Reference,
locals: &[Rc<Value<'a, T, C>>],
locals: &[Rc<Value<'a, T>>],
generic_args: &HashMap<String, Type>,
symbols: &impl SymbolLookup<'a, T, C>,
) -> Result<Value<'a, T, C>, EvalError> {
symbols: &impl SymbolLookup<'a, T>,
) -> Result<Value<'a, T>, EvalError> {
Ok(match reference {
Reference::LocalVar(i, _name) => (*locals[*i as usize]).clone(),
@@ -656,16 +621,12 @@ mod internal {
})
}
fn evaluate_binary_operation<'a, T: FieldElement, C: Custom>(
left: Value<'a, T, C>,
fn evaluate_binary_operation<'a, T: FieldElement>(
left: Value<'a, T>,
op: BinaryOperator,
right: Value<'a, T, C>,
symbols: &impl SymbolLookup<'a, T, C>,
) -> Result<Value<'a, T, C>, EvalError> {
right: Value<'a, T>,
) -> Result<Value<'a, T>, EvalError> {
Ok(match (left, op, right) {
(l @ Value::Custom(_), _, r) | (l, _, r @ Value::Custom(_)) => {
symbols.eval_binary_operation(l, op, r)?
}
(Value::Array(mut l), BinaryOperator::Add, Value::Array(r)) => {
l.extend(r);
Value::Array(l)
@@ -716,7 +677,7 @@ mod internal {
(Value::Expression(l), op, Value::Expression(r)) => match (l, r) {
(AlgebraicExpression::Number(l), AlgebraicExpression::Number(r)) => {
let Value::FieldElement(result) =
evaluate_binary_operation_field::<'a, T, C>(l, op, r)?
evaluate_binary_operation_field::<'a, T>(l, op, r)?
else {
panic!()
};
@@ -738,11 +699,11 @@ mod internal {
}
#[allow(clippy::print_stdout)]
pub fn evaluate_builtin_function<'a, T: FieldElement, C: Custom>(
pub fn evaluate_builtin_function<'a, T: FieldElement>(
b: BuiltinFunction,
mut arguments: Vec<Rc<Value<'a, T, C>>>,
symbols: &impl SymbolLookup<'a, T, C>,
) -> Result<Value<'a, T, C>, EvalError> {
mut arguments: Vec<Rc<Value<'a, T>>>,
symbols: &impl SymbolLookup<'a, T>,
) -> Result<Value<'a, T>, EvalError> {
let params = match b {
BuiltinFunction::ArrayLen => 1,
BuiltinFunction::Modulus => 0,
@@ -816,11 +777,11 @@ mod internal {
}
}
pub fn evaluate_binary_operation_field<'a, T: FieldElement, C>(
pub fn evaluate_binary_operation_field<'a, T: FieldElement>(
left: T,
op: BinaryOperator,
right: T,
) -> Result<Value<'a, T, C>, EvalError> {
) -> Result<Value<'a, T>, EvalError> {
Ok(match op {
BinaryOperator::Add => Value::FieldElement(left + right),
BinaryOperator::Sub => Value::FieldElement(left - right),
@@ -833,11 +794,11 @@ pub fn evaluate_binary_operation_field<'a, T: FieldElement, C>(
})
}
pub fn evaluate_binary_operation_integer<'a, T, C>(
pub fn evaluate_binary_operation_integer<'a, T>(
left: &BigInt,
op: BinaryOperator,
right: &BigInt,
) -> Result<Value<'a, T, C>, EvalError> {
) -> Result<Value<'a, T>, EvalError> {
Ok(match op {
BinaryOperator::Add => Value::Integer(left + right),
BinaryOperator::Sub => Value::Integer(left - right),
@@ -880,7 +841,7 @@ mod test {
else {
panic!()
};
evaluate::<_, NoCustom>(symbol, &Definitions(&analyzed.definitions))
evaluate(symbol, &Definitions(&analyzed.definitions))
.unwrap()
.to_string()
}

View File

@@ -157,8 +157,8 @@ pub fn std_analyzed<T: FieldElement>() -> Analyzed<T> {
pub fn evaluate_function<'a, T: FieldElement>(
analyzed: &'a Analyzed<T>,
function: &'a str,
arguments: Vec<Rc<evaluator::Value<'a, T, evaluator::NoCustom>>>,
) -> evaluator::Value<'a, T, evaluator::NoCustom> {
arguments: Vec<Rc<evaluator::Value<'a, T>>>,
) -> evaluator::Value<'a, T> {
let symbols = evaluator::Definitions(&analyzed.definitions);
let function = symbols.lookup(function, None).unwrap();
evaluator::evaluate_function_call(function, arguments, &symbols).unwrap()