This commit is contained in:
chriseth
2024-07-31 18:19:34 +02:00
committed by GitHub
parent ce6cb0bc6c
commit 98ebf3fbc6
7 changed files with 438 additions and 46 deletions

View File

@@ -2,7 +2,8 @@
//! i.e. it turns more complex expressions in identities to simpler expressions.
use std::{
collections::{BTreeMap, HashMap, HashSet},
collections::{hash_map::Entry, BTreeMap, HashMap, HashSet},
fmt::Display,
iter::once,
str::FromStr,
sync::Arc,
@@ -11,7 +12,7 @@ use std::{
use powdr_ast::{
analyzed::{
self, AlgebraicExpression, AlgebraicReference, Analyzed, Expression,
FunctionValueDefinition, Identity, IdentityKind, PolynomialType, PublicDeclaration,
FunctionValueDefinition, Identity, IdentityKind, PolyID, PolynomialType, PublicDeclaration,
SelectedExpressions, StatementIdentifier, Symbol, SymbolKind,
},
parsed::{
@@ -19,7 +20,7 @@ use powdr_ast::{
asm::{AbsoluteSymbolPath, SymbolPath},
display::format_type_scheme_around_name,
types::{ArrayType, Type},
TypedExpression,
FunctionKind, TypedExpression,
},
};
use powdr_number::{DegreeType, FieldElement};
@@ -45,6 +46,7 @@ pub fn condense<T: FieldElement>(
let mut condensed_identities = vec![];
let mut intermediate_columns = HashMap::new();
let mut new_columns = vec![];
let mut new_values = HashMap::new();
// Condense identities and intermediate columns and update the source order.
let source_order = source_order
.into_iter()
@@ -98,15 +100,14 @@ pub fn condense<T: FieldElement>(
}
s => Some(s),
};
// Extract and prepend the new witness columns, then identities
// Extract and prepend the new columns, then identities
// and finally the original statement (if it exists).
let new_cols = condenser
.extract_new_columns()
.into_iter()
.map(|(new_col, value)| {
let name = new_col.absolute_name.clone();
new_columns.push((new_col, value));
StatementIdentifier::Definition(name)
.map(|new_col| {
new_columns.push(new_col.clone());
StatementIdentifier::Definition(new_col.absolute_name)
})
.collect::<Vec<_>>();
@@ -120,6 +121,12 @@ pub fn condense<T: FieldElement>(
})
.collect::<Vec<_>>();
for (name, hint) in condenser.extract_new_column_values() {
if new_values.insert(name.clone(), hint).is_some() {
panic!("Column {name} already has a hint set, but tried to add another one.",)
}
}
new_cols
.into_iter()
.chain(identity_statements)
@@ -128,8 +135,20 @@ pub fn condense<T: FieldElement>(
.collect();
definitions.retain(|name, _| !intermediate_columns.contains_key(name));
for (symbol, value) in new_columns {
definitions.insert(symbol.absolute_name.clone(), (symbol, value));
for symbol in new_columns {
definitions.insert(symbol.absolute_name.clone(), (symbol, None));
}
for (name, new_value) in new_values {
if let Some((_, value)) = definitions.get_mut(&name) {
if !value.is_none() {
panic!(
"Column {name} already has a value / hint set, but tried to add another one."
)
}
*value = Some(new_value);
} else {
panic!("Column {name} not found.");
}
}
for decl in public_declarations.values_mut() {
@@ -164,10 +183,12 @@ pub struct Condenser<'a, T> {
namespace: AbsoluteSymbolPath,
/// ID dispensers.
counters: Counters,
/// The generated columns since the last extraction.
new_columns: Vec<(Symbol, Option<FunctionValueDefinition>)>,
/// The names of all new olumns ever generated, to avoid duplicates.
all_new_names: HashSet<String>,
/// The generated columns since the last extraction in creation order.
new_columns: Vec<Symbol>,
/// The hints and fixed column definitions added since the last extraction.
new_column_values: HashMap<String, FunctionValueDefinition>,
/// The names of all new columns ever generated, to avoid duplicates.
new_symbols: HashSet<String>,
new_constraints: Vec<AnalyzedIdentity<T>>,
}
@@ -181,7 +202,8 @@ impl<'a, T: FieldElement> Condenser<'a, T> {
namespace: Default::default(),
counters,
new_columns: vec![],
all_new_names: HashSet::new(),
new_column_values: Default::default(),
new_symbols: HashSet::new(),
new_constraints: vec![],
}
}
@@ -226,11 +248,17 @@ impl<'a, T: FieldElement> Condenser<'a, T> {
self.degree = degree;
}
/// Returns the witness columns generated since the last call to this function.
pub fn extract_new_columns(&mut self) -> Vec<(Symbol, Option<FunctionValueDefinition>)> {
/// Returns columns generated since the last call to this function.
pub fn extract_new_columns(&mut self) -> Vec<Symbol> {
std::mem::take(&mut self.new_columns)
}
/// Return the new column values (fixed column definitions or witness column hints)
/// added since the last call to this function.
pub fn extract_new_column_values(&mut self) -> HashMap<String, FunctionValueDefinition> {
std::mem::take(&mut self.new_column_values)
}
/// Returns the new constraints generated since the last call to this function.
pub fn extract_new_constraints(&mut self) -> Vec<AnalyzedIdentity<T>> {
std::mem::take(&mut self.new_constraints)
@@ -322,24 +350,16 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> {
} else {
PolynomialType::Committed
});
let value = value.map(|v|{
if let Value::Closure(evaluator::Closure {
lambda,
environment: _,
type_args: _,
}) = v.as_ref()
{
if !lambda.outer_var_references.is_empty() {
return Err(EvalError::TypeError(format!("Lambda expression for fixed column {name} must not reference outer variables.")))
}
Ok(FunctionValueDefinition::Expression(TypedExpression {
e: Expression::LambdaExpression(source.clone(), (*lambda).clone()),
type_scheme: None,
}))
} else {
Err(EvalError::TypeError(format!("Only lambda expressions are allowed for dynamically-created fixed columns. Got {v}.")))
}
}).transpose()?;
let value = value
.map(|v| {
closure_to_function(&source, v.as_ref(), FunctionKind::Pure).map_err(|e| match e {
EvalError::TypeError(e) => {
EvalError::TypeError(format!("Error creating fixed column {name}: {e}"))
}
_ => e,
})
})
.transpose()?;
let symbol = Symbol {
id: self.counters.dispense_symbol_id(kind, None),
@@ -351,18 +371,71 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T> for Condenser<'a, T> {
degree: Some(self.degree.unwrap()),
};
self.all_new_names.insert(name.clone());
self.new_columns.push((symbol.clone(), value));
self.new_symbols.insert(name.clone());
self.new_columns.push(symbol.clone());
if let Some(value) = value {
self.new_column_values.insert(name.clone(), value);
}
Ok(
Value::Expression(AlgebraicExpression::Reference(AlgebraicReference {
name,
poly_id: (&symbol).into(),
poly_id: PolyID::from(&symbol),
next: false,
}))
.into(),
)
}
fn set_hint(
&mut self,
col: Arc<Value<'a, T>>,
expr: Arc<Value<'a, T>>,
) -> Result<(), EvalError> {
let name = match col.as_ref() {
Value::Expression(AlgebraicExpression::Reference(AlgebraicReference {
name,
poly_id,
next: false,
})) => {
if poly_id.ptype != PolynomialType::Committed {
return Err(EvalError::TypeError(format!(
"Expected reference to witness column as first argument for std::prover::set_hint, but got {} column {name}.",
poly_id.ptype
)));
}
if name.contains('[') {
return Err(EvalError::TypeError(format!(
"Array elements are not supported for std::prover::set_hint (called on {name})."
)));
}
name.clone()
}
col => {
return Err(EvalError::TypeError(format!(
"Expected reference to witness column as first argument for std::prover::set_hint, but got {col}: {}",
col.type_formatted()
)));
}
};
let value = closure_to_function(&SourceRef::unknown(), expr.as_ref(), FunctionKind::Query)
.map_err(|e| match e {
EvalError::TypeError(e) => {
EvalError::TypeError(format!("Error setting hint for column {col}: {e}"))
}
_ => e,
})?;
match self.new_column_values.entry(name) {
Entry::Vacant(entry) => entry.insert(value),
Entry::Occupied(_) => {
return Err(EvalError::TypeError(format!(
"Column {col} already has a hint set, but tried to add another one."
)));
}
};
Ok(())
}
fn add_constraints(
&mut self,
constraints: Arc<Value<'a, T>>,
@@ -392,7 +465,7 @@ impl<'a, T: FieldElement> Condenser<'a, T> {
.chain((1..).map(Some))
.map(|cnt| format!("{name}{}", cnt.map(|c| format!("_{c}")).unwrap_or_default()))
.map(|name| self.namespace.with_part(&name).to_dotted_string())
.find(|name| !self.symbols.contains_key(name) && !self.all_new_names.contains(name))
.find(|name| !self.symbols.contains_key(name) && !self.new_symbols.contains(name))
.unwrap()
}
}
@@ -513,3 +586,48 @@ fn to_expr<T: Clone>(value: &Value<'_, T>) -> AlgebraicExpression<T> {
panic!()
}
}
/// Turns a value of function type (i.e. a closure) into a FunctionValueDefinition
/// and sets the expected function kind.
/// Does not allow captured variables.
fn closure_to_function<T: Clone + Display>(
source: &SourceRef,
value: &Value<'_, T>,
expected_kind: FunctionKind,
) -> Result<FunctionValueDefinition, EvalError> {
let Value::Closure(evaluator::Closure {
lambda,
environment: _,
type_args,
}) = value
else {
return Err(EvalError::TypeError(format!(
"Expected lambda expressions but got {value}."
)));
};
if !type_args.is_empty() {
return Err(EvalError::TypeError(
"Lambda expression must not have type arguments.".to_string(),
));
}
if !lambda.outer_var_references.is_empty() {
return Err(EvalError::TypeError(format!(
"Lambda expression must not reference outer variables: {lambda}"
)));
}
if lambda.kind != FunctionKind::Pure && lambda.kind != expected_kind {
return Err(EvalError::TypeError(format!(
"Expected {expected_kind} lambda expression but got {}.",
lambda.kind
)));
}
let mut lambda = (*lambda).clone();
lambda.kind = expected_kind;
Ok(FunctionValueDefinition::Expression(TypedExpression {
e: Expression::LambdaExpression(source.clone(), lambda),
type_scheme: None,
}))
}

View File

@@ -308,7 +308,7 @@ impl<'a, T: FieldElement> Value<'a, T> {
}
}
const BUILTINS: [(&str, BuiltinFunction); 10] = [
const BUILTINS: [(&str, BuiltinFunction); 11] = [
("std::array::len", BuiltinFunction::ArrayLen),
("std::check::panic", BuiltinFunction::Panic),
("std::convert::expr", BuiltinFunction::ToExpr),
@@ -317,6 +317,7 @@ const BUILTINS: [(&str, BuiltinFunction); 10] = [
("std::debug::print", BuiltinFunction::Print),
("std::field::modulus", BuiltinFunction::Modulus),
("std::prelude::challenge", BuiltinFunction::Challenge),
("std::prover::set_hint", BuiltinFunction::SetHint),
("std::prover::degree", BuiltinFunction::Degree),
("std::prover::eval", BuiltinFunction::Eval),
];
@@ -341,6 +342,8 @@ pub enum BuiltinFunction {
ToFe,
/// std::prover::challenge: int, int -> expr, constructs a challenge with a given stage and ID.
Challenge,
/// std::prover::set_hint: expr, (int -> std::prover::Query) -> (), adds a hint to a witness column.
SetHint,
/// std::prover::degree: -> int, returns the current column length / degree.
Degree,
/// std::prover::eval: expr -> fe, evaluates an expression on the current row
@@ -551,6 +554,16 @@ pub trait SymbolLookup<'a, T: FieldElement> {
)))
}
fn set_hint(
&mut self,
_col: Arc<Value<'a, T>>,
_expr: Arc<Value<'a, T>>,
) -> Result<(), EvalError> {
Err(EvalError::Unsupported(
"Tried to add hint to column outside of statement context.".to_string(),
))
}
fn add_constraints(
&mut self,
_constraints: Arc<Value<'a, T>>,
@@ -1105,6 +1118,7 @@ fn evaluate_builtin_function<'a, T: FieldElement>(
BuiltinFunction::ToFe => 1,
BuiltinFunction::ToInt => 1,
BuiltinFunction::Challenge => 2,
BuiltinFunction::SetHint => 2,
BuiltinFunction::Degree => 0,
BuiltinFunction::Eval => 1,
};
@@ -1140,7 +1154,7 @@ fn evaluate_builtin_function<'a, T: FieldElement>(
} else {
print!("{msg}");
}
Value::Array(Default::default()).into()
Value::Tuple(vec![]).into()
}
BuiltinFunction::ToExpr => {
let arg = arguments.pop().unwrap();
@@ -1173,6 +1187,12 @@ fn evaluate_builtin_function<'a, T: FieldElement>(
}))
.into()
}
BuiltinFunction::SetHint => {
let expr = arguments.pop().unwrap();
let col = arguments.pop().unwrap();
symbols.set_hint(col, expr)?;
Value::Tuple(vec![]).into()
}
BuiltinFunction::Degree => symbols.degree()?,
BuiltinFunction::Eval => {
let arg = arguments.pop().unwrap();

View File

@@ -11,7 +11,7 @@ use powdr_ast::{
use powdr_number::DegreeType;
use powdr_parser_util::SourceRef;
use std::{
collections::{HashMap, HashSet},
collections::{BTreeSet, HashMap, HashSet},
str::FromStr,
};
@@ -311,7 +311,7 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> {
}: LambdaExpression,
) -> LambdaExpression<Expression> {
let previous_local_vars = self.save_local_variables();
let previous_local_var_refs = self.local_var_references.clone();
let previous_local_var_refs = std::mem::take(&mut self.local_var_references);
let local_variable_height = self.local_variable_counter;
let params = params
@@ -326,11 +326,13 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> {
}
let body = Box::new(self.process_expression(*body));
let outer_var_references =
let outer_var_references: BTreeSet<u64> =
std::mem::replace(&mut self.local_var_references, previous_local_var_refs)
.into_iter()
.filter(|id| *id < local_variable_height)
.collect();
self.local_var_references
.extend(outer_var_references.clone());
self.reset_local_variables(previous_local_vars);
LambdaExpression {
kind,

View File

@@ -4,7 +4,10 @@ use powdr_ast::{
analyzed::{
Expression, FunctionValueDefinition, Reference, Symbol, SymbolKind, TypedExpression,
},
parsed::{types::Type, BlockExpression, FunctionKind, LambdaExpression, StatementInsideBlock},
parsed::{
types::Type, BlockExpression, FunctionCall, FunctionKind, LambdaExpression,
StatementInsideBlock,
},
};
use lazy_static::lazy_static;
@@ -79,6 +82,32 @@ impl<'a> SideEffectChecker<'a> {
}
e.children().try_for_each(|e| self.check(e))
}
Expression::FunctionCall(
_,
FunctionCall {
function,
arguments,
},
) if matches!(function.as_ref(), Expression::Reference(_, Reference::Poly(r)) if r.name == "std::prover::set_hint") =>
{
// The function "set_hint" is special: It expects a "query" function as
// second argument, so we switch context when descending into the second argument.
self.check(function)?;
match &arguments[..] {
[col, hint] => {
self.check(col)?;
assert_eq!(self.context, FunctionKind::Constr);
self.context = FunctionKind::Query;
let result = self.check(hint);
self.context = FunctionKind::Constr;
result
}
_ => {
// Not the correct number of arguments, will lead to a type error later.
arguments.iter().try_for_each(|e| self.check(e))
}
}
}
_ => e.children().try_for_each(|e| self.check(e)),
}
}
@@ -117,6 +146,7 @@ lazy_static! {
("std::field::modulus", FunctionKind::Pure),
("std::prelude::challenge", FunctionKind::Constr), // strictly, only new_challenge would need "constr"
("std::prover::degree", FunctionKind::Pure),
("std::prover::set_hint", FunctionKind::Constr),
("std::prover::eval", FunctionKind::Query),
]
.into_iter()

View File

@@ -48,6 +48,10 @@ lazy_static! {
("std::field::modulus", ("", "-> int")),
("std::prelude::challenge", ("", "int, int -> expr")),
("std::prover::degree", ("", "-> int")),
(
"std::prover::set_hint",
("", "expr, (int -> std::prover::Query) -> ()")
),
("std::prover::eval", ("", "expr -> fe")),
]
.into_iter()

View File

@@ -230,7 +230,7 @@ fn new_fixed_column() {
}
#[test]
#[should_panic = "Lambda expression for fixed column N.fi must not reference outer variables."]
#[should_panic = "Error creating fixed column N.fi: Lambda expression must not reference outer variables: (|i| (i + j) * 2)"]
fn new_fixed_column_as_closure() {
let input = r#"namespace N(16);
let f = constr |j| {
@@ -243,3 +243,166 @@ fn new_fixed_column_as_closure() {
"#;
analyze_string::<GoldilocksField>(input);
}
#[test]
fn set_hint() {
let input = r#"
namespace std::prover;
let set_hint = 8;
let eval = 8;
enum Query { Hint(fe), None, }
namespace N(16);
let x;
let y;
std::prover::set_hint(y, |i| std::prover::Query::Hint(std::prover::eval(x)));
{
let z;
std::prover::set_hint(z, query |_| std::prover::Query::Hint(1));
};
"#;
let expected = r#"namespace std::prover;
let set_hint = 8;
let eval = 8;
enum Query {
Hint(fe),
None,
}
namespace N(16);
col witness x;
col witness y(i) query std::prover::Query::Hint(std::prover::eval(N.x));
col witness z(_) query std::prover::Query::Hint(1);
"#;
let formatted = analyze_string::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}
#[test]
#[should_panic = "Expected type: int -> std::prover::Query"]
fn set_hint_invalid_function() {
let input = r#"
namespace std::prover;
let set_hint = 8;
let eval = 8;
enum Query { Hint(fe), None, }
namespace N(16);
let x;
std::prover::set_hint(x, query |_, _| std::prover::Query::Hint(1));
"#;
analyze_string::<GoldilocksField>(input);
}
#[test]
#[should_panic = "Array elements are not supported for std::prover::set_hint (called on N.x[0])."]
fn set_hint_array_element() {
let input = r#"
namespace std::prover;
let set_hint = 8;
enum Query { Hint(fe), None, }
namespace N(16);
let x: col[2];
std::prover::set_hint(x[0], query |_| std::prover::Query::Hint(1));
"#;
let expected = r#"namespace std::prover;
let set_hint = 8;
let eval = 8;
enum Query {
Hint(fe),
None,
}
namespace N(16);
col witness x(_) query std::prover::Query::Hint(1);
col witness y(i) query std::prover::Query::Hint(std::prover::eval(N.x));
"#;
let formatted = analyze_string::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}
#[test]
#[should_panic = "Expected reference to witness column as first argument for std::prover::set_hint, but got intermediate column N.y."]
fn set_hint_no_col() {
let input = r#"
namespace std::prover;
let set_hint = 8;
enum Query { Hint(fe), None, }
namespace N(16);
let x;
let y: expr = x;
std::prover::set_hint(y, query |_| std::prover::Query::Hint(1));
"#;
analyze_string::<GoldilocksField>(input);
}
#[test]
#[should_panic = "Column N.x already has a hint set, but tried to add another one."]
fn set_hint_twice() {
let input = r#"
namespace std::prover;
let set_hint = 8;
enum Query { Hint(fe), None, }
namespace N(16);
let x;
std::prover::set_hint(x, query |_| std::prover::Query::Hint(1));
std::prover::set_hint(x, query |_| std::prover::Query::Hint(2));
"#;
analyze_string::<GoldilocksField>(input);
}
#[test]
#[should_panic = "Column N.x already has a hint set, but tried to add another one."]
fn set_hint_twice_in_constr() {
let input = r#"
namespace std::prover;
let set_hint = 8;
enum Query { Hint(fe), None, }
namespace N(16);
let y;
{
let x;
std::prover::set_hint(x, query |_| std::prover::Query::Hint(1));
std::prover::set_hint(x, query |_| std::prover::Query::Hint(2));
};
"#;
analyze_string::<GoldilocksField>(input);
}
#[test]
fn set_hint_outside() {
let input = r#"
namespace std::prover;
let set_hint = 8;
let eval = 8;
enum Query { Hint(fe), None, }
namespace N(16);
let x;
let y;
let create_wit = constr || { let w; w };
let z = create_wit();
let set_hint = constr |c| { std::prover::set_hint(c, query |_| std::prover::Query::Hint(8)); };
set_hint(x);
set_hint(y);
(|| { set_hint(z); })();
"#;
let expected = r#"namespace std::prover;
let set_hint = 8;
let eval = 8;
enum Query {
Hint(fe),
None,
}
namespace N(16);
col witness x(_) query std::prover::Query::Hint(8);
col witness y(_) query std::prover::Query::Hint(8);
let create_wit: -> expr = (constr || {
let w: col;
w
});
let z: expr = N.create_wit();
let set_hint: expr -> () = (constr |c| {
std::prover::set_hint(c, (query |_| std::prover::Query::Hint(8)));
});
col witness w(_) query std::prover::Query::Hint(8);
"#;
let formatted = analyze_string::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}

View File

@@ -100,3 +100,58 @@ fn fixed_with_constr_type() {
let input = "let x: col = constr |i| 2;";
analyze_string::<GoldilocksField>(input);
}
#[test]
fn set_hint() {
let input = r#"
namespace std::prover;
let set_hint = 8;
enum Query { Hint(fe), None, }
namespace N(16);
let x;
std::prover::set_hint(x, query |i| std::prover::Query::Hint(1));
"#;
analyze_string::<GoldilocksField>(input);
}
#[test]
fn set_hint_can_use_query() {
let input = r#"
namespace std::prover;
let set_hint = 8;
let eval = 7;
enum Query { Hint(fe), None, }
namespace N(16);
let x;
let y;
std::prover::set_hint(x, query |_| std::prover::Query::Hint(std::prover::eval(y)));
"#;
analyze_string::<GoldilocksField>(input);
}
#[test]
fn set_hint_pure() {
let input = r#"
namespace std::prover;
let set_hint = 8;
enum Query { Hint(fe), None, }
namespace N(16);
let x;
std::prover::set_hint(x, |i| std::prover::Query::Hint(1));
"#;
analyze_string::<GoldilocksField>(input);
}
#[test]
#[should_panic = "Used a constr lambda function inside a query context"]
fn set_hint_constr() {
let input = r#"
namespace std::prover;
let set_hint = 8;
enum Query { Hint(fe), None, }
namespace N(16);
let x;
std::prover::set_hint(x, constr |i| std::prover::Query::Hint(1));
"#;
analyze_string::<GoldilocksField>(input);
}