User-directed transcript-based RAM checking. (#176)

* user-directed transcript-based RAM checking

when the `transcript` type qualifier is used.

* fix ram tests

* fmt & lint

* rm incomplete test
This commit is contained in:
Alex Ozdemir
2023-11-02 23:50:16 -07:00
committed by GitHub
parent c0355299df
commit 68b0b45556
18 changed files with 251 additions and 95 deletions

View File

@@ -0,0 +1,10 @@
const u32 N = 100
const u32 A = 100
const field[N] TABLE = [4, ...[5; N-1]]
def main(field[A] is) -> field:
field sum = 0
for u32 i in 0..A do
sum = sum + TABLE[is[i]]
endfor
return sum

View File

@@ -6,7 +6,7 @@ struct Pt {
field y
field z
}
const Pt [LEN] array = [Pt {x: 4, y: 5, z: 6}, ...[Pt {x: 0, y: 1, z: 2}; LEN - 1]]
const transcript Pt [LEN] array = [Pt {x: 4, y: 5, z: 6}, ...[Pt {x: 0, y: 1, z: 2}; LEN - 1]]
def main(private field[ACCESSES] idx) -> field:
field prod = 1

View File

@@ -1,7 +1,7 @@
const u32 LEN = 4
const u32 ACCESSES = 2
const field[LEN] array = [0, ...[100; LEN-1]]
const transcript field[LEN] array = [0, ...[100; LEN-1]]
def main(private field[ACCESSES] y) -> field:
field result = 0

View File

@@ -3,7 +3,7 @@ const u32 LEN = 8196
const field ACC = 10
def main(private field x, private field y, private bool b) -> field:
field[LEN] array = [0; LEN]
transcript field[LEN] array = [0; LEN]
for field i in 0..ACC do
cond_store(array, x+i, 1f, b)
endfor

View File

@@ -8,7 +8,7 @@ struct Pt {
}
def main(private field x, private field y, private bool b) -> field:
Pt [LEN] array = [Pt {x: 0, y: 0} ; LEN]
transcript Pt [LEN] array = [Pt {x: 0, y: 0} ; LEN]
for field i in 0..ACCESSES do
array[x+i] = if b then Pt{x : 1, y: i} else array[x+i] fi
endfor

View File

@@ -270,18 +270,12 @@ fn main() {
opts.push(Opt::ConstantFold(Box::new([])));
opts.push(Opt::Obliv);
// The obliv elim pass produces more tuples, that must be eliminated
if options.circ.ram.enabled {
// Waksman can only route scalars, so tuple first!
if options.circ.ram.permutation == circ_opt::PermutationStrategy::Waksman {
opts.push(Opt::Tuple);
}
opts.push(Opt::PersistentRam);
opts.push(Opt::VolatileRam);
opts.push(Opt::SkolemizeChallenges);
opts.push(Opt::ScalarizeVars);
opts.push(Opt::ConstantFold(Box::new([])));
opts.push(Opt::Obliv);
}
opts.push(Opt::PersistentRam);
opts.push(Opt::VolatileRam);
opts.push(Opt::SkolemizeChallenges);
opts.push(Opt::ScalarizeVars);
opts.push(Opt::ConstantFold(Box::new([])));
opts.push(Opt::Obliv);
opts.push(Opt::LinearScan);
// The linear scan pass produces more tuples, that must be eliminated
opts.push(Opt::Tuple);

View File

@@ -16,7 +16,7 @@ function ram_test {
proof_impl=$2
ex_name=$1
rm -rf P V pi
$BIN --ram true $=3 $ex_name r1cs --action setup --proof-impl $proof_impl
$BIN $=3 $ex_name r1cs --action setup --proof-impl $proof_impl
$ZK_BIN --inputs $ex_name.pin --action prove --proof-impl $proof_impl
$ZK_BIN --inputs $ex_name.vin --action verify --proof-impl $proof_impl
rm -rf P V pi
@@ -24,8 +24,10 @@ function ram_test {
ram_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
ram_test ./examples/ZoKrates/pf/mem/volatile.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
ram_test ./examples/ZoKrates/pf/mem/volatile_struct.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
ram_test ./examples/ZoKrates/pf/mem/arr_of_str.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
# waksman is broken for non-scalar array values
# ram_test ./examples/ZoKrates/pf/mem/volatile_struct.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
# waksman is broken for non-scalar array values
# ram_test ./examples/ZoKrates/pf/mem/arr_of_str.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
ram_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok mirage ""
ram_test ./examples/ZoKrates/pf/mem/volatile.zok mirage ""
ram_test ./examples/ZoKrates/pf/mem/volatile_struct.zok mirage ""

View File

@@ -11,7 +11,7 @@ use crate::front::proof::PROVER_ID;
use crate::ir::proof::ConstraintMetadata;
use crate::ir::term::*;
use log::{debug, trace, warn};
use log::{debug, info, trace, warn};
use rug::Integer;
use std::cell::{Cell, RefCell};
use std::collections::HashMap;
@@ -141,7 +141,11 @@ fn loc_store(struct_: T, loc: &[ZAccess], val: T) -> Result<T, String> {
enum ZVis {
Public,
Private(u8),
}
enum ArrayParamMetadata {
Committed,
Transcript,
}
impl<'ast> ZGen<'ast> {
@@ -685,12 +689,25 @@ impl<'ast> ZGen<'ast> {
for p in f.parameters.iter() {
let ty = self.type_(&p.ty);
debug!("Entry param: {}: {}", p.id.value, ty);
let md = self.interpret_array_md(&p.array_metadata);
let vis = self.interpret_visibility(&p.visibility);
if let ZVis::Committed = &vis {
persistent_arrays.push(p.id.value.clone());
let r = self.circ_declare_input(p.id.value.clone(), &ty, vis, None, false, &md);
let unwrapped = self.unwrap(r, &p.span);
if let Some(md_some) = md {
match md_some {
ArrayParamMetadata::Committed => {
info!(
"Input committed array of type {} in {:?}",
ty,
self.file_stack.borrow().last().unwrap()
);
persistent_arrays.push(p.id.value.clone());
}
ArrayParamMetadata::Transcript => {
self.mark_array_as_transcript(&p.id.value, unwrapped);
}
}
}
let r = self.circ_declare_input(p.id.value.clone(), &ty, vis, None, false);
self.unwrap(r, &p.span);
}
for s in &f.statements {
self.unwrap(self.stmt_impl_::<false>(s), s.span());
@@ -722,7 +739,14 @@ impl<'ast> ZGen<'ast> {
let name = "return".to_owned();
let ret_val = r.unwrap_term();
let ret_var_val = self
.circ_declare_input(name, ty, ZVis::Public, Some(ret_val.clone()), false)
.circ_declare_input(
name,
ty,
ZVis::Public,
Some(ret_val.clone()),
false,
&None,
)
.expect("circ_declare return");
let ret_eq = eq(ret_val, ret_var_val).unwrap().term;
let mut assertions = std::mem::take(&mut *self.assertions.borrow_mut());
@@ -772,13 +796,20 @@ impl<'ast> ZGen<'ast> {
}
}
}
fn interpret_array_md(
&self,
md: &Option<ast::ArrayParamMetadata<'ast>>,
) -> Option<ArrayParamMetadata> {
match md {
Some(ast::ArrayParamMetadata::Committed(_)) => Some(ArrayParamMetadata::Committed),
Some(ast::ArrayParamMetadata::Transcript(_)) => Some(ArrayParamMetadata::Transcript),
None => None,
}
}
fn interpret_visibility(&self, visibility: &Option<ast::Visibility<'ast>>) -> ZVis {
match visibility {
None | Some(ast::Visibility::Public(_)) => ZVis::Public,
Some(ast::Visibility::Committed(_)) => match self.mode {
Mode::Proof => ZVis::Committed,
_ => unimplemented!(),
},
Some(ast::Visibility::Private(private)) => match self.mode {
Mode::Proof | Mode::Opt | Mode::ProofOfHighValue(_) => {
if private.number.is_some() {
@@ -1230,7 +1261,16 @@ impl<'ast> ZGen<'ast> {
l.identifier.value.clone(),
decl_ty,
e,
)
)?;
let md = self.interpret_array_md(&l.array_metadata);
if let Some(ArrayParamMetadata::Transcript) = md {
let value = self
.circ_get_value(Loc::local(l.identifier.value.clone()))
.map_err(|e| format!("{e}"))?
.unwrap_term();
self.mark_array_as_transcript(&l.identifier.value, value);
}
Ok(())
}
}
} else {
@@ -1465,6 +1505,13 @@ impl<'ast> ZGen<'ast> {
);
}
if let Some(ast::ArrayParamMetadata::Transcript(_)) = &c.array_metadata {
if !value.type_().is_array() {
self.err(format!("Non-array transcript {}", &c.id.value), &c.span);
}
self.mark_array_as_transcript(&c.id.value, value.clone());
}
// insert into constant map
if self
.constants
@@ -1849,6 +1896,22 @@ impl<'ast> ZGen<'ast> {
}
}
fn mark_array_as_transcript(&self, name: &str, array: T) {
info!(
"Transcript array {} of type {} in {:?}",
name,
array.ty,
self.file_stack.borrow().last().unwrap()
);
self.circ
.borrow()
.cir_ctx()
.cs
.borrow_mut()
.ram_arrays
.insert(array.term);
}
/*** circify wrapper functions (hides RefCell) ***/
fn circ_enter_condition(&self, cond: Term) {
@@ -1894,32 +1957,30 @@ impl<'ast> ZGen<'ast> {
vis: ZVis,
precomputed_value: Option<T>,
mangle_name: bool,
md: &Option<ArrayParamMetadata>,
) -> Result<T, CircError> {
match vis {
ZVis::Public => {
self.circ
.borrow_mut()
.declare_input(name, ty, None, precomputed_value, mangle_name)
}
ZVis::Private(i) => self.circ.borrow_mut().declare_input(
if let Some(ArrayParamMetadata::Committed) = md {
let size = match ty {
Ty::Array(size, _) => *size,
_ => panic!(),
};
Ok(self.circ.borrow_mut().start_persistent_array(
&name,
size,
default_field(),
crate::front::proof::PROVER_ID,
))
} else {
self.circ.borrow_mut().declare_input(
name,
ty,
Some(i),
match vis {
ZVis::Public => None,
ZVis::Private(i) => Some(i),
},
precomputed_value,
mangle_name,
),
ZVis::Committed => {
let size = match ty {
Ty::Array(size, _) => *size,
_ => panic!(),
};
Ok(self.circ.borrow_mut().start_persistent_array(
&name,
size,
default_field(),
crate::front::proof::PROVER_ID,
))
}
)
}
}

View File

@@ -102,6 +102,10 @@ impl Ty {
_ => panic!("Not an array type: {:?}", self),
}
}
/// Is this an array?
pub fn is_array(&self) -> bool {
matches!(self, Self::Array(_, _) | Self::MutArray(_))
}
}
#[derive(Clone, Debug)]

View File

@@ -189,6 +189,17 @@ pub fn walk_parameter<'ast, Z: ZVisitorMut<'ast>>(
visitor.visit_span(&mut param.span)
}
pub fn walk_array_param_metadata<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
vis: &mut ast::ArrayParamMetadata<'ast>,
) -> ZVisitorResult {
use ast::ArrayParamMetadata::*;
match vis {
Committed(x) => visitor.visit_array_committed(x),
Transcript(x) => visitor.visit_array_transcript(x),
}
}
pub fn walk_visibility<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
vis: &mut ast::Visibility<'ast>,
@@ -196,7 +207,6 @@ pub fn walk_visibility<'ast, Z: ZVisitorMut<'ast>>(
use ast::Visibility::*;
match vis {
Public(pu) => visitor.visit_public_visibility(pu),
Committed(c) => visitor.visit_commited_visibility(c),
Private(pr) => visitor.visit_private_visibility(pr),
}
}

View File

@@ -113,7 +113,18 @@ pub trait ZVisitorMut<'ast>: Sized {
Ok(())
}
fn visit_commited_visibility(&mut self, _c: &mut ast::CommittedVisibility) -> ZVisitorResult {
fn visit_array_param_metadata(
&mut self,
vis: &mut ast::ArrayParamMetadata<'ast>,
) -> ZVisitorResult {
walk_array_param_metadata(self, vis)
}
fn visit_array_committed(&mut self, _c: &mut ast::ArrayCommitted<'ast>) -> ZVisitorResult {
Ok(())
}
fn visit_array_transcript(&mut self, _c: &mut ast::ArrayTranscript<'ast>) -> ZVisitorResult {
Ok(())
}

View File

@@ -5,7 +5,7 @@ use fxhash::FxHashMap as HashMap;
use fxhash::FxHashSet as HashSet;
use std::collections::BinaryHeap;
use log::trace;
use log::{debug, trace};
/// Graph of the *arrays* in the computation.
///
@@ -72,6 +72,7 @@ impl ArrayGraph {
let mut cs = TermMap::default();
let mut arrs = TermSet::default();
// locate all array terms
for t in c.terms_postorder() {
if check(&t).is_array() {
arrs.insert(t.clone());
@@ -80,6 +81,7 @@ impl ArrayGraph {
}
}
// compute parents and children
for t in c.terms_postorder() {
if check(&t).is_array() {
for c in t.cs() {
@@ -90,12 +92,16 @@ impl ArrayGraph {
}
}
}
let mut ram_terms: TermSet = TermSet::default();
// first, we grow the set of RAM terms, from leaves towards dependents.
{
let mut stack: Vec<Term> = arrs
// we start with the explicitly marked RAMs
trace!("Starting with {} RAMS", c.ram_arrays.len());
let mut stack: Vec<Term> = c
.ram_arrays
.iter()
.filter(|a| right_sort(a, field) && array_leaf(a))
.filter(|a| arrs.contains(a))
.cloned()
.collect();
while let Some(top) = stack.pop() {
@@ -371,6 +377,10 @@ pub fn extract(c: &mut Computation, cfg: AccessCfg) -> Vec<Ram> {
/// Extract any volatile RAMS from a computation, and emit checks.
pub fn apply(c: &mut Computation, cfg: &AccessCfg) {
if c.ram_arrays.is_empty() {
debug!("Skipping VolatileRam; no RAM arrays");
return;
}
let rams = extract(c, cfg.clone());
if !rams.is_empty() {
for ram in rams {
@@ -392,6 +402,7 @@ mod test {
(computation
(metadata (parties ) (inputs ) (commitments))
(precompute () () (#t ))
(ram_arrays (#a (mod 11) #f0m11 4 ()))
(set_default_modulus 11
(let
(
@@ -422,6 +433,7 @@ mod test {
(computation
(metadata (parties ) (inputs ) (commitments))
(precompute () () (#t ))
(ram_arrays (#a (mod 11) #f0m11 4 ()))
(set_default_modulus 11
(let
(
@@ -464,6 +476,7 @@ mod test {
(computation
(metadata (parties ) (inputs (a bool)) (commitments))
(precompute () () (#t ))
(ram_arrays (#a (mod 11) #f0m11 4 ()))
(set_default_modulus 11
(let
(
@@ -506,6 +519,7 @@ mod test {
(computation
(metadata (parties ) (inputs (a bool)) (commitments))
(precompute () () (#t ))
(ram_arrays (#a (mod 11) #f0m11 4 ()))
(set_default_modulus 11
(let
(
@@ -547,6 +561,7 @@ mod test {
(computation
(metadata (parties ) (inputs (a bool)) (commitments))
(precompute () () (#t ))
(ram_arrays (#a (mod 11) #f000m11 4 ()))
(set_default_modulus 11
(let
(
@@ -593,6 +608,7 @@ mod test {
(computation
(metadata (parties ) (inputs (a bool)) (commitments))
(precompute () () (#t ))
(ram_arrays (#a (mod 11) #f0m11 16 ()))
(set_default_modulus 11
(let
(
@@ -622,6 +638,7 @@ mod test {
(computation
(metadata (parties ) (inputs ) (commitments))
(precompute () () (#t ))
(ram_arrays (#a (mod 11) #f0m11 4 ()))
(set_default_modulus 11
(let
(
@@ -653,6 +670,7 @@ mod test {
(computation
(metadata (parties ) (inputs ) (commitments))
(precompute () () (#t ))
(ram_arrays (#a (mod 11) #f0m11 4 ()))
(set_default_modulus 11
(let
(
@@ -708,6 +726,7 @@ mod test {
(commitments)
)
(precompute () () (#t ))
(ram_arrays ((fill (mod 101) 4) #f0m11))
(set_default_modulus 101
(let(
('1 ((fill (mod 101) 4) #f0))

View File

@@ -72,13 +72,15 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I)
}
Opt::ConstantFold(ignore) => {
let mut cache = TermCache::with_capacity(TERM_CACHE_LIMIT);
cache.resize(std::usize::MAX);
for a in &mut c.outputs {
// allow unbounded size during a single fold_cache call
cache.resize(std::usize::MAX);
*a = cfold::fold_cache(a, &mut cache, &ignore.clone());
// then shrink back down to size between calls
cache.resize(TERM_CACHE_LIMIT);
}
c.ram_arrays = c
.ram_arrays
.iter()
.map(|a| cfold::fold_cache(a, &mut cache, &ignore.clone()))
.collect();
}
Opt::Sha => {
for a in &mut c.outputs {

View File

@@ -77,6 +77,7 @@ impl Constraints for Computation {
metadata,
precomputes: Default::default(),
persistent_arrays: Default::default(),
ram_arrays: Default::default(),
}
}
}

View File

@@ -2072,6 +2072,8 @@ pub struct Computation {
/// * name: a variable name (array type) indicating the input state
/// * name: a term indicating the output state
pub persistent_arrays: Vec<(String, Term)>,
/// Check these arrays using RAM transcripts
pub ram_arrays: TermSet,
}
impl Computation {
@@ -2221,6 +2223,7 @@ impl Computation {
metadata: ComputationMetadata::default(),
precomputes: Default::default(),
persistent_arrays: Default::default(),
ram_arrays: Default::default(),
}
}

View File

@@ -15,7 +15,7 @@
//! * `I`: integer (arbitrary-precision)
//! * `X`: identifier
//! * regex: `[^()0-9#; \t\n\f][^(); \t\n\f#]*`
//! * Computation `C`: `(computation M P ARRAYS T)`
//! * Computation `C`: `(computation M P PERSISTENT_ARRAYS RAM_ARRAYS T)`
//! * Metadata `M`: `(metadata PARTIES INPUTS COMMITMENTS)`
//! * PARTIES is `(parties X1 .. Xn)`
//! * INPUTS is `(inputs INPUT1 .. INPUTn)`
@@ -28,11 +28,12 @@
//! * INPUTS is `((X1 S1) .. (Xn Sn))`
//! * OUTPUTS is `((X1 S1) .. (Xn Sn))`
//! * TUPLE_TERM is a tuple of the same arity as the output
//! * ARRAYS (optional): `(persistent_arrays ARRAY*)`:
//! * PERSISTENT_ARRAYS (optional): `(persistent_arrays ARRAY*)`:
//! * ARRAY is `(X S T)`
//! * X is the name of the inital state
//! * S is the size
//! * T is the state (final)
//! * RAM_ARRAYS (optional): `(ram_arrays T*)`:
//! * Sort `S`:
//! * `bool`
//! * `f32`
@@ -723,10 +724,10 @@ impl<'src> IrInterp<'src> {
let (metadata, input_names) = self.metadata(&tts[0]);
let precomputes = self.precompute(&tts[1]);
let mut persistent_arrays = Vec::new();
let mut skip_one = false;
if let List(tts_inner) = &tts[2] {
let mut ram_arrays = Vec::new();
let mut num_skipped = 0;
while let List(tts_inner) = &tts[2 + num_skipped] {
if tts_inner[0] == Leaf(Token::Ident, b"persistent_arrays") {
skip_one = true;
for tti in tts_inner.iter().skip(1) {
let ttis = self.unwrap_list(tti, "persistent_arrays");
let id = self.ident_string(&ttis[0]);
@@ -734,12 +735,18 @@ impl<'src> IrInterp<'src> {
let term = self.term(&ttis[2]);
persistent_arrays.push((id, term));
}
num_skipped += 1;
} else if tts_inner[0] == Leaf(Token::Ident, b"ram_arrays") {
for tti in tts_inner.iter().skip(1) {
let term = self.term(tti);
ram_arrays.push(term);
}
num_skipped += 1;
} else {
break;
}
}
let mut iter = tts.iter().skip(2);
if skip_one {
iter.next();
}
let iter = tts.iter().skip(2 + num_skipped);
let outputs = iter.map(|tti| self.term(tti)).collect();
self.unbind(input_names);
Computation {
@@ -747,6 +754,7 @@ impl<'src> IrInterp<'src> {
metadata,
precomputes,
persistent_arrays,
ram_arrays: ram_arrays.into_iter().collect(),
}
}
@@ -903,6 +911,13 @@ pub fn serialize_computation(c: &Computation) -> String {
}
writeln!(&mut out, "\n)").unwrap();
}
if !c.ram_arrays.is_empty() {
writeln!(&mut out, "(ram_arrays").unwrap();
for term in &c.ram_arrays {
writeln!(&mut out, " {}", serialize_term(term)).unwrap();
}
writeln!(&mut out, "\n)").unwrap();
}
for o in &c.outputs {
writeln!(&mut out, "\n {}", serialize_term(o)).unwrap();
}
@@ -1221,6 +1236,7 @@ mod test {
(tuple (not (and c d)))
)
(persistent_arrays (AA 2 (#a (bv 4) false 4 ((#b0000 true)))))
(ram_arrays (#a (bv 4) false 4 ((#b0001 true))))
(let (
(B ((update 1) A b))
) (xor ((field 1) B)

View File

@@ -14,14 +14,14 @@ main_import_directive = { "import" ~ quoted_string ~ ("as" ~ identifier)? ~ NEWL
import_symbol = { identifier ~ ("as" ~ identifier)? }
import_symbol_list = _{ import_symbol ~ ("," ~ import_symbol)* }
function_definition = {"def" ~ identifier ~ constant_generics_declaration? ~ "(" ~ parameter_list ~ ")" ~ return_types ~ ":" ~ NEWLINE* ~ statement* }
const_definition = {"const" ~ ty ~ identifier ~ "=" ~ expression ~ NEWLINE*}
const_definition = {"const" ~ array_param_metadata? ~ ty ~ identifier ~ "=" ~ expression ~ NEWLINE*}
type_definition = {"type" ~ identifier ~ constant_generics_declaration? ~ "=" ~ ty ~ NEWLINE*}
return_types = _{ ( "->" ~ ( "(" ~ ty_list ~ ")" | ty ))? }
constant_generics_declaration = _{ "<" ~ constant_generics_list ~ ">" }
constant_generics_list = _{ identifier ~ ("," ~ identifier)* }
parameter_list = _{(parameter ~ ("," ~ parameter)*)?}
parameter = {vis? ~ ty ~ identifier}
parameter = {array_param_metadata? ~ vis? ~ ty ~ identifier}
// basic types
ty_field = {"field"}
@@ -45,8 +45,11 @@ struct_field = { ty ~ identifier }
vis_private_num = @{ "<" ~ ASCII_DIGIT* ~ ">" }
vis_private = {"private" ~ vis_private_num? }
vis_public = {"public"}
vis_committed = {"committed"}
vis = { vis_private | vis_public | vis_committed }
vis = { vis_private | vis_public }
array_param_metadata = { apm_committed | apm_transcript }
apm_committed = { "committed" }
apm_transcript = { "transcript" }
// Statements
statement = { (return_statement // does not require subsequent newline
@@ -109,7 +112,7 @@ array_initializer_expression = { "[" ~ expression ~ ";" ~ expression ~ "]" }
// End Expressions
typed_identifier = { ty ~ identifier }
typed_identifier = { array_param_metadata? ~ ty ~ identifier }
assignee = { identifier ~ assignee_access* }
assignee_access = { array_access | member_access }
identifier = @{ ((!keyword ~ ASCII_ALPHA) | (keyword ~ (ASCII_ALPHANUMERIC | "_"))) ~ (ASCII_ALPHANUMERIC | "_")* }

View File

@@ -8,22 +8,22 @@ use zokrates_parser::Rule;
extern crate lazy_static;
pub use ast::{
Access, AnyString, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType,
AssertionStatement, Assignee, AssigneeAccess, BasicOrStructType, BasicType, BinaryExpression,
BinaryOperator, BooleanLiteralExpression, BooleanType, CallAccess, CommittedVisibility,
CondStoreStatement, ConstantDefinition, ConstantGenericValue, Curve, DecimalLiteralExpression,
DecimalNumber, DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldSuffix,
FieldType, File, FromExpression, FromImportDirective, FunctionDefinition, HexLiteralExpression,
HexNumberExpression, IdentifierExpression, ImportDirective, ImportSymbol,
InlineArrayExpression, InlineStructExpression, InlineStructMember, IterationStatement,
LiteralExpression, MainImportDirective, MemberAccess, NegOperator, NotOperator, Parameter,
PosOperator, PostfixExpression, Pragma, PrivateNumber, PrivateVisibility, PublicVisibility,
Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement,
StrOperator, StructDefinition, StructField, StructType, SymbolDeclaration, TernaryExpression,
ToExpression, Type, TypeDefinition, TypedIdentifier, TypedIdentifierOrAssignee,
U16NumberExpression, U16Suffix, U16Type, U32NumberExpression, U32Suffix, U32Type,
U64NumberExpression, U64Suffix, U64Type, U8NumberExpression, U8Suffix, U8Type, UnaryExpression,
UnaryOperator, Underscore, Visibility, EOI,
Access, AnyString, Arguments, ArrayAccess, ArrayCommitted, ArrayInitializerExpression,
ArrayParamMetadata, ArrayTranscript, ArrayType, AssertionStatement, Assignee, AssigneeAccess,
BasicOrStructType, BasicType, BinaryExpression, BinaryOperator, BooleanLiteralExpression,
BooleanType, CallAccess, CondStoreStatement, ConstantDefinition, ConstantGenericValue, Curve,
DecimalLiteralExpression, DecimalNumber, DecimalSuffix, DefinitionStatement, ExplicitGenerics,
Expression, FieldSuffix, FieldType, File, FromExpression, FromImportDirective,
FunctionDefinition, HexLiteralExpression, HexNumberExpression, IdentifierExpression,
ImportDirective, ImportSymbol, InlineArrayExpression, InlineStructExpression,
InlineStructMember, IterationStatement, LiteralExpression, MainImportDirective, MemberAccess,
NegOperator, NotOperator, Parameter, PosOperator, PostfixExpression, Pragma, PrivateNumber,
PrivateVisibility, PublicVisibility, Range, RangeOrExpression, ReturnStatement, Span, Spread,
SpreadOrExpression, Statement, StrOperator, StructDefinition, StructField, StructType,
SymbolDeclaration, TernaryExpression, ToExpression, Type, TypeDefinition, TypedIdentifier,
TypedIdentifierOrAssignee, U16NumberExpression, U16Suffix, U16Type, U32NumberExpression,
U32Suffix, U32Type, U64NumberExpression, U64Suffix, U64Type, U8NumberExpression, U8Suffix,
U8Type, UnaryExpression, UnaryOperator, Underscore, Visibility, EOI,
};
mod ast {
@@ -193,6 +193,7 @@ mod ast {
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::const_definition))]
pub struct ConstantDefinition<'ast> {
pub array_metadata: Option<ArrayParamMetadata<'ast>>,
pub ty: Type<'ast>,
pub id: IdentifierExpression<'ast>,
pub expression: Expression<'ast>,
@@ -342,6 +343,7 @@ mod ast {
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::parameter))]
pub struct Parameter<'ast> {
pub array_metadata: Option<ArrayParamMetadata<'ast>>,
pub visibility: Option<Visibility<'ast>>,
pub ty: Type<'ast>,
pub id: IdentifierExpression<'ast>,
@@ -349,11 +351,31 @@ mod ast {
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::array_param_metadata))]
pub enum ArrayParamMetadata<'ast> {
Committed(ArrayCommitted<'ast>),
Transcript(ArrayTranscript<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::apm_committed))]
pub struct ArrayCommitted<'ast> {
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::apm_transcript))]
pub struct ArrayTranscript<'ast> {
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::vis))]
pub enum Visibility<'ast> {
Public(PublicVisibility),
Committed(CommittedVisibility),
Private(PrivateVisibility<'ast>),
}
@@ -370,10 +392,6 @@ mod ast {
#[pest_ast(rule(Rule::vis_public))]
pub struct PublicVisibility {}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::vis_committed))]
pub struct CommittedVisibility {}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::vis_private))]
pub struct PrivateVisibility<'ast> {
@@ -720,6 +738,7 @@ mod ast {
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::typed_identifier))]
pub struct TypedIdentifier<'ast> {
pub array_metadata: Option<ArrayParamMetadata<'ast>>,
pub ty: Type<'ast>,
pub identifier: IdentifierExpression<'ast>,
#[pest_ast(outer())]
@@ -1414,6 +1433,7 @@ mod tests {
statements: vec![Statement::Definition(DefinitionStatement {
lhs: vec![
TypedIdentifierOrAssignee::TypedIdentifier(TypedIdentifier {
array_metadata: None,
ty: Type::Basic(BasicType::Field(FieldType {
span: Span::new(&source, 23, 28).unwrap()
})),