mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-04-20 03:03:25 -04:00
Use evaluator for condenser.
This commit is contained in:
@@ -1,20 +1,22 @@
|
||||
//! Component that turns data from the PILAnalyzer into Analyzed,
|
||||
//! i.e. it turns more complex expressions in identities to simpler expressions.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, fmt::Display, rc::Rc};
|
||||
|
||||
use ast::{
|
||||
analyzed::{
|
||||
AlgebraicExpression, AlgebraicReference, Analyzed, Expression, FunctionValueDefinition,
|
||||
Identity, PolyID, PolynomialReference, PolynomialType, PublicDeclaration, Reference,
|
||||
Identity, PolynomialReference, PolynomialType, PublicDeclaration, Reference,
|
||||
StatementIdentifier, Symbol, SymbolKind,
|
||||
},
|
||||
evaluate_binary_operation, evaluate_unary_operation,
|
||||
parsed::{visitor::ExpressionVisitable, IndexAccess, SelectedExpressions, UnaryOperator},
|
||||
parsed::{visitor::ExpressionVisitable, SelectedExpressions, UnaryOperator},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use number::{DegreeType, FieldElement};
|
||||
|
||||
use crate::evaluator;
|
||||
use crate::evaluator::{
|
||||
self, evaluate, evaluate_function_call, Custom, EvalError, SymbolLookup, Value,
|
||||
};
|
||||
|
||||
pub fn condense<T: FieldElement>(
|
||||
degree: Option<DegreeType>,
|
||||
@@ -24,7 +26,6 @@ pub fn condense<T: FieldElement>(
|
||||
source_order: Vec<StatementIdentifier>,
|
||||
) -> Analyzed<T> {
|
||||
let condenser = Condenser {
|
||||
constants: compute_constants(&definitions),
|
||||
symbols: definitions.clone(),
|
||||
};
|
||||
|
||||
@@ -78,8 +79,6 @@ pub fn condense<T: FieldElement>(
|
||||
pub struct Condenser<T> {
|
||||
/// All the definitions from the PIL file.
|
||||
pub symbols: HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
|
||||
/// Definitions that evaluate to constant numbers.
|
||||
pub constants: HashMap<String, T>,
|
||||
}
|
||||
|
||||
impl<T: FieldElement> Condenser<T> {
|
||||
@@ -125,143 +124,202 @@ impl<T: FieldElement> Condenser<T> {
|
||||
}
|
||||
|
||||
pub fn condense_expression(&self, e: &Expression<T>) -> AlgebraicExpression<T> {
|
||||
match e {
|
||||
Expression::Reference(Reference::Poly(poly)) => {
|
||||
if let Some(value) = self.constants.get(&poly.name) {
|
||||
return AlgebraicExpression::Number(*value);
|
||||
evaluator::evaluate(e, &self)
|
||||
.and_then(|result| {
|
||||
// TODO at this point, we could also support arrays of constraints, but we would
|
||||
// need to make it clear if an array is supported in this context (it is at statement level,
|
||||
// but not inside a lookup for example).
|
||||
match result {
|
||||
Value::Custom(Condensate { expr, .. }) => Ok(expr),
|
||||
Value::Number(n) => Ok(n.into()),
|
||||
_ => Err(EvalError::TypeError(format!(
|
||||
"Expected constraint, but got {result}"
|
||||
))),
|
||||
}
|
||||
let symbol = &self
|
||||
.symbols
|
||||
.get(&poly.name)
|
||||
.unwrap_or_else(|| panic!("Symbol {} not found.", poly.name))
|
||||
.0;
|
||||
})
|
||||
.unwrap_or_else(|err| {
|
||||
panic!("Error reducing expression to constraint:\nExpression: {e}\nError: {err:?}")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
symbol.length.is_none(),
|
||||
"Arrays cannot be used as a whole in this context, only individual array elements can be used."
|
||||
);
|
||||
impl<'a, T: FieldElement> SymbolLookup<'a, T, Condensate<T>> for &'a Condenser<T> {
|
||||
fn lookup(&self, name: &str) -> Result<Value<'a, T, Condensate<T>>, EvalError> {
|
||||
let name = name.to_string();
|
||||
let (symbol, value) = &self
|
||||
.symbols
|
||||
.get(&name)
|
||||
.ok_or_else(|| EvalError::SymbolNotFound(format!("Symbol {name} not found.")))?;
|
||||
|
||||
Ok(if matches!(symbol.kind, SymbolKind::Poly(_)) {
|
||||
if symbol.is_array() {
|
||||
Value::Array(
|
||||
symbol
|
||||
.array_elements()
|
||||
.map(|(name, poly_id)| {
|
||||
AlgebraicExpression::Reference(AlgebraicReference {
|
||||
name,
|
||||
poly_id,
|
||||
next: false,
|
||||
})
|
||||
.into()
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
AlgebraicExpression::Reference(AlgebraicReference {
|
||||
name: poly.name.clone(),
|
||||
name,
|
||||
poly_id: symbol.into(),
|
||||
next: false,
|
||||
})
|
||||
.into()
|
||||
}
|
||||
Expression::Reference(Reference::LocalVar(_, _)) => {
|
||||
panic!("Local variables not allowed here.")
|
||||
}
|
||||
Expression::Number(n) => AlgebraicExpression::Number(*n),
|
||||
Expression::BinaryOperation(left, op, right) => {
|
||||
match (
|
||||
self.condense_expression(left),
|
||||
self.condense_expression(right),
|
||||
) {
|
||||
(AlgebraicExpression::Number(l), AlgebraicExpression::Number(r)) => {
|
||||
AlgebraicExpression::Number(evaluate_binary_operation(l, *op, r))
|
||||
}
|
||||
(l, r) => AlgebraicExpression::BinaryOperation(
|
||||
Box::new(l),
|
||||
(*op).try_into().unwrap(),
|
||||
Box::new(r),
|
||||
),
|
||||
} else {
|
||||
match value {
|
||||
Some(FunctionValueDefinition::Expression(value)) => {
|
||||
evaluator::evaluate(value, self)?
|
||||
}
|
||||
_ => Err(EvalError::Unsupported(
|
||||
"Cannot evaluate arrays and queries.".to_string(),
|
||||
))?,
|
||||
}
|
||||
Expression::UnaryOperation(op, inner) => {
|
||||
let inner = self.condense_expression(inner);
|
||||
if *op == UnaryOperator::Next {
|
||||
let AlgebraicExpression::Reference(reference) = inner else {
|
||||
panic!(
|
||||
"Can apply \"'\" operator only directly to columns in this context."
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
fn lookup_public_reference(
|
||||
&self,
|
||||
name: &str,
|
||||
) -> Result<Value<'a, T, Condensate<T>>, EvalError> {
|
||||
Ok(AlgebraicExpression::PublicReference(name.to_string()).into())
|
||||
}
|
||||
|
||||
fn eval_function_application(
|
||||
&self,
|
||||
function: Condensate<T>,
|
||||
arguments: &[Rc<Value<'a, T, Condensate<T>>>],
|
||||
) -> Result<Value<'a, T, Condensate<T>>, EvalError> {
|
||||
match function.expr {
|
||||
AlgebraicExpression::Reference(AlgebraicReference {
|
||||
name,
|
||||
poly_id,
|
||||
next,
|
||||
}) if poly_id.ptype == PolynomialType::Constant => {
|
||||
let arguments = if next {
|
||||
assert_eq!(arguments.len(), 1);
|
||||
let Value::Number(arg) = *arguments[0] else {
|
||||
return Err(EvalError::TypeError(
|
||||
"Expected numeric argument when evaluating function with next ref."
|
||||
.to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
assert!(!reference.next, "Double application of \"'\"");
|
||||
AlgebraicExpression::Reference(AlgebraicReference {
|
||||
next: true,
|
||||
..reference
|
||||
})
|
||||
vec![Rc::new(Value::Number(arg + 1.into()))]
|
||||
} else {
|
||||
match inner {
|
||||
AlgebraicExpression::Number(n) => {
|
||||
AlgebraicExpression::Number(evaluate_unary_operation(*op, n))
|
||||
}
|
||||
_ => AlgebraicExpression::UnaryOperation(
|
||||
(*op).try_into().unwrap(),
|
||||
Box::new(inner),
|
||||
),
|
||||
arguments.to_vec()
|
||||
};
|
||||
|
||||
match self.symbols[&name].1.as_ref() {
|
||||
Some(FunctionValueDefinition::Expression(v)) => {
|
||||
let function = evaluate(v, self)?;
|
||||
evaluate_function_call(function, arguments, self)
|
||||
}
|
||||
None => Err(EvalError::SymbolNotFound(format!(
|
||||
"Symbol not found in function call: {name}"
|
||||
))),
|
||||
_ => Err(EvalError::Unsupported(format!(
|
||||
"Cannot evaluate arrays or queries: {name}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
Expression::PublicReference(r) => AlgebraicExpression::PublicReference(r.clone()),
|
||||
Expression::IndexAccess(IndexAccess { array, index }) => {
|
||||
let array_symbol = match array.as_ref() {
|
||||
ast::parsed::Expression::Reference(Reference::Poly(PolynomialReference {
|
||||
name,
|
||||
poly_id: _,
|
||||
})) => {
|
||||
&self
|
||||
.symbols
|
||||
.get(name)
|
||||
.unwrap_or_else(|| panic!("Symbol {name} not found."))
|
||||
.0
|
||||
}
|
||||
_ => panic!("Expected direct reference before array index access."),
|
||||
};
|
||||
let Some(length) = array_symbol.length else {
|
||||
panic!("Array-access for non-array {}.", array_symbol.absolute_name);
|
||||
};
|
||||
_ => Err(EvalError::TypeError(format!(
|
||||
"Function application not supported: {function}({})",
|
||||
arguments.iter().format(", ")
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
let index = evaluator::evaluate_expression(index, &self.symbols)
|
||||
.and_then(|v| v.try_to_number())
|
||||
.expect("Index needs to be constant number.")
|
||||
.to_degree();
|
||||
assert!(
|
||||
index < length,
|
||||
"Array access to index {index} for array of length {length}: {}",
|
||||
array_symbol.absolute_name,
|
||||
);
|
||||
let poly_id: PolyID = array_symbol.into();
|
||||
AlgebraicExpression::Reference(AlgebraicReference {
|
||||
poly_id: PolyID {
|
||||
id: poly_id.id + index,
|
||||
..poly_id
|
||||
},
|
||||
name: array_symbol.array_element_name(index),
|
||||
next: false,
|
||||
})
|
||||
fn eval_binary_operation(
|
||||
&self,
|
||||
left: Value<'a, T, Condensate<T>>,
|
||||
op: ast::parsed::BinaryOperator,
|
||||
right: Value<'a, T, Condensate<T>>,
|
||||
) -> Result<Value<'a, T, Condensate<T>>, EvalError> {
|
||||
let left: Condensate<T> = left.try_into()?;
|
||||
let right: Condensate<T> = right.try_into()?;
|
||||
Ok(AlgebraicExpression::BinaryOperation(
|
||||
Box::new(left.expr),
|
||||
op.try_into().map_err(EvalError::TypeError)?,
|
||||
Box::new(right.expr),
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
fn eval_unary_operation(
|
||||
&self,
|
||||
op: UnaryOperator,
|
||||
inner: Condensate<T>,
|
||||
) -> Result<Value<'a, T, Condensate<T>>, EvalError> {
|
||||
if op == UnaryOperator::Next {
|
||||
let AlgebraicExpression::Reference(reference) = inner.expr else {
|
||||
return Err(EvalError::TypeError(format!(
|
||||
"Expected column for \"'\" operator, but got: {inner}"
|
||||
)));
|
||||
};
|
||||
|
||||
if reference.next {
|
||||
return Err(EvalError::TypeError(format!(
|
||||
"Double application of \"'\" on: {reference}"
|
||||
)));
|
||||
}
|
||||
Expression::String(_) => panic!("Strings are not allowed here."),
|
||||
Expression::Tuple(_) => panic!(),
|
||||
Expression::LambdaExpression(_) => panic!(),
|
||||
Expression::ArrayLiteral(_) => panic!(),
|
||||
Expression::FunctionCall(_) => panic!(),
|
||||
Expression::FreeInput(_) => panic!(),
|
||||
Expression::MatchExpression(_, _) => panic!(),
|
||||
Ok(AlgebraicExpression::Reference(AlgebraicReference {
|
||||
next: true,
|
||||
..reference
|
||||
})
|
||||
.into())
|
||||
} else {
|
||||
Ok(AlgebraicExpression::UnaryOperation(
|
||||
op.try_into().map_err(EvalError::TypeError)?,
|
||||
Box::new(inner.expr),
|
||||
)
|
||||
.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a HashMap of all symbols that have a constant single value.
|
||||
fn compute_constants<T: FieldElement>(
|
||||
definitions: &HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
|
||||
) -> HashMap<String, T> {
|
||||
definitions
|
||||
.iter()
|
||||
.filter_map(|(name, (symbol, value))| {
|
||||
// TODO we could try to compute anything that evaluates to a "value" here.
|
||||
if symbol.kind == SymbolKind::Constant() {
|
||||
let Some(FunctionValueDefinition::Expression(value)) = value else {
|
||||
panic!();
|
||||
};
|
||||
Ok(value)
|
||||
.and_then(|value| {
|
||||
evaluator::evaluate_expression(value, definitions)?.try_to_number()
|
||||
})
|
||||
.map(|value| (name.to_owned(), value))
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
#[derive(Clone)]
|
||||
struct Condensate<T> {
|
||||
pub expr: AlgebraicExpression<T>,
|
||||
}
|
||||
|
||||
impl<T: PartialEq> PartialEq for Condensate<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.expr == other.expr
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Display> Display for Condensate<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.expr)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FieldElement> Custom for Condensate<T> {}
|
||||
|
||||
impl<'a, T: FieldElement> TryFrom<Value<'a, T, Self>> for Condensate<T> {
|
||||
type Error = EvalError;
|
||||
|
||||
fn try_from(value: Value<'a, T, Self>) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
Value::Number(n) => Ok(Self { expr: n.into() }),
|
||||
Value::Custom(v) => Ok(v),
|
||||
value => Err(EvalError::TypeError(format!(
|
||||
"Expected algebraic expression, got {value}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: FieldElement> From<AlgebraicExpression<T>> for Value<'a, T, Condensate<T>> {
|
||||
fn from(expr: AlgebraicExpression<T>) -> Self {
|
||||
Value::Custom(Condensate { expr })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,10 @@ use std::{collections::HashMap, fmt::Display, rc::Rc};
|
||||
use ast::{
|
||||
analyzed::{Expression, FunctionValueDefinition, Reference, Symbol},
|
||||
evaluate_binary_operation, evaluate_unary_operation,
|
||||
parsed::{display::quote, FunctionCall, LambdaExpression, MatchArm, MatchPattern},
|
||||
parsed::{
|
||||
display::quote, BinaryOperator, FunctionCall, LambdaExpression, MatchArm, MatchPattern,
|
||||
UnaryOperator,
|
||||
},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use number::FieldElement;
|
||||
@@ -199,6 +202,23 @@ pub trait SymbolLookup<'a, T, C> {
|
||||
function: C,
|
||||
arguments: &[Rc<Value<'a, T, C>>],
|
||||
) -> Result<Value<'a, T, C>, EvalError>;
|
||||
|
||||
fn eval_binary_operation(
|
||||
&self,
|
||||
_left: Value<'a, T, C>,
|
||||
_op: BinaryOperator,
|
||||
_right: Value<'a, T, C>,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn eval_unary_operation(
|
||||
&self,
|
||||
_op: UnaryOperator,
|
||||
_inner: C,
|
||||
) -> Result<Value<'a, T, C>, EvalError> {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
mod internal {
|
||||
@@ -228,16 +248,27 @@ mod internal {
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
Expression::BinaryOperation(left, op, right) => {
|
||||
Value::Number(evaluate_binary_operation(
|
||||
evaluate(left, locals, symbols)?.try_to_number()?,
|
||||
*op,
|
||||
evaluate(right, locals, symbols)?.try_to_number()?,
|
||||
))
|
||||
let left = evaluate(left, locals, symbols)?;
|
||||
let right = evaluate(right, locals, symbols)?;
|
||||
match (&left, &right) {
|
||||
(Value::Custom(_), _) | (_, Value::Custom(_)) => {
|
||||
symbols.eval_binary_operation(left, *op, right)?
|
||||
}
|
||||
(Value::Number(l), Value::Number(r)) => {
|
||||
Value::Number(evaluate_binary_operation(*l, *op, *r))
|
||||
}
|
||||
_ => Err(EvalError::TypeError(format!(
|
||||
"Operator {op} not supported on types: {left} {op} {right}"
|
||||
)))?,
|
||||
}
|
||||
}
|
||||
Expression::UnaryOperation(op, expr) => Value::Number(evaluate_unary_operation(
|
||||
*op,
|
||||
evaluate(expr, locals, symbols)?.try_to_number()?,
|
||||
)),
|
||||
Expression::UnaryOperation(op, expr) => match evaluate(expr, locals, symbols)? {
|
||||
Value::Custom(inner) => symbols.eval_unary_operation(*op, inner)?,
|
||||
Value::Number(n) => Value::Number(evaluate_unary_operation(*op, n)),
|
||||
inner => Err(EvalError::TypeError(format!(
|
||||
"Operator {op} not supported on types: {op} {inner}"
|
||||
)))?,
|
||||
},
|
||||
Expression::LambdaExpression(lambda) => {
|
||||
// TODO only copy the part of the environment that is actually referenced?
|
||||
(Closure {
|
||||
|
||||
@@ -834,7 +834,7 @@ namespace N(65536);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic = "Arrays cannot be used as a whole in this context"]
|
||||
#[should_panic = "Operator - not supported on types"]
|
||||
fn no_direct_array_references() {
|
||||
let input = r#"namespace N(16);
|
||||
col witness y[3];
|
||||
@@ -845,7 +845,7 @@ namespace N(65536);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic = "Array access to index 3 for array of length 3"]
|
||||
#[should_panic = "Tried to access element 3 of array of size 3."]
|
||||
fn no_out_of_bounds() {
|
||||
let input = r#"namespace N(16);
|
||||
col witness y[3];
|
||||
@@ -865,4 +865,74 @@ namespace N(65536);
|
||||
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
|
||||
assert_eq!(formatted, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn symbolic_functions() {
|
||||
let input = r#"namespace N(16);
|
||||
let last_row = 15;
|
||||
let ISLAST = |i| match i { last_row => 1, _ => 0 };
|
||||
let x;
|
||||
let y;
|
||||
let constrain_equal_expr = |A, B| A - B;
|
||||
let on_regular_row = |cond| (1 - ISLAST) * cond;
|
||||
on_regular_row(constrain_equal_expr(x', y)) = 0;
|
||||
on_regular_row(constrain_equal_expr(y', x + y)) = 0;
|
||||
"#;
|
||||
let expected = r#"constant last_row = 15;
|
||||
namespace N(16);
|
||||
col fixed ISLAST(i) { match i { last_row => 1, _ => 0, } };
|
||||
col witness x;
|
||||
col witness y;
|
||||
let constrain_equal_expr = |A, B| (A - B);
|
||||
col fixed on_regular_row(cond) { ((1 - N.ISLAST) * cond) };
|
||||
((1 - N.ISLAST) * (N.x' - N.y)) = 0;
|
||||
((1 - N.ISLAST) * (N.y' - (N.x + N.y))) = 0;
|
||||
"#;
|
||||
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
|
||||
assert_eq!(formatted, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_op_on_param() {
|
||||
let input = r#"namespace N(16);
|
||||
let last_row = 15;
|
||||
let ISLAST = |i| match i { last_row => 1, _ => 0 };
|
||||
let x;
|
||||
let y;
|
||||
let next_is_seven = |t| t' - 7;
|
||||
next_is_seven(y) = 0;
|
||||
"#;
|
||||
let expected = r#"constant last_row = 15;
|
||||
namespace N(16);
|
||||
col fixed ISLAST(i) { match i { last_row => 1, _ => 0, } };
|
||||
col witness x;
|
||||
col witness y;
|
||||
col fixed next_is_seven(t) { (t' - 7) };
|
||||
(N.y' - 7) = 0;
|
||||
"#;
|
||||
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
|
||||
assert_eq!(formatted, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fixed_concrete_and_symbolic() {
|
||||
let input = r#"namespace N(16);
|
||||
let last_row = 15;
|
||||
let ISLAST = |i| match i { last_row => 1, _ => 0, };
|
||||
let x;
|
||||
let y;
|
||||
y - ISLAST(3) = 0;
|
||||
x - ISLAST = 0;
|
||||
"#;
|
||||
let expected = r#"constant last_row = 15;
|
||||
namespace N(16);
|
||||
col fixed ISLAST(i) { match i { last_row => 1, _ => 0, } };
|
||||
col witness x;
|
||||
col witness y;
|
||||
(N.y - 0) = 0;
|
||||
(N.x - N.ISLAST) = 0;
|
||||
"#;
|
||||
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
|
||||
assert_eq!(formatted, expected);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user