Use evaluator for condenser.

This commit is contained in:
chriseth
2023-11-16 15:44:55 +01:00
parent 912cab8fa3
commit 01bfba1a40
3 changed files with 297 additions and 138 deletions

View File

@@ -1,20 +1,22 @@
//! Component that turns data from the PILAnalyzer into Analyzed,
//! i.e. it turns more complex expressions in identities to simpler expressions.
use std::collections::HashMap;
use std::{collections::HashMap, fmt::Display, rc::Rc};
use ast::{
analyzed::{
AlgebraicExpression, AlgebraicReference, Analyzed, Expression, FunctionValueDefinition,
Identity, PolyID, PolynomialReference, PolynomialType, PublicDeclaration, Reference,
Identity, PolynomialReference, PolynomialType, PublicDeclaration, Reference,
StatementIdentifier, Symbol, SymbolKind,
},
evaluate_binary_operation, evaluate_unary_operation,
parsed::{visitor::ExpressionVisitable, IndexAccess, SelectedExpressions, UnaryOperator},
parsed::{visitor::ExpressionVisitable, SelectedExpressions, UnaryOperator},
};
use itertools::Itertools;
use number::{DegreeType, FieldElement};
use crate::evaluator;
use crate::evaluator::{
self, evaluate, evaluate_function_call, Custom, EvalError, SymbolLookup, Value,
};
pub fn condense<T: FieldElement>(
degree: Option<DegreeType>,
@@ -24,7 +26,6 @@ pub fn condense<T: FieldElement>(
source_order: Vec<StatementIdentifier>,
) -> Analyzed<T> {
let condenser = Condenser {
constants: compute_constants(&definitions),
symbols: definitions.clone(),
};
@@ -78,8 +79,6 @@ pub fn condense<T: FieldElement>(
pub struct Condenser<T> {
/// All the definitions from the PIL file.
pub symbols: HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
/// Definitions that evaluate to constant numbers.
pub constants: HashMap<String, T>,
}
impl<T: FieldElement> Condenser<T> {
@@ -125,143 +124,202 @@ impl<T: FieldElement> Condenser<T> {
}
pub fn condense_expression(&self, e: &Expression<T>) -> AlgebraicExpression<T> {
match e {
Expression::Reference(Reference::Poly(poly)) => {
if let Some(value) = self.constants.get(&poly.name) {
return AlgebraicExpression::Number(*value);
evaluator::evaluate(e, &self)
.and_then(|result| {
// TODO at this point, we could also support arrays of constraints, but we would
// need to make it clear if an array is supported in this context (it is at statement level,
// but not inside a lookup for example).
match result {
Value::Custom(Condensate { expr, .. }) => Ok(expr),
Value::Number(n) => Ok(n.into()),
_ => Err(EvalError::TypeError(format!(
"Expected constraint, but got {result}"
))),
}
let symbol = &self
.symbols
.get(&poly.name)
.unwrap_or_else(|| panic!("Symbol {} not found.", poly.name))
.0;
})
.unwrap_or_else(|err| {
panic!("Error reducing expression to constraint:\nExpression: {e}\nError: {err:?}")
})
}
}
assert!(
symbol.length.is_none(),
"Arrays cannot be used as a whole in this context, only individual array elements can be used."
);
impl<'a, T: FieldElement> SymbolLookup<'a, T, Condensate<T>> for &'a Condenser<T> {
fn lookup(&self, name: &str) -> Result<Value<'a, T, Condensate<T>>, EvalError> {
let name = name.to_string();
let (symbol, value) = &self
.symbols
.get(&name)
.ok_or_else(|| EvalError::SymbolNotFound(format!("Symbol {name} not found.")))?;
Ok(if matches!(symbol.kind, SymbolKind::Poly(_)) {
if symbol.is_array() {
Value::Array(
symbol
.array_elements()
.map(|(name, poly_id)| {
AlgebraicExpression::Reference(AlgebraicReference {
name,
poly_id,
next: false,
})
.into()
})
.collect(),
)
} else {
AlgebraicExpression::Reference(AlgebraicReference {
name: poly.name.clone(),
name,
poly_id: symbol.into(),
next: false,
})
.into()
}
Expression::Reference(Reference::LocalVar(_, _)) => {
panic!("Local variables not allowed here.")
}
Expression::Number(n) => AlgebraicExpression::Number(*n),
Expression::BinaryOperation(left, op, right) => {
match (
self.condense_expression(left),
self.condense_expression(right),
) {
(AlgebraicExpression::Number(l), AlgebraicExpression::Number(r)) => {
AlgebraicExpression::Number(evaluate_binary_operation(l, *op, r))
}
(l, r) => AlgebraicExpression::BinaryOperation(
Box::new(l),
(*op).try_into().unwrap(),
Box::new(r),
),
} else {
match value {
Some(FunctionValueDefinition::Expression(value)) => {
evaluator::evaluate(value, self)?
}
_ => Err(EvalError::Unsupported(
"Cannot evaluate arrays and queries.".to_string(),
))?,
}
Expression::UnaryOperation(op, inner) => {
let inner = self.condense_expression(inner);
if *op == UnaryOperator::Next {
let AlgebraicExpression::Reference(reference) = inner else {
panic!(
"Can apply \"'\" operator only directly to columns in this context."
);
})
}
fn lookup_public_reference(
&self,
name: &str,
) -> Result<Value<'a, T, Condensate<T>>, EvalError> {
Ok(AlgebraicExpression::PublicReference(name.to_string()).into())
}
fn eval_function_application(
&self,
function: Condensate<T>,
arguments: &[Rc<Value<'a, T, Condensate<T>>>],
) -> Result<Value<'a, T, Condensate<T>>, EvalError> {
match function.expr {
AlgebraicExpression::Reference(AlgebraicReference {
name,
poly_id,
next,
}) if poly_id.ptype == PolynomialType::Constant => {
let arguments = if next {
assert_eq!(arguments.len(), 1);
let Value::Number(arg) = *arguments[0] else {
return Err(EvalError::TypeError(
"Expected numeric argument when evaluating function with next ref."
.to_string(),
));
};
assert!(!reference.next, "Double application of \"'\"");
AlgebraicExpression::Reference(AlgebraicReference {
next: true,
..reference
})
vec![Rc::new(Value::Number(arg + 1.into()))]
} else {
match inner {
AlgebraicExpression::Number(n) => {
AlgebraicExpression::Number(evaluate_unary_operation(*op, n))
}
_ => AlgebraicExpression::UnaryOperation(
(*op).try_into().unwrap(),
Box::new(inner),
),
arguments.to_vec()
};
match self.symbols[&name].1.as_ref() {
Some(FunctionValueDefinition::Expression(v)) => {
let function = evaluate(v, self)?;
evaluate_function_call(function, arguments, self)
}
None => Err(EvalError::SymbolNotFound(format!(
"Symbol not found in function call: {name}"
))),
_ => Err(EvalError::Unsupported(format!(
"Cannot evaluate arrays or queries: {name}"
))),
}
}
Expression::PublicReference(r) => AlgebraicExpression::PublicReference(r.clone()),
Expression::IndexAccess(IndexAccess { array, index }) => {
let array_symbol = match array.as_ref() {
ast::parsed::Expression::Reference(Reference::Poly(PolynomialReference {
name,
poly_id: _,
})) => {
&self
.symbols
.get(name)
.unwrap_or_else(|| panic!("Symbol {name} not found."))
.0
}
_ => panic!("Expected direct reference before array index access."),
};
let Some(length) = array_symbol.length else {
panic!("Array-access for non-array {}.", array_symbol.absolute_name);
};
_ => Err(EvalError::TypeError(format!(
"Function application not supported: {function}({})",
arguments.iter().format(", ")
))),
}
}
let index = evaluator::evaluate_expression(index, &self.symbols)
.and_then(|v| v.try_to_number())
.expect("Index needs to be constant number.")
.to_degree();
assert!(
index < length,
"Array access to index {index} for array of length {length}: {}",
array_symbol.absolute_name,
);
let poly_id: PolyID = array_symbol.into();
AlgebraicExpression::Reference(AlgebraicReference {
poly_id: PolyID {
id: poly_id.id + index,
..poly_id
},
name: array_symbol.array_element_name(index),
next: false,
})
fn eval_binary_operation(
&self,
left: Value<'a, T, Condensate<T>>,
op: ast::parsed::BinaryOperator,
right: Value<'a, T, Condensate<T>>,
) -> Result<Value<'a, T, Condensate<T>>, EvalError> {
let left: Condensate<T> = left.try_into()?;
let right: Condensate<T> = right.try_into()?;
Ok(AlgebraicExpression::BinaryOperation(
Box::new(left.expr),
op.try_into().map_err(EvalError::TypeError)?,
Box::new(right.expr),
)
.into())
}
fn eval_unary_operation(
&self,
op: UnaryOperator,
inner: Condensate<T>,
) -> Result<Value<'a, T, Condensate<T>>, EvalError> {
if op == UnaryOperator::Next {
let AlgebraicExpression::Reference(reference) = inner.expr else {
return Err(EvalError::TypeError(format!(
"Expected column for \"'\" operator, but got: {inner}"
)));
};
if reference.next {
return Err(EvalError::TypeError(format!(
"Double application of \"'\" on: {reference}"
)));
}
Expression::String(_) => panic!("Strings are not allowed here."),
Expression::Tuple(_) => panic!(),
Expression::LambdaExpression(_) => panic!(),
Expression::ArrayLiteral(_) => panic!(),
Expression::FunctionCall(_) => panic!(),
Expression::FreeInput(_) => panic!(),
Expression::MatchExpression(_, _) => panic!(),
Ok(AlgebraicExpression::Reference(AlgebraicReference {
next: true,
..reference
})
.into())
} else {
Ok(AlgebraicExpression::UnaryOperation(
op.try_into().map_err(EvalError::TypeError)?,
Box::new(inner.expr),
)
.into())
}
}
}
/// Returns a HashMap of all symbols that have a constant single value.
fn compute_constants<T: FieldElement>(
definitions: &HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
) -> HashMap<String, T> {
definitions
.iter()
.filter_map(|(name, (symbol, value))| {
// TODO we could try to compute anything that evaluates to a "value" here.
if symbol.kind == SymbolKind::Constant() {
let Some(FunctionValueDefinition::Expression(value)) = value else {
panic!();
};
Ok(value)
.and_then(|value| {
evaluator::evaluate_expression(value, definitions)?.try_to_number()
})
.map(|value| (name.to_owned(), value))
.ok()
} else {
None
}
})
.collect()
#[derive(Clone)]
struct Condensate<T> {
pub expr: AlgebraicExpression<T>,
}
impl<T: PartialEq> PartialEq for Condensate<T> {
fn eq(&self, other: &Self) -> bool {
self.expr == other.expr
}
}
impl<T: Display> Display for Condensate<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.expr)
}
}
impl<T: FieldElement> Custom for Condensate<T> {}
impl<'a, T: FieldElement> TryFrom<Value<'a, T, Self>> for Condensate<T> {
type Error = EvalError;
fn try_from(value: Value<'a, T, Self>) -> Result<Self, Self::Error> {
match value {
Value::Number(n) => Ok(Self { expr: n.into() }),
Value::Custom(v) => Ok(v),
value => Err(EvalError::TypeError(format!(
"Expected algebraic expression, got {value}"
))),
}
}
}
impl<'a, T: FieldElement> From<AlgebraicExpression<T>> for Value<'a, T, Condensate<T>> {
fn from(expr: AlgebraicExpression<T>) -> Self {
Value::Custom(Condensate { expr })
}
}

View File

@@ -3,7 +3,10 @@ use std::{collections::HashMap, fmt::Display, rc::Rc};
use ast::{
analyzed::{Expression, FunctionValueDefinition, Reference, Symbol},
evaluate_binary_operation, evaluate_unary_operation,
parsed::{display::quote, FunctionCall, LambdaExpression, MatchArm, MatchPattern},
parsed::{
display::quote, BinaryOperator, FunctionCall, LambdaExpression, MatchArm, MatchPattern,
UnaryOperator,
},
};
use itertools::Itertools;
use number::FieldElement;
@@ -199,6 +202,23 @@ pub trait SymbolLookup<'a, T, C> {
function: C,
arguments: &[Rc<Value<'a, T, C>>],
) -> Result<Value<'a, T, C>, EvalError>;
fn eval_binary_operation(
&self,
_left: Value<'a, T, C>,
_op: BinaryOperator,
_right: Value<'a, T, C>,
) -> Result<Value<'a, T, C>, EvalError> {
unreachable!()
}
fn eval_unary_operation(
&self,
_op: UnaryOperator,
_inner: C,
) -> Result<Value<'a, T, C>, EvalError> {
unreachable!()
}
}
mod internal {
@@ -228,16 +248,27 @@ mod internal {
.collect::<Result<_, _>>()?,
),
Expression::BinaryOperation(left, op, right) => {
Value::Number(evaluate_binary_operation(
evaluate(left, locals, symbols)?.try_to_number()?,
*op,
evaluate(right, locals, symbols)?.try_to_number()?,
))
let left = evaluate(left, locals, symbols)?;
let right = evaluate(right, locals, symbols)?;
match (&left, &right) {
(Value::Custom(_), _) | (_, Value::Custom(_)) => {
symbols.eval_binary_operation(left, *op, right)?
}
(Value::Number(l), Value::Number(r)) => {
Value::Number(evaluate_binary_operation(*l, *op, *r))
}
_ => Err(EvalError::TypeError(format!(
"Operator {op} not supported on types: {left} {op} {right}"
)))?,
}
}
Expression::UnaryOperation(op, expr) => Value::Number(evaluate_unary_operation(
*op,
evaluate(expr, locals, symbols)?.try_to_number()?,
)),
Expression::UnaryOperation(op, expr) => match evaluate(expr, locals, symbols)? {
Value::Custom(inner) => symbols.eval_unary_operation(*op, inner)?,
Value::Number(n) => Value::Number(evaluate_unary_operation(*op, n)),
inner => Err(EvalError::TypeError(format!(
"Operator {op} not supported on types: {op} {inner}"
)))?,
},
Expression::LambdaExpression(lambda) => {
// TODO only copy the part of the environment that is actually referenced?
(Closure {

View File

@@ -834,7 +834,7 @@ namespace N(65536);
}
#[test]
#[should_panic = "Arrays cannot be used as a whole in this context"]
#[should_panic = "Operator - not supported on types"]
fn no_direct_array_references() {
let input = r#"namespace N(16);
col witness y[3];
@@ -845,7 +845,7 @@ namespace N(65536);
}
#[test]
#[should_panic = "Array access to index 3 for array of length 3"]
#[should_panic = "Tried to access element 3 of array of size 3."]
fn no_out_of_bounds() {
let input = r#"namespace N(16);
col witness y[3];
@@ -865,4 +865,74 @@ namespace N(65536);
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, input);
}
#[test]
fn symbolic_functions() {
let input = r#"namespace N(16);
let last_row = 15;
let ISLAST = |i| match i { last_row => 1, _ => 0 };
let x;
let y;
let constrain_equal_expr = |A, B| A - B;
let on_regular_row = |cond| (1 - ISLAST) * cond;
on_regular_row(constrain_equal_expr(x', y)) = 0;
on_regular_row(constrain_equal_expr(y', x + y)) = 0;
"#;
let expected = r#"constant last_row = 15;
namespace N(16);
col fixed ISLAST(i) { match i { last_row => 1, _ => 0, } };
col witness x;
col witness y;
let constrain_equal_expr = |A, B| (A - B);
col fixed on_regular_row(cond) { ((1 - N.ISLAST) * cond) };
((1 - N.ISLAST) * (N.x' - N.y)) = 0;
((1 - N.ISLAST) * (N.y' - (N.x + N.y))) = 0;
"#;
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}
#[test]
fn next_op_on_param() {
let input = r#"namespace N(16);
let last_row = 15;
let ISLAST = |i| match i { last_row => 1, _ => 0 };
let x;
let y;
let next_is_seven = |t| t' - 7;
next_is_seven(y) = 0;
"#;
let expected = r#"constant last_row = 15;
namespace N(16);
col fixed ISLAST(i) { match i { last_row => 1, _ => 0, } };
col witness x;
col witness y;
col fixed next_is_seven(t) { (t' - 7) };
(N.y' - 7) = 0;
"#;
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}
#[test]
fn fixed_concrete_and_symbolic() {
let input = r#"namespace N(16);
let last_row = 15;
let ISLAST = |i| match i { last_row => 1, _ => 0, };
let x;
let y;
y - ISLAST(3) = 0;
x - ISLAST = 0;
"#;
let expected = r#"constant last_row = 15;
namespace N(16);
col fixed ISLAST(i) { match i { last_row => 1, _ => 0, } };
col witness x;
col witness y;
(N.y - 0) = 0;
(N.x - N.ISLAST) = 0;
"#;
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}
}