From 3dcdbdee85c3d4dee582a22cf06db65e903ac7ac Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 10 Mar 2023 17:19:47 +0100 Subject: [PATCH] Read-write memory. --- src/asm_compiler/mod.rs | 35 +-- .../double_sorted_witness_machine.rs | 218 ++++++++++++++++++ src/commit_evaluator/evaluator.rs | 9 +- src/commit_evaluator/fixed_lookup_machine.rs | 16 +- src/commit_evaluator/machine.rs | 4 +- src/commit_evaluator/machine_extractor.rs | 63 ++++- src/commit_evaluator/mod.rs | 1 + .../sorted_witness_machine.rs | 5 +- src/parser/asm_ast.rs | 8 +- src/parser/powdr.lalrpop | 11 +- tests/integration.rs | 5 + tests/mem_read_write.asm | 76 ++++++ 12 files changed, 410 insertions(+), 41 deletions(-) create mode 100644 src/commit_evaluator/double_sorted_witness_machine.rs create mode 100644 tests/mem_read_write.asm diff --git a/src/asm_compiler/mod.rs b/src/asm_compiler/mod.rs index a601817f4..837971681 100644 --- a/src/asm_compiler/mod.rs +++ b/src/asm_compiler/mod.rs @@ -199,22 +199,17 @@ impl ASMPILConverter { )), } } - InstructionBodyElement::PlookupIdentity(left, right) => { + InstructionBodyElement::PlookupIdentity(left, op, right) => { assert!(left.selector.is_none(), "LHS selector not supported, could and-combine with instruction flag later."); - self.pil.push(Statement::PlookupIdentity( - *start, - SelectedExpressions { - selector: Some(direct_reference(&instruction_flag)), - expressions: substitute_vec(&left.expressions, &substitutions), - }, - SelectedExpressions { - selector: right - .selector - .as_ref() - .map(|s| substitute(s, &substitutions)), - expressions: substitute_vec(&right.expressions, &substitutions), - }, - )); + let left = SelectedExpressions { + selector: Some(direct_reference(&instruction_flag)), + expressions: substitute_vec(&left.expressions, &substitutions), + }; + let right = substitute_selected_exprs(right, &substitutions); + self.pil.push(match op { + PlookupOperator::In => Statement::PlookupIdentity(*start, left, right), + PlookupOperator::Is => Statement::PermutationIdentity(*start, left, right), + }) } } } @@ -677,6 +672,16 @@ fn substitute(input: &Expression, substitution: &HashMap) -> Exp } } +fn substitute_selected_exprs( + input: &SelectedExpressions, + substitution: &HashMap, +) -> SelectedExpressions { + SelectedExpressions { + selector: input.selector.as_ref().map(|s| substitute(s, substitution)), + expressions: substitute_vec(&input.expressions, substitution), + } +} + fn substitute_vec(input: &[Expression], substitution: &HashMap) -> Vec { input.iter().map(|e| substitute(e, substitution)).collect() } diff --git a/src/commit_evaluator/double_sorted_witness_machine.rs b/src/commit_evaluator/double_sorted_witness_machine.rs new file mode 100644 index 000000000..a6d189cd6 --- /dev/null +++ b/src/commit_evaluator/double_sorted_witness_machine.rs @@ -0,0 +1,218 @@ +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::iter::once; + +use itertools::{Either, Itertools}; + +use crate::analyzer::PolynomialReference; +use crate::analyzer::{Expression, Identity, IdentityKind, SelectedExpressions}; +use crate::commit_evaluator::eval_error; +use crate::commit_evaluator::machine::LookupReturn; +use crate::number::AbstractNumberType; + +use super::affine_expression::AffineExpression; +use super::eval_error::EvalError; +use super::machine::{LookupResult, Machine}; +use super::FixedData; + +/// TODO make this generic + +#[derive(Default)] +pub struct DoubleSortedWitnesses { + //key_col: String, + /// Position of the witness columns in the data. + /// The key column has a position of usize::max + //witness_positions: HashMap, + /// (addr, step) -> value + trace: BTreeMap<(AbstractNumberType, AbstractNumberType), Operation>, + data: BTreeMap, +} + +struct Operation { + pub is_write: bool, + pub value: AbstractNumberType, +} + +impl DoubleSortedWitnesses { + pub fn try_new( + _fixed_data: &FixedData, + _identities: &[&Identity], + witness_names: &HashSet<&str>, + ) -> Option> { + // TODO check the identities. + let expected_witnesses: HashSet<_> = [ + "Assembly.m_value", + "Assembly.m_addr", + "Assembly.m_step", + "Assembly.m_change", + "Assembly.m_op", + "Assembly.m_is_write", + "Assembly.m_is_read", + ] + .into_iter() + .collect(); + if expected_witnesses + .symmetric_difference(witness_names) + .next() + .is_none() + { + Some(Box::default()) + } else { + None + } + } +} + +impl Machine for DoubleSortedWitnesses { + fn process_plookup( + &mut self, + fixed_data: &FixedData, + kind: IdentityKind, + left: &[Result], + right: &SelectedExpressions, + ) -> LookupResult { + if kind != IdentityKind::Permutation + || (right.selector + != Some(Expression::PolynomialReference(PolynomialReference { + name: "Assembly.m_is_read".to_owned(), + index: None, + next: false, + })) + && right.selector + != Some(Expression::PolynomialReference(PolynomialReference { + name: "Assembly.m_is_write".to_owned(), + index: None, + next: false, + }))) + { + return Ok(LookupReturn::NotApplicable); + } + + // We blindly assume the lookup is of the form + // OP { ADDR, STEP, X } is m_is_write { m_addr, m_step, m_value } + // or + // OP { ADDR, STEP, X } is m_is_read { m_addr, m_step, m_value } + + // Fail if the LHS has an error. + let (left, errors): (Vec<_>, Vec<_>) = left.iter().partition_map(|x| match x { + Ok(x) => Either::Left(x), + Err(x) => Either::Right(x), + }); + if !errors.is_empty() { + return Err(errors + .into_iter() + .cloned() + .reduce(eval_error::combine) + .unwrap()); + } + + let is_write = match &right.selector { + Some(Expression::PolynomialReference(p)) => p.name == "Assembly.m_is_write", + _ => panic!(), + }; + let addr = left[0].constant_value().ok_or_else(|| { + format!( + "Address must be known: {} = {}", + left[0].format(fixed_data), + right.expressions[0] + ) + })?; + let step = left[1].constant_value().ok_or_else(|| { + format!( + "Step must be known: {} = {}", + left[1].format(fixed_data), + right.expressions[1] + ) + })?; + + println!( + "Query addr={addr}, step={step}, write: {is_write}, left: {}", + left[2].format(fixed_data) + ); + + // TODO this does not check any of the failure modes + let mut assignments = vec![]; + if is_write { + let value = match left[2].constant_value() { + Some(v) => v, + None => return Ok(LookupReturn::Assignments(vec![])), + }; + if fixed_data.verbose { + println!("Memory write: addr={addr}, step={step}, value={value}"); + } + self.data.insert(addr.clone(), value.clone()); + self.trace + .insert((addr, step), Operation { is_write, value }); + } else { + let value = self.data.entry(addr.clone()).or_default(); + self.trace.insert( + (addr.clone(), step.clone()), + Operation { + is_write, + value: value.clone(), + }, + ); + if fixed_data.verbose { + println!("Memory read: addr={addr}, step={step}, value={value}"); + } + assignments.push(match (left[2].clone() - value.clone().into()).solve() { + Some(ass) => ass, + None => return Ok(LookupReturn::Assignments(vec![])), + }); + } + Ok(LookupReturn::Assignments(assignments)) + } + + fn witness_col_values( + &mut self, + fixed_data: &FixedData, + ) -> HashMap> { + let mut addr = vec![]; + let mut step = vec![]; + let mut value = vec![]; + let mut op = vec![]; + let mut is_write = vec![]; + let mut is_read = vec![]; + + for ((a, s), o) in std::mem::take(&mut self.trace) { + addr.push(a.clone()); + step.push(s.clone()); + value.push(o.value); + op.push(1.into()); + + is_write.push((if o.is_write { 1 } else { 0 }).into()); + is_read.push((if o.is_write { 0 } else { 1 }).into()); + } + if addr.is_empty() { + todo!(); + } + while addr.len() < fixed_data.degree as usize { + addr.push(addr.last().unwrap().clone()); + step.push(step.last().unwrap().clone() + 1); + value.push(value.last().unwrap().clone()); + op.push(0.into()); + is_write.push(0.into()); + is_read.push(0.into()); + } + + let change = addr + .iter() + .tuple_windows() + .map(|(a, a_next)| if a == a_next { 0.into() } else { 1.into() }) + .chain(once(1.into())) + .collect::>(); + assert_eq!(change.len(), addr.len()); + + vec![ + ("Assembly.m_value", value), + ("Assembly.m_addr", addr), + ("Assembly.m_step", step), + ("Assembly.m_change", change), + ("Assembly.m_op", op), + ("Assembly.m_is_write", is_write), + ("Assembly.m_is_read", is_read), + ] + .into_iter() + .map(|(n, v)| (n.to_string(), v)) + .collect() + } +} diff --git a/src/commit_evaluator/evaluator.rs b/src/commit_evaluator/evaluator.rs index e6e1e72ef..af302ee14 100644 --- a/src/commit_evaluator/evaluator.rs +++ b/src/commit_evaluator/evaluator.rs @@ -83,7 +83,9 @@ where IdentityKind::Polynomial => { self.process_polynomial_identity(identity.left.selector.as_ref().unwrap()) } - IdentityKind::Plookup => self.process_plookup(identity), + IdentityKind::Plookup | IdentityKind::Permutation => { + self.process_plookup(identity) + } _ => Err("Unsupported lookup type".to_string().into()), } .map_err(|err| { @@ -260,9 +262,6 @@ where } }; } - if identity.right.selector.is_some() { - return Err("Selectors at the RHS not yet supported.".to_string().into()); - } let left = identity .left @@ -278,7 +277,7 @@ where for m in &mut self.machines { // TODO also consider the reasons above. if let LookupReturn::Assignments(assignments) = - m.process_plookup(self.fixed_data, &left, &identity.right)? + m.process_plookup(self.fixed_data, identity.kind, &left, &identity.right)? { return Ok(assignments); } diff --git a/src/commit_evaluator/fixed_lookup_machine.rs b/src/commit_evaluator/fixed_lookup_machine.rs index cae2dfba0..6f76c86a4 100644 --- a/src/commit_evaluator/fixed_lookup_machine.rs +++ b/src/commit_evaluator/fixed_lookup_machine.rs @@ -1,6 +1,6 @@ use std::collections::{HashMap, HashSet}; -use crate::analyzer::{Expression, Identity, SelectedExpressions}; +use crate::analyzer::{Expression, Identity, IdentityKind, SelectedExpressions}; use crate::commit_evaluator::eval_error; use crate::commit_evaluator::expression_evaluator::ExpressionEvaluator; use crate::commit_evaluator::machine::LookupReturn; @@ -35,15 +35,17 @@ impl Machine for FixedLookup { fn process_plookup( &mut self, fixed_data: &FixedData, + kind: IdentityKind, left: &[Result], right: &SelectedExpressions, ) -> LookupResult { - // This is a matching machine if the RHS is fully constant. - assert!(right.selector.is_none()); - if right - .expressions - .iter() - .any(|e| contains_witness_ref(e, fixed_data)) + // This is a matching machine if it is a plookup and the RHS is fully constant. + if kind != IdentityKind::Plookup + || right.selector.is_some() + || right + .expressions + .iter() + .any(|e| contains_witness_ref(e, fixed_data)) { return Ok(LookupReturn::NotApplicable); } diff --git a/src/commit_evaluator/machine.rs b/src/commit_evaluator/machine.rs index 9aca32011..2d75ee9ed 100644 --- a/src/commit_evaluator/machine.rs +++ b/src/commit_evaluator/machine.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; -use crate::{analyzer::SelectedExpressions, number::AbstractNumberType}; +use crate::analyzer::{IdentityKind, SelectedExpressions}; +use crate::number::AbstractNumberType; use super::{affine_expression::AffineExpression, eval_error::EvalError, FixedData}; @@ -21,6 +22,7 @@ pub trait Machine { fn process_plookup( &mut self, fixed_data: &FixedData, + kind: IdentityKind, left: &[Result], right: &SelectedExpressions, ) -> LookupResult; diff --git a/src/commit_evaluator/machine_extractor.rs b/src/commit_evaluator/machine_extractor.rs index d73f6ba1d..2e19e8aba 100644 --- a/src/commit_evaluator/machine_extractor.rs +++ b/src/commit_evaluator/machine_extractor.rs @@ -2,6 +2,7 @@ use std::collections::HashSet; use crate::analyzer::{Expression, Identity, SelectedExpressions}; +use super::double_sorted_witness_machine::DoubleSortedWitnesses; use super::fixed_lookup_machine::FixedLookup; use super::machine::Machine; @@ -20,25 +21,26 @@ pub fn split_out_machines<'a>( // We could also split the machine into independent sub-machines. // The lookup-in-fixed-columns machine, it always exists with an empty set of witnesses. - let fixed_lookup = FixedLookup::try_new(fixed, &[], &Default::default()).unwrap(); + let mut machines: Vec> = + vec![FixedLookup::try_new(fixed, &[], &Default::default()).unwrap()]; + let witness_names = witness_cols.iter().map(|c| c.name).collect::>(); let all_witnesses = ReferenceExtractor::new(witness_names.clone()); // Extract all witness columns in the RHS of lookups. - let machine_witnesses = identities + let lookup_witnesses = identities .iter() .map(|i| all_witnesses.in_selected_expressions(&i.right)) .reduce(|l, r| &l | &r) .unwrap_or_default(); - if machine_witnesses.is_empty() { - return (vec![fixed_lookup], identities.iter().collect()); - } + // Recursively extend the set to all witnesses connected through identities. + let machine_witnesses = all_connected_witnesses(&all_witnesses, lookup_witnesses, identities); let machine_witness_extractor = ReferenceExtractor::new(machine_witnesses.clone()); // Split identities into those that only concern the machine // witnesses and those that concern any other witness. - let (machine_identities, identities): (Vec<_>, _) = identities.iter().partition(|i| { - // The identity has at least one a machine witness, but + let (machine_identities, base_identities): (Vec<_>, _) = identities.iter().partition(|i| { + // The identity has at least one machine witness, but // all referenced witnesses are machine witnesses. let mw = machine_witness_extractor.in_identity(i); !mw.is_empty() && all_witnesses.in_identity(i).is_subset(&mw) @@ -49,9 +51,50 @@ pub fn split_out_machines<'a>( if let Some(machine) = SortedWitnesses::try_new(fixed, &machine_identities, &machine_witnesses) { - (vec![machine, fixed_lookup], identities) - } else { - (vec![fixed_lookup], identities) + machines.push(machine); + } else if let Some(machine) = + DoubleSortedWitnesses::try_new(fixed, &machine_identities, &machine_witnesses) + { + machines.push(machine); + } + (machines, base_identities) +} + +fn all_connected_witnesses<'a>( + all_witnesses: &'a ReferenceExtractor, + mut witnesses: HashSet<&'a str>, + identities: &'a [Identity], +) -> HashSet<&'a str> { + let mut count = witnesses.len(); + loop { + for i in identities { + match i.kind { + crate::analyzer::IdentityKind::Polynomial => { + // Any current witness in the identity adds all other witnesses. + let in_identity = all_witnesses.in_identity(i); + if in_identity.intersection(&witnesses).next().is_some() { + witnesses.extend(in_identity); + } + } + crate::analyzer::IdentityKind::Plookup + | crate::analyzer::IdentityKind::Permutation + | crate::analyzer::IdentityKind::Connect => { + // If we already have witnesses on the LHS, include the RHS, but not vice-versa. + let in_lhs = all_witnesses.in_selected_expressions(&i.left); + let in_rhs = all_witnesses.in_selected_expressions(&i.right); + if in_lhs.intersection(&witnesses).next().is_some() { + witnesses.extend(in_lhs); + witnesses.extend(in_rhs); + } else if in_rhs.intersection(&witnesses).next().is_some() { + witnesses.extend(in_rhs); + } + } + }; + } + if witnesses.len() == count { + return witnesses; + } + count = witnesses.len() } } diff --git a/src/commit_evaluator/mod.rs b/src/commit_evaluator/mod.rs index 430aa847e..091052abf 100644 --- a/src/commit_evaluator/mod.rs +++ b/src/commit_evaluator/mod.rs @@ -7,6 +7,7 @@ use self::eval_error::EvalError; use self::util::WitnessColumnNamer; mod affine_expression; +mod double_sorted_witness_machine; mod eval_error; mod evaluator; mod expression_evaluator; diff --git a/src/commit_evaluator/sorted_witness_machine.rs b/src/commit_evaluator/sorted_witness_machine.rs index 477de9082..6f2f5f054 100644 --- a/src/commit_evaluator/sorted_witness_machine.rs +++ b/src/commit_evaluator/sorted_witness_machine.rs @@ -120,10 +120,13 @@ impl Machine for SortedWitnesses { fn process_plookup( &mut self, fixed_data: &FixedData, + kind: IdentityKind, left: &[Result], right: &SelectedExpressions, ) -> LookupResult { - assert!(right.selector.is_none()); + if kind != IdentityKind::Plookup || right.selector.is_some() { + return Ok(LookupReturn::NotApplicable); + } let rhs = right .expressions .iter() diff --git a/src/parser/asm_ast.rs b/src/parser/asm_ast.rs index c4ad9b2f2..62a1d8c19 100644 --- a/src/parser/asm_ast.rs +++ b/src/parser/asm_ast.rs @@ -37,5 +37,11 @@ pub struct InstructionParam { #[derive(Debug, PartialEq, Eq, Clone)] pub enum InstructionBodyElement { Expression(Expression), - PlookupIdentity(SelectedExpressions, SelectedExpressions), + PlookupIdentity(SelectedExpressions, PlookupOperator, SelectedExpressions), +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum PlookupOperator { + In, + Is, } diff --git a/src/parser/powdr.lalrpop b/src/parser/powdr.lalrpop index 9668553ba..b05bae684 100644 --- a/src/parser/powdr.lalrpop +++ b/src/parser/powdr.lalrpop @@ -167,7 +167,16 @@ InstructionBodyElements: Vec = { InstructionBodyElement: InstructionBodyElement = { "=" => InstructionBodyElement::Expression(Expression::BinaryOperation(l, BinaryOperator::Sub, r)), - "in" => InstructionBodyElement::PlookupIdentity(<>), + => InstructionBodyElement::PlookupIdentity(<>), +} + +// This is only valid in instructions, not in PIL in general. +// "connect" is not supported because it does not support selectors +// and we need that for the instruction. + +PlookupOperator: PlookupOperator = { + "in" => PlookupOperator::In, + "is" => PlookupOperator::Is, } InstructionParamList: Vec = { diff --git a/tests/integration.rs b/tests/integration.rs index 8ffe20364..48637eeb0 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -141,3 +141,8 @@ fn palindrome() { [7, 1, 7, 3, 9, 3, 7, 1].iter().map(|&x| x.into()).collect(), ); } + +#[test] +fn test_mem_read_write() { + verify_asm("mem_read_write.asm", Default::default()); +} diff --git a/tests/mem_read_write.asm b/tests/mem_read_write.asm new file mode 100644 index 000000000..5036e7617 --- /dev/null +++ b/tests/mem_read_write.asm @@ -0,0 +1,76 @@ +reg pc[@pc]; +reg X[<=]; +reg A; +reg B; +reg I; +reg CNT; +reg ADDR; + +pil{ + col witness XInv; + col witness XIsZero; + XIsZero = 1 - X * XInv; + XIsZero * X = 0; + XIsZero * (1 - XIsZero) = 0; + + // Read-write memory. Columns are sorted by m_addr and + // then by m_step. m_change is 1 if and only if m_addr changes + // in the next row. + col witness m_addr; + col witness m_step; + col witness m_change; + col witness m_value; + // If we have an operation at all (needed because this needs to be a permutation) + col witness m_op; + // If the operation is a write operation. + col witness m_is_write; + col witness m_is_read; + + // positive numbers (assumed to be much smaller than the field order) + col fixed POSITIVE(i) { i + 1 }; + col fixed FIRST = [1]; + col fixed LAST(i) { FIRST(i + 1) }; + col fixed STEP(i) { i }; + + m_change * (1 - m_change) = 0; + + // if m_change is zero, m_addr has to stay the same. + (m_addr' - m_addr) * (1 - m_change) = 0; + + // Except for the last row, if m_change is 1, then m_addr has to increase, + // if it is zero, m_step has to increase. + (1 - LAST) { m_change * (m_addr' - m_addr) + (1 - m_change) * (m_step' - m_step) } in POSITIVE; + + m_op * (1 - m_op) = 0; + m_is_write * (1 - m_is_write) = 0; + m_is_read * (1 - m_is_read) = 0; + // m_is_write can only be 1 if m_op is 1. + m_is_write * (1 - m_op) = 0; + m_is_read * (1 - m_op) = 0; + m_is_read * m_is_write = 0; + + + // If the next line is a read and we stay at the same address, then the + // value cannot change. + (1 - m_is_write') * (1 - m_change) * (m_value' - m_value) = 0; + + // If the next line is a read and we have an address change, + // then the value is zero. + (1 - m_is_write') * m_change * m_value' = 0; + +} + +instr assert_zero <=X= a { XIsZero = 1 } +instr mstore <=X= val { { ADDR, STEP, X } is m_is_write { m_addr, m_step, m_value } } +instr mload r <=X= { { ADDR, STEP, X } is m_is_read { m_addr, m_step, m_value } } + + +ADDR <=X= 3; +mstore 1; +ADDR <=X= 2; +mstore 4; +mload A; +assert_zero A - 4; +ADDR <=X= 3; +mload A; +assert_zero A - 1;