Merge branch 'master' into c_frontend

This commit is contained in:
Edward Chen
2022-02-07 03:03:45 -05:00
272 changed files with 14422 additions and 3177 deletions

73
Cargo.lock generated
View File

@@ -17,6 +17,17 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "ahash"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [
"getrandom",
"once_cell",
"version_check",
]
[[package]]
name = "aho-corasick"
version = "0.7.18"
@@ -225,6 +236,7 @@ dependencies = [
"pest",
"pest-ast",
"pest_derive",
"petgraph",
"quickcheck",
"quickcheck_macros",
"rand",
@@ -396,6 +408,12 @@ dependencies = [
"subtle",
]
[[package]]
name = "fixedbitset"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "398ea4fabe40b9b0d885340a2a991a44c8a645624075ad966d21f88688e2b69e"
[[package]]
name = "fnv"
version = "1.0.7"
@@ -486,12 +504,22 @@ dependencies = [
"subtle",
]
[[package]]
name = "hashbrown"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
dependencies = [
"ahash",
]
[[package]]
name = "hashconsing"
version = "1.3.0"
source = "git+https://github.com/alex-ozdemir/hashconsing.git?branch=phash#4070a07409b8536b0379eedf516658c92a4caa84"
source = "git+https://github.com/alex-ozdemir/hashconsing.git?branch=phash#a74c1d01742580a16243b6805b647349abbdfe59"
dependencies = [
"lazy_static",
"lru",
]
[[package]]
@@ -524,6 +552,16 @@ version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9007da9cacbd3e6343da136e98b0d2df013f553d35bdec8b518f07bea768e19c"
[[package]]
name = "indexmap"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc633605454125dec4b66843673f01c7df2b89479b32e0ed634e43a91cff62a5"
dependencies = [
"autocfg",
"hashbrown",
]
[[package]]
name = "itertools"
version = "0.7.11"
@@ -585,6 +623,15 @@ dependencies = [
"xml-rs",
]
[[package]]
name = "lru"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "274353858935c992b13c0ca408752e2121da852d07dec7ce5f108c77dfa14d1f"
dependencies = [
"hashbrown",
]
[[package]]
name = "maplit"
version = "1.0.2"
@@ -644,6 +691,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "once_cell"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da32515d9f6e6e489d7bc9d84c71b060db7247dc035bbe44eac88cf87486d8d5"
[[package]]
name = "opaque-debug"
version = "0.2.3"
@@ -715,6 +768,16 @@ dependencies = [
"sha-1",
]
[[package]]
name = "petgraph"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a13a2fa9d0b63e5f22328828741e523766fff0ee9e779316902290dff3f824f"
dependencies = [
"fixedbitset",
"indexmap",
]
[[package]]
name = "ppv-lite86"
version = "0.2.14"
@@ -911,9 +974,9 @@ dependencies = [
[[package]]
name = "rsmt2"
version = "0.12.0"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b616b1ac7f0393f4441a0d6301b35b1bccd209c61dd65c8852cb8397d68cdc89"
checksum = "4affc8a99241732d2e214728974bd002bf3aa1d114760cf5c3ac6c1fd5650c7d"
dependencies = [
"error-chain",
]
@@ -1227,7 +1290,7 @@ checksum = "d2d7d3948613f75c98fd9328cfdcc45acc4d360655289d0a7d4ec931392200a3"
[[package]]
name = "zokrates_parser"
version = "0.1.6"
version = "0.2.4"
dependencies = [
"pest",
"pest_derive",
@@ -1235,7 +1298,7 @@ dependencies = [
[[package]]
name = "zokrates_pest_ast"
version = "0.1.5"
version = "0.2.3"
dependencies = [
"from-pest",
"lazy_static",

View File

@@ -13,7 +13,7 @@ rug = "1.11"
gmp-mpfr-sys = "1.4"
lazy_static = "1.4"
rand = "0.8"
rsmt2 = "0.12"
rsmt2 = "0.14"
#rsmt2 = { git = "https://github.com/alex-ozdemir/rsmt2.git" }
ieee754 = "0.2"
zokrates_parser = { path = "third_party/ZoKrates/zokrates_parser" }
@@ -25,8 +25,8 @@ bellman = "0.11"
ff = "0.11"
#funty = "=1.1"
fxhash = "0.2"
good_lp = { version = "1.1", features = ["lp-solvers", "coin_cbc"], default-features = false }
lp-solvers = "0.0.4"
good_lp = { version = "1.1", features = ["lp-solvers", "coin_cbc"], default-features = false, optional = true }
lp-solvers = { version = "0.0.4", optional = true }
serde_json = "1.0"
lang-c = "0.10.1"
pest = "2.1"
@@ -34,6 +34,7 @@ pest_derive = "2.1"
pest-ast = "0.3"
from-pest = "0.3"
itertools = "0.10"
petgraph = "0.6"
[dev-dependencies]
quickcheck = "1"
@@ -43,5 +44,18 @@ bls12_381 = "0.6"
structopt = "0.3"
approx = "0.5.0"
[features]
default = ["lp", "bls12381"]
lp = ["good_lp", "lp-solvers"]
bls12381 = []
[[example]]
name = "circ"
required-features = ["lp"]
[[example]]
name = "opa_bench"
required-features = ["lp"]
[profile.release]
debug = true

View File

@@ -15,8 +15,8 @@ build_aby_c: build_deps build
./scripts/build_aby.zsh
build:
cargo build --release --example circ
cargo build --example circ
cargo build --release --examples
cargo build --examples
test: build build_aby_zokrates build_aby_c
cargo test

31
README_zsharp.md Normal file
View File

@@ -0,0 +1,31 @@
# zsharp (nee zok07) interpreter quickstart
**WARNING** this interpreter is still experimental! When things break, please
tell me about them :)
## building
1. see `scripts/dependencies_*` for info on installing deps.
Note that on M1 macs the homebrew instructions don't quite work, because
the coin-or build from homebrew is broken (for now). Don't worry---you
don't actually need this dep to build the zsharp interpreter.
2. circ uses some experimental APIs, so you'll need rust nightly!
3. To build the Z# interpreter cli,
`cargo build --release --example zxi --no-default-features`
## running
After building as above, `target/release/examples/zxi` will have been
generated. This executable takes one argument, the name of a .zok file.
Absolute and relative paths are both OK:
target/release/examples/zxi /tmp/foo.zok
target/release/examples/zxi ../../path/to/somewhere/else.zok
You may want to set the `RUST_LOG` environment variable to see more info
about the typechecking and interpreting process:
RUST_LOG=debug target/release/examples/zxi /tmp/foo.zok

View File

@@ -21,6 +21,11 @@ Concrete:
* We use it to fuzz IR passes
* General problem: Fuzzing language FEs
[ ] Implement sorts using hash-consing.
[ ] Modeling RAM transformations in Coq and proving their correctness
1. model a term IR, with functional arrays (like ours!).
2. model a RAM-augmented term IR, with conditional stores and reads
3. write a converter
4. prove that it works
Vague:
[ ] FE analysis infrastructure
@@ -49,4 +54,4 @@ Bigger research questions:
[ ] SoK: compiling to R1CS
* focus on embedding complex datatypes:
* use lookups
[ ] Compiling to branching programs
[ ] Compiling to branching programs

View File

@@ -8,7 +8,7 @@ use bellman::Circuit;
use bls12_381::{Bls12, Scalar};
use circ::front::c::{self, C};
use circ::front::datalog::{self, Datalog};
use circ::front::zokrates::{self, Zokrates};
use circ::front::zsharp::{self, ZSharpFE};
use circ::front::{FrontEnd, Mode};
use circ::ir::{
opt::{opt, Opt},
@@ -21,12 +21,11 @@ use circ::target::r1cs::bellman::parse_instance;
use circ::target::r1cs::opt::reduce_linearities;
use circ::target::r1cs::trans::to_r1cs;
use circ::target::smt::find_model;
use env_logger;
use good_lp::default_solver;
use std::fs::File;
use std::io::Read;
use std::io::Write;
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use structopt::clap::arg_enum;
use structopt::StructOpt;
@@ -96,7 +95,7 @@ enum Backend {
arg_enum! {
#[derive(PartialEq, Debug)]
enum Language {
Zokrates,
Zsharp,
Datalog,
C,
Auto,
@@ -105,7 +104,7 @@ arg_enum! {
#[derive(PartialEq, Debug)]
pub enum DeterminedLanguage {
Zokrates,
Zsharp,
Datalog,
C,
}
@@ -134,22 +133,18 @@ arg_enum! {
}
}
fn determine_language(l: &Language, input_path: &PathBuf) -> DeterminedLanguage {
match l {
&Language::Datalog => DeterminedLanguage::Datalog,
&Language::Zokrates => DeterminedLanguage::Zokrates,
&Language::C => DeterminedLanguage::C,
&Language::Auto => {
fn determine_language(l: &Language, input_path: &Path) -> DeterminedLanguage {
match *l {
Language::Datalog => DeterminedLanguage::Datalog,
Language::Zsharp => DeterminedLanguage::Zsharp,
Language::C => DeterminedLanguage::C,
Language::Auto => {
let p = input_path.to_str().unwrap();
if p.ends_with(".zok") {
DeterminedLanguage::Zokrates
DeterminedLanguage::Zsharp
} else if p.ends_with(".pl") {
DeterminedLanguage::Datalog
} else if p.ends_with(".c") {
DeterminedLanguage::C
} else if p.ends_with(".cpp") {
DeterminedLanguage::C
} else if p.ends_with(".cc") {
} else if p.ends_with(".c") || p.ends_with(".cpp") || p.ends_with(".cc") {
DeterminedLanguage::C
} else {
println!("Could not deduce the input language from path '{}', please set the language manually", p);
@@ -178,13 +173,13 @@ fn main() {
};
let language = determine_language(&options.frontend.language, &options.path);
let cs = match language {
DeterminedLanguage::Zokrates => {
let inputs = zokrates::Inputs {
DeterminedLanguage::Zsharp => {
let inputs = zsharp::Inputs {
file: options.path,
inputs: options.frontend.inputs,
mode: mode.clone(),
mode,
};
Zokrates::gen(inputs)
ZSharpFE::gen(inputs)
}
DeterminedLanguage::Datalog => {
let inputs = datalog::Inputs {
@@ -198,7 +193,7 @@ fn main() {
let inputs = c::Inputs {
file: options.path,
inputs: options.frontend.inputs,
mode: mode.clone(),
mode,
};
C::gen(inputs)
}
@@ -260,7 +255,7 @@ fn main() {
..
} => {
println!("Converting to r1cs");
let r1cs = to_r1cs(cs, circ::front::zokrates::ZOKRATES_MODULUS.clone());
let r1cs = to_r1cs(cs, circ::front::zsharp::ZSHARP_MODULUS.clone());
println!("Pre-opt R1cs size: {}", r1cs.constraints().len());
let r1cs = reduce_linearities(r1cs);
println!("Final R1cs size: {}", r1cs.constraints().len());
@@ -301,7 +296,7 @@ fn main() {
println!("Converting to aby");
let lang_str = match language {
DeterminedLanguage::C => "c".to_string(),
DeterminedLanguage::Zokrates => "zok".to_string(),
DeterminedLanguage::Zsharp => "zok".to_string(),
_ => panic!("Language isn't supported by MPC backend: {:#?}", language),
};
println!("Cost model: {}", cost_model);
@@ -323,7 +318,7 @@ fn main() {
if var.contains("f0") {
let i = var.find("f0").unwrap();
let s = &var[i + 8..];
let e = s.find("_").unwrap();
let e = s.find('_').unwrap();
writeln!(f, "{} {}", &s[..e], val.round() as u64).unwrap();
}
}

View File

@@ -30,6 +30,6 @@ fn main() {
metadata: ComputationMetadata::default(),
values: None,
};
let _assignment = ilp::assign(&cs, &format!("hycc"));
let _assignment = ilp::assign(&cs, "hycc");
//dbg!(&assignment);
}

187
examples/zxc.rs Normal file
View File

@@ -0,0 +1,187 @@
/*
use bellman::gadgets::test::TestConstraintSystem;
use bellman::groth16::{
create_random_proof, generate_parameters, generate_random_parameters, prepare_verifying_key,
verify_proof, Parameters, Proof, VerifyingKey,
};
use bellman::Circuit;
use bls12_381::{Bls12, Scalar};
*/
use circ::front::zsharp::{self, ZSharpFE};
use circ::front::{FrontEnd, Mode};
use circ::ir::opt::{opt, Opt};
/*
use circ::target::r1cs::bellman::parse_instance;
*/
use circ::target::r1cs::opt::reduce_linearities;
use circ::target::r1cs::trans::to_r1cs;
/*
use std::fs::File;
use std::io::Read;
use std::io::Write;
*/
use std::path::PathBuf;
use structopt::clap::arg_enum;
use structopt::StructOpt;
#[derive(Debug, StructOpt)]
#[structopt(name = "zxc", about = "CirC: the circuit compiler")]
struct Options {
/// Input file
#[structopt(parse(from_os_str), name = "PATH")]
path: PathBuf,
#[structopt(flatten)]
frontend: FrontendOptions,
/*
#[structopt(long, default_value = "P", parse(from_os_str))]
prover_key: PathBuf,
#[structopt(long, default_value = "V", parse(from_os_str))]
verifier_key: PathBuf,
#[structopt(long, default_value = "pi", parse(from_os_str))]
proof: PathBuf,
#[structopt(long, default_value = "x", parse(from_os_str))]
instance: PathBuf,
*/
#[structopt(short = "L")]
skip_linred: bool,
#[structopt(long, default_value = "count")]
action: ProofAction,
}
#[derive(Debug, StructOpt)]
struct FrontendOptions {
/// File with input witness
#[structopt(long, name = "FILE", parse(from_os_str))]
inputs: Option<PathBuf>,
}
arg_enum! {
#[derive(PartialEq, Debug)]
enum ProofAction {
Count,
Prove,
Setup,
Verify,
}
}
arg_enum! {
#[derive(PartialEq, Debug)]
enum ProofOption {
Count,
Prove,
}
}
fn main() {
env_logger::Builder::from_default_env()
.format_level(false)
.format_timestamp(None)
.init();
let options = Options::from_args();
println!("{:?}", options);
let cs = {
let inputs = zsharp::Inputs {
file: options.path,
inputs: options.frontend.inputs,
mode: Mode::Proof,
};
ZSharpFE::gen(inputs)
};
print!("Optimizing IR... ");
let cs = opt(
cs,
vec![
Opt::ScalarizeVars,
Opt::Flatten,
Opt::Sha,
Opt::ConstantFold,
Opt::Flatten,
Opt::Inline,
// Tuples must be eliminated before oblivious array elim
Opt::Tuple,
Opt::ConstantFold,
Opt::Obliv,
// The obliv elim pass produces more tuples, that must be eliminated
Opt::Tuple,
Opt::LinearScan,
// The linear scan pass produces more tuples, that must be eliminated
Opt::Tuple,
Opt::Flatten,
Opt::ConstantFold,
Opt::Inline,
],
);
println!("done.");
let action = options.action;
/*
let proof = options.proof;
let prover_key = options.prover_key;
let verifier_key = options.verifier_key;
let instance = options.instance;
*/
println!("Converting to r1cs");
let r1cs = to_r1cs(cs, circ::front::zsharp::ZSHARP_MODULUS.clone());
let r1cs = if options.skip_linred {
println!("Skipping linearity reduction, as requested.");
r1cs
} else {
println!(
"R1cs size before linearity reduction: {}",
r1cs.constraints().len()
);
reduce_linearities(r1cs)
};
println!("Final R1cs size: {}", r1cs.constraints().len());
match action {
ProofAction::Count => {
println!("{:#?}", r1cs.constraints());
}
ProofAction::Prove => {
unimplemented!()
/*
println!("Proving");
r1cs.check_all();
let rng = &mut rand::thread_rng();
let mut pk_file = File::open(prover_key).unwrap();
let pk = Parameters::<Bls12>::read(&mut pk_file, false).unwrap();
let pf = create_random_proof(&r1cs, &pk, rng).unwrap();
let mut pf_file = File::create(proof).unwrap();
pf.write(&mut pf_file).unwrap();
*/
}
ProofAction::Setup => {
unimplemented!()
/*
let rng = &mut rand::thread_rng();
let p =
generate_random_parameters::<bls12_381::Bls12, _, _>(&r1cs, rng).unwrap();
let mut pk_file = File::create(prover_key).unwrap();
p.write(&mut pk_file).unwrap();
let mut vk_file = File::create(verifier_key).unwrap();
p.vk.write(&mut vk_file).unwrap();
*/
}
ProofAction::Verify => {
unimplemented!()
/*
println!("Verifying");
let mut vk_file = File::open(verifier_key).unwrap();
let vk = VerifyingKey::<Bls12>::read(&mut vk_file).unwrap();
let pvk = prepare_verifying_key(&vk);
let mut pf_file = File::open(proof).unwrap();
let pf = Proof::read(&mut pf_file).unwrap();
let instance_vec = parse_instance(&instance);
verify_proof(&pvk, &pf, &instance_vec).unwrap();
*/
}
};
}

50
examples/zxi.rs Normal file
View File

@@ -0,0 +1,50 @@
use circ::front::zsharp::{Inputs, ZSharpFE};
use circ::front::Mode;
use std::path::PathBuf;
use structopt::StructOpt;
#[derive(Debug, StructOpt)]
#[structopt(name = "circ", about = "CirC: the circuit compiler")]
struct Options {
/// Input file
#[structopt(parse(from_os_str))]
zsharp_path: PathBuf,
/// File with input witness
#[structopt(short, long, name = "FILE", parse(from_os_str))]
inputs: Option<PathBuf>,
/// Number of parties for an MPC. If missing, generates a proof circuit.
#[structopt(short, long, name = "PARTIES")]
parties: Option<u8>,
/// Whether to maximize the output
#[structopt(short, long)]
maximize: bool,
}
fn main() {
env_logger::Builder::from_default_env()
.format_level(false)
.format_timestamp(None)
.init();
let options = Options::from_args();
//println!("{:?}", options);
let mode = if options.maximize {
Mode::Opt
} else {
match options.parties {
Some(p) => Mode::Mpc(p),
None => Mode::Proof,
}
};
let inputs = Inputs {
file: options.zsharp_path,
inputs: options.inputs,
mode,
};
let cs = ZSharpFE::interpret(inputs);
cs.pretty(&mut std::io::stdout().lock())
.expect("error pretty-printing value");
println!();
}

View File

@@ -1,2 +1,2 @@
set -xe
pacman -S cvc4 coinor-cbc
pacman -S cvc4 coin-or-cbc

View File

@@ -0,0 +1,9 @@
set -xe
# breaks on M1 processors for now
# https://github.com/coin-or-tools/homebrew-coinor/issues/62
brew tap coin-or-tools/coinor
brew install coin-or-tools/coinor/cbc
brew tap cvc4/cvc4
brew install cvc4/cvc4/cvc4

View File

@@ -56,3 +56,5 @@ pf_test str_arr_str
pf_test arr_str_arr_str
pf_test var_idx_arr_str_arr_str
pf_test mm
scripts/zx_tests/run_tests.sh

View File

@@ -0,0 +1,4 @@
def main() -> bool:
bool a = [4u32; 4u32] == [5u32; 4u32]
bool b = [4u32; 4u32] != [5u32; 4u32]
return a || b

View File

@@ -0,0 +1,4 @@
def main() -> bool:
bool a = [4u32; 4u32] == [5u32; 5u32]
bool b = [4u32; 4u32] != [5u32; 4u32]
return a || b

View File

@@ -0,0 +1,4 @@
def main() -> bool:
bool a = [4u32; 4u32] == [5u32; 4u32]
bool b = [4u32; 4u32] != [5u32; 5u32]
return a || b

View File

@@ -0,0 +1,3 @@
def main() -> u32:
u32[3] a = [1, 2, 3]
return a[3]

View File

@@ -0,0 +1,4 @@
def main() -> u32:
u32[3] a = [1, 2, 3]
a[3] = 4
return a[0]

View File

@@ -0,0 +1,5 @@
def main() -> u32:
u32[4] a = [1, 2, 3, 4]
a[2] = 5
assert(a[2] == 5)
return a[2]

View File

@@ -0,0 +1,103 @@
import "utils/casts/u8_to_bits"
import "utils/casts/u8_from_bits"
import "utils/casts/u8_to_field"
import "utils/casts/field_to_u8"
import "utils/casts/u16_to_bits"
import "utils/casts/u16_from_bits"
import "utils/casts/u16_to_field"
import "utils/casts/field_to_u16"
import "utils/casts/u32_to_bits"
import "utils/casts/u32_from_bits"
import "utils/casts/u32_to_field"
import "utils/casts/field_to_u32"
import "utils/casts/u64_to_bits"
import "utils/casts/u64_from_bits"
import "utils/casts/u64_to_field"
import "utils/casts/field_to_u64"
import "utils/pack/bool/unpack"
import "utils/pack/bool/pack"
def main() -> bool:
// check for msb0 bit order
u8 i1 = 128
bool[8] o1 = u8_to_bits(i1)
assert(o1[0])
assert(!o1[7])
u16 i2 = 32768
bool[16] o2 = u16_to_bits(i2)
assert(o2[0])
assert(!o2[15])
u32 i3 = 2147483648
bool[32] o3 = u32_to_bits(i3)
assert(o3[0])
assert(!o3[31])
u64 i4 = 9223372036854775808
bool[64] o4 = u64_to_bits(i4)
assert(o4[0])
assert(!o4[63])
// u8 -> field -> bits -> u8
u8 t1_0 = 42
field t1_1 = u8_to_field(t1_0)
bool[8] t1_2 = unpack(t1_1)
u8 t1_3 = u8_from_bits(t1_2)
assert(t1_0 == t1_3)
// XXX(TODO) pack builtin
// u8 -> bits -> field -> u8
u8 t2_0 = 77
bool[8] t2_1 = u8_to_bits(t2_0)
field t2_2 = pack(t2_1)
u8 t2_3 = field_to_u8(t2_2)
assert(t2_0 == t2_3)
// u16 -> field -> bits -> u16
u16 t3_0 = 46971
field t3_1 = u16_to_field(t3_0)
bool[16] t3_2 = unpack(t3_1)
u16 t3_3 = u16_from_bits(t3_2)
assert(t3_0 == t3_3)
// u16 -> bits -> field -> u16
u16 t4_0 = 63336
bool[16] t4_1 = u16_to_bits(t4_0)
field t4_2 = pack(t4_1)
u16 t4_3 = field_to_u16(t4_2)
assert(t4_0 == t4_3)
// u32 -> field -> bits -> u32
u32 t5_0 = 2652390681
field t5_1 = u32_to_field(t5_0)
bool[32] t5_2 = unpack(t5_1)
u32 t5_3 = u32_from_bits(t5_2)
assert(t5_0 == t5_3)
// u32 -> bits -> field -> u32
u32 t6_0 = 1173684415
bool[32] t6_1 = u32_to_bits(t6_0)
field t6_2 = pack(t6_1)
u32 t6_3 = field_to_u32(t6_2)
assert(t6_0 == t6_3)
// u64 -> field -> bits -> u64
u64 t7_0 = 18312416462297086083
field t7_1 = u64_to_field(t7_0)
bool[64] t7_2 = unpack(t7_1)
u64 t7_3 = u64_from_bits(t7_2)
assert(t7_0 == t7_3)
// u64 -> bits -> field -> u64
u64 t8_0 = 4047977501435466453
bool[64] t8_1 = u64_to_bits(t8_0)
field t8_2 = pack(t8_1)
u64 t8_3 = field_to_u64(t8_2)
assert(t8_0 == t8_3)
return true

View File

@@ -0,0 +1,52 @@
import "utils/casts/bool_array_to_u32_array"
def main() -> u32:
bool[2] ones = [true, true]
bool[6] zeros = [false, false, false, false, false, false]
bool[8] byte0 = [...ones, ...zeros] // 0xc0
bool[8] byte1 = [...zeros, ...ones] // 0x03
bool[16] word0 = [...byte0, ...byte0] // 0xc0c0
bool[16] word1 = [...byte0, ...byte1] // 0xc003
bool[16] word2 = [...byte1, ...byte0] // 0x03c0
bool[16] word3 = [...byte1, ...byte1] // 0x0303
bool[32] dwrd0 = [...word0, ...word0]
bool[32] dwrd1 = [...word0, ...word1]
bool[32] dwrd2 = [...word0, ...word2]
bool[32] dwrd3 = [...word0, ...word3]
bool[32] dwrd4 = [...word1, ...word0]
bool[32] dwrd5 = [...word1, ...word1]
bool[32] dwrd6 = [...word1, ...word2]
bool[32] dwrd7 = [...word1, ...word3]
bool[32] dwrd8 = [...word2, ...word0]
bool[32] dwrd9 = [...word2, ...word1]
bool[32] dwrdA = [...word2, ...word2]
bool[32] dwrdB = [...word2, ...word3]
bool[32] dwrdC = [...word3, ...word0]
bool[32] dwrdD = [...word3, ...word1]
bool[32] dwrdE = [...word3, ...word2]
bool[32] dwrdF = [...word3, ...word3]
bool[16 * 32] foo = [...dwrd0, ...dwrd1, ...dwrd2, ...dwrd3, ...dwrd4, ...dwrd5, ...dwrd6, ...dwrd7, ...dwrd8, ...dwrd9, ...dwrdA, ...dwrdB, ...dwrdC, ...dwrdD, ...dwrdE, ...dwrdF ]
u32[16] a = bool_array_to_u32_array(foo)
assert(a[0] == 0xc0c0c0c0)
assert(a[1] == 0xc0c0c003)
assert(a[2] == 0xc0c003c0)
assert(a[3] == 0xc0c00303)
assert(a[4] == 0xc003c0c0)
assert(a[5] == 0xc003c003)
assert(a[6] == 0xc00303c0)
assert(a[7] == 0xc0030303)
assert(a[8] == 0x03c0c0c0)
assert(a[9] == 0x03c0c003)
assert(a[10] == 0x03c003c0)
assert(a[11] == 0x03c00303)
assert(a[12] == 0x0303c0c0)
assert(a[13] == 0x0303c003)
assert(a[14] == 0x030303c0)
assert(a[15] == 0x03030303)
return a[0]

View File

@@ -0,0 +1,10 @@
const u32[5] asdf = [1,2,3,4,5]
def last<N>(u32[N] a) -> u32:
return a[N-1]
def foo<N>(u32[N] a) -> u32:
return last([...a, ...a])
def main() -> u32:
return foo([1,2,3])

View File

@@ -0,0 +1,10 @@
const u32[5] asdf = [1,2,3,4,5]
def last<N>(u32[N] a) -> u32:
return a[N-1]
def foo<N>(u32[N] a) -> u32:
return last([...a, ...a])
def main() -> u32:
return foo(asdf)

View File

@@ -0,0 +1,18 @@
def main() -> bool:
field a = 0
field b = -1
field c = 2
field d = 2
assert(b > a)
assert(b >= a)
assert(a < b)
assert(a <= b)
assert(c > a)
assert(c >= a)
assert(c < b)
assert(c <= b)
assert(d >= c)
assert(c <= d)
assert(c != b)
assert(c == d)
return true

View File

@@ -0,0 +1,5 @@
def main() -> bool:
field a = 12824923210
field b = 18423229
assert(a % b == 2355826)
return false

View File

@@ -0,0 +1,4 @@
from "EMBED" import FIELD_SIZE_IN_BITS
def main() -> u32:
return FIELD_SIZE_IN_BITS

View File

@@ -0,0 +1,3 @@
def main() -> u32:
u32[3][2] foo = [[1,2], [3,4], [5,6,7]]
return foo[0][0]

View File

@@ -0,0 +1,6 @@
const u32[3] A = [1, 2, 3]
const u32[2][3] B = [A, A]
const u32[1][2][3] C = [B]
def main() -> u32[1][2][3]:
return C

View File

@@ -0,0 +1,9 @@
const u32[3] A = [1, 2, 3]
const u32[2][3] B = [A, A]
const u32[1][2][3] C = [B]
def get_C() -> u32[1][2][3]:
return C
def main() -> u32[3]:
return get_C()[0][1]

View File

@@ -0,0 +1,9 @@
const u32[3] A = [1, 2, 3]
const u32[2][3] B = [A, A]
const u32[1][2][3] C = [B]
def get_C() -> u32[1][2][3]:
return C
def main() -> u32[3]:
return get_C()[1][1]

View File

@@ -0,0 +1,5 @@
const u32[5] asdf = [1,2,3,4,5]
def main() -> u32[4]:
u32[5] qwer = [1,2,3,4,5]
return [...asdf[1..3], 4, qwer[2]]

View File

@@ -0,0 +1,5 @@
const u32[5] asdf = [1,2,3,4,5]
def main() -> u32[5]:
u32[5] qwer = [1,2,3,4,5]
return [...asdf[1..3], 4, qwer[2]]

View File

@@ -0,0 +1,8 @@
struct InlineTest<N> {
u32[N] x
field y
}
def main() -> InlineTest<4>:
InlineTest<4> foo = InlineTest { x: [1, 2, 3, 4], y: 1 }
return foo

View File

@@ -0,0 +1,8 @@
struct InlineTest<N> {
u32[N] x
field y
}
def main() -> InlineTest<4>:
InlineTest<5> foo = InlineTest { x: [1, 2, 3, 4, 5], y: 1 }
return foo

View File

@@ -0,0 +1,9 @@
struct InlineTest<N> {
u32[N] x
field y
}
const InlineTest<4> foo = InlineTest { x: [1, 2, 3, 4, 5], y: 1 }
def main() -> InlineTest<4>:
return foo

View File

@@ -0,0 +1,8 @@
struct InlineTest<N> {
u32[N] x
field y
}
def main() -> InlineTest<4>:
InlineTest<4> foo = InlineTest { x: [1, 2, 3, 4, 5], y: 1 }
return foo

View File

@@ -0,0 +1,8 @@
struct InlineTest<N> {
u32[N] x
field y
}
def main() -> InlineTest<4>:
InlineTest<4> foo = MisspelledInlineTest { x: [1, 2, 3, 4], y: 1 }
return foo

View File

@@ -0,0 +1,9 @@
struct InlineTest<N> {
u32[N] x
field y
}
const InlineTest<4> foo = MisspelledInlineTest { x: [1, 2, 3, 4], y: 1 }
def main() -> InlineTest<4>:
return foo

View File

@@ -0,0 +1,17 @@
struct Foo<N> {
u32[N] a
u64 b
}
struct Bar<N> {
Foo<N> a
u64 b
}
const Bar<4> baz = Bar {
a: Foo { a: [1, 2, 3, 4], b: 0 },
b: 0
}
def main() -> Bar<4>:
return baz

View File

@@ -0,0 +1,17 @@
struct Foo<N> {
u32[N] a
u64 b
}
struct Bar<N> {
Foo<N> a
u64 b
}
const Bar<4> baz = Bar {
a: Foo { a: [1, 2, 3], b: 0 },
b: 0
}
def main() -> Bar<4>:
return baz

View File

@@ -0,0 +1,6 @@
def main() -> bool:
assert(0xfa == 250)
assert(0xbeef == 48879)
assert(0xdeadbeef == 3735928559)
assert(0xc0ffee1111111111 == 13907095931411566865)
return true

View File

@@ -0,0 +1,5 @@
const u32 A = 1
const u32 A = 2
def main() -> bool:
return false

View File

@@ -0,0 +1,6 @@
from "EMBED" import FIELD_SIZE_IN_BITS as A
const u32 A = 2
def main() -> bool:
return false

View File

@@ -0,0 +1,8 @@
def foo() -> u32:
return 1
def foo() -> u32:
return 2
def main() -> u32:
return foo()

View File

@@ -0,0 +1,5 @@
from "EMBED" import FIELD_SIZE_IN_BITS as A
from "EMBED" import u16_to_bits as A
def main() -> bool:
return false

View File

@@ -0,0 +1,5 @@
import "EMBED"
import "EMBED"
def main() -> bool:
return false

View File

@@ -0,0 +1,10 @@
struct Foo {
u32 a
}
struct Foo {
u32 b
}
def main() -> bool:
return true

28
scripts/zx_tests/run_tests.sh Executable file
View File

@@ -0,0 +1,28 @@
#!/bin/bash
TESTDIR=$(dirname -- "$0")
ZXI=${TESTDIR}/../../target/release/examples/zxi
error=0
echo Running zx should-pass tests:
for i in ${TESTDIR}/*.zx; do
${ZXI} "$i" &>/dev/null
if [ "$?" != "0" ]; then
echo "[failure: should-pass] $i"
error=1
fi
done
echo Done.
echo
echo Running zx should-fail tests:
for i in ${TESTDIR}/*.zxf; do
${ZXI} "$i" &>/dev/null
if [ "$?" == "0" ]; then
echo "[failure: should-fail] $i"
error=1
fi
done
echo Done.
exit $error

View File

@@ -0,0 +1,21 @@
from "field" import s_divisible, s_remainder
def main() -> bool:
field q = 4
field a = -2048
assert((a % q) != 0)
assert(s_divisible(a, q))
assert(s_remainder(a, q) == 0)
field b = 2048
assert((b % q) == 0)
assert(s_divisible(b, q))
assert(s_remainder(b, q) == 0)
field c = -2049
assert((c % 2) == 0)
assert(!s_divisible(c, q))
assert(s_remainder(c, q) == 3)
return true

View File

@@ -0,0 +1,7 @@
def main() -> bool:
u32 total = 0
for u32 j in 0..7 do
total = total + j
endfor
assert(total == 21)
return true

View File

@@ -0,0 +1,5 @@
def last<N>(u32[N] a) -> u32:
return a[N-1]
def main() -> u32:
return last([1u32,2,3])

View File

@@ -0,0 +1,6 @@
def dbl<N,NN>(u32[N] a) -> u32[NN]:
// XXX NN is unconstrained! this is a weird and annoying thing
return [...a,...a]
def main() -> u32[6]:
return dbl([1u32,2,3])

View File

@@ -0,0 +1,6 @@
def dbl<N,NN>(u32[N] a) -> u32[NN]:
// XXX NN is unconstrained! this is a weird and annoying thing
return [...a,...a]
def main() -> u32[5]:
return dbl([1u32,2,3])

View File

@@ -0,0 +1,5 @@
def last<N>(u32[N] a) -> u32:
return a[N-1]
def main() -> u32:
return last([1u32, 2, ...[3u32, 4, 5]])

View File

@@ -0,0 +1,7 @@
const u32[5] asdf = [1,2,3,4,5]
def last<N>(u32[N] a) -> u32:
return a[N-1]
def main() -> u32:
return last(asdf)

View File

@@ -0,0 +1,7 @@
const u32[5] asdf = [1,2,3,4,5]
def dbl<N,NN>(u32[N] a) -> u32[NN]:
return [...a,...a]
def main() -> u32[10]:
return dbl(asdf)

View File

@@ -0,0 +1,7 @@
const u32[5] asdf = [1,2,3,4,5]
def dbl<N,NN>(u32[N] a) -> u32[NN]:
return [...a,...a]
def main() -> u32[6]:
return dbl(asdf)

View File

@@ -0,0 +1,22 @@
struct Bar {
u8 d
u16 e
}
struct Foo {
u32[7] a
field b
u64 c
Bar d
}
def main() -> Foo:
Bar w = Bar { d: 0, e: 0 }
assert(w == w)
Foo x = Foo { a: [7; 7], b: 1, c: 0, d: w }
Foo y = Foo { a: [8; 7], b: 0, c: 1, d: w }
assert(x != y)
assert(!(x == y))
return x

View File

@@ -0,0 +1,18 @@
struct Bar {
u8 d
u16 e
}
struct Foo {
u32[7] a
field b
u64 c
Bar d
}
def main() -> bool:
Bar w = Bar { d: 0, e: 0 }
Foo x = Foo { a: [7; 7], b: 1, c: 0, d: w }
assert(x != w)
assert(!(x == y))
return x == y || x != y

View File

@@ -0,0 +1,12 @@
struct Foo {
u32 a
u8 b
}
def main() -> u8:
Foo bar = Foo { a: 1, b: 2 }
assert(bar.a == 1)
assert(bar.b == 2)
bar.a = 2
assert(bar.a == 2)
return bar.b

View File

@@ -598,7 +598,12 @@ impl<E: Embeddable> Circify<E> {
// get condition under which assignment happens
let guard = self.condition.clone();
// build condition-aware new value
let ite_val = Val::Term(self.e.ite(&mut self.cir_ctx, guard, new, (*old).clone()));
let ite = match guard.as_bool_opt() {
Some(true) => new,
Some(false) => old.clone(),
None => self.e.ite(&mut self.cir_ctx, guard, new, (*old).clone()),
};
let ite_val = Val::Term(ite);
// TODO: add language-specific coersion here if needed
assert!(self.vals.insert(new_name, ite_val.clone()).is_none());
Ok(ite_val)
@@ -908,7 +913,7 @@ mod test {
&**a,
format!("{}.0", raw_name),
user_name.as_ref().map(|u| format!("{}.0", u)),
visibility.clone(),
visibility,
)),
Box::new(self.declare(
ctx,
@@ -941,13 +946,7 @@ mod test {
match t {
T::Base(a) => T::Base(ctx.cs.borrow_mut().assign(&name, a, visibility)),
T::Pair(a, b) => T::Pair(
Box::new(self.assign(
ctx,
_ty,
format!("{}.0", name),
*a,
visibility.clone(),
)),
Box::new(self.assign(ctx, _ty, format!("{}.0", name), *a, visibility)),
Box::new(self.assign(ctx, _ty, format!("{}.1", name), *b, visibility)),
),
}

View File

@@ -52,7 +52,7 @@ program = { SOI ~ rule* ~ EOI }
WHITESPACE = _{ " " | "\t" | "\n" }
// basic types (ZoKrates)
// basic types (ZoKrates/Z#)
ty_field = {"field"}
ty_uint = @{"u" ~ ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* }
ty_bool = {"bool"}

View File

@@ -10,7 +10,7 @@ use log::debug;
use rug::Integer;
use crate::circify::{Circify, Loc, Val};
use crate::front::zokrates::{PROVER_VIS, PUBLIC_VIS};
use crate::front::zsharp::{PROVER_VIS, PUBLIC_VIS};
use crate::ir::opt::cfold::fold;
use crate::ir::term::extras::as_uint_constant;
use crate::ir::term::*;

View File

@@ -9,7 +9,7 @@ use super::error::ErrorKind;
use super::ty::Ty;
use crate::circify::{CirCtx, Embeddable};
use crate::front::zokrates::{ZOKRATES_MODULUS_ARC, ZOK_FIELD_SORT};
use crate::front::zsharp::{ZSHARP_FIELD_SORT, ZSHARP_MODULUS_ARC};
use crate::ir::term::*;
/// A term
@@ -66,7 +66,7 @@ where
{
leaf_term(Op::Const(Value::Field(FieldElem::new(
Integer::from(i),
ZOKRATES_MODULUS_ARC.clone(),
ZSHARP_MODULUS_ARC.clone(),
))))
}
@@ -85,9 +85,9 @@ impl Ty {
match self {
Self::Bool => Sort::Bool,
Self::Uint(w) => Sort::BitVector(*w as usize),
Self::Field => ZOK_FIELD_SORT.clone(),
Self::Field => ZSHARP_FIELD_SORT.clone(),
Self::Array(n, b) => {
Sort::Array(Box::new(ZOK_FIELD_SORT.clone()), Box::new(b.sort()), *n)
Sort::Array(Box::new(ZSHARP_FIELD_SORT.clone()), Box::new(b.sort()), *n)
}
}
}
@@ -320,7 +320,7 @@ pub fn or(s: &T, t: &T) -> Result<T> {
pub fn uint_to_field(s: &T) -> Result<T> {
match &s.ty {
Ty::Uint(_) => Ok(T::new(
term![Op::UbvToPf(ZOKRATES_MODULUS_ARC.clone()); s.ir.clone()],
term![Op::UbvToPf(ZSHARP_MODULUS_ARC.clone()); s.ir.clone()],
Ty::Field,
)),
_ => Err(ErrorKind::InvalidUnOp("to_field".into(), s.clone())),
@@ -451,7 +451,7 @@ impl Datalog {
/// Initialize the Datalog lang def
pub fn new() -> Self {
Self {
modulus: ZOKRATES_MODULUS_ARC.clone(),
modulus: ZSHARP_MODULUS_ARC.clone(),
}
}
}

View File

@@ -2,8 +2,7 @@
pub mod c;
pub mod datalog;
#[allow(clippy::all)]
pub mod zokrates;
pub mod zsharp;
use super::ir::term::Computation;
use std::fmt::{self, Display, Formatter};

View File

@@ -1,646 +0,0 @@
//! The ZoKrates front-end
mod parser;
mod term;
use super::{FrontEnd, Mode};
use crate::circify::{Circify, Loc, Val};
use crate::ir::proof::{self, ConstraintMetadata};
use crate::ir::term::extras::Letified;
use crate::ir::term::*;
use log::debug;
use rug::Integer;
use std::collections::HashMap;
use std::fmt::Display;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use zokrates_pest_ast as ast;
use term::*;
/// The modulus for the ZoKrates language.
pub use term::ZOKRATES_MODULUS;
/// The modulus for the ZoKrates language.
pub use term::ZOKRATES_MODULUS_ARC;
/// The modulus for the ZoKrates language.
pub use term::ZOK_FIELD_SORT;
/// The prover visibility
pub const PROVER_VIS: Option<PartyId> = Some(proof::PROVER_ID);
/// Public visibility
pub const PUBLIC_VIS: Option<PartyId> = None;
/// Inputs to the ZoKrates compiler
pub struct Inputs {
/// The file to look for `main` in.
pub file: PathBuf,
/// The file to look for concrete arguments to main in. Optional.
///
/// ## Examples
///
/// If main takes `x: u64, y: field`, this file might contain
///
/// ```ignore
/// x 4
/// y -1
/// ```
pub inputs: Option<PathBuf>,
/// The mode to generate for (MPC or proof). Effects visibility.
pub mode: Mode,
}
/// The ZoKrates front-end. Implements [FrontEnd].
pub struct Zokrates;
impl FrontEnd for Zokrates {
type Inputs = Inputs;
fn gen(i: Inputs) -> Computation {
let loader = parser::ZLoad::new();
let asts = loader.load(&i.file);
let mut g = ZGen::new(i.inputs, asts, i.mode);
g.visit_files();
g.file_stack.push(i.file);
g.entry_fn("main");
g.file_stack.pop();
g.circ.consume().borrow().clone()
}
}
struct ZGen<'ast> {
circ: Circify<ZoKrates>,
stdlib: parser::ZStdLib,
asts: HashMap<PathBuf, ast::File<'ast>>,
file_stack: Vec<PathBuf>,
functions: HashMap<(PathBuf, String), ast::Function<'ast>>,
import_map: HashMap<(PathBuf, String), (PathBuf, String)>,
mode: Mode,
}
struct ZLoc {
var: Loc,
accesses: Vec<ZAccess>,
}
enum ZAccess {
Member(String),
Idx(T),
}
fn loc_store(struct_: T, loc: &[ZAccess], val: T) -> Result<T, String> {
match loc.first() {
None => Ok(val),
Some(ZAccess::Member(field)) => {
let inner = field_select(&struct_, &field)?;
let new_inner = loc_store(inner, &loc[1..], val)?;
field_store(struct_, &field, new_inner)
}
Some(ZAccess::Idx(idx)) => {
let old_inner = array_select(struct_.clone(), idx.clone())?;
let new_inner = loc_store(old_inner, &loc[1..], val)?;
array_store(struct_, idx.clone(), new_inner)
}
}
}
impl<'ast> ZGen<'ast> {
fn new(inputs: Option<PathBuf>, asts: HashMap<PathBuf, ast::File<'ast>>, mode: Mode) -> Self {
let this = Self {
circ: Circify::new(ZoKrates::new(inputs.map(|i| parser::parse_inputs(i)))),
asts,
stdlib: parser::ZStdLib::new(),
file_stack: vec![],
functions: HashMap::default(),
import_map: HashMap::default(),
mode,
};
this.circ
.cir_ctx()
.cs
.borrow_mut()
.metadata
.add_prover_and_verifier();
this
}
/// Unwrap a result with a span-dependent error
fn err<E: Display>(&self, e: E, s: &ast::Span) -> ! {
println!("Error: {}", e);
println!("In: {}", self.cur_path().display());
for l in s.lines() {
println!(" {}", l);
}
std::process::exit(1)
}
fn unwrap<T, E: Display>(&self, r: Result<T, E>, s: &ast::Span) -> T {
r.unwrap_or_else(|e| self.err(e, s))
}
fn builtin_call(fn_name: &str, mut args: Vec<T>) -> Result<T, String> {
match fn_name {
"EMBED/u8_to_bits" if args.len() == 1 => uint_to_bits(args.pop().unwrap()),
"EMBED/u16_to_bits" if args.len() == 1 => uint_to_bits(args.pop().unwrap()),
"EMBED/u32_to_bits" if args.len() == 1 => uint_to_bits(args.pop().unwrap()),
"EMBED/u8_from_bits" if args.len() == 1 => uint_from_bits(args.pop().unwrap()),
"EMBED/u16_from_bits" if args.len() == 1 => uint_from_bits(args.pop().unwrap()),
"EMBED/u32_from_bits" if args.len() == 1 => uint_from_bits(args.pop().unwrap()),
"EMBED/unpack" if args.len() == 1 => field_to_bits(args.pop().unwrap()),
_ => Err(format!("Unknown builtin '{}'", fn_name)),
}
}
fn stmt(&mut self, s: &ast::Statement<'ast>) {
debug!("Stmt: {}", s.span().as_str());
match s {
ast::Statement::Return(r) => {
assert!(r.expressions.len() <= 1);
if let Some(e) = r.expressions.first() {
let ret = self.expr(e);
let ret_res = self.circ.return_(Some(ret));
self.unwrap(ret_res, &r.span);
} else {
let ret_res = self.circ.return_(None);
self.unwrap(ret_res, &r.span);
}
}
ast::Statement::Assertion(e) => {
let b = bool(self.expr(&e.expression));
let e = self.unwrap(b, &e.span);
self.circ.assert(e);
}
ast::Statement::Iteration(i) => {
let ty = self.type_(&i.ty);
let s = self.const_int(&i.from);
let e = self.const_int(&i.to);
let v_name = i.index.value.clone();
self.circ.enter_scope();
let decl_res = self.circ.declare(v_name.clone(), &ty, false, PROVER_VIS);
self.unwrap(decl_res, &i.index.span);
for j in s..e {
self.circ.enter_scope();
let ass_res = self.circ.assign(
Loc::local(v_name.clone()),
Val::Term(match ty {
Ty::Uint(8) => uint_lit(j, 8),
Ty::Uint(16) => uint_lit(j, 16),
Ty::Uint(32) => uint_lit(j, 32),
Ty::Field => field_lit(j),
_ => panic!("Unexpected type for iteration: {:?}", ty),
}),
);
self.unwrap(ass_res, &i.index.span);
for s in &i.statements {
self.stmt(s);
}
self.circ.exit_scope();
}
self.circ.exit_scope();
}
ast::Statement::Definition(d) => {
assert!(d.lhs.len() <= 1);
let e = self.expr(&d.expression);
if let Some(l) = d.lhs.first() {
let ty = e.type_();
if let Some(t) = l.ty.as_ref() {
let decl_ty = self.type_(t);
if &decl_ty != ty {
self.err(
format!(
"Assignment type mismatch: {} annotated vs {} actual",
decl_ty, ty,
),
&d.span,
);
}
assert!(l.a.accesses.len() == 0);
let d_res =
self.circ
.declare_init(l.a.id.value.clone(), decl_ty, Val::Term(e));
self.unwrap(d_res, &d.span);
} else {
// Assignee case
let lval = self.lval(&l.a);
let mod_res = self.mod_lval(lval, e);
self.unwrap(mod_res, &d.span);
}
}
}
}
}
fn mod_lval(&mut self, loc: ZLoc, val: T) -> Result<(), String> {
let old = self
.circ
.get_value(loc.var.clone())
.map_err(|e| format!("{}", e))?
.unwrap_term();
let new = loc_store(old, &loc.accesses, val)?;
debug!("Assign: {:?} = {}", loc.var, Letified(new.term.clone()));
self.circ
.assign(loc.var, Val::Term(new))
.map_err(|e| format!("{}", e))
.map(|_| ())
}
fn lval(&mut self, l: &ast::Assignee<'ast>) -> ZLoc {
let mut loc = ZLoc {
var: Loc::local(l.id.value.clone()),
accesses: vec![],
};
for acc in &l.accesses {
loc.accesses.push(match acc {
ast::AssigneeAccess::Member(m) => ZAccess::Member(m.id.value.clone()),
ast::AssigneeAccess::Select(m) => ZAccess::Idx(
if let ast::RangeOrExpression::Expression(e) = &m.expression {
self.expr(&e)
} else {
panic!("Cannot assign to slice")
},
),
})
}
loc
}
fn const_(&mut self, e: &ast::ConstantExpression<'ast>) -> T {
match e {
ast::ConstantExpression::U8(u) => {
uint_lit(u8::from_str_radix(&u.value[2..], 16).unwrap(), 8)
}
ast::ConstantExpression::U16(u) => {
uint_lit(u16::from_str_radix(&u.value[2..], 16).unwrap(), 16)
}
ast::ConstantExpression::U32(u) => {
uint_lit(u32::from_str_radix(&u.value[2..], 16).unwrap(), 32)
}
ast::ConstantExpression::DecimalNumber(u) => {
field_lit(Integer::from_str_radix(&u.value, 10).unwrap())
}
ast::ConstantExpression::BooleanLiteral(u) => {
z_bool_lit(bool::from_str(&u.value).unwrap())
}
}
}
fn bin_op(&self, o: &ast::BinaryOperator) -> fn(T, T) -> Result<T, String> {
match o {
ast::BinaryOperator::BitXor => bitxor,
ast::BinaryOperator::BitAnd => bitand,
ast::BinaryOperator::BitOr => bitor,
ast::BinaryOperator::RightShift => shr,
ast::BinaryOperator::LeftShift => shl,
ast::BinaryOperator::Or => or,
ast::BinaryOperator::And => and,
ast::BinaryOperator::Add => add,
ast::BinaryOperator::Sub => sub,
ast::BinaryOperator::Mul => mul,
ast::BinaryOperator::Div => div,
ast::BinaryOperator::Rem => rem,
ast::BinaryOperator::Eq => eq,
ast::BinaryOperator::NotEq => neq,
ast::BinaryOperator::Lt => ult,
ast::BinaryOperator::Gt => ugt,
ast::BinaryOperator::Lte => ule,
ast::BinaryOperator::Gte => uge,
ast::BinaryOperator::Pow => unimplemented!(),
}
}
fn expr(&mut self, e: &ast::Expression<'ast>) -> T {
debug!("Expr: {}", e.span().as_str());
let res = match e {
ast::Expression::Constant(c) => Ok(self.const_(c)),
ast::Expression::Unary(u) => not(self.expr(&u.expression)),
ast::Expression::Binary(u) => {
let f = self.bin_op(&u.op);
let a = self.expr(&u.left);
let b = self.expr(&u.right);
f(a, b)
}
ast::Expression::Ternary(u) => {
let c = self.expr(&u.first);
let a = self.expr(&u.second);
let b = self.expr(&u.third);
cond(c, a, b)
}
ast::Expression::Identifier(u) => Ok(self
.unwrap(self.circ.get_value(Loc::local(u.value.clone())), &u.span)
.unwrap_term()),
ast::Expression::InlineArray(u) => T::new_array(
u.expressions
.iter()
.flat_map(|x| self.array_lit_elem(x))
.collect(),
),
ast::Expression::InlineStruct(u) => Ok(T::new_struct(
u.ty.value.clone(),
u.members
.iter()
.map(|m| (m.id.value.clone(), self.expr(&m.expression)))
.collect(),
)),
ast::Expression::ArrayInitializer(a) => {
let v = self.expr(&a.value);
let n = const_int(self.const_(&a.count))
.unwrap()
.to_usize()
.unwrap();
array(vec![v; n])
}
ast::Expression::Postfix(p) => {
// Assume no functions in arrays, etc.
let (base, accs) = if let Some(ast::Access::Call(c)) = p.accesses.first() {
debug!("Call: {}", p.id.value);
let (f_path, f_name) = self.deref_import(p.id.value.clone());
let args = c
.expressions
.iter()
.map(|e| self.expr(e))
.collect::<Vec<_>>();
let res = if f_path.to_string_lossy().starts_with("EMBED") {
Self::builtin_call(f_path.to_str().unwrap(), args).unwrap()
} else {
let p = (f_path, f_name);
let f = self
.functions
.get(&p)
.unwrap_or_else(|| panic!("No function '{}'", p.1))
.clone();
self.file_stack.push(p.0);
assert!(f.returns.len() <= 1);
let ret_ty = f.returns.first().map(|r| self.type_(r));
self.circ.enter_fn(p.1, ret_ty);
assert_eq!(f.parameters.len(), args.len());
for (p, a) in f.parameters.iter().zip(args) {
let ty = self.type_(&p.ty);
let d_res =
self.circ.declare_init(p.id.value.clone(), ty, Val::Term(a));
self.unwrap(d_res, &c.span);
}
for s in &f.statements {
self.stmt(s);
}
let ret = self
.circ
.exit_fn()
.map(|a| a.unwrap_term())
.unwrap_or_else(|| z_bool_lit(false));
self.file_stack.pop();
ret
};
(res, &p.accesses[1..])
} else {
// Assume no calls
(
self.unwrap(
self.circ.get_value(Loc::local(p.id.value.clone())),
&p.id.span,
)
.unwrap_term(),
&p.accesses[..],
)
};
accs.iter().fold(Ok(base), |b, acc| match acc {
ast::Access::Member(m) => field_select(&b?, &m.id.value),
ast::Access::Select(a) => match &a.expression {
ast::RangeOrExpression::Expression(e) => array_select(b?, self.expr(e)),
ast::RangeOrExpression::Range(r) => {
let s = r.from.as_ref().map(|s| self.const_int(&s.0) as usize);
let e = r.to.as_ref().map(|s| self.const_int(&s.0) as usize);
slice(b?, s, e)
}
},
ast::Access::Call(_) => unreachable!("stray call"),
})
}
};
self.unwrap(res, e.span())
}
fn array_lit_elem(&mut self, e: &ast::SpreadOrExpression<'ast>) -> Vec<T> {
match e {
ast::SpreadOrExpression::Expression(e) => vec![self.expr(e)],
ast::SpreadOrExpression::Spread(s) => self.expr(&s.expression).unwrap_array().unwrap(),
}
}
fn entry_fn(&mut self, n: &str) {
debug!("Entry: {}", n);
// find the entry function
let (f_path, f_name) = self.deref_import(n.to_owned());
let p = (f_path, f_name);
let f = self
.functions
.get(&p)
.unwrap_or_else(|| panic!("No function '{}'", p.1))
.clone();
assert!(f.returns.len() <= 1);
// get return type
let ret_ty = f.returns.first().map(|r| self.type_(r));
// setup stack frame for entry function
self.circ.enter_fn(n.to_owned(), ret_ty.clone());
for p in f.parameters.iter() {
let ty = self.type_(&p.ty);
debug!("Entry param: {}: {}", p.id.value, ty);
let vis = self.interpret_visibility(&p.visibility);
let r = self.circ.declare(p.id.value.clone(), &ty, true, vis);
self.unwrap(r, &p.span);
}
for s in &f.statements {
self.stmt(s);
}
if let Some(r) = self.circ.exit_fn() {
match self.mode {
Mode::Mpc(_) => {
let ret_term = r.unwrap_term();
let ret_terms = ret_term.terms();
self.circ
.cir_ctx()
.cs
.borrow_mut()
.outputs
.extend(ret_terms);
}
Mode::Proof => {
let ty = ret_ty.as_ref().unwrap();
let name = "return".to_owned();
let term = r.unwrap_term();
let _r = self.circ.declare(name.clone(), &ty, false, PROVER_VIS);
self.circ
.assign_with_assertions(name, term, &ty, PUBLIC_VIS);
}
Mode::Opt => {
let ret_term = r.unwrap_term();
let ret_terms = ret_term.terms();
assert!(
ret_terms.len() == 1,
"When compiling to optimize, there can only be one output"
);
let t = ret_terms.into_iter().next().unwrap();
match check(&t) {
Sort::BitVector(_) => {}
s => panic!("Cannot maximize output of type {}", s),
}
self.circ.cir_ctx().cs.borrow_mut().outputs.push(t);
}
Mode::ProofOfHighValue(v) => {
let ret_term = r.unwrap_term();
let ret_terms = ret_term.terms();
assert!(
ret_terms.len() == 1,
"When compiling to optimize, there can only be one output"
);
let t = ret_terms.into_iter().next().unwrap();
let cmp = match check(&t) {
Sort::BitVector(w) => term![BV_UGE; t, bv_lit(v, w)],
s => panic!("Cannot maximize output of type {}", s),
};
self.circ.cir_ctx().cs.borrow_mut().outputs.push(cmp);
}
}
}
}
fn interpret_visibility(&self, visibility: &Option<ast::Visibility<'ast>>) -> Option<PartyId> {
match visibility {
None | Some(ast::Visibility::Public(_)) => PUBLIC_VIS.clone(),
Some(ast::Visibility::Private(private)) => match self.mode {
Mode::Proof | Mode::Opt | Mode::ProofOfHighValue(_) => {
if private.number.is_some() {
self.err(
format!(
"Party number found, but we're generating a {} circuit",
self.mode
),
&private.span,
);
}
PROVER_VIS.clone()
}
Mode::Mpc(n_parties) => {
let num_str = private
.number
.as_ref()
.unwrap_or_else(|| self.err("No party number", &private.span));
let num_val =
u8::from_str_radix(&num_str.value[1..num_str.value.len() - 1], 10)
.unwrap_or_else(|e| {
self.err(format!("Bad party number: {}", e), &private.span)
});
if num_val <= n_parties {
Some(num_val - 1)
} else {
self.err(
format!(
"Party number {} greater than the number of parties ({})",
num_val, n_parties
),
&private.span,
)
}
}
},
}
}
fn cur_path(&self) -> &Path {
self.file_stack.last().unwrap()
}
fn cur_dir(&self) -> PathBuf {
let mut p = self.file_stack.last().unwrap().to_path_buf();
p.pop();
p
}
fn deref_import(&self, s: String) -> (PathBuf, String) {
let r = (self.cur_path().to_path_buf(), s);
self.import_map.get(&r).cloned().unwrap_or(r)
}
fn const_int(&mut self, e: &ast::Expression<'ast>) -> isize {
let i = const_int(self.expr(e));
self.unwrap(i, e.span()).to_isize().unwrap()
}
fn type_(&mut self, t: &ast::Type<'ast>) -> Ty {
fn lift<'ast>(t: &ast::BasicOrStructType<'ast>) -> ast::Type<'ast> {
match t {
ast::BasicOrStructType::Basic(b) => ast::Type::Basic(b.clone()),
ast::BasicOrStructType::Struct(b) => ast::Type::Struct(b.clone()),
}
}
match t {
ast::Type::Basic(ast::BasicType::U8(_)) => Ty::Uint(8),
ast::Type::Basic(ast::BasicType::U16(_)) => Ty::Uint(16),
ast::Type::Basic(ast::BasicType::U32(_)) => Ty::Uint(32),
ast::Type::Basic(ast::BasicType::Boolean(_)) => Ty::Bool,
ast::Type::Basic(ast::BasicType::Field(_)) => Ty::Field,
ast::Type::Array(a) => {
let b = self.type_(&lift(&a.ty));
a.dimensions
.iter()
.map(|d| self.const_int(d))
.fold(b, |b, d| Ty::Array(d as usize, Box::new(b)))
}
ast::Type::Struct(s) => self.circ.get_type(&s.id.value).clone(),
}
}
fn visit_files(&mut self) {
let t = std::mem::take(&mut self.asts);
for (p, f) in &t {
self.file_stack.push(p.to_owned());
for func in &f.functions {
debug!("fn {} in {}", func.id.value, self.cur_path().display());
self.functions.insert(
(self.cur_path().to_owned(), func.id.value.clone()),
func.clone(),
);
}
for i in &f.imports {
let (src_path, src_name, dst_name) = match i {
ast::ImportDirective::Main(m) => (
m.source.value.clone(),
"main".to_owned(),
m.alias
.as_ref()
.map(|a| a.value.clone())
.unwrap_or_else(|| {
PathBuf::from(m.source.value.clone())
.file_stem()
.unwrap_or_else(|| panic!("Bad import: {}", m.source.value))
.to_string_lossy()
.to_string()
}),
),
ast::ImportDirective::From(m) => (
m.source.value.clone(),
m.symbol.value.clone(),
m.alias
.as_ref()
.map(|a| a.value.clone())
.unwrap_or_else(|| m.symbol.value.clone()),
),
};
let abs_src_path = self.stdlib.canonicalize(&self.cur_dir(), src_path.as_str());
debug!(
"Import of {} from {} as {}",
src_name,
abs_src_path.display(),
dst_name
);
self.import_map.insert(
(self.cur_path().to_path_buf(), dst_name),
(abs_src_path, src_name),
);
}
for s in &f.structs {
let ty = Ty::new_struct(
s.id.value.clone(),
s.fields
.clone()
.iter()
.map(|f| (f.id.value.clone(), self.type_(&f.ty))),
);
debug!("struct {}", s.id.value);
self.circ.def_type(&s.id.value, ty);
}
self.file_stack.pop();
}
self.asts = t;
}
}

70
src/front/zsharp/TODO Normal file
View File

@@ -0,0 +1,70 @@
- casts
- widening casts are free!
- check narrowing cast correctness!
- look at unpack functions again
- look at pack: advantage to builtin?
u8: | u16 u32 u64 field
u16: u8 | u32 u64 field
u32: u8 u16 | u64 field
u64: u8 u16 u32 | field
- error messages: (String, &Span) instead of String to avoid recursively
expanding Spans on error?
- talk to AO about bit-split
- generalized bit-split, i.e., into vector-of-bitvectors?
goes nicely with lookup table--based range checks...
- add explicit-generic-expr to parser, e,g., foo::<(N+1)>(5)
- maybe not necessary: can just say `u32 Np1 = N + 1 ; foo::<Np1>(5)`
- POW width - allow full-width rhs? (const only? in that case, easy)
- correctness: bit width for field comparisons, etc
- when lowering to r1cs, enforce MSB=0?
- correctness: when lowering IR arrays, enforce bounds checks!
- optimization: for a < b, only expand a to b's bit width;
a < b is (a_expansion == a) && (a_expansion < b_expansion)
== done ==
[x] generic inf: monomorphize at call time
[x] make sure we got a UNIQUE solution! find_unique_model() fn
- stdlib rename to avoid confusion with parallel ZoK checkouts
[x] and/or: check ZSHARP_STDLIB_PATH envvar
[WONTFIX] remove ret requirement for fns
- typecheck with bool if no type? (and test function_call() for compatibility)
- add () or nil type?
[x] unify_inline_array revisit
[x] field `%`
[x] unsigned
- divrem? (is this necessary for efficiency? can just do r=a%b, c=(a-r)/b
- signed?
[x] const / non-const cleanup
[x] tuples: Box<[_]> rather than Vec<_>
[x] tuple typecheck on update
[x] multi-returns?
- no. if we need tuples, we'll add them to the type system properly
[x] oob array read fix
[x] solver-related optimizations
[x] duh, don't call the solver if not needed
[x] cache generic inf results
[x] array construction optimization
[x] lints
[x] pretty-printing T

1732
src/front/zsharp/mod.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
//! Parsing and recursively loading ZoKrates.
//! Parsing and recursively loading Z#.
//!
//! Based on the original ZoKrates parser, with extra machinery for recursive loading and locating
//! the standard library.
@@ -7,6 +7,7 @@ use zokrates_pest_ast as ast;
use log::debug;
use std::collections::HashMap;
use std::env::var_os;
use crate::circify::includer::Loader;
use rug::Integer;
@@ -28,7 +29,7 @@ pub fn parse_inputs(p: PathBuf) -> HashMap<String, Integer> {
for l in BufReader::new(File::open(p).unwrap()).lines() {
let l = l.unwrap();
let l = l.trim();
if l.len() > 0 {
if !l.is_empty() {
let mut s = l.split_whitespace();
let key = s.next().unwrap().to_owned();
let value = Integer::from(Integer::parse_radix(&s.next().unwrap(), 10).unwrap());
@@ -39,6 +40,7 @@ pub fn parse_inputs(p: PathBuf) -> HashMap<String, Integer> {
}
/// A representation of the standard libary's location.
#[derive(Default)]
pub struct ZStdLib {
path: PathBuf,
}
@@ -47,6 +49,18 @@ impl ZStdLib {
/// Looks for a "ZoKrates/zokrates_stdlib/stdlib" path in some ancestor of the current
/// directory.
pub fn new() -> Self {
if let Some(p) = var_os("ZSHARP_STDLIB_PATH") {
let p = PathBuf::from(p);
if p.exists() {
return Self { path: p };
} else {
panic!(
"ZStdLib: ZSHARP_STDLIB_PATH {:?} does not appear to exist",
p
);
}
}
let p = std::env::current_dir().unwrap().canonicalize().unwrap();
assert!(p.is_absolute());
let stdlib_subdirs = vec![
@@ -62,15 +76,12 @@ impl ZStdLib {
}
}
}
panic!("Could not find ZoKrates stdlib from {}", p.display())
panic!("Could not find ZoKrates/Z# stdlib from {}", p.display())
}
/// Turn `child`, relative to `parent` (or to the standard libary!), into an absolute path.
pub fn canonicalize(&self, parent: &Path, child: &str) -> PathBuf {
debug!("Looking for {} from {}", child, parent.display());
if child.contains("EMBED") {
return PathBuf::from(child);
}
let paths = vec![parent.to_path_buf(), self.path.clone()];
let paths = [parent.to_path_buf(), self.path.clone()];
for mut p in paths {
p.push(child);
if p.extension().is_none() {
@@ -83,16 +94,22 @@ impl ZStdLib {
}
panic!("Could not find {} from {}", child, parent.display())
}
/// check if this path is the EMBED prototypes path
pub fn is_embed<P: AsRef<Path>>(&self, p: P) -> bool {
p.as_ref().starts_with(&self.path)
&& p.as_ref().file_stem().map(|s| s.to_str()).flatten() == Some("EMBED")
}
}
/// A recrusive zokrates loader
/// A recrusive Z# loader
#[derive(Default)]
pub struct ZLoad {
sources: Arena<String>,
stdlib: ZStdLib,
}
impl ZLoad {
/// Make a new ZoKrates loader, looking for the standard library somewhere above the current
/// Make a new Z# loader, looking for the standard library somewhere above the current
/// dirdirectory. See [ZStdLib::new].
pub fn new() -> Self {
Self {
@@ -101,7 +118,7 @@ impl ZLoad {
}
}
/// Recursively load a ZoKrates file.
/// Recursively load a Z# file.
///
/// ## Returns
///
@@ -109,6 +126,11 @@ impl ZLoad {
pub fn load<P: AsRef<Path>>(&self, p: &P) -> HashMap<PathBuf, ast::File> {
self.recursive_load(p).unwrap()
}
/// Get ref to contained ZStdLib
pub fn stdlib(&self) -> &ZStdLib {
&self.stdlib
}
}
impl<'a> Loader for &'a ZLoad {
@@ -129,16 +151,19 @@ impl<'a> Loader for &'a ZLoad {
fn includes<P: AsRef<Path>>(&self, ast: &Self::AST, p: &P) -> Vec<PathBuf> {
let mut c = p.as_ref().to_path_buf();
c.pop();
ast.imports
ast.declarations
.iter()
.map(|i| {
let ext = match i {
ast::ImportDirective::Main(m) => &m.source.value,
ast::ImportDirective::From(m) => &m.source.value,
};
self.stdlib.canonicalize(&c, ext)
.filter_map(|d| {
if let ast::SymbolDeclaration::Import(i) = d {
let ext = match i {
ast::ImportDirective::Main(m) => &m.source.value,
ast::ImportDirective::From(m) => &m.source.value,
};
Some(self.stdlib.canonicalize(&c, ext))
} else {
None
}
})
.filter(|p| p.to_str().map(|s| !s.contains("EMBED")).unwrap_or(true))
.collect()
}
}

View File

@@ -1,26 +1,43 @@
//! Symbolic ZoKrates terms
//! Symbolic Z# terms
use std::collections::{BTreeMap, HashMap};
use std::fmt::{self, Display, Formatter};
use std::sync::Arc;
use lazy_static::lazy_static;
use log::warn;
use rug::Integer;
use crate::circify::{CirCtx, Embeddable};
use crate::ir::opt::cfold::fold as constant_fold;
use crate::ir::term::*;
// The modulus for Z#.
// TODO: handle this better!
#[cfg(feature = "bls12381")]
lazy_static! {
// TODO: handle this better
/// The modulus for ZoKrates.
pub static ref ZOKRATES_MODULUS: Integer = Integer::from_str_radix(
"52435875175126190479447740508185965837690552500527637822603658699938581184513",
/// The modulus for Z#
pub static ref ZSHARP_MODULUS: Integer = Integer::from_str_radix(
"52435875175126190479447740508185965837690552500527637822603658699938581184513", // BLS12-381 group order
10
)
.unwrap();
/// The modulus for ZoKrates, as an ARC
pub static ref ZOKRATES_MODULUS_ARC: Arc<Integer> = Arc::new(ZOKRATES_MODULUS.clone());
/// The modulus for ZoKrates, as an IR sort
pub static ref ZOK_FIELD_SORT: Sort = Sort::Field(ZOKRATES_MODULUS_ARC.clone());
}
#[cfg(not(feature = "bls12381"))]
lazy_static! {
/// The modulus for Z#
pub static ref ZSHARP_MODULUS: Integer = Integer::from_str_radix(
"21888242871839275222246405745257275088548364400416034343698204186575808495617", // BN-254 group order
10
)
.unwrap();
}
lazy_static! {
/// The modulus for Z#, as an ARC
pub static ref ZSHARP_MODULUS_ARC: Arc<Integer> = Arc::new(ZSHARP_MODULUS.clone());
/// The modulus for Z#, as an IR sort
pub static ref ZSHARP_FIELD_SORT: Sort = Sort::Field(ZSHARP_MODULUS_ARC.clone());
}
#[derive(Clone, PartialEq, Eq)]
@@ -38,6 +55,7 @@ pub use field_list::FieldList;
///
/// It gets its own module so that its member can be private.
mod field_list {
use std::collections::BTreeMap;
#[derive(Clone, PartialEq, Eq)]
pub struct FieldList<T> {
@@ -63,6 +81,9 @@ mod field_list {
pub fn fields(&self) -> impl Iterator<Item = &(String, T)> {
self.list.iter()
}
pub fn into_map(self) -> BTreeMap<String, T> {
self.list.into_iter().collect()
}
}
}
@@ -79,7 +100,16 @@ impl Display for Ty {
}
o.finish()
}
Ty::Array(n, b) => write!(f, "{}[{}]", b, n),
Ty::Array(n, b) => {
let mut dims = vec![n];
let mut bb = b.as_ref();
while let Ty::Array(n, b) = bb {
bb = b.as_ref();
dims.push(n);
}
write!(f, "{}", bb)?;
dims.iter().try_for_each(|d| write!(f, "[{}]", d))
}
}
}
}
@@ -95,9 +125,9 @@ impl Ty {
match self {
Self::Bool => Sort::Bool,
Self::Uint(w) => Sort::BitVector(*w),
Self::Field => ZOK_FIELD_SORT.clone(),
Self::Field => ZSHARP_FIELD_SORT.clone(),
Self::Array(n, b) => {
Sort::Array(Box::new(ZOK_FIELD_SORT.clone()), Box::new(b.sort()), *n)
Sort::Array(Box::new(ZSHARP_FIELD_SORT.clone()), Box::new(b.sort()), *n)
}
Self::Struct(_name, fs) => {
Sort::Tuple(fs.fields().map(|(_f_name, f_ty)| f_ty.sort()).collect())
@@ -107,10 +137,10 @@ impl Ty {
fn default_ir_term(&self) -> Term {
self.sort().default_term()
}
fn default(&self) -> T {
pub fn default(&self) -> T {
T {
term: self.default_ir_term(),
ty: self.clone(),
term: self.default_ir_term(),
}
}
/// Creates a new structure type, sorting the keys.
@@ -178,6 +208,7 @@ impl T {
pub fn new_array(v: Vec<T>) -> Result<T, String> {
array(v)
}
pub fn new_struct(name: String, fields: Vec<(String, T)>) -> T {
let (field_tys, ir_terms): (Vec<_>, Vec<_>) = fields
.into_iter()
@@ -193,6 +224,102 @@ impl T {
});
T::new(Ty::Struct(name, field_ty_list), ir_term)
}
// XXX(rsw) hrm is there a nicer way to do this?
pub fn new_field<I>(v: I) -> Self
where
Integer: From<I>,
{
T::new(Ty::Field, pf_lit_ir(v))
}
pub fn new_u8<I>(v: I) -> Self
where
Integer: From<I>,
{
T::new(Ty::Uint(8), bv_lit(v, 8))
}
pub fn new_u16<I>(v: I) -> Self
where
Integer: From<I>,
{
T::new(Ty::Uint(16), bv_lit(v, 16))
}
pub fn new_u32<I>(v: I) -> Self
where
Integer: From<I>,
{
T::new(Ty::Uint(32), bv_lit(v, 32))
}
pub fn new_u64<I>(v: I) -> Self
where
Integer: From<I>,
{
T::new(Ty::Uint(64), bv_lit(v, 64))
}
pub fn pretty<W: std::io::Write>(&self, f: &mut W) -> Result<(), std::io::Error> {
use std::io::{Error, ErrorKind};
let val = match &self.term.op {
Op::Const(v) => Ok(v),
_ => Err(Error::new(ErrorKind::Other, "not a const val")),
}?;
match val {
Value::Bool(b) => write!(f, "{}", b),
Value::Field(fe) => write!(f, "{}f", fe.i()),
Value::BitVector(bv) => match bv.width() {
8 => write!(f, "0x{:02x}", bv.uint()),
16 => write!(f, "0x{:04x}", bv.uint()),
32 => write!(f, "0x{:08x}", bv.uint()),
64 => write!(f, "0x{:016x}", bv.uint()),
_ => unreachable!(),
},
Value::Tuple(vs) => {
let (n, fl) = if let Ty::Struct(n, fl) = &self.ty {
Ok((n, fl))
} else {
Err(Error::new(
ErrorKind::Other,
"expected struct, got something else",
))
}?;
write!(f, "{} {{ ", n)?;
fl.fields().zip(vs.iter()).try_for_each(|((n, ty), v)| {
write!(f, "{}: ", n)?;
T::new(ty.clone(), leaf_term(Op::Const(v.clone()))).pretty(f)?;
write!(f, ", ")
})?;
write!(f, "}}")
}
Value::Array(arr) => {
let inner_ty = if let Ty::Array(_, ty) = &self.ty {
Ok(ty)
} else {
Err(Error::new(
ErrorKind::Other,
"expected array, got something else",
))
}?;
write!(f, "[")?;
arr.key_sort
.elems_iter()
.take(arr.size)
.try_for_each(|idx| {
T::new(
*inner_ty.clone(),
leaf_term(Op::Const(arr.select(idx.as_value_opt().unwrap()))),
)
.pretty(f)?;
write!(f, ", ")
})?;
write!(f, "]")
}
_ => unreachable!(),
}
}
}
impl Display for T {
@@ -293,12 +420,19 @@ pub fn div(a: T, b: T) -> Result<T, String> {
wrap_bin_op("/", Some(div_uint), Some(div_field), None, a, b)
}
fn rem_field(a: Term, b: Term) -> Term {
let len = ZSHARP_MODULUS.significant_bits() as usize;
let a_bv = term![Op::PfToBv(len); a];
let b_bv = term![Op::PfToBv(len); b];
term![Op::UbvToPf(ZSHARP_MODULUS_ARC.clone()); term![Op::BvBinOp(BvBinOp::Urem); a_bv, b_bv]]
}
fn rem_uint(a: Term, b: Term) -> Term {
term![Op::BvBinOp(BvBinOp::Urem); a, b]
}
pub fn rem(a: T, b: T) -> Result<T, String> {
wrap_bin_op("%", Some(rem_uint), None, None, a, b)
wrap_bin_op("%", Some(rem_uint), Some(rem_field), None, a, b)
}
fn bitand_uint(a: Term, b: Term) -> Term {
@@ -341,52 +475,108 @@ pub fn and(a: T, b: T) -> Result<T, String> {
wrap_bin_op("&&", None, None, Some(and_bool), a, b)
}
fn eq_base(a: Term, b: Term) -> Term {
term![Op::Eq; a, b]
fn eq_base(a: T, b: T) -> Result<Term, String> {
if a.ty != b.ty {
Err(format!(
"Cannot '==' dissimilar types {} and {}",
a.type_(),
b.type_()
))
} else {
Ok(term![Op::Eq; a.term, b.term])
}
}
pub fn eq(a: T, b: T) -> Result<T, String> {
wrap_bin_pred("==", Some(eq_base), Some(eq_base), Some(eq_base), a, b)
}
fn neq_base(a: Term, b: Term) -> Term {
term![Op::Not; term![Op::Eq; a, b]]
Ok(T::new(Ty::Bool, eq_base(a, b)?))
}
pub fn neq(a: T, b: T) -> Result<T, String> {
wrap_bin_pred("!=", Some(neq_base), Some(neq_base), Some(neq_base), a, b)
Ok(T::new(Ty::Bool, not_bool(eq_base(a, b)?)))
}
fn ult_uint(a: Term, b: Term) -> Term {
term![Op::BvBinPred(BvBinPred::Ult); a, b]
}
// XXX(constr_opt) see TODO file - only need to expand to MIN of two bit-lengths if done right
// XXX(constr_opt) do this using subtraction instead?
fn field_comp(a: Term, b: Term, op: BvBinPred) -> Term {
let len = ZSHARP_MODULUS.significant_bits() as usize;
let a_bv = term![Op::PfToBv(len); a];
let b_bv = term![Op::PfToBv(len); b];
term![Op::BvBinPred(op); a_bv, b_bv]
}
fn ult_field(a: Term, b: Term) -> Term {
field_comp(a, b, BvBinPred::Ult)
}
pub fn ult(a: T, b: T) -> Result<T, String> {
wrap_bin_pred("<", Some(ult_uint), None, None, a, b)
wrap_bin_pred("<", Some(ult_uint), Some(ult_field), None, a, b)
}
fn ule_uint(a: Term, b: Term) -> Term {
term![Op::BvBinPred(BvBinPred::Ule); a, b]
}
fn ule_field(a: Term, b: Term) -> Term {
field_comp(a, b, BvBinPred::Ule)
}
pub fn ule(a: T, b: T) -> Result<T, String> {
wrap_bin_pred("<=", Some(ule_uint), None, None, a, b)
wrap_bin_pred("<=", Some(ule_uint), Some(ule_field), None, a, b)
}
fn ugt_uint(a: Term, b: Term) -> Term {
term![Op::BvBinPred(BvBinPred::Ugt); a, b]
}
fn ugt_field(a: Term, b: Term) -> Term {
field_comp(a, b, BvBinPred::Ugt)
}
pub fn ugt(a: T, b: T) -> Result<T, String> {
wrap_bin_pred(">", Some(ugt_uint), None, None, a, b)
wrap_bin_pred(">", Some(ugt_uint), Some(ugt_field), None, a, b)
}
fn uge_uint(a: Term, b: Term) -> Term {
term![Op::BvBinPred(BvBinPred::Uge); a, b]
}
fn uge_field(a: Term, b: Term) -> Term {
field_comp(a, b, BvBinPred::Uge)
}
pub fn uge(a: T, b: T) -> Result<T, String> {
wrap_bin_pred(">=", Some(uge_uint), None, None, a, b)
wrap_bin_pred(">=", Some(uge_uint), Some(uge_field), None, a, b)
}
pub fn pow(a: T, b: T) -> Result<T, String> {
if a.ty != Ty::Field || b.ty != Ty::Uint(32) {
return Err(format!(
"Cannot compute {} ** {} : must be Field ** U32",
a, b
));
}
let a = a.term;
let b = const_int(b)?;
if b == 0 {
return Ok(field_lit(1));
}
let res = (0..b.significant_bits() - 1)
.rev()
.fold(a.clone(), |acc, ix| {
let acc = mul_field(acc.clone(), acc);
if b.get_bit(ix) {
mul_field(acc, a.clone())
} else {
acc
}
});
Ok(T::new(Ty::Field, res))
}
fn wrap_un_op(
@@ -412,7 +602,6 @@ fn neg_uint(a: Term) -> Term {
term![Op::BvUnOp(BvUnOp::Neg); a]
}
#[allow(dead_code)]
// Missing from ZoKrates.
pub fn neg(a: T) -> Result<T, String> {
wrap_un_op("unary-", Some(neg_uint), Some(neg_field), None, a)
@@ -431,12 +620,33 @@ pub fn not(a: T) -> Result<T, String> {
}
pub fn const_int(a: T) -> Result<Integer, String> {
match &a.term.op {
Op::Const(Value::Field(f)) => Some(f.i().clone()),
Op::Const(Value::BitVector(f)) => Some(f.uint().clone()),
match const_value(&a.term) {
Some(Value::Field(f)) => Ok(f.i().clone()),
Some(Value::BitVector(f)) => Ok(f.uint().clone()),
_ => Err(format!("{} is not a constant integer", a)),
}
}
pub fn const_bool(a: T) -> Option<bool> {
match const_value(&a.term) {
Some(Value::Bool(b)) => Some(b),
_ => None,
}
}
pub fn const_val(a: T) -> Result<T, String> {
match const_value(&a.term) {
Some(v) => Ok(T::new(a.ty, leaf_term(Op::Const(v)))),
_ => Err(format!("{} is not a constant basic type", &a)),
}
}
fn const_value(t: &Term) -> Option<Value> {
let folded = constant_fold(t);
match &folded.op {
Op::Const(v) => Some(v.clone()),
_ => None,
}
.ok_or_else(|| format!("{} is not a constant integer", a))
}
pub fn bool(a: T) -> Result<Term, String> {
@@ -463,7 +673,7 @@ pub fn shr(a: T, b: T) -> Result<T, String> {
}
fn ite(c: Term, a: T, b: T) -> Result<T, String> {
if &a.ty != &b.ty {
if a.ty != b.ty {
Err(format!("Cannot perform ITE on {} and {}", a, b))
} else {
Ok(T::new(a.ty.clone(), term![Op::Ite; c, a.term, b.term]))
@@ -478,10 +688,14 @@ pub fn pf_lit_ir<I>(i: I) -> Term
where
Integer: From<I>,
{
leaf_term(Op::Const(Value::Field(FieldElem::new(
Integer::from(i),
ZOKRATES_MODULUS_ARC.clone(),
))))
leaf_term(Op::Const(pf_val(i)))
}
fn pf_val<I>(i: I) -> Value
where
Integer: From<I>,
{
Value::Field(FieldElem::new(Integer::from(i), ZSHARP_MODULUS_ARC.clone()))
}
pub fn field_lit<I>(i: I) -> T
@@ -555,26 +769,64 @@ pub fn field_store(struct_: T, field: &str, val: T) -> Result<T, String> {
}
pub fn array_select(array: T, idx: T) -> Result<T, String> {
match (array.ty, idx.ty) {
(Ty::Array(_size, elem_ty), Ty::Field) => {
Ok(T::new(*elem_ty, term![Op::Select; array.term, idx.term]))
match array.ty {
Ty::Array(_, elem_ty) if matches!(idx.ty, Ty::Uint(_) | Ty::Field) => {
let iterm = if matches!(idx.ty, Ty::Uint(_)) {
warn!("warning: indexing array with Uint type");
term![Op::UbvToPf(ZSHARP_MODULUS_ARC.clone()); idx.term]
} else {
idx.term
};
Ok(T::new(*elem_ty, term![Op::Select; array.term, iterm]))
}
(a, b) => Err(format!("Cannot index {} by {}", b, a)),
_ => Err(format!("Cannot index {} using {}", &array.ty, &idx.ty)),
}
}
pub fn array_store(array: T, idx: T, val: T) -> Result<T, String> {
match (&array.ty, idx.ty) {
(Ty::Array(_, _), Ty::Field) => Ok(T::new(
if matches!(&array.ty, Ty::Array(_, _)) && matches!(&idx.ty, Ty::Uint(_) | Ty::Field) {
// XXX(q) typecheck here?
let iterm = if matches!(idx.ty, Ty::Uint(_)) {
warn!("warning: indexing array with Uint type");
term![Op::UbvToPf(ZSHARP_MODULUS_ARC.clone()); idx.term]
} else {
idx.term
};
Ok(T::new(
array.ty,
term![Op::Store; array.term, idx.term, val.term],
)),
(a, b) => Err(format!("Cannot index {} by {}", b, a)),
term![Op::Store; array.term, iterm, val.term],
))
} else {
Err(format!("Cannot index {} using {}", &array.ty, &idx.ty))
}
}
fn ir_array<I: IntoIterator<Item = Term>>(sort: Sort, elems: I) -> Term {
make_array(ZOK_FIELD_SORT.clone(), sort, elems.into_iter().collect())
let mut values = BTreeMap::new();
let to_insert = elems
.into_iter()
.enumerate()
.filter_map(|(i, t)| {
let i_val = pf_val(i);
match const_value(&t) {
Some(v) => {
values.insert(i_val, v);
None
}
None => Some((leaf_term(Op::Const(i_val)), t)),
}
})
.collect::<Vec<(Term, Term)>>();
let len = values.len() + to_insert.len();
let arr = leaf_term(Op::Const(Value::Array(Array::new(
ZSHARP_FIELD_SORT.clone(),
Box::new(sort.default_value()),
values,
len,
))));
to_insert
.into_iter()
.fold(arr, |arr, (idx, val)| term![Op::Store; arr, idx, val])
}
pub fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> {
@@ -582,7 +834,7 @@ pub fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> {
if let Some(e) = v.first() {
let ty = e.type_();
if v.iter().skip(1).any(|a| a.type_() != ty) {
Err(format!("Inconsistent types in array"))
Err("Inconsistent types in array".to_string())
} else {
let sort = check(&e.term);
Ok(T::new(
@@ -591,7 +843,7 @@ pub fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> {
))
}
} else {
Err(format!("Empty array"))
Err("Empty array".to_string())
}
}
@@ -601,17 +853,18 @@ pub fn uint_to_bits(u: T) -> Result<T, String> {
Ty::Array(*n, Box::new(Ty::Bool)),
ir_array(
Sort::Bool,
(0..*n).map(|i| term![Op::BvBit(i); u.term.clone()]),
(0..*n).rev().map(|i| term![Op::BvBit(i); u.term.clone()]),
),
)),
u => Err(format!("Cannot do uint-to-bits on {}", u)),
}
}
// XXX(rsw) is it correct to enforce length here, vs. in (say) builtin_call in mod.rs?
pub fn uint_from_bits(u: T) -> Result<T, String> {
match &u.ty {
Ty::Array(bits, elem_ty) if &**elem_ty == &Ty::Bool => match bits {
8 | 16 | 32 => Ok(T::new(
Ty::Array(bits, elem_ty) if **elem_ty == Ty::Bool => match bits {
8 | 16 | 32 | 64 => Ok(T::new(
Ty::Uint(*bits),
term(
Op::BvConcat,
@@ -627,14 +880,53 @@ pub fn uint_from_bits(u: T) -> Result<T, String> {
}
}
pub fn field_to_bits(f: T) -> Result<T, String> {
pub fn field_to_bits(f: T, n: usize) -> Result<T, String> {
match &f.ty {
Ty::Field => uint_to_bits(T::new(Ty::Uint(254), term![Op::PfToBv(254); f.term])),
Ty::Field => uint_to_bits(T::new(Ty::Uint(n), term![Op::PfToBv(n); f.term])),
u => Err(format!("Cannot do uint-to-bits on {}", u)),
}
}
pub struct ZoKrates {
fn bv_from_bits(barr: Term, size: usize) -> Term {
term(
Op::BvConcat,
(0..size)
.map(|i| term![Op::BoolToBv; term![Op::Select; barr.clone(), pf_lit_ir(i)]])
.collect(),
)
}
pub fn bit_array_le(a: T, b: T, n: usize) -> Result<T, String> {
match (&a.ty, &b.ty) {
(Ty::Array(la, ta), Ty::Array(lb, tb)) => {
if **ta != Ty::Bool || **tb != Ty::Bool {
Err("bit-array-le must be called on arrays of Bools".to_string())
} else if la != lb {
Err(format!(
"bit-array-le called on arrays with lengths {} != {}",
la, lb
))
} else if *la != n {
Err(format!(
"bit-array-le::<{}> called on arrays with length {}",
n, la
))
} else {
Ok(())
}
}
_ => Err(format!("Cannot do bit-array-le on ({}, {})", &a.ty, &b.ty)),
}?;
let at = bv_from_bits(a.term, n);
let bt = bv_from_bits(b.term, n);
Ok(T::new(
Ty::Bool,
term![Op::BvBinPred(BvBinPred::Ule); at, bt],
))
}
pub struct ZSharp {
values: Option<HashMap<String, Integer>>,
modulus: Arc<Integer>,
}
@@ -647,16 +939,16 @@ fn idx_name(struct_name: &str, idx: usize) -> String {
format!("{}.{}", struct_name, idx)
}
impl ZoKrates {
impl ZSharp {
pub fn new(values: Option<HashMap<String, Integer>>) -> Self {
Self {
values,
modulus: ZOKRATES_MODULUS_ARC.clone(),
modulus: ZSHARP_MODULUS_ARC.clone(),
}
}
}
impl Embeddable for ZoKrates {
impl Embeddable for ZSharp {
type T = T;
type Ty = Ty;
fn declare(
@@ -713,7 +1005,7 @@ impl Embeddable for ZoKrates {
&*ty,
idx_name(&raw_name, i),
user_name.as_ref().map(|u| idx_name(u, i)),
visibility.clone(),
visibility,
)
}))
.unwrap(),
@@ -728,7 +1020,7 @@ impl Embeddable for ZoKrates {
f_ty,
field_name(&raw_name, f_name),
user_name.as_ref().map(|u| field_name(u, f_name)),
visibility.clone(),
visibility,
),
)
})

View File

@@ -0,0 +1,41 @@
overloading:
Functions are imported by name. If many functions have the same name
but different signatures, all of them get imported, and which one to
use in a particular call is inferred. (ZoK manual section 3.8)
==> we disallow explicitly
inferred types for decimal literals (ZoK manual section 3.2)
inside expressions
in assignments if LHS
add "untypedInteger" and unify as we go?
==> handled
multi-assignment
==> not implemented (WONTFIX?)
generics
==> handled, but with edge cases
add array-membership operator
add arithmetic-progression literal (to use with array-membership op)
should we make range checks explicit in IR?
we have power-of-2 right now
add non-power-of-2 range check?
===
// Following is totally broken right now (may work in ref compiler because
// they seem to monomorphize on-the-fly --- consider doing this?)
def last<N>(u32[N] a) -> u32:
return a[N-1]
def foo<N>(u32[N] a) -> u32:
// can't compute 2*N and pass to last because it has to be a const value!
// (and cannot declare const values inside functions)
return last([...a, ...a])
def main() -> u32:
return foo([1,2,3])
// XXX do we want to add const decls *inside* functions?
// not possible right now, but could help for cases like this

View File

@@ -0,0 +1,74 @@
//! AST Walker for zokrates_pest_ast
use super::{ZVisitorError, ZVisitorResult};
use zokrates_pest_ast as ast;
pub fn eq_type<'ast>(ty: &ast::Type<'ast>, ty2: &ast::Type<'ast>) -> ZVisitorResult {
use ast::Type::*;
match (ty, ty2) {
(Basic(bty), Basic(bty2)) => eq_basic_type(bty, bty2),
(Array(aty), Array(aty2)) => eq_array_type(aty, aty2),
(Struct(sty), Struct(sty2)) => eq_struct_type(sty, sty2),
_ => Err(ZVisitorError(format!(
"type mismatch: expected {:?}, found {:?}",
ty, ty2,
))),
}
}
pub fn eq_basic_type<'ast>(
ty: &ast::BasicType<'ast>,
ty2: &ast::BasicType<'ast>,
) -> ZVisitorResult {
use ast::BasicType::*;
match (ty, ty2) {
(Field(_), Field(_)) => Ok(()),
(Boolean(_), Boolean(_)) => Ok(()),
(U8(_), U8(_)) => Ok(()),
(U16(_), U16(_)) => Ok(()),
(U32(_), U32(_)) => Ok(()),
(U64(_), U64(_)) => Ok(()),
_ => Err(ZVisitorError(format!(
"basic type mismatch: expected {:?}, found {:?}",
ty, ty2,
))),
}
}
pub fn eq_array_type<'ast>(
ty: &ast::ArrayType<'ast>,
ty2: &ast::ArrayType<'ast>,
) -> ZVisitorResult {
use ast::BasicOrStructType::*;
if ty.dimensions.len() != ty2.dimensions.len() {
return Err(ZVisitorError(format!(
"array type mismatch: expected {}-dimensional array, found {}-dimensional array",
ty.dimensions.len(),
ty2.dimensions.len(),
)));
}
match (&ty.ty, &ty2.ty) {
(Basic(bty), Basic(bty2)) => eq_basic_type(bty, bty2),
(Struct(sty), Struct(sty2)) => eq_struct_type(sty, sty2),
_ => Err(ZVisitorError(format!(
"array type mismatch: expected elms of type {:?}, found {:?}",
&ty.ty, &ty2.ty,
))),
}
}
pub fn eq_struct_type<'ast>(
ty: &ast::StructType<'ast>,
ty2: &ast::StructType<'ast>,
) -> ZVisitorResult {
if ty.id.value != ty2.id.value {
Err(ZVisitorError(format!(
"struct type mismatch: expected {:?}, found {:?}",
&ty.id.value, &ty2.id.value,
)))
} else {
// don't check generics here; they'll get checked after monomorphization
Ok(())
}
}

View File

@@ -0,0 +1,34 @@
//! AST Walker for zokrates_pest_ast
#![allow(missing_docs)]
mod eqtype;
mod walkfns;
mod zconstlitrw;
mod zgenericinf;
mod zstmtwalker;
mod zvmut;
pub(super) use zconstlitrw::ZConstLiteralRewriter;
pub(super) use zgenericinf::ZGenericInf;
pub(super) use zstmtwalker::ZStatementWalker;
pub use zvmut::ZVisitorMut;
use zokrates_pest_ast as ast;
pub struct ZVisitorError(pub String);
pub type ZResult<T> = Result<T, ZVisitorError>;
pub type ZVisitorResult = ZResult<()>;
impl From<String> for ZVisitorError {
fn from(f: String) -> Self {
Self(f)
}
}
fn bos_to_type(bos: ast::BasicOrStructType) -> ast::Type {
use ast::{BasicOrStructType::*, Type};
match bos {
Struct(st) => Type::Struct(st),
Basic(bt) => Type::Basic(bt),
}
}

View File

@@ -0,0 +1,783 @@
//! AST Walker for zokrates_pest_ast
use super::{ZVisitorMut, ZVisitorResult};
use zokrates_pest_ast as ast;
pub fn walk_file<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
file: &mut ast::File<'ast>,
) -> ZVisitorResult {
if let Some(p) = &mut file.pragma {
visitor.visit_pragma(p)?;
}
file.declarations
.iter_mut()
.try_for_each(|d| visitor.visit_symbol_declaration(d))?;
visitor.visit_eoi(&mut file.eoi)?;
visitor.visit_span(&mut file.span)
}
pub fn walk_pragma<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
pragma: &mut ast::Pragma<'ast>,
) -> ZVisitorResult {
visitor.visit_curve(&mut pragma.curve)?;
visitor.visit_span(&mut pragma.span)
}
pub fn walk_curve<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
curve: &mut ast::Curve<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut curve.span)
}
pub fn walk_symbol_declaration<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
sd: &mut ast::SymbolDeclaration<'ast>,
) -> ZVisitorResult {
use ast::SymbolDeclaration::*;
match sd {
Import(i) => visitor.visit_import_directive(i),
Constant(c) => visitor.visit_constant_definition(c),
Struct(s) => visitor.visit_struct_definition(s),
Function(f) => visitor.visit_function_definition(f),
}
}
pub fn walk_import_directive<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
import: &mut ast::ImportDirective<'ast>,
) -> ZVisitorResult {
use ast::ImportDirective::*;
match import {
Main(m) => visitor.visit_main_import_directive(m),
From(f) => visitor.visit_from_import_directive(f),
}
}
pub fn walk_main_import_directive<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
mimport: &mut ast::MainImportDirective<'ast>,
) -> ZVisitorResult {
visitor.visit_import_source(&mut mimport.source)?;
if let Some(ie) = &mut mimport.alias {
visitor.visit_identifier_expression(ie)?;
}
visitor.visit_span(&mut mimport.span)
}
pub fn walk_from_import_directive<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
fimport: &mut ast::FromImportDirective<'ast>,
) -> ZVisitorResult {
visitor.visit_import_source(&mut fimport.source)?;
fimport
.symbols
.iter_mut()
.try_for_each(|s| visitor.visit_import_symbol(s))?;
visitor.visit_span(&mut fimport.span)
}
pub fn walk_import_source<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
is: &mut ast::ImportSource<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut is.span)
}
pub fn walk_identifier_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
ie: &mut ast::IdentifierExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut ie.span)
}
pub fn walk_import_symbol<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
is: &mut ast::ImportSymbol<'ast>,
) -> ZVisitorResult {
visitor.visit_identifier_expression(&mut is.id)?;
if let Some(ie) = &mut is.alias {
visitor.visit_identifier_expression(ie)?;
}
visitor.visit_span(&mut is.span)
}
pub fn walk_constant_definition<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
cnstdef: &mut ast::ConstantDefinition<'ast>,
) -> ZVisitorResult {
visitor.visit_type(&mut cnstdef.ty)?;
visitor.visit_identifier_expression(&mut cnstdef.id)?;
visitor.visit_expression(&mut cnstdef.expression)?;
visitor.visit_span(&mut cnstdef.span)
}
pub fn walk_struct_definition<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
structdef: &mut ast::StructDefinition<'ast>,
) -> ZVisitorResult {
visitor.visit_identifier_expression(&mut structdef.id)?;
structdef
.generics
.iter_mut()
.try_for_each(|g| visitor.visit_identifier_expression(g))?;
structdef
.fields
.iter_mut()
.try_for_each(|f| visitor.visit_struct_field(f))?;
visitor.visit_span(&mut structdef.span)
}
pub fn walk_struct_field<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
structfield: &mut ast::StructField<'ast>,
) -> ZVisitorResult {
visitor.visit_type(&mut structfield.ty)?;
visitor.visit_identifier_expression(&mut structfield.id)?;
visitor.visit_span(&mut structfield.span)
}
pub fn walk_function_definition<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
fundef: &mut ast::FunctionDefinition<'ast>,
) -> ZVisitorResult {
visitor.visit_identifier_expression(&mut fundef.id)?;
fundef
.generics
.iter_mut()
.try_for_each(|g| visitor.visit_identifier_expression(g))?;
fundef
.parameters
.iter_mut()
.try_for_each(|p| visitor.visit_parameter(p))?;
fundef
.returns
.iter_mut()
.try_for_each(|r| visitor.visit_type(r))?;
fundef
.statements
.iter_mut()
.try_for_each(|s| visitor.visit_statement(s))?;
visitor.visit_span(&mut fundef.span)
}
pub fn walk_parameter<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
param: &mut ast::Parameter<'ast>,
) -> ZVisitorResult {
if let Some(v) = &mut param.visibility {
visitor.visit_visibility(v)?;
}
visitor.visit_type(&mut param.ty)?;
visitor.visit_identifier_expression(&mut param.id)?;
visitor.visit_span(&mut param.span)
}
pub fn walk_visibility<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
vis: &mut ast::Visibility<'ast>,
) -> ZVisitorResult {
use ast::Visibility::*;
match vis {
Public(pu) => visitor.visit_public_visibility(pu),
Private(pr) => visitor.visit_private_visibility(pr),
}
}
pub fn walk_private_visibility<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
prv: &mut ast::PrivateVisibility<'ast>,
) -> ZVisitorResult {
if let Some(pn) = &mut prv.number {
visitor.visit_private_number(pn)?;
}
visitor.visit_span(&mut prv.span)
}
pub fn walk_private_number<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
pn: &mut ast::PrivateNumber<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut pn.span)
}
pub fn walk_type<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
ty: &mut ast::Type<'ast>,
) -> ZVisitorResult {
use ast::Type::*;
match ty {
Basic(b) => visitor.visit_basic_type(b),
Array(a) => visitor.visit_array_type(a),
Struct(s) => visitor.visit_struct_type(s),
}
}
pub fn walk_basic_type<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
bty: &mut ast::BasicType<'ast>,
) -> ZVisitorResult {
use ast::BasicType::*;
match bty {
Field(f) => visitor.visit_field_type(f),
Boolean(b) => visitor.visit_boolean_type(b),
U8(u) => visitor.visit_u8_type(u),
U16(u) => visitor.visit_u16_type(u),
U32(u) => visitor.visit_u32_type(u),
U64(u) => visitor.visit_u64_type(u),
}
}
pub fn walk_field_type<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
fty: &mut ast::FieldType<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut fty.span)
}
pub fn walk_boolean_type<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
bty: &mut ast::BooleanType<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut bty.span)
}
pub fn walk_u8_type<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u8ty: &mut ast::U8Type<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u8ty.span)
}
pub fn walk_u16_type<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u16ty: &mut ast::U16Type<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u16ty.span)
}
pub fn walk_u32_type<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u32ty: &mut ast::U32Type<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u32ty.span)
}
pub fn walk_u64_type<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u64ty: &mut ast::U64Type<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u64ty.span)
}
pub fn walk_array_type<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
aty: &mut ast::ArrayType<'ast>,
) -> ZVisitorResult {
visitor.visit_basic_or_struct_type(&mut aty.ty)?;
aty.dimensions
.iter_mut()
.try_for_each(|d| visitor.visit_expression(d))?;
visitor.visit_span(&mut aty.span)
}
pub fn walk_basic_or_struct_type<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
bsty: &mut ast::BasicOrStructType<'ast>,
) -> ZVisitorResult {
use ast::BasicOrStructType::*;
match bsty {
Struct(s) => visitor.visit_struct_type(s),
Basic(b) => visitor.visit_basic_type(b),
}
}
pub fn walk_struct_type<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
sty: &mut ast::StructType<'ast>,
) -> ZVisitorResult {
visitor.visit_identifier_expression(&mut sty.id)?;
if let Some(eg) = &mut sty.explicit_generics {
visitor.visit_explicit_generics(eg)?;
}
visitor.visit_span(&mut sty.span)
}
pub fn walk_explicit_generics<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
eg: &mut ast::ExplicitGenerics<'ast>,
) -> ZVisitorResult {
eg.values
.iter_mut()
.try_for_each(|v| visitor.visit_constant_generic_value(v))?;
visitor.visit_span(&mut eg.span)
}
pub fn walk_constant_generic_value<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
cgv: &mut ast::ConstantGenericValue<'ast>,
) -> ZVisitorResult {
use ast::ConstantGenericValue::*;
match cgv {
Value(l) => visitor.visit_literal_expression(l),
Identifier(i) => visitor.visit_identifier_expression(i),
Underscore(u) => visitor.visit_underscore(u),
}
}
pub fn walk_literal_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
lexpr: &mut ast::LiteralExpression<'ast>,
) -> ZVisitorResult {
use ast::LiteralExpression::*;
match lexpr {
DecimalLiteral(d) => visitor.visit_decimal_literal_expression(d),
BooleanLiteral(b) => visitor.visit_boolean_literal_expression(b),
HexLiteral(h) => visitor.visit_hex_literal_expression(h),
}
}
pub fn walk_decimal_literal_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
dle: &mut ast::DecimalLiteralExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_decimal_number(&mut dle.value)?;
if let Some(s) = &mut dle.suffix {
visitor.visit_decimal_suffix(s)?;
}
visitor.visit_span(&mut dle.span)
}
pub fn walk_decimal_number<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
dn: &mut ast::DecimalNumber<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut dn.span)
}
pub fn walk_decimal_suffix<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
ds: &mut ast::DecimalSuffix<'ast>,
) -> ZVisitorResult {
use ast::DecimalSuffix::*;
match ds {
U8(u8s) => visitor.visit_u8_suffix(u8s),
U16(u16s) => visitor.visit_u16_suffix(u16s),
U32(u32s) => visitor.visit_u32_suffix(u32s),
U64(u64s) => visitor.visit_u64_suffix(u64s),
Field(fs) => visitor.visit_field_suffix(fs),
}
}
pub fn walk_u8_suffix<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u8s: &mut ast::U8Suffix<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u8s.span)
}
pub fn walk_u16_suffix<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u16s: &mut ast::U16Suffix<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u16s.span)
}
pub fn walk_u32_suffix<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u32s: &mut ast::U32Suffix<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u32s.span)
}
pub fn walk_u64_suffix<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u64s: &mut ast::U64Suffix<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u64s.span)
}
pub fn walk_field_suffix<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
fs: &mut ast::FieldSuffix<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut fs.span)
}
pub fn walk_boolean_literal_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
ble: &mut ast::BooleanLiteralExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut ble.span)
}
pub fn walk_hex_literal_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
hle: &mut ast::HexLiteralExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_hex_number_expression(&mut hle.value)?;
visitor.visit_span(&mut hle.span)
}
pub fn walk_hex_number_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
hne: &mut ast::HexNumberExpression<'ast>,
) -> ZVisitorResult {
use ast::HexNumberExpression::*;
match hne {
U8(u8e) => visitor.visit_u8_number_expression(u8e),
U16(u16e) => visitor.visit_u16_number_expression(u16e),
U32(u32e) => visitor.visit_u32_number_expression(u32e),
U64(u64e) => visitor.visit_u64_number_expression(u64e),
}
}
pub fn walk_u8_number_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u8e: &mut ast::U8NumberExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u8e.span)
}
pub fn walk_u16_number_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u16e: &mut ast::U16NumberExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u16e.span)
}
pub fn walk_u32_number_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u32e: &mut ast::U32NumberExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u32e.span)
}
pub fn walk_u64_number_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u64e: &mut ast::U64NumberExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u64e.span)
}
pub fn walk_underscore<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
u: &mut ast::Underscore<'ast>,
) -> ZVisitorResult {
visitor.visit_span(&mut u.span)
}
pub fn walk_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
expr: &mut ast::Expression<'ast>,
) -> ZVisitorResult {
use ast::Expression::*;
match expr {
Ternary(te) => visitor.visit_ternary_expression(te),
Binary(be) => visitor.visit_binary_expression(be),
Unary(ue) => visitor.visit_unary_expression(ue),
Postfix(pe) => visitor.visit_postfix_expression(pe),
Identifier(ie) => visitor.visit_identifier_expression(ie),
Literal(le) => visitor.visit_literal_expression(le),
InlineArray(iae) => visitor.visit_inline_array_expression(iae),
InlineStruct(ise) => visitor.visit_inline_struct_expression(ise),
ArrayInitializer(aie) => visitor.visit_array_initializer_expression(aie),
}
}
pub fn walk_ternary_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
te: &mut ast::TernaryExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_expression(&mut te.first)?;
visitor.visit_expression(&mut te.second)?;
visitor.visit_expression(&mut te.third)?;
visitor.visit_span(&mut te.span)
}
pub fn walk_binary_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
be: &mut ast::BinaryExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_binary_operator(&mut be.op)?;
visitor.visit_expression(&mut be.left)?;
visitor.visit_expression(&mut be.right)?;
visitor.visit_span(&mut be.span)
}
pub fn walk_unary_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
ue: &mut ast::UnaryExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_unary_operator(&mut ue.op)?;
visitor.visit_expression(&mut ue.expression)?;
visitor.visit_span(&mut ue.span)
}
pub fn walk_unary_operator<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
uo: &mut ast::UnaryOperator,
) -> ZVisitorResult {
use ast::UnaryOperator::*;
match uo {
Pos(po) => visitor.visit_pos_operator(po),
Neg(ne) => visitor.visit_neg_operator(ne),
Not(no) => visitor.visit_not_operator(no),
}
}
pub fn walk_postfix_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
pe: &mut ast::PostfixExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_identifier_expression(&mut pe.id)?;
pe.accesses
.iter_mut()
.try_for_each(|a| visitor.visit_access(a))?;
visitor.visit_span(&mut pe.span)
}
pub fn walk_access<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
acc: &mut ast::Access<'ast>,
) -> ZVisitorResult {
use ast::Access::*;
match acc {
Call(ca) => visitor.visit_call_access(ca),
Select(aa) => visitor.visit_array_access(aa),
Member(ma) => visitor.visit_member_access(ma),
}
}
pub fn walk_call_access<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
ca: &mut ast::CallAccess<'ast>,
) -> ZVisitorResult {
if let Some(eg) = &mut ca.explicit_generics {
visitor.visit_explicit_generics(eg)?;
}
visitor.visit_arguments(&mut ca.arguments)?;
visitor.visit_span(&mut ca.span)
}
pub fn walk_arguments<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
args: &mut ast::Arguments<'ast>,
) -> ZVisitorResult {
args.expressions
.iter_mut()
.try_for_each(|e| visitor.visit_expression(e))?;
visitor.visit_span(&mut args.span)
}
pub fn walk_array_access<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
aa: &mut ast::ArrayAccess<'ast>,
) -> ZVisitorResult {
visitor.visit_range_or_expression(&mut aa.expression)?;
visitor.visit_span(&mut aa.span)
}
pub fn walk_range_or_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
roe: &mut ast::RangeOrExpression<'ast>,
) -> ZVisitorResult {
use ast::RangeOrExpression::*;
match roe {
Range(r) => visitor.visit_range(r),
Expression(e) => visitor.visit_expression(e),
}
}
pub fn walk_range<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
rng: &mut ast::Range<'ast>,
) -> ZVisitorResult {
if let Some(f) = &mut rng.from {
visitor.visit_from_expression(f)?;
}
if let Some(t) = &mut rng.to {
visitor.visit_to_expression(t)?;
}
visitor.visit_span(&mut rng.span)
}
pub fn walk_from_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
from: &mut ast::FromExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_expression(&mut from.0)
}
pub fn walk_to_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
to: &mut ast::ToExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_expression(&mut to.0)
}
pub fn walk_member_access<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
ma: &mut ast::MemberAccess<'ast>,
) -> ZVisitorResult {
visitor.visit_identifier_expression(&mut ma.id)?;
visitor.visit_span(&mut ma.span)
}
pub fn walk_inline_array_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
iae: &mut ast::InlineArrayExpression<'ast>,
) -> ZVisitorResult {
iae.expressions
.iter_mut()
.try_for_each(|e| visitor.visit_spread_or_expression(e))?;
visitor.visit_span(&mut iae.span)
}
pub fn walk_spread_or_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
soe: &mut ast::SpreadOrExpression<'ast>,
) -> ZVisitorResult {
use ast::SpreadOrExpression::*;
match soe {
Spread(s) => visitor.visit_spread(s),
Expression(e) => visitor.visit_expression(e),
}
}
pub fn walk_spread<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
spread: &mut ast::Spread<'ast>,
) -> ZVisitorResult {
visitor.visit_expression(&mut spread.expression)?;
visitor.visit_span(&mut spread.span)
}
pub fn walk_inline_struct_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
ise: &mut ast::InlineStructExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_identifier_expression(&mut ise.ty)?;
ise.members
.iter_mut()
.try_for_each(|m| visitor.visit_inline_struct_member(m))?;
visitor.visit_span(&mut ise.span)
}
pub fn walk_inline_struct_member<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
ism: &mut ast::InlineStructMember<'ast>,
) -> ZVisitorResult {
visitor.visit_identifier_expression(&mut ism.id)?;
visitor.visit_expression(&mut ism.expression)?;
visitor.visit_span(&mut ism.span)
}
pub fn walk_array_initializer_expression<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
aie: &mut ast::ArrayInitializerExpression<'ast>,
) -> ZVisitorResult {
visitor.visit_expression(&mut aie.value)?;
visitor.visit_expression(&mut aie.count)?;
visitor.visit_span(&mut aie.span)
}
pub fn walk_statement<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
stmt: &mut ast::Statement<'ast>,
) -> ZVisitorResult {
use ast::Statement::*;
match stmt {
Return(r) => visitor.visit_return_statement(r),
Definition(d) => visitor.visit_definition_statement(d),
Assertion(a) => visitor.visit_assertion_statement(a),
Iteration(i) => visitor.visit_iteration_statement(i),
}
}
pub fn walk_return_statement<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
ret: &mut ast::ReturnStatement<'ast>,
) -> ZVisitorResult {
ret.expressions
.iter_mut()
.try_for_each(|e| visitor.visit_expression(e))?;
visitor.visit_span(&mut ret.span)
}
pub fn walk_definition_statement<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
def: &mut ast::DefinitionStatement<'ast>,
) -> ZVisitorResult {
def.lhs
.iter_mut()
.try_for_each(|l| visitor.visit_typed_identifier_or_assignee(l))?;
visitor.visit_expression(&mut def.expression)?;
visitor.visit_span(&mut def.span)
}
pub fn walk_typed_identifier_or_assignee<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
tioa: &mut ast::TypedIdentifierOrAssignee<'ast>,
) -> ZVisitorResult {
use ast::TypedIdentifierOrAssignee::*;
match tioa {
Assignee(a) => visitor.visit_assignee(a),
TypedIdentifier(ti) => visitor.visit_typed_identifier(ti),
}
}
pub fn walk_typed_identifier<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
tid: &mut ast::TypedIdentifier<'ast>,
) -> ZVisitorResult {
visitor.visit_type(&mut tid.ty)?;
visitor.visit_identifier_expression(&mut tid.identifier)?;
visitor.visit_span(&mut tid.span)
}
pub fn walk_assignee<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
asgn: &mut ast::Assignee<'ast>,
) -> ZVisitorResult {
visitor.visit_identifier_expression(&mut asgn.id)?;
asgn.accesses
.iter_mut()
.try_for_each(|a| visitor.visit_assignee_access(a))?;
visitor.visit_span(&mut asgn.span)
}
pub fn walk_assignee_access<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
acc: &mut ast::AssigneeAccess<'ast>,
) -> ZVisitorResult {
use ast::AssigneeAccess::*;
match acc {
Select(aa) => visitor.visit_array_access(aa),
Member(ma) => visitor.visit_member_access(ma),
}
}
pub fn walk_assertion_statement<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
asrt: &mut ast::AssertionStatement<'ast>,
) -> ZVisitorResult {
visitor.visit_expression(&mut asrt.expression)?;
visitor.visit_span(&mut asrt.span)
}
pub fn walk_iteration_statement<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
iter: &mut ast::IterationStatement<'ast>,
) -> ZVisitorResult {
visitor.visit_type(&mut iter.ty)?;
visitor.visit_identifier_expression(&mut iter.index)?;
visitor.visit_expression(&mut iter.from)?;
visitor.visit_expression(&mut iter.to)?;
iter.statements
.iter_mut()
.try_for_each(|s| visitor.visit_statement(s))?;
visitor.visit_span(&mut iter.span)
}

View File

@@ -0,0 +1,354 @@
//! AST Walker for zokrates_pest_ast
use super::super::term::Ty;
use super::walkfns::*;
use super::{ZVisitorError, ZVisitorMut, ZVisitorResult};
use zokrates_pest_ast as ast;
pub(in super::super) struct ZConstLiteralRewriter {
to_ty: Option<Ty>,
found: bool,
}
impl ZConstLiteralRewriter {
pub fn new(to_ty: Option<Ty>) -> Self {
Self {
to_ty,
found: false,
}
}
#[allow(dead_code)]
pub fn found(&self) -> bool {
self.found
}
pub fn replace(&mut self, to_ty: Option<Ty>) -> Option<Ty> {
std::mem::replace(&mut self.to_ty, to_ty)
}
}
impl<'ast> ZVisitorMut<'ast> for ZConstLiteralRewriter {
/*
Expressions can be any of:
Binary(BinaryExpression<'ast>),
-> depends on operator. e.g., == outputs Bool but takes in arbitrary l and r
Ternary(TernaryExpression<'ast>)
-> first expr is Bool, other two are expected type
Unary(UnaryExpression<'ast>),
-> no change to expected type: each sub-expr should have the expected type
Postfix(PostfixExpression<'ast>),
-> cannot type Access results, but descend into sub-exprs to type array indices
Identifier(IdentifierExpression<'ast>),
-> nothing to do (terminal)
Literal(LiteralExpression<'ast>),
-> literal should have same type as expression
InlineArray(InlineArrayExpression<'ast>),
-> descend into SpreadOrExpression, looking for either array or element type
InlineStruct(InlineStructExpression<'ast>),
-> check that struct types are equal
ArrayInitializer(ArrayInitializerExpression<'ast>),
-> value should have type of value inside Array
-> count should have type Field
*/
fn visit_ternary_expression(
&mut self,
te: &mut ast::TernaryExpression<'ast>,
) -> ZVisitorResult {
// first expression in a ternary should have type bool
let to_ty = self.replace(Some(Ty::Bool));
self.visit_expression(&mut te.first)?;
self.replace(to_ty);
self.visit_expression(&mut te.second)?;
self.visit_expression(&mut te.third)?;
self.visit_span(&mut te.span)
}
fn visit_binary_expression(&mut self, be: &mut ast::BinaryExpression<'ast>) -> ZVisitorResult {
let (ty_l, ty_r) = {
use ast::BinaryOperator::*;
match be.op {
Pow | RightShift | LeftShift => (self.to_ty.clone(), Some(Ty::Uint(32))),
Eq | NotEq | Lt | Gt | Lte | Gte => (None, None),
_ => (self.to_ty.clone(), self.to_ty.clone()),
}
};
self.visit_binary_operator(&mut be.op)?;
let to_ty = self.replace(ty_l);
self.visit_expression(&mut be.left)?;
self.replace(ty_r);
self.visit_expression(&mut be.right)?;
self.replace(to_ty);
self.visit_span(&mut be.span)
}
fn visit_decimal_literal_expression(
&mut self,
dle: &mut ast::DecimalLiteralExpression<'ast>,
) -> ZVisitorResult {
if dle.suffix.is_none() && self.to_ty.is_some() {
self.found = true;
dle.suffix.replace(match self.to_ty.as_ref().unwrap() {
Ty::Uint(8) => Ok(ast::DecimalSuffix::U8(ast::U8Suffix {
span: dle.span.clone(),
})),
Ty::Uint(16) => Ok(ast::DecimalSuffix::U16(ast::U16Suffix {
span: dle.span.clone(),
})),
Ty::Uint(32) => Ok(ast::DecimalSuffix::U32(ast::U32Suffix {
span: dle.span.clone(),
})),
Ty::Uint(64) => Ok(ast::DecimalSuffix::U64(ast::U64Suffix {
span: dle.span.clone(),
})),
Ty::Uint(_) => Err(
"ZConstLiteralRewriter: Uint size must be divisible by 8".to_string(),
),
Ty::Field => Ok(ast::DecimalSuffix::Field(ast::FieldSuffix {
span: dle.span.clone(),
})),
_ => Err(
"ZConstLiteralRewriter: rewriting DecimalLiteralExpression to incompatible type"
.to_string(),
),
}?);
}
walk_decimal_literal_expression(self, dle)
}
fn visit_array_initializer_expression(
&mut self,
aie: &mut ast::ArrayInitializerExpression<'ast>,
) -> ZVisitorResult {
if self.to_ty.is_some() {
if let Ty::Array(_, arr_ty) = self.to_ty.clone().unwrap() {
// ArrayInitializerExpression::value should match arr_ty
let to_ty = self.replace(Some(*arr_ty));
self.visit_expression(&mut aie.value)?;
self.to_ty = to_ty;
} else {
return Err(
"ZConstLiteralRewriter: rewriting ArrayInitializerExpression to non-Array type"
.to_string()
.into(),
);
}
}
// always rewrite ArrayInitializerExpression::count literals to type U32
let to_ty = self.replace(Some(Ty::Uint(32)));
self.visit_expression(&mut aie.count)?;
self.to_ty = to_ty;
self.visit_span(&mut aie.span)
}
fn visit_inline_struct_expression(
&mut self,
ise: &mut ast::InlineStructExpression<'ast>,
) -> ZVisitorResult {
self.visit_identifier_expression(&mut ise.ty)?;
let to_ty = self.replace(None);
let ty_map = if let Some(t) = to_ty.as_ref() {
if let Ty::Struct(name, ty_map) = t {
if name != &ise.ty.value {
Err(format!("ZConstLiteralRewriter: got struct {}, expected {} visiting inline struct expression", &ise.ty.value, name))
} else {
Ok(Some(ty_map.clone()))
}
} else {
Err(
"ZConstLiteralRewriter: rewriting InlineStructExpression to non-Struct type"
.to_string(),
)
}
} else {
Ok(None)
}?;
if let Some(ty_map) = ty_map {
let mut ty_map = ty_map.into_map();
let (mem, str_name) = (&mut ise.members, &ise.ty.value);
mem.iter_mut()
.try_for_each(|m| ty_map
.remove(&m.id.value)
.ok_or_else(|| ZVisitorError(format!(
"ZConstLiteralRewriter: no member {} in struct {}, or duplicate member in inline expression",
&m.id.value,
str_name,
)))
.and_then(|ty| {
self.to_ty = Some(ty);
self.visit_inline_struct_member(m)
})
)?;
if !ty_map.is_empty() {
return Err(format!(
"ZConstLiteralRewriter: inline expression for struct {} has extra fields: {:?}",
&ise.ty.value,
ty_map.keys().collect::<Vec<_>>(),
)
.into());
}
} else {
ise.members
.iter_mut()
.try_for_each(|m| self.visit_inline_struct_member(m))?;
}
self.to_ty = to_ty;
self.visit_span(&mut ise.span)
}
fn visit_inline_array_expression(
&mut self,
iae: &mut ast::InlineArrayExpression<'ast>,
) -> ZVisitorResult {
let mut inner_ty = if let Some(t) = self.to_ty.as_ref() {
if let Ty::Array(_, arr_ty) = t.clone() {
Ok(Some(*arr_ty))
} else {
Err(
"ZConstLiteralRewriter: rewriting InlineArrayExpression to non-Array type"
.to_string(),
)
}
} else {
Ok(None)
}?;
for e in iae.expressions.iter_mut() {
use ast::SpreadOrExpression::*;
match e {
Spread(s) => {
// a spread expression is an array; array type should match (we ignore number)
self.visit_spread(s)?;
}
Expression(e) => {
// an expression here is an individual array element, inner type should match
inner_ty = self.replace(inner_ty);
self.visit_expression(e)?;
inner_ty = self.replace(inner_ty);
}
}
}
self.visit_span(&mut iae.span)
}
fn visit_postfix_expression(
&mut self,
pe: &mut ast::PostfixExpression<'ast>,
) -> ZVisitorResult {
self.visit_identifier_expression(&mut pe.id)?;
// descend into accesses. we do not know expected type for these expressions
// (but we may end up descending into an ArrayAccess, which would get typed)
let to_ty = self.replace(None);
pe.accesses
.iter_mut()
.try_for_each(|a| self.visit_access(a))?;
self.to_ty = to_ty;
self.visit_span(&mut pe.span)
}
fn visit_array_type(&mut self, aty: &mut ast::ArrayType<'ast>) -> ZVisitorResult {
if self.to_ty.is_some() {
if let Ty::Array(_, arr_ty) = self.to_ty.clone().unwrap() {
// ArrayType::value should match arr_ty
let to_ty = self.replace(Some(*arr_ty));
self.visit_basic_or_struct_type(&mut aty.ty)?;
self.to_ty = to_ty;
} else {
return Err(
"ZConstLiteralRewriter: rewriting ArrayType to non-Array type"
.to_string()
.into(),
);
}
}
// always rewrite ArrayType::dimensions literals to type U32
let to_ty = self.replace(Some(Ty::Uint(32)));
aty.dimensions
.iter_mut()
.try_for_each(|d| self.visit_expression(d))?;
self.to_ty = to_ty;
self.visit_span(&mut aty.span)
}
fn visit_explicit_generics(&mut self, eg: &mut ast::ExplicitGenerics<'ast>) -> ZVisitorResult {
// always rewrite ConstantGenericValue literals to type U32
let to_ty = self.replace(Some(Ty::Uint(32)));
walk_explicit_generics(self, eg)?;
self.to_ty = to_ty;
Ok(())
}
fn visit_field_type(&mut self, fty: &mut ast::FieldType<'ast>) -> ZVisitorResult {
if self.to_ty.is_some() && !matches!(self.to_ty, Some(Ty::Field)) {
return Err("ZConstLiteralRewriter: Field type mismatch"
.to_string()
.into());
}
walk_field_type(self, fty)
}
fn visit_boolean_type(&mut self, bty: &mut ast::BooleanType<'ast>) -> ZVisitorResult {
if self.to_ty.is_some() && !matches!(self.to_ty, Some(Ty::Bool)) {
return Err("ZConstLiteralRewriter: Bool type mismatch"
.to_string()
.into());
}
walk_boolean_type(self, bty)
}
fn visit_u8_type(&mut self, u8ty: &mut ast::U8Type<'ast>) -> ZVisitorResult {
if self.to_ty.is_some() && !matches!(self.to_ty, Some(Ty::Uint(8))) {
return Err("ZConstLiteralRewriter: u8 type mismatch".to_string().into());
}
walk_u8_type(self, u8ty)
}
fn visit_u16_type(&mut self, u16ty: &mut ast::U16Type<'ast>) -> ZVisitorResult {
if self.to_ty.is_some() && !matches!(self.to_ty, Some(Ty::Uint(16))) {
return Err("ZConstLiteralRewriter: u16 type mismatch"
.to_string()
.into());
}
walk_u16_type(self, u16ty)
}
fn visit_u32_type(&mut self, u32ty: &mut ast::U32Type<'ast>) -> ZVisitorResult {
if self.to_ty.is_some() && !matches!(self.to_ty, Some(Ty::Uint(32))) {
return Err("ZConstLiteralRewriter: u32 type mismatch"
.to_string()
.into());
}
walk_u32_type(self, u32ty)
}
fn visit_u64_type(&mut self, u64ty: &mut ast::U64Type<'ast>) -> ZVisitorResult {
if self.to_ty.is_some() && !matches!(self.to_ty, Some(Ty::Uint(64))) {
return Err("ZConstLiteralRewriter: u64 type mismatch"
.to_string()
.into());
}
walk_u64_type(self, u64ty)
}
}

View File

@@ -0,0 +1,455 @@
//! Generic parameter inference
use super::super::term::{cond, const_val, Ty, T};
use super::super::{span_to_string, ZGen};
use crate::ir::term::{bv_lit, leaf_term, term, BoolNaryOp, Op, Sort, Term, Value};
use crate::target::smt::find_unique_model;
use lazy_static::lazy_static;
use log::debug;
use std::collections::HashMap;
use std::path::Path;
use std::sync::RwLock;
use zokrates_pest_ast as ast;
lazy_static! {
static ref CACHE: RwLock<HashMap<Term, HashMap<String, T>>> = RwLock::new(HashMap::new());
}
pub(in super::super) struct ZGenericInf<'ast, 'gen, const IS_CNST: bool> {
zgen: &'gen ZGen<'ast>,
fdef: &'gen ast::FunctionDefinition<'ast>,
gens: &'gen [ast::IdentifierExpression<'ast>],
sfx: String,
constr: Option<Term>,
}
impl<'ast, 'gen, const IS_CNST: bool> ZGenericInf<'ast, 'gen, IS_CNST> {
pub fn new(
zgen: &'gen ZGen<'ast>,
fdef: &'gen ast::FunctionDefinition<'ast>,
path: &Path,
name: &str,
) -> Self {
let gens = fdef.generics.as_ref();
let mut path_str = "___".to_string();
path_str.push_str(&path.to_string_lossy());
path_str.push_str("___");
path_str.push_str(name);
path_str.push_str("___");
path_str.push_str(&fdef.id.value);
let sfx = make_sfx(path_str, &fdef.id.value);
Self {
zgen,
fdef,
gens,
sfx,
constr: None,
}
}
fn is_generic_var(&self, var: &str) -> bool {
self.gens.iter().any(|id| id.value == var)
}
fn add_constraint(&mut self, lhs: Term, rhs: Term) {
let new_term = term![Op::Eq; lhs, rhs];
let new_term = if let Some(old_term) = self.constr.take() {
term![Op::BoolNaryOp(BoolNaryOp::And); old_term, new_term]
} else {
new_term
};
self.constr = Some(new_term);
}
fn const_id_(&self, id: &ast::IdentifierExpression<'ast>) -> Result<T, String> {
self.zgen
.identifier_impl_::<IS_CNST>(id)
.and_then(const_val)
}
pub fn unify_generic<ATIter: Iterator<Item = Ty>>(
&mut self,
egv: &[ast::ConstantGenericValue<'ast>],
rty: Option<Ty>,
arg_tys: ATIter,
) -> Result<HashMap<String, T>, String> {
debug!("ZGenericInf::unify_generic");
use ast::ConstantGenericValue as CGV;
self.constr = None;
self.gens = &self.fdef.generics[..];
// early returns: monomorphized or not generic
if self.gens.is_empty() {
debug!("done (no generics)");
return Ok(HashMap::new());
}
if egv.len() == self.gens.len() && !egv.iter().any(|cgv| matches!(cgv, CGV::Underscore(_)))
{
match self
.zgen
.egvs_impl_::<IS_CNST>(egv, self.fdef.generics.clone())
{
Ok(gens) if gens.len() == self.gens.len() => {
debug!("done (explicit generics)");
return Ok(gens);
}
_ => (),
};
}
// 1. build up the already-known generics
for (cgv, id) in egv.iter().zip(self.fdef.generics.iter()) {
if let Some(v) = match cgv {
CGV::Underscore(_) => None,
CGV::Value(v) => Some(self.zgen.literal_(v)),
CGV::Identifier(i) => Some(self.const_id_(i)),
} {
let v = v?;
let var = make_varname(&id.value, &self.sfx);
let val = match v.ty {
Ty::Uint(32) => Ok(v.term),
ty => Err(format!(
"ZGenericInf: ConstantGenericValue for {} had type {}, expected u32",
&id.value, ty
)),
}?;
self.add_constraint(var, val);
}
}
// 2. for each argument, update the const generic values
for (pty, arg_ty) in self.fdef.parameters.iter().map(|p| &p.ty).zip(arg_tys) {
self.fdef_gen_ty(arg_ty, pty)?;
// bracketing invariant
assert!(self.gens == &self.fdef.generics[..]);
assert!(self.sfx.ends_with(&self.fdef.id.value));
}
// 3. unify the return type
match (rty, self.fdef.returns.first()) {
(Some(rty), Some(ret)) => self.fdef_gen_ty(rty, ret),
(Some(rty), None) if rty != Ty::Bool => Err(format!(
"Function {} expected implicit Bool ret, but got {}",
&self.fdef.id.value, rty
)),
(Some(_), None) => Ok(()),
(None, _) => Ok(()),
}?;
// bracketing invariant
assert!(self.gens == &self.fdef.generics[..]);
assert!(self.sfx.ends_with(&self.fdef.id.value));
// 4. run the solver on the term stack, if it's not already cached
if let Some(res) = self
.constr
.as_ref()
.and_then(|t| CACHE.read().unwrap().get(t).cloned())
{
assert!(self.gens.len() == res.len());
assert!(self.gens.iter().all(|g| res.contains_key(&g.value)));
debug!("done (cached result for {})", &self.sfx);
return Ok(res);
}
let g_names = self
.gens
.iter()
.map(|gid| make_varname_str(&gid.value, &self.sfx))
.collect::<Vec<_>>();
let mut solved = self
.constr
.as_ref()
.and_then(|t| find_unique_model(t, g_names.clone()))
.unwrap_or_else(HashMap::new);
// 5. extract the assignments from the solver result
let mut res = HashMap::with_capacity(g_names.len());
assert_eq!(g_names.len(), self.gens.len());
g_names
.into_iter()
.enumerate()
.for_each(|(idx, mut g_name)| {
if let Some(g_val) = solved.remove(&g_name) {
match &g_val {
Value::BitVector(bv) => assert!(bv.width() == 32),
_ => unreachable!(),
}
g_name.truncate(self.gens[idx].value.len());
g_name.shrink_to_fit();
assert!(res
.insert(g_name, T::new(Ty::Uint(32), term![Op::Const(g_val)]))
.is_none());
}
});
if self.constr.is_some() {
CACHE
.write()
.unwrap()
.insert(self.constr.take().unwrap(), res.clone());
}
debug!("done (finished inference)");
Ok(res)
}
fn fdef_gen_ty(&mut self, arg_ty: Ty, def_ty: &ast::Type<'ast>) -> Result<(), String> {
use ast::Type as TT;
match def_ty {
TT::Basic(dty_b) => self.fdef_gen_ty_basic(arg_ty, dty_b),
TT::Array(dty_a) => self.fdef_gen_ty_array(arg_ty, dty_a),
TT::Struct(dty_s) => self.fdef_gen_ty_struct(arg_ty, dty_s),
}
}
fn fdef_gen_ty_basic(&self, arg_ty: Ty, bas_ty: &ast::BasicType<'ast>) -> Result<(), String> {
// XXX(q) dispatch to const_ or not? does not seem necessary because arg is Type::Basic
if arg_ty
!= self
.zgen
.type_impl_::<IS_CNST>(&ast::Type::Basic(bas_ty.clone()))?
{
Err(format!(
"Type mismatch unifying generics: got {}, decl was {:?}",
arg_ty, bas_ty
))
} else {
Ok(())
}
}
fn fdef_gen_ty_array(
&mut self,
mut arg_ty: Ty,
def_ty: &ast::ArrayType<'ast>,
) -> Result<(), String> {
if !matches!(arg_ty, Ty::Array(_, _)) {
return Err(format!(
"Type mismatch unifying generics: got {}, decl was Array",
arg_ty
));
}
// iterate through array dimensions, unifying each with fn decl
let mut dim_off = 0;
loop {
match arg_ty {
Ty::Array(arg_dim, nty) => {
// make sure that we expect at least one more array dim
if dim_off >= def_ty.dimensions.len() {
return Err(format!(
"Type mismatch: got >={}-dim array, decl was {} dims",
dim_off,
def_ty.dimensions.len(),
));
}
// unify actual dimension with dim expression
self.fdef_gen_ty_expr(arg_dim, &def_ty.dimensions[dim_off])?;
// iterate
dim_off += 1;
arg_ty = *nty;
}
nty => {
// make sure we didn't expect any more array dims!
if dim_off != def_ty.dimensions.len() {
return Err(format!(
"Type mismatch: got {}-dim array, decl had {} dims",
dim_off,
def_ty.dimensions.len(),
));
}
arg_ty = nty;
break;
}
};
}
use ast::BasicOrStructType as BoST;
match &def_ty.ty {
BoST::Struct(dty_s) => self.fdef_gen_ty_struct(arg_ty, dty_s),
BoST::Basic(dty_b) => self.fdef_gen_ty_basic(arg_ty, dty_b),
}
}
fn fdef_gen_ty_struct(
&mut self,
arg_ty: Ty,
def_ty: &ast::StructType<'ast>,
) -> Result<(), String> {
// check type and struct name
let mut aty_map = match arg_ty {
Ty::Struct(aty_n, aty_map) if aty_n == def_ty.id.value => Ok(aty_map.into_map()),
Ty::Struct(aty_n, _) => Err(format!(
"Type mismatch: got struct {}, decl was struct {}",
&aty_n, &def_ty.id.value
)),
arg_ty => Err(format!(
"Type mismatch unifying generics: got {}, decl was Struct",
arg_ty
)),
}?;
let strdef = self
.zgen
.get_struct(&def_ty.id.value)
.ok_or_else(|| format!("ZGenericInf: no such struct {}", &def_ty.id.value))?;
// short-circuit if there are no generics in this struct
if strdef.generics.is_empty() {
return if def_ty.explicit_generics.is_some() {
Err(format!(
"Unifying generics: got explicit generics for non-generic struct type {}:\n{}",
&def_ty.id.value,
span_to_string(&def_ty.span),
))
} else {
Ok(())
};
}
// struct type in fn defn must provide explicit generics
use ast::ConstantGenericValue as CGV;
if def_ty
.explicit_generics
.as_ref()
.map(|eg| eg.values.iter().any(|eg| matches!(eg, CGV::Underscore(_))))
.unwrap_or(true)
{
return Err(format!(
"Cannot infer generic values for struct {} arg to function {}\nGeneric structs in fn defns must have explicit generics (in terms of fn generic vars)",
&def_ty.id.value,
&self.fdef.id.value,
));
}
// 1. set up mapping from outer explicit generics to inner explicit generics
let new_sfx = make_sfx(self.sfx.clone(), &def_ty.id.value);
def_ty
.explicit_generics
.as_ref()
.unwrap()
.values
.iter()
.zip(strdef.generics.iter())
.try_for_each::<_, Result<(), String>>(|(cgv, id)| {
let sgid = make_varname(&id.value, &new_sfx);
let val = match cgv {
CGV::Underscore(_) => unreachable!(),
CGV::Value(le) => u32_term(self.zgen.literal_(le)?)?,
CGV::Identifier(id) => {
if self.is_generic_var(&id.value) {
make_varname(&id.value, &self.sfx)
} else {
u32_term(self.const_id_(id)?)?
}
}
};
self.add_constraint(sgid, val);
Ok(())
})?;
// 2. walk through struct def to generate constraints on inner explicit generics
let old_sfx = std::mem::replace(&mut self.sfx, new_sfx);
let old_gens = std::mem::replace(&mut self.gens, &strdef.generics[..]);
for ast::StructField { ty, id, .. } in strdef.fields.iter() {
if let Some(t) = aty_map.remove(&id.value) {
self.fdef_gen_ty(t, ty)?;
} else {
return Err(format!(
"ZGenericInf: missing member {} in struct {} value",
&id.value, &def_ty.id.value,
));
}
}
if !aty_map.is_empty() {
return Err(format!(
"ZGenericInf: struct {} value had extra members: {:?}",
&def_ty.id.value,
aty_map.keys().collect::<Vec<_>>(),
));
}
// 3. pop stack and continue
self.gens = old_gens;
self.sfx = old_sfx;
Ok(())
}
// turn an expr into a set of terms and assert equality
fn fdef_gen_ty_expr(
&mut self,
arg_dim: usize,
def_exp: &ast::Expression<'ast>,
) -> Result<(), String> {
let t = u32_term(self.expr(def_exp)?)?;
self.add_constraint(bv_lit(arg_dim, 32), t);
Ok(())
}
fn expr(&self, expr: &ast::Expression<'ast>) -> Result<T, String> {
use ast::Expression::*;
match expr {
Ternary(te) => {
let cnd = self.expr(&te.first)?;
let csq = self.expr(&te.second)?;
let alt = self.expr(&te.third)?;
cond(cnd, csq, alt)
}
Binary(be) => {
let lhs = self.expr(&be.left)?;
let rhs = self.expr(&be.right)?;
let op = self.zgen.bin_op(&be.op);
op(lhs, rhs)
}
Unary(ue) => {
let exp = self.expr(&ue.expression)?;
let op = self.zgen.unary_op(&ue.op);
op(exp)
}
Identifier(id) => {
if self.is_generic_var(&id.value) {
Ok(T::new(Ty::Uint(32), make_varname(&id.value, &self.sfx)))
} else {
self.const_id_(id)
}
}
Literal(le) => self.zgen.literal_(le),
Postfix(_) => Err("ZGenericInf: got Postfix in array dim expr (unimpl)".into()),
InlineArray(_) => Err("ZGenericInf: got InlineArray in array dim expr (unimpl)".into()),
InlineStruct(_) => {
Err("ZGenericInf: got InlineStruct in array dim expr (unimpl)".into())
}
ArrayInitializer(_) => {
Err("ZGenericInf: got ArrayInitializer in array dim expr (unimpl)".into())
}
}
}
}
fn u32_term(t: T) -> Result<Term, String> {
match t.ty {
Ty::Uint(32) => Ok(t.term),
ty => Err(format!(
"ZGenericInf: got {} for expr, expected T::Uint(32)",
ty
)),
}
}
fn make_sfx(mut base: String, sfx: &str) -> String {
base.push('_');
base.push_str(sfx);
base
}
fn make_varname_str(id: &str, sfx: &str) -> String {
let mut tmp = String::from(id);
tmp.push('_');
tmp.push_str(sfx);
tmp
}
fn make_varname(id: &str, sfx: &str) -> Term {
let tmp = make_varname_str(id, sfx);
term![Op::Var(tmp, Sort::BitVector(32))]
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,40 @@
//! AST Walker for zokrates_pest_ast
use super::{ZVisitorMut, ZVisitorResult};
use std::collections::HashMap;
use zokrates_pest_ast as ast;
pub(super) struct ZExpressionRewriter<'ast> {
gvmap: HashMap<String, ast::Expression<'ast>>,
}
impl<'ast> ZExpressionRewriter<'ast> {
pub fn new(gvmap: HashMap<String, ast::Expression<'ast>>) -> Self {
Self { gvmap }
}
}
impl<'ast> ZVisitorMut<'ast> for ZExpressionRewriter<'ast> {
fn visit_expression(&mut self, expr: &mut ast::Expression<'ast>) -> ZVisitorResult {
use ast::Expression::*;
match expr {
Ternary(te) => self.visit_ternary_expression(te),
Binary(be) => self.visit_binary_expression(be),
Unary(ue) => self.visit_unary_expression(ue),
Postfix(pe) => self.visit_postfix_expression(pe),
Literal(le) => self.visit_literal_expression(le),
InlineArray(iae) => self.visit_inline_array_expression(iae),
InlineStruct(ise) => self.visit_inline_struct_expression(ise),
ArrayInitializer(aie) => self.visit_array_initializer_expression(aie),
Identifier(ie) => {
if let Some(e) = self.gvmap.get(&ie.value) {
*expr = e.clone();
Ok(())
} else {
self.visit_identifier_expression(ie)
}
}
}
}
}

View File

@@ -0,0 +1,337 @@
//! AST Walker for zokrates_pest_ast
use super::super::eqtype::*;
use super::super::{bos_to_type, ZVisitorError, ZVisitorMut, ZVisitorResult};
use super::ZStatementWalker;
use zokrates_pest_ast as ast;
pub(super) struct ZExpressionTyper<'ast, 'ret, 'wlk> {
walker: &'wlk ZStatementWalker<'ast, 'ret>,
ty: Option<ast::Type<'ast>>,
}
impl<'ast, 'ret, 'wlk> ZExpressionTyper<'ast, 'ret, 'wlk> {
pub fn new(walker: &'wlk ZStatementWalker<'ast, 'ret>) -> Self {
Self { walker, ty: None }
}
pub fn take(&mut self) -> Option<ast::Type<'ast>> {
self.ty.take()
}
fn visit_identifier_expression_t(
&mut self,
ie: &mut ast::IdentifierExpression<'ast>,
) -> ZVisitorResult {
assert!(self.ty.is_none());
self.walker.lookup_type(ie).map(|t| {
self.ty.replace(t);
})
}
fn arrayize(
&self,
ty: ast::Type<'ast>,
cnt: ast::Expression<'ast>,
spn: &ast::Span<'ast>,
) -> ast::ArrayType<'ast> {
use ast::Type::*;
match ty {
Array(mut aty) => {
aty.dimensions.insert(0, cnt);
aty
}
Basic(bty) => ast::ArrayType {
ty: ast::BasicOrStructType::Basic(bty),
dimensions: vec![cnt],
span: spn.clone(),
},
Struct(sty) => ast::ArrayType {
ty: ast::BasicOrStructType::Struct(sty),
dimensions: vec![cnt],
span: spn.clone(),
},
}
}
}
impl<'ast, 'ret, 'wlk> ZVisitorMut<'ast> for ZExpressionTyper<'ast, 'ret, 'wlk> {
fn visit_expression(&mut self, expr: &mut ast::Expression<'ast>) -> ZVisitorResult {
use ast::Expression::*;
if self.ty.is_some() {
return Err(ZVisitorError(
"ZExpressionTyper: type found at expression entry?".to_string(),
));
}
match expr {
Ternary(te) => self.visit_ternary_expression(te),
Binary(be) => self.visit_binary_expression(be),
Unary(ue) => self.visit_unary_expression(ue),
Postfix(pe) => self.visit_postfix_expression(pe),
Identifier(ie) => self.visit_identifier_expression_t(ie),
Literal(le) => self.visit_literal_expression(le),
InlineArray(iae) => self.visit_inline_array_expression(iae),
InlineStruct(ise) => self.visit_inline_struct_expression(ise),
ArrayInitializer(aie) => self.visit_array_initializer_expression(aie),
}
}
fn visit_ternary_expression(
&mut self,
te: &mut ast::TernaryExpression<'ast>,
) -> ZVisitorResult {
assert!(self.ty.is_none());
self.visit_expression(&mut te.second)?;
let ty2 = self.take();
self.visit_expression(&mut te.third)?;
let ty3 = self.take();
match (ty2, ty3) {
(Some(t), None) => self.ty.replace(t),
(None, Some(t)) => self.ty.replace(t),
(Some(t1), Some(t2)) => {
eq_type(&t1, &t2)?;
self.ty.replace(t2)
}
(None, None) => None,
};
Ok(())
}
fn visit_binary_expression(&mut self, be: &mut ast::BinaryExpression<'ast>) -> ZVisitorResult {
use ast::{BasicType::*, BinaryOperator::*, Type::*};
assert!(self.ty.is_none());
match &be.op {
Or | And | Eq | NotEq | Lt | Gt | Lte | Gte => {
self.ty.replace(Basic(Boolean(ast::BooleanType {
span: be.span.clone(),
})));
}
Pow => {
self.ty.replace(Basic(Field(ast::FieldType {
span: be.span.clone(),
})));
}
BitXor | BitAnd | BitOr | RightShift | LeftShift | Add | Sub | Mul | Div | Rem => {
self.visit_expression(&mut be.left)?;
let ty_l = self.take();
self.visit_expression(&mut be.right)?;
let ty_r = self.take();
if let Some(ty) = match (ty_l, ty_r) {
(Some(t), None) => Some(t),
(None, Some(t)) => Some(t),
(Some(t1), Some(t2)) => {
eq_type(&t1, &t2)?;
Some(t2)
}
(None, None) => None,
} {
if !matches!(&ty, Basic(_)) {
return Err(ZVisitorError(
"ZExpressionTyper: got non-Basic type for a binop".to_string(),
));
}
if matches!(&ty, Basic(Boolean(_))) {
return Err(ZVisitorError(
"ZExpressionTyper: got Bool for a binop that cannot support it"
.to_string(),
));
}
if matches!(&be.op, BitXor | BitAnd | BitOr | RightShift | LeftShift)
&& matches!(&ty, Basic(Field(_)))
{
return Err(ZVisitorError(
"ZExpressionTyper: got Field for a binop that cannot support it"
.to_string(),
));
}
self.ty.replace(ty);
}
}
};
Ok(())
}
fn visit_unary_expression(&mut self, ue: &mut ast::UnaryExpression<'ast>) -> ZVisitorResult {
use ast::{BasicType::*, Type::*, UnaryOperator::*};
assert!(self.ty.is_none());
match &ue.op {
Pos(_) | Neg(_) => {
self.visit_expression(&mut ue.expression)?;
if let Some(ty) = &self.ty {
if !matches!(ty, Basic(_)) || matches!(ty, Basic(Boolean(_))) {
return Err(ZVisitorError(
"ZExpressionTyper: got Bool or non-Basic for unary op".to_string(),
));
}
}
}
Not(_) => {
self.ty.replace(Basic(Boolean(ast::BooleanType {
span: ue.span.clone(),
})));
}
}
Ok(())
}
fn visit_boolean_literal_expression(
&mut self,
ble: &mut ast::BooleanLiteralExpression<'ast>,
) -> ZVisitorResult {
assert!(self.ty.is_none());
self.ty.replace(ast::Type::Basic(ast::BasicType::Boolean(
ast::BooleanType {
span: ble.span.clone(),
},
)));
Ok(())
}
fn visit_decimal_suffix(&mut self, ds: &mut ast::DecimalSuffix<'ast>) -> ZVisitorResult {
assert!(self.ty.is_none());
use ast::{BasicType::*, DecimalSuffix as DS, Type::*};
match ds {
DS::U8(s) => self.ty.replace(Basic(U8(ast::U8Type {
span: s.span.clone(),
}))),
DS::U16(s) => self.ty.replace(Basic(U16(ast::U16Type {
span: s.span.clone(),
}))),
DS::U32(s) => self.ty.replace(Basic(U32(ast::U32Type {
span: s.span.clone(),
}))),
DS::U64(s) => self.ty.replace(Basic(U64(ast::U64Type {
span: s.span.clone(),
}))),
DS::Field(s) => self.ty.replace(Basic(Field(ast::FieldType {
span: s.span.clone(),
}))),
};
Ok(())
}
fn visit_hex_number_expression(
&mut self,
hne: &mut ast::HexNumberExpression<'ast>,
) -> ZVisitorResult {
assert!(self.ty.is_none());
use ast::{BasicType::*, HexNumberExpression as HNE, Type::*};
match hne {
HNE::U8(s) => self.ty.replace(Basic(U8(ast::U8Type {
span: s.span.clone(),
}))),
HNE::U16(s) => self.ty.replace(Basic(U16(ast::U16Type {
span: s.span.clone(),
}))),
HNE::U32(s) => self.ty.replace(Basic(U32(ast::U32Type {
span: s.span.clone(),
}))),
HNE::U64(s) => self.ty.replace(Basic(U64(ast::U64Type {
span: s.span.clone(),
}))),
};
Ok(())
}
fn visit_array_initializer_expression(
&mut self,
aie: &mut ast::ArrayInitializerExpression<'ast>,
) -> ZVisitorResult {
assert!(self.ty.is_none());
use ast::Type::*;
self.visit_expression(&mut *aie.value)?;
if let Some(ty) = self.take() {
let ty = self.arrayize(ty, aie.count.as_ref().clone(), &aie.span);
self.ty.replace(Array(ty));
}
Ok(())
}
fn visit_inline_struct_expression(
&mut self,
ise: &mut ast::InlineStructExpression<'ast>,
) -> ZVisitorResult {
// XXX(unimpl) we don't monomorphize struct type here... OK?
self.visit_identifier_expression_t(&mut ise.ty)
}
fn visit_inline_array_expression(
&mut self,
iae: &mut ast::InlineArrayExpression<'ast>,
) -> ZVisitorResult {
assert!(self.ty.is_none());
assert!(!iae.expressions.is_empty());
let mut acc_ty = None;
let mut acc_len = 0;
iae.expressions
.iter_mut()
.try_for_each::<_, ZVisitorResult>(|soe| {
self.visit_spread_or_expression(soe)?;
if let Some(ty) = self.take() {
let (nty, nln) = if matches!(soe, ast::SpreadOrExpression::Expression(_)) {
Ok((ty, 1))
} else if let ast::Type::Array(mut at) = ty {
assert!(!at.dimensions.is_empty());
let len = self.walker.zgen.const_usize_(&at.dimensions[0])?;
if at.dimensions.len() == 1 {
Ok((bos_to_type(at.ty), len))
} else {
at.dimensions.remove(0);
Ok((ast::Type::Array(at), len))
}
} else {
Err(format!(
"ZExpressionTyper: Spread expression: expected array, got {:?}",
ty
))
}?;
if let Some(acc) = &acc_ty {
eq_type(acc, &nty)?;
} else {
acc_ty.replace(nty);
}
acc_len += nln;
Ok(())
} else if matches!(soe, ast::SpreadOrExpression::Expression(_)) {
// assume expression type is OK, just increment count
acc_len += 1;
Ok(())
} else {
Err(ZVisitorError(format!(
"ZExpressionTyper: Could not type SpreadOrExpression::Spread {:#?}",
soe
)))
}
})?;
self.ty = acc_ty.map(|at| {
ast::Type::Array(self.arrayize(
at,
ast::Expression::Literal(ast::LiteralExpression::HexLiteral(
ast::HexLiteralExpression {
value: ast::HexNumberExpression::U32(ast::U32NumberExpression {
value: format!("{:04x}", acc_len),
span: iae.span.clone(),
}),
span: iae.span.clone(),
},
)),
&iae.span,
))
});
Ok(())
}
fn visit_postfix_expression(
&mut self,
pfe: &mut ast::PostfixExpression<'ast>,
) -> ZVisitorResult {
assert!(self.ty.is_none());
self.ty.replace(self.walker.get_postfix_ty(pfe, None)?);
Ok(())
}
}

View File

@@ -0,0 +1,444 @@
//! AST Walker for zokrates_pest_ast
use super::walkfns::*;
use super::ZVisitorResult;
use zokrates_pest_ast as ast;
pub trait ZVisitorMut<'ast>: Sized {
fn visit_file(&mut self, file: &mut ast::File<'ast>) -> ZVisitorResult {
walk_file(self, file)
}
fn visit_pragma(&mut self, pragma: &mut ast::Pragma<'ast>) -> ZVisitorResult {
walk_pragma(self, pragma)
}
fn visit_curve(&mut self, curve: &mut ast::Curve<'ast>) -> ZVisitorResult {
walk_curve(self, curve)
}
fn visit_span(&mut self, _span: &mut ast::Span<'ast>) -> ZVisitorResult {
Ok(())
}
fn visit_symbol_declaration(
&mut self,
sd: &mut ast::SymbolDeclaration<'ast>,
) -> ZVisitorResult {
walk_symbol_declaration(self, sd)
}
fn visit_eoi(&mut self, _eoi: &mut ast::EOI) -> ZVisitorResult {
Ok(())
}
fn visit_import_directive(
&mut self,
import: &mut ast::ImportDirective<'ast>,
) -> ZVisitorResult {
walk_import_directive(self, import)
}
fn visit_main_import_directive(
&mut self,
mimport: &mut ast::MainImportDirective<'ast>,
) -> ZVisitorResult {
walk_main_import_directive(self, mimport)
}
fn visit_from_import_directive(
&mut self,
fimport: &mut ast::FromImportDirective<'ast>,
) -> ZVisitorResult {
walk_from_import_directive(self, fimport)
}
fn visit_import_source(&mut self, is: &mut ast::ImportSource<'ast>) -> ZVisitorResult {
walk_import_source(self, is)
}
fn visit_import_symbol(&mut self, is: &mut ast::ImportSymbol<'ast>) -> ZVisitorResult {
walk_import_symbol(self, is)
}
fn visit_identifier_expression(
&mut self,
ie: &mut ast::IdentifierExpression<'ast>,
) -> ZVisitorResult {
walk_identifier_expression(self, ie)
}
fn visit_constant_definition(
&mut self,
cnstdef: &mut ast::ConstantDefinition<'ast>,
) -> ZVisitorResult {
walk_constant_definition(self, cnstdef)
}
fn visit_struct_definition(
&mut self,
structdef: &mut ast::StructDefinition<'ast>,
) -> ZVisitorResult {
walk_struct_definition(self, structdef)
}
fn visit_struct_field(&mut self, structfield: &mut ast::StructField<'ast>) -> ZVisitorResult {
walk_struct_field(self, structfield)
}
fn visit_function_definition(
&mut self,
fundef: &mut ast::FunctionDefinition<'ast>,
) -> ZVisitorResult {
walk_function_definition(self, fundef)
}
fn visit_parameter(&mut self, param: &mut ast::Parameter<'ast>) -> ZVisitorResult {
walk_parameter(self, param)
}
fn visit_visibility(&mut self, vis: &mut ast::Visibility<'ast>) -> ZVisitorResult {
walk_visibility(self, vis)
}
fn visit_public_visibility(&mut self, _pu: &mut ast::PublicVisibility) -> ZVisitorResult {
Ok(())
}
fn visit_private_visibility(
&mut self,
pr: &mut ast::PrivateVisibility<'ast>,
) -> ZVisitorResult {
walk_private_visibility(self, pr)
}
fn visit_private_number(&mut self, pn: &mut ast::PrivateNumber<'ast>) -> ZVisitorResult {
walk_private_number(self, pn)
}
fn visit_type(&mut self, ty: &mut ast::Type<'ast>) -> ZVisitorResult {
walk_type(self, ty)
}
fn visit_basic_type(&mut self, bty: &mut ast::BasicType<'ast>) -> ZVisitorResult {
walk_basic_type(self, bty)
}
fn visit_field_type(&mut self, fty: &mut ast::FieldType<'ast>) -> ZVisitorResult {
walk_field_type(self, fty)
}
fn visit_boolean_type(&mut self, bty: &mut ast::BooleanType<'ast>) -> ZVisitorResult {
walk_boolean_type(self, bty)
}
fn visit_u8_type(&mut self, u8ty: &mut ast::U8Type<'ast>) -> ZVisitorResult {
walk_u8_type(self, u8ty)
}
fn visit_u16_type(&mut self, u16ty: &mut ast::U16Type<'ast>) -> ZVisitorResult {
walk_u16_type(self, u16ty)
}
fn visit_u32_type(&mut self, u32ty: &mut ast::U32Type<'ast>) -> ZVisitorResult {
walk_u32_type(self, u32ty)
}
fn visit_u64_type(&mut self, u64ty: &mut ast::U64Type<'ast>) -> ZVisitorResult {
walk_u64_type(self, u64ty)
}
fn visit_array_type(&mut self, aty: &mut ast::ArrayType<'ast>) -> ZVisitorResult {
walk_array_type(self, aty)
}
fn visit_basic_or_struct_type(
&mut self,
bsty: &mut ast::BasicOrStructType<'ast>,
) -> ZVisitorResult {
walk_basic_or_struct_type(self, bsty)
}
fn visit_struct_type(&mut self, sty: &mut ast::StructType<'ast>) -> ZVisitorResult {
walk_struct_type(self, sty)
}
fn visit_explicit_generics(&mut self, eg: &mut ast::ExplicitGenerics<'ast>) -> ZVisitorResult {
walk_explicit_generics(self, eg)
}
fn visit_constant_generic_value(
&mut self,
cgv: &mut ast::ConstantGenericValue<'ast>,
) -> ZVisitorResult {
walk_constant_generic_value(self, cgv)
}
fn visit_literal_expression(
&mut self,
lexpr: &mut ast::LiteralExpression<'ast>,
) -> ZVisitorResult {
walk_literal_expression(self, lexpr)
}
fn visit_decimal_literal_expression(
&mut self,
dle: &mut ast::DecimalLiteralExpression<'ast>,
) -> ZVisitorResult {
walk_decimal_literal_expression(self, dle)
}
fn visit_decimal_number(&mut self, dn: &mut ast::DecimalNumber<'ast>) -> ZVisitorResult {
walk_decimal_number(self, dn)
}
fn visit_decimal_suffix(&mut self, ds: &mut ast::DecimalSuffix<'ast>) -> ZVisitorResult {
walk_decimal_suffix(self, ds)
}
fn visit_u8_suffix(&mut self, u8s: &mut ast::U8Suffix<'ast>) -> ZVisitorResult {
walk_u8_suffix(self, u8s)
}
fn visit_u16_suffix(&mut self, u16s: &mut ast::U16Suffix<'ast>) -> ZVisitorResult {
walk_u16_suffix(self, u16s)
}
fn visit_u32_suffix(&mut self, u32s: &mut ast::U32Suffix<'ast>) -> ZVisitorResult {
walk_u32_suffix(self, u32s)
}
fn visit_u64_suffix(&mut self, u64s: &mut ast::U64Suffix<'ast>) -> ZVisitorResult {
walk_u64_suffix(self, u64s)
}
fn visit_field_suffix(&mut self, fs: &mut ast::FieldSuffix<'ast>) -> ZVisitorResult {
walk_field_suffix(self, fs)
}
fn visit_boolean_literal_expression(
&mut self,
ble: &mut ast::BooleanLiteralExpression<'ast>,
) -> ZVisitorResult {
walk_boolean_literal_expression(self, ble)
}
fn visit_hex_literal_expression(
&mut self,
hle: &mut ast::HexLiteralExpression<'ast>,
) -> ZVisitorResult {
walk_hex_literal_expression(self, hle)
}
fn visit_hex_number_expression(
&mut self,
hne: &mut ast::HexNumberExpression<'ast>,
) -> ZVisitorResult {
walk_hex_number_expression(self, hne)
}
fn visit_u8_number_expression(
&mut self,
u8e: &mut ast::U8NumberExpression<'ast>,
) -> ZVisitorResult {
walk_u8_number_expression(self, u8e)
}
fn visit_u16_number_expression(
&mut self,
u16e: &mut ast::U16NumberExpression<'ast>,
) -> ZVisitorResult {
walk_u16_number_expression(self, u16e)
}
fn visit_u32_number_expression(
&mut self,
u32e: &mut ast::U32NumberExpression<'ast>,
) -> ZVisitorResult {
walk_u32_number_expression(self, u32e)
}
fn visit_u64_number_expression(
&mut self,
u64e: &mut ast::U64NumberExpression<'ast>,
) -> ZVisitorResult {
walk_u64_number_expression(self, u64e)
}
fn visit_underscore(&mut self, u: &mut ast::Underscore<'ast>) -> ZVisitorResult {
walk_underscore(self, u)
}
fn visit_expression(&mut self, expr: &mut ast::Expression<'ast>) -> ZVisitorResult {
walk_expression(self, expr)
}
fn visit_ternary_expression(
&mut self,
te: &mut ast::TernaryExpression<'ast>,
) -> ZVisitorResult {
walk_ternary_expression(self, te)
}
fn visit_binary_expression(&mut self, be: &mut ast::BinaryExpression<'ast>) -> ZVisitorResult {
walk_binary_expression(self, be)
}
fn visit_binary_operator(&mut self, _bo: &mut ast::BinaryOperator) -> ZVisitorResult {
Ok(())
}
fn visit_unary_expression(&mut self, ue: &mut ast::UnaryExpression<'ast>) -> ZVisitorResult {
walk_unary_expression(self, ue)
}
fn visit_unary_operator(&mut self, uo: &mut ast::UnaryOperator) -> ZVisitorResult {
walk_unary_operator(self, uo)
}
fn visit_pos_operator(&mut self, _po: &mut ast::PosOperator) -> ZVisitorResult {
Ok(())
}
fn visit_neg_operator(&mut self, _po: &mut ast::NegOperator) -> ZVisitorResult {
Ok(())
}
fn visit_not_operator(&mut self, _po: &mut ast::NotOperator) -> ZVisitorResult {
Ok(())
}
fn visit_postfix_expression(
&mut self,
pe: &mut ast::PostfixExpression<'ast>,
) -> ZVisitorResult {
walk_postfix_expression(self, pe)
}
fn visit_access(&mut self, acc: &mut ast::Access<'ast>) -> ZVisitorResult {
walk_access(self, acc)
}
fn visit_call_access(&mut self, ca: &mut ast::CallAccess<'ast>) -> ZVisitorResult {
walk_call_access(self, ca)
}
fn visit_arguments(&mut self, args: &mut ast::Arguments<'ast>) -> ZVisitorResult {
walk_arguments(self, args)
}
fn visit_array_access(&mut self, aa: &mut ast::ArrayAccess<'ast>) -> ZVisitorResult {
walk_array_access(self, aa)
}
fn visit_range_or_expression(
&mut self,
roe: &mut ast::RangeOrExpression<'ast>,
) -> ZVisitorResult {
walk_range_or_expression(self, roe)
}
fn visit_range(&mut self, rng: &mut ast::Range<'ast>) -> ZVisitorResult {
walk_range(self, rng)
}
fn visit_from_expression(&mut self, from: &mut ast::FromExpression<'ast>) -> ZVisitorResult {
walk_from_expression(self, from)
}
fn visit_to_expression(&mut self, to: &mut ast::ToExpression<'ast>) -> ZVisitorResult {
walk_to_expression(self, to)
}
fn visit_member_access(&mut self, ma: &mut ast::MemberAccess<'ast>) -> ZVisitorResult {
walk_member_access(self, ma)
}
fn visit_inline_array_expression(
&mut self,
iae: &mut ast::InlineArrayExpression<'ast>,
) -> ZVisitorResult {
walk_inline_array_expression(self, iae)
}
fn visit_spread_or_expression(
&mut self,
soe: &mut ast::SpreadOrExpression<'ast>,
) -> ZVisitorResult {
walk_spread_or_expression(self, soe)
}
fn visit_spread(&mut self, spread: &mut ast::Spread<'ast>) -> ZVisitorResult {
walk_spread(self, spread)
}
fn visit_inline_struct_expression(
&mut self,
ise: &mut ast::InlineStructExpression<'ast>,
) -> ZVisitorResult {
walk_inline_struct_expression(self, ise)
}
fn visit_inline_struct_member(
&mut self,
ism: &mut ast::InlineStructMember<'ast>,
) -> ZVisitorResult {
walk_inline_struct_member(self, ism)
}
fn visit_array_initializer_expression(
&mut self,
aie: &mut ast::ArrayInitializerExpression<'ast>,
) -> ZVisitorResult {
walk_array_initializer_expression(self, aie)
}
fn visit_statement(&mut self, stmt: &mut ast::Statement<'ast>) -> ZVisitorResult {
walk_statement(self, stmt)
}
fn visit_return_statement(&mut self, ret: &mut ast::ReturnStatement<'ast>) -> ZVisitorResult {
walk_return_statement(self, ret)
}
fn visit_definition_statement(
&mut self,
def: &mut ast::DefinitionStatement<'ast>,
) -> ZVisitorResult {
walk_definition_statement(self, def)
}
fn visit_typed_identifier_or_assignee(
&mut self,
tioa: &mut ast::TypedIdentifierOrAssignee<'ast>,
) -> ZVisitorResult {
walk_typed_identifier_or_assignee(self, tioa)
}
fn visit_typed_identifier(&mut self, ti: &mut ast::TypedIdentifier<'ast>) -> ZVisitorResult {
walk_typed_identifier(self, ti)
}
fn visit_assignee(&mut self, asgn: &mut ast::Assignee<'ast>) -> ZVisitorResult {
walk_assignee(self, asgn)
}
fn visit_assignee_access(&mut self, acc: &mut ast::AssigneeAccess<'ast>) -> ZVisitorResult {
walk_assignee_access(self, acc)
}
fn visit_assertion_statement(
&mut self,
asrt: &mut ast::AssertionStatement<'ast>,
) -> ZVisitorResult {
walk_assertion_statement(self, asrt)
}
fn visit_iteration_statement(
&mut self,
iter: &mut ast::IterationStatement<'ast>,
) -> ZVisitorResult {
walk_iteration_statement(self, iter)
}
}

View File

@@ -8,7 +8,7 @@ use std::sync::RwLock;
lazy_static! {
// TODO: use weak pointers to allow GC
static ref FOLDS: RwLock<TermMap<Term>> = RwLock::new(TermMap::new());
static ref FOLDS: RwLock<TermCache<Term>> = RwLock::new(TermCache::new(TERM_CACHE_LIMIT));
}
/// Create a constant boolean
@@ -23,18 +23,27 @@ fn cbv(b: BitVector) -> Option<Term> {
/// Fold away operators over constants.
pub fn fold(node: &Term) -> Term {
let mut cache = FOLDS.write().unwrap();
fold_cache(node, cache.deref_mut())
let mut cache_handle = FOLDS.write().unwrap();
let cache = cache_handle.deref_mut();
// make the cache unbounded during the fold_cache call
let old_capacity = cache.cap();
cache.resize(std::usize::MAX);
let ret = fold_cache(node, cache);
// shrink cache to its max size
cache.resize(old_capacity);
ret
}
/// Do constant-folding backed by a cache.
pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
pub fn fold_cache(node: &Term, cache: &mut TermCache<Term>) -> Term {
// (node, children pushed)
let mut stack = vec![(node.clone(), false)];
// Maps terms to their rewritten versions.
while let Some((t, children_pushed)) = stack.pop() {
if cache.contains_key(&t) {
if cache.contains(&t) {
continue;
}
if !children_pushed {
@@ -42,11 +51,11 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
stack.extend(t.cs.iter().map(|c| (c.clone(), false)));
continue;
}
let c_get = |x: &Term| -> Term { cache.get(x).expect("postorder cache").clone() };
let get = |i: usize| c_get(&t.cs[i]);
let mut c_get = |x: &Term| -> Term { cache.get(x).expect("postorder cache").clone() };
let mut get = |i: usize| c_get(&t.cs[i]);
let new_t_opt = match &t.op {
&NOT => get(0).as_bool_opt().and_then(|c| cbool(!c)),
&IMPLIES => match get(0).as_bool_opt() {
Op::Not => get(0).as_bool_opt().and_then(|c| cbool(!c)),
Op::Implies => match get(0).as_bool_opt() {
Some(true) => Some(get(1).clone()),
Some(false) => cbool(true),
None => match get(1).as_bool_opt() {
@@ -63,12 +72,17 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
Op::Eq => {
let c0 = get(0);
let c1 = get(1);
match (&c0.op, &c1.op) {
(Op::Const(Value::Bool(b0)), Op::Const(Value::Bool(b1))) => cbool(*b0 == *b1),
(Op::Const(Value::BitVector(b0)), Op::Const(Value::BitVector(b1))) => {
cbool(*b0 == *b1)
match (c0.as_value_opt(), c1.as_value_opt()) {
(Some(Value::BitVector(b0)), Some(Value::BitVector(b1))) => cbool(*b0 == *b1),
(Some(Value::F32(b0)), Some(Value::F32(b1))) => cbool(*b0 == *b1),
(Some(Value::F64(b0)), Some(Value::F64(b1))) => cbool(*b0 == *b1),
(Some(Value::Int(b0)), Some(Value::Int(b1))) => cbool(*b0 == *b1),
(Some(Value::Field(b0)), Some(Value::Field(b1))) => cbool(*b0 == *b1),
(Some(Value::Bool(b0)), Some(Value::Bool(b1))) => cbool(*b0 == *b1),
(Some(Value::Tuple(t0)), Some(Value::Tuple(t1))) => cbool(*t0 == *t1),
(Some(Value::Array(a0)), Some(Value::Array(a1))) => {
cbool(a0.size == a1.size && a0.map == a1.map)
}
(Op::Const(Value::Field(b0)), Op::Const(Value::Field(b1))) => cbool(*b0 == *b1),
_ => None,
}
}
@@ -179,12 +193,69 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
PfUnOp::Neg => -pf.clone(),
})))
}),
Op::UbvToPf(m) => get(0).as_bv_opt().map(|bv| {
leaf_term(Op::Const(Value::Field(FieldElem::new(
bv.uint().clone(),
m.clone(),
))))
}),
Op::Store => {
match (
get(0).as_array_opt(),
get(1).as_value_opt(),
get(2).as_value_opt(),
) {
(Some(arr), Some(idx), Some(val)) => {
let new_arr = arr.clone().store(idx.clone(), val.clone());
Some(leaf_term(Op::Const(Value::Array(new_arr))))
}
_ => None,
}
}
Op::Select => match (get(0).as_array_opt(), get(1).as_value_opt()) {
(Some(arr), Some(idx)) => Some(leaf_term(Op::Const(arr.select(idx)))),
_ => None,
},
Op::Tuple => {
t.cs.iter()
.map(|c| c_get(c).as_value_opt().cloned())
.collect::<Option<_>>()
.map(|v| leaf_term(Op::Const(Value::Tuple(v))))
}
Op::Field(n) => get(0)
.as_tuple_opt()
.map(|t| leaf_term(Op::Const(t[*n].clone()))),
Op::Update(n) => match (get(0).as_tuple_opt(), get(1).as_value_opt()) {
(Some(t), Some(v)) => {
let mut new_vec = Vec::from(t).into_boxed_slice();
assert_eq!(new_vec[*n].sort(), v.sort());
new_vec[*n] = v.clone();
Some(leaf_term(Op::Const(Value::Tuple(new_vec))))
}
_ => None,
},
Op::BvConcat => {
t.cs.iter()
.map(|c| c_get(c).as_bv_opt().cloned())
.collect::<Option<Vec<_>>>()
.map(|v| v.into_iter().reduce(BitVector::concat))
.flatten()
.map(|bv| leaf_term(Op::Const(Value::BitVector(bv))))
}
Op::BoolToBv => get(0).as_bool_opt().map(|b| {
leaf_term(Op::Const(Value::BitVector(BitVector::new(
Integer::from(b),
1,
))))
}),
_ => None,
};
let c_get = |x: &Term| -> Term { cache.get(x).expect("postorder cache").clone() };
let new_t = new_t_opt
.unwrap_or_else(|| term(t.op.clone(), t.cs.iter().map(|c| c_get(c)).collect()));
cache.insert(t, new_t);
let new_t = {
let mut cc_get = |x: &Term| -> Term { cache.get(x).expect("postorder cache").clone() };
new_t_opt
.unwrap_or_else(|| term(t.op.clone(), t.cs.iter().map(|c| cc_get(c)).collect()))
};
cache.put(t, new_t);
}
cache.get(node).expect("postorder cache").clone()
}

View File

@@ -192,7 +192,7 @@ mod test {
use crate::target::smt::{check_sat, find_model};
fn b_var(b: &str) -> Term {
leaf_term(Op::Var(format!("{}", b), Sort::Bool))
leaf_term(Op::Var(b.to_string(), Sort::Bool))
}
fn sub_test(xs: Vec<Term>, n: usize) {
@@ -236,8 +236,8 @@ mod test {
}
panic!("Invalid inline");
}
assert_eq!(check_sat(&not_imp), false);
assert_eq!(check_sat(&not_imp_not), false);
assert!(!check_sat(&not_imp));
assert!(!check_sat(&not_imp_not));
}
#[test]

View File

@@ -11,7 +11,7 @@ fn arr_val_to_tup(v: &Value) -> Value {
Value::Array(Array {
default, map, size, ..
}) => Value::Tuple({
let mut vec: Vec<Value> = vec![arr_val_to_tup(default); *size];
let mut vec = vec![arr_val_to_tup(default); *size].into_boxed_slice();
for (i, v) in map {
vec[i.as_usize().expect("non usize key")] = arr_val_to_tup(v);
}
@@ -23,7 +23,9 @@ fn arr_val_to_tup(v: &Value) -> Value {
fn arr_sort_to_tup(v: &Sort) -> Sort {
match v {
Sort::Array(_key, value, size) => Sort::Tuple(vec![arr_sort_to_tup(value); *size]),
Sort::Array(_key, value, size) => {
Sort::Tuple(vec![arr_sort_to_tup(value); *size].into_boxed_slice())
}
v => v.clone(),
}
}
@@ -110,7 +112,7 @@ mod test {
fn count_ites(t: &Term) -> usize {
PostOrderIter::new(t.clone())
.filter(|t| &t.op == &Op::Ite)
.filter(|t| t.op == Op::Ite)
.count()
}
@@ -133,7 +135,7 @@ mod test {
term![Op::Ite;
leaf_term(Op::Const(Value::Bool(true))),
term![Op::Store; z.clone(), bv_lit(3, 4), bv_lit(1, 4)],
term![Op::Store; z.clone(), bv_lit(2, 4), bv_lit(1, 4)]
term![Op::Store; z, bv_lit(2, 4), bv_lit(1, 4)]
],
bv_lit(3, 4)
];
@@ -156,7 +158,7 @@ mod test {
term![Op::Ite;
leaf_term(Op::Const(Value::Bool(true))),
term![Op::Store; z.clone(), field_lit(3), bv_lit(1, 4)],
term![Op::Store; z.clone(), field_lit(2), bv_lit(1, 4)]
term![Op::Store; z, field_lit(2), bv_lit(1, 4)]
],
field_lit(3)
];

View File

@@ -11,10 +11,11 @@
//!
//! ## Pass 1: Identifying oblivious arrays
//!
//! We maintain a set of non-oblivious arrays, initially empty. We traverse the whole SMT
//! constraint system, performing the following inferences:
//! We maintain a set of non-oblivious arrays, initially empty. We traverse the whole computation
//! system, performing the following inferences:
//!
//! * If `a[i]` for non-constant `i`, then `a` is not oblivious
//! * If `a[i]` for non-constant `i`, then `a` and `a[i]` are not oblivious;
//! * If `a[i]`, `a` and `a[i]` are equi-oblivious
//! * If `a[i\v]` for non-constant `i`, then neither `a[i\v]` nor `a` are oblivious
//! * If `a[i\v]`, then `a[i\v]` and `a` are equi-oblivious
//! * If `ite(c,a,b)`, then `ite(c,a,b)`, `a`, and `b` are equi-oblivious
@@ -22,12 +23,51 @@
//!
//! This procedure is iterated to fixpoint.
//!
//! Notice that we flag some *array* terms as non-oblivious, and we also flag their derived select
//! terms as non-oblivious. This makes it easy to see which selects should be replaced later.
//!
//! ### Sharing & Constant Arrays
//!
//! This pass is effective given the somewhat naive assumption that array terms in the term graph
//! can be separated into different "threads", which are not connected. Sometimes they are,
//! especially by constant arrays.
//!
//! For example, consider code like this:
//!
//! ```ignore
//! x = [0, 0, 0, 0]
//! y = [0, 0, 0, 0]
//! // oblivious modifications to x
//! // non-oblivious modifications to y
//! ```
//!
//! In this situation, we would hope that x and its derived arrays will be identified as
//! "oblivious" while y will not.
//!
//! However, because of term sharing, the constant array [0,0,0,0] happens to be the root of both
//! x's and y's store chains. If the constant array is `c`, then the first store to x might be
//! `c[0\v1]` while the first store to y might be `c[i2\v2]`. The "store" rule for non-oblivious
//! analysis would say that `c` is non-oblivious (b/c of the second store) and therefore the whole
//! x store chain would b too...
//!
//! The problem isn't just with constants. If any non-oblivious stores branch off an otherwise
//! oblivious store chain, the same thing happens.
//!
//! Since constants are a pervasive problem, we special-case them, omitting them from the analysis.
//!
//! We probably want a better idea of what this pass does (and how to handle arrays) at some
//! point...
//!
//! ## Pass 2: Replacing oblivious arrays with term lists.
//!
//! In this pass, the goal is to
//!
//! * map array terms to tuple terms
//! * map array selections to tuple field gets
//!
//! In both cases we look at the non-oblivious array/select set to see whether to do the
//! replacement.
//!
use super::super::visit::*;
use crate::ir::term::extras::as_uint_constant;
@@ -66,6 +106,7 @@ impl NonOblivComputer {
false
}
}
fn new() -> Self {
Self {
not_obliv: TermSet::new(),
@@ -98,13 +139,19 @@ impl ProgressAnalysisPass for NonOblivComputer {
progress
}
Op::Select => {
// Even though the selected value may not have array sort, we still flag it as
// non-oblivious so we know whether to replace it or not.
let a = &term.cs[0];
let i = &term.cs[1];
let mut progress = false;
if let Op::Const(_) = i.op {
false
// pass
} else {
self.mark(a)
progress = self.mark(a) || progress;
progress = self.mark(term) || progress;
}
progress = self.bi_implicate(term, a) || progress;
progress
}
Op::Ite => {
let t = &term.cs[1];
@@ -150,7 +197,7 @@ fn arr_val_to_tup(v: &Value) -> Value {
Value::Array(Array {
default, map, size, ..
}) => Value::Tuple({
let mut vec: Vec<Value> = vec![arr_val_to_tup(default); *size];
let mut vec = vec![arr_val_to_tup(default); *size].into_boxed_slice();
for (i, v) in map {
vec[i.as_usize().expect("non usize key")] = arr_val_to_tup(v);
}
@@ -169,7 +216,9 @@ fn term_arr_val_to_tup(a: Term) -> Term {
fn arr_sort_to_tup(v: &Sort) -> Sort {
match v {
Sort::Array(_key, value, size) => Sort::Tuple(vec![arr_sort_to_tup(value); *size]),
Sort::Array(_key, value, size) => {
Sort::Tuple(vec![arr_sort_to_tup(value); *size].into_boxed_slice())
}
v => v.clone(),
}
}
@@ -213,7 +262,8 @@ impl RewritePass for Replacer {
}
}
Op::Select => {
if self.should_replace(&orig.cs[0]) {
// we mark the selected term as non-obliv...
if self.should_replace(orig) {
let mut cs = get_cs();
debug_assert_eq!(cs.len(), 2);
let k_const = get_const(&cs.pop().unwrap());
@@ -290,7 +340,7 @@ mod test {
term![Op::Ite;
leaf_term(Op::Const(Value::Bool(true))),
term![Op::Store; z.clone(), bv_lit(3, 4), bv_lit(1, 4)],
term![Op::Store; z.clone(), bv_lit(2, 4), bv_lit(1, 4)]
term![Op::Store; z, bv_lit(2, 4), bv_lit(1, 4)]
],
bv_lit(3, 4)
];
@@ -312,7 +362,7 @@ mod test {
term![Op::Ite;
leaf_term(Op::Const(Value::Bool(true))),
term![Op::Store; z.clone(), v_bv("a", 4), bv_lit(1, 4)],
term![Op::Store; z.clone(), bv_lit(2, 4), bv_lit(1, 4)]
term![Op::Store; z, bv_lit(2, 4), bv_lit(1, 4)]
],
bv_lit(3, 4)
];
@@ -321,4 +371,58 @@ mod test {
elim_obliv(&mut c);
assert!(!array_free(&c.outputs[0]));
}
#[test]
fn mix_diff_constant() {
let z0 = term![Op::Const(Value::Array(Array::new(
Sort::BitVector(4),
Box::new(Sort::BitVector(4).default_value()),
Default::default(),
6
)))];
let z1 = term![Op::Const(Value::Array(Array::new(
Sort::BitVector(4),
Box::new(Sort::BitVector(4).default_value()),
Default::default(),
5
)))];
let t0 = term![Op::Select;
term![Op::Store; z0.clone(), v_bv("a", 4), bv_lit(1, 4)],
bv_lit(3, 4)
];
let t1 = term![Op::Select;
term![Op::Store; z1.clone(), bv_lit(3, 4), bv_lit(1, 4)],
bv_lit(3, 4)
];
let mut c = Computation::default();
c.outputs.push(t0);
c.outputs.push(t1);
elim_obliv(&mut c);
assert!(!array_free(&c.outputs[0]));
assert!(array_free(&c.outputs[1]));
}
#[test]
fn mix_same_constant() {
let z = term![Op::Const(Value::Array(Array::new(
Sort::BitVector(4),
Box::new(Sort::BitVector(4).default_value()),
Default::default(),
6
)))];
let t0 = term![Op::Select;
term![Op::Store; z.clone(), v_bv("a", 4), bv_lit(1, 4)],
bv_lit(3, 4)
];
let t1 = term![Op::Select;
term![Op::Store; z.clone(), bv_lit(3, 4), bv_lit(1, 4)],
bv_lit(3, 4)
];
let mut c = Computation::default();
c.outputs.push(t0);
c.outputs.push(t1);
elim_obliv(&mut c);
assert!(!array_free(&c.outputs[0]));
assert!(array_free(&c.outputs[1]));
}
}

View File

@@ -47,9 +47,13 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computation, optimizations: I) -
scalarize_vars::scalarize_inputs(&mut cs);
}
Opt::ConstantFold => {
let mut cache = TermMap::new();
let mut cache = TermCache::new(TERM_CACHE_LIMIT);
for a in &mut cs.outputs {
// allow unbounded size during a single fold_cache call
cache.resize(std::usize::MAX);
*a = cfold::fold_cache(a, &mut cache);
// then shrink back down to size between calls
cache.resize(TERM_CACHE_LIMIT);
}
}
Opt::Sha => {

View File

@@ -174,7 +174,8 @@ mod test {
let b = bool_lit(false);
let c = bool_lit(false);
let t = term![Op::BoolMaj; a.clone(), b.clone(),c.clone()];
let tt = term![OR; term![AND; a.clone(), b.clone()], term![AND; b.clone(), c.clone()], term![AND; c.clone(), a.clone()]];
let tt =
term![OR; term![AND; a.clone(), b.clone()], term![AND; b, c.clone()], term![AND; c, a]];
assert_eq!(tt, sha_maj_elim(&t));
}

View File

@@ -141,7 +141,7 @@ impl ValueTupleTree {
fn rec_unroll_into(t: &Value, out: &mut Vec<Value>) {
match t {
Value::Tuple(vs) => {
for c in vs {
for c in vs.iter() {
rec_unroll_into(c, out);
}
}
@@ -166,7 +166,10 @@ impl ValueTupleTree {
fn termify_val_tuples(v: Value) -> Term {
if let Value::Tuple(vs) = v {
term(Op::Tuple, vs.into_iter().map(termify_val_tuples).collect())
term(
Op::Tuple,
Vec::from(vs).into_iter().map(termify_val_tuples).collect(),
)
} else {
leaf_term(Op::Const(v))
}

View File

@@ -179,7 +179,7 @@ impl FixedSizeDist {
}],
Op::Tuple => {
if let Sort::Tuple(sorts) = sort {
sorts.clone()
sorts.to_vec()
} else {
unreachable!("Bad sort for tuple cons: {}", sort)
}
@@ -343,9 +343,12 @@ pub mod test {
}
fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
let ts = PostOrderIter::new(self.0.clone()).collect::<Vec<_>>();
let ts = PostOrderIter::new(self.0.clone())
.collect::<Vec<_>>()
.into_iter()
.rev();
Box::new(ts.into_iter().rev().skip(1).map(ArbitraryTerm))
Box::new(ts.skip(1).map(ArbitraryTerm))
}
}
@@ -368,15 +371,13 @@ pub mod test {
}
fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
let ts = PostOrderIter::new(self.0.clone()).collect::<Vec<_>>();
let ts = PostOrderIter::new(self.0.clone())
.collect::<Vec<_>>()
.into_iter()
.rev();
let vs = self.1.clone();
Box::new(
ts.into_iter()
.rev()
.skip(1)
.map(move |t| ArbitraryBoolEnv(t, vs.clone())),
)
Box::new(ts.skip(1).map(move |t| ArbitraryBoolEnv(t, vs.clone())))
}
}
@@ -411,15 +412,13 @@ pub mod test {
}
fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
let ts = PostOrderIter::new(self.0.clone()).collect::<Vec<_>>();
let ts = PostOrderIter::new(self.0.clone())
.collect::<Vec<_>>()
.into_iter()
.rev();
let vs = self.1.clone();
Box::new(
ts.into_iter()
.rev()
.skip(1)
.map(move |t| ArbitraryTermEnv(t, vs.clone())),
)
Box::new(ts.skip(1).map(move |t| ArbitraryTermEnv(t, vs.clone())))
}
}

View File

@@ -623,7 +623,7 @@ pub enum Value {
/// Array
Array(Array),
/// Tuple
Tuple(Vec<Value>),
Tuple(Box<[Value]>),
}
#[derive(Clone, PartialEq, Debug, PartialOrd, Hash)]
@@ -649,6 +649,12 @@ impl Array {
map: BTreeMap<Value, Value>,
size: usize,
) -> Self {
if key_sort.default_value().as_usize().is_none() {
panic!(
"IR Arrays cannot have {} index (Int, BitVector, Bool, or Field only)",
key_sort
);
}
Self {
key_sort,
default,
@@ -665,13 +671,48 @@ impl Array {
size,
)
}
// consistency check for index
fn check_idx(&self, idx: &Value) {
if idx.sort() != self.key_sort {
panic!(
"Tried to index array with key {}, but {} was expected",
idx.sort(),
self.key_sort
);
}
match idx.as_usize() {
Some(idx_u) if idx_u < self.size => (),
Some(idx_u) => panic!(
"IR Array out of range: accessed {}, size is {}",
idx_u, self.size
),
_ => panic!("IR Array index {} not convertible to usize", idx),
}
}
// consistency check for value
fn check_val(&self, vsrt: Sort) {
if vsrt != self.default.sort() {
panic!(
"Attempted to store {} to an array of {}",
vsrt,
self.default.sort()
);
}
}
/// Store
pub fn store(mut self, idx: Value, val: Value) -> Self {
self.check_idx(&idx);
self.check_val(val.sort());
self.map.insert(idx, val);
self
}
/// Select
pub fn select(&self, idx: &Value) -> Value {
self.check_idx(idx);
self.map.get(idx).unwrap_or(&*self.default).clone()
}
}
@@ -687,7 +728,7 @@ impl Display for Value {
Value::BitVector(b) => write!(f, "{}", b),
Value::Tuple(fields) => {
write!(f, "(tuple")?;
for field in fields {
for field in fields.iter() {
write!(f, " {}", field)?;
}
write!(f, ")")
@@ -701,7 +742,7 @@ impl Display for Array {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"(map default:{} size:{} {:?})",
"(array default:{} size:{} {:?})",
self.default, self.size, self.map
)
}
@@ -756,7 +797,7 @@ pub enum Sort {
/// size presumes an order, and a zero, for the key sort.
Array(Box<Sort>, Box<Sort>, usize),
/// A tuple
Tuple(Vec<Sort>),
Tuple(Box<[Sort]>),
}
impl Sort {
@@ -782,7 +823,7 @@ impl Sort {
#[track_caller]
/// Unwrap the constituent sorts of this tuple, panicking otherwise.
pub fn as_tuple(&self) -> &Vec<Sort> {
pub fn as_tuple(&self) -> &[Sort] {
if let Sort::Tuple(w) = self {
w
} else {
@@ -880,7 +921,7 @@ impl Display for Sort {
Sort::Array(k, v, n) => write!(f, "(array {} {} {})", k, v, n),
Sort::Tuple(fields) => {
write!(f, "(tuple")?;
for field in fields {
for field in fields.iter() {
write!(f, " {}", field)?;
}
write!(f, ")")
@@ -898,6 +939,7 @@ pub type TTerm = WHConsed<TermData>;
struct TermTable {
map: FxHashMap<TermData, TTerm>,
count: u64,
last_len: usize,
}
impl TermTable {
@@ -926,6 +968,14 @@ impl TermTable {
// ...and return consed version.
hconsed
}
fn should_collect(&mut self) -> bool {
let ret = LEN_THRESH_DEN * self.map.len() > LEN_THRESH_NUM * self.last_len;
if self.last_len > TERM_CACHE_LIMIT {
// when last_len is big, force a garbage collect every once in a while
self.last_len = (self.last_len * LEN_DECAY_NUM) / LEN_DECAY_DEN;
}
ret
}
fn collect(&mut self) {
let old_size = self.map.len();
let mut to_check: OnceQueue<Term> = OnceQueue::new();
@@ -952,6 +1002,39 @@ impl TermTable {
assert!(v.elm.upgrade().is_some(), "Can not upgrade: {:?}", k)
}
debug!(target: "ir::term::gc", "{} of {} terms collected", old_size - new_size, old_size);
self.last_len = new_size;
}
}
struct TypeTable {
map: FxHashMap<TTerm, Sort>,
last_len: usize,
}
impl std::ops::Deref for TypeTable {
type Target = FxHashMap<TTerm, Sort>;
fn deref(&self) -> &Self::Target {
&self.map
}
}
impl std::ops::DerefMut for TypeTable {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.map
}
}
impl TypeTable {
fn should_collect(&mut self) -> bool {
let ret = LEN_THRESH_DEN * self.map.len() > LEN_THRESH_NUM * self.last_len;
if self.last_len > TERM_CACHE_LIMIT {
// when last_len is big, force a garbage collect every once in a while
self.last_len = (self.last_len * LEN_DECAY_NUM) / LEN_DECAY_DEN;
}
ret
}
fn collect(&mut self) {
let old_size = self.map.len();
self.map.retain(|term, _| term.to_hconsed().is_some());
let new_size = self.map.len();
debug!(target: "ir::term::gc", "{} of {} types collected", old_size - new_size, old_size);
self.last_len = new_size;
}
}
@@ -959,6 +1042,7 @@ lazy_static! {
static ref TERMS: RwLock<TermTable> = RwLock::new(TermTable {
map: FxHashMap::default(),
count: 0,
last_len: 0,
});
}
@@ -973,16 +1057,37 @@ pub fn garbage_collect() {
collect_types();
}
const LEN_THRESH_NUM: usize = 8;
const LEN_THRESH_DEN: usize = 1;
const LEN_DECAY_NUM: usize = 15;
const LEN_DECAY_DEN: usize = 16;
/// Scan term and type databases only if they've grown in size since last scan
pub fn maybe_garbage_collect() -> bool {
let mut ran = {
let mut term_table = TERMS.write().unwrap();
if term_table.should_collect() {
term_table.collect();
true
} else {
false
}
};
{
let mut type_table = ty::TERM_TYPES.write().unwrap();
if type_table.should_collect() {
type_table.collect();
ran = true;
}
}
ran
}
fn collect_terms() {
TERMS.write().unwrap().collect();
}
fn collect_types() {
let mut ty_map = ty::TERM_TYPES.write().unwrap();
let old_size = ty_map.len();
ty_map.retain(|term, _| term.to_hconsed().is_some());
let new_size = ty_map.len();
debug!(target: "ir::term::gc", "{} of {} types collected", old_size - new_size, old_size);
ty::TERM_TYPES.write().unwrap().collect();
}
impl TermData {
@@ -1010,10 +1115,39 @@ impl TermData {
None
}
}
/// Get the underlying tuple constant, if possible.
pub fn as_tuple_opt(&self) -> Option<&[Value]> {
if let Op::Const(Value::Tuple(t)) = &self.op {
Some(t)
} else {
None
}
}
/// Get the underlying array constant, if possible.
pub fn as_array_opt(&self) -> Option<&Array> {
if let Op::Const(Value::Array(a)) = &self.op {
Some(a)
} else {
None
}
}
/// Get the underlying constant value, if possible.
pub fn as_value_opt(&self) -> Option<&Value> {
if let Op::Const(v) = &self.op {
Some(v)
} else {
None
}
}
/// Is this a variable?
pub fn is_var(&self) -> bool {
matches!(&self.op, Op::Var(..))
}
/// Is this a value
pub fn is_const(&self) -> bool {
matches!(&self.op, Op::Const(..))
@@ -1068,7 +1202,7 @@ impl Value {
}
#[track_caller]
/// Get the underlying tuple's constituent values, if possible.
pub fn as_tuple(&self) -> &Vec<Value> {
pub fn as_tuple(&self) -> &[Value] {
if let Value::Tuple(b) = self {
b
} else {
@@ -1102,7 +1236,8 @@ impl Value {
None
}
}
/// Compute the sort of this value
/// Convert this value into a usize if possible
pub fn as_usize(&self) -> Option<usize> {
match &self {
Value::Bool(b) => Some(*b as usize),
@@ -1260,9 +1395,10 @@ pub fn eval(t: &Term, h: &FxHashMap<String, Value>) -> Value {
t[*i].clone()
}
Op::Update(i) => {
let mut t = vs.get(&c.cs[0]).unwrap().as_tuple().clone();
let mut t = Vec::from(vs.get(&c.cs[0]).unwrap().as_tuple()).into_boxed_slice();
assert!(i < &t.len(), "{} out of bounds for {}", i, c.cs[0]);
let e = vs.get(&c.cs[1]).unwrap().clone();
assert_eq!(t[*i].sort(), e.sort());
t[*i] = e;
Value::Tuple(t)
}
@@ -1307,6 +1443,7 @@ pub fn leaf_term(op: Op) -> Term {
/// Make a term with arguments.
#[track_caller]
pub fn term(op: Op, cs: Vec<Term>) -> Term {
#[cfg_attr(not(debug_assertions), allow(clippy::let_and_return))]
let t = mk(TermData { op, cs });
#[cfg(debug_assertions)]
check_rec(&t);
@@ -1348,9 +1485,15 @@ macro_rules! term {
/// Map from terms
pub type TermMap<T> = hashconsing::coll::HConMap<Term, T>;
/// LRU cache of terms (like TermMap, but limited size)
pub type TermCache<T> = hashconsing::coll::HConLru<Term, T>;
/// Set of terms
pub type TermSet = hashconsing::coll::HConSet<Term>;
// default LRU cache size
// this size avoids quadratic behavior for Falcon verification
pub(super) const TERM_CACHE_LIMIT: usize = 65536;
/// Iterator over descendents in child-first order.
pub struct PostOrderIter {
// (cs stacked, term)
@@ -1595,7 +1738,7 @@ impl Computation {
/// Assert `s` in the system.
pub fn assert(&mut self, s: Term) {
assert!(check(&s) == Sort::Bool);
debug!("Assert: {}", extras::Letified(s.clone()));
debug!("Assert: {}", &s.op);
self.outputs.push(s);
}
/// If tracking values, evaluate `term`, and set the result to `name`.

View File

@@ -4,7 +4,10 @@ use super::*;
lazy_static! {
/// Cache of all types
pub static ref TERM_TYPES: RwLock<FxHashMap<TTerm, Sort>> = RwLock::new(FxHashMap::default());
pub(super) static ref TERM_TYPES: RwLock<TypeTable> = RwLock::new(TypeTable {
map: FxHashMap::default(),
last_len: 0,
});
}
#[track_caller]
@@ -27,6 +30,28 @@ pub fn check_raw(t: &Term) -> Result<Sort, TypeError> {
if let Some(s) = TERM_TYPES.read().unwrap().get(&t.to_weak()) {
return Ok(s.clone());
}
// RSW: the below loop is a band-aid to keep from blowing the stack
// XXX(q) is there a better way to write this function?
let mut t = t;
loop {
let t_new = match &t.op {
Op::Ite => &t.cs[1],
Op::BvBinOp(_) => &t.cs[0],
Op::BvNaryOp(_) => &t.cs[0],
Op::BvUnOp(_) => &t.cs[0],
Op::FpBinOp(_) => &t.cs[0],
Op::FpUnOp(_) => &t.cs[0],
Op::PfUnOp(_) => &t.cs[0],
Op::PfNaryOp(_) => &t.cs[0],
Op::Store => &t.cs[0],
Op::Update(_i) => &t.cs[0],
_ => break,
};
if std::ptr::eq(t, t_new) {
panic!("infinite loop detected in check_raw");
}
t = t_new;
}
let ty = match &t.op {
Op::Ite => Ok(check_raw(&t.cs[1])?),
Op::Eq => Ok(Sort::Bool),
@@ -88,7 +113,7 @@ pub fn check_raw(t: &Term) -> Result<Sort, TypeError> {
Op::Select => array_or(&check_raw(&t.cs[0])?, "select").map(|(_, v)| v.clone()),
Op::Store => Ok(check_raw(&t.cs[0])?),
Op::Tuple => Ok(Sort::Tuple(
t.cs.iter().map(check_raw).collect::<Result<Vec<_>, _>>()?,
t.cs.iter().map(check_raw).collect::<Result<_, _>>()?,
)),
Op::Field(i) => {
let sort = check_raw(&t.cs[0])?;
@@ -371,7 +396,7 @@ fn pf_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason
}
}
fn tuple_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Vec<Sort>, TypeErrorReason> {
fn tuple_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a [Sort], TypeErrorReason> {
match a {
Sort::Tuple(a) => Ok(a),
_ => Err(TypeErrorReason::ExpectedTuple(ctx)),

View File

@@ -129,7 +129,7 @@ mod test {
let solution = vars
.maximise(a + b + c)
.using(default_solver)
.with(a + b << 30.0)
.with((a + b) << 30.0)
.solve()
.unwrap();
assert_eq!(solution.value(a), 1.0);
@@ -145,7 +145,7 @@ mod test {
let solution = vars
.maximise(a + b + c)
.using(good_lp::solvers::lp_solvers::LpSolver(s))
.with(a + b << 30.0)
.with((a + b) << 30.0)
.solve()
.unwrap();
assert_eq!(solution.value(a), 1.0);

View File

@@ -727,13 +727,13 @@ mod test {
match val {
Value::Bool(true) => {
if let Some(var) = ilp.var_names.get(v) {
let e = Expression::from(var.clone());
let e = Expression::from(*var);
ilp.new_constraint(e.eq(1.0));
}
}
Value::Bool(false) => {
if let Some(var) = ilp.var_names.get(v) {
let e = Expression::from(var.clone());
let e = Expression::from(*var);
ilp.new_constraint(e.eq(0.0));
}
}

View File

@@ -1,6 +1,8 @@
//! Target circuit representations (and lowering passes)
#[cfg(feature = "lp")]
pub mod aby;
#[cfg(feature = "lp")]
pub mod ilp;
pub mod r1cs;
pub mod smt;

View File

@@ -188,7 +188,7 @@ mod test {
)
.unwrap(),
);
convert(modulus.clone() - 1);
convert(modulus - 1);
}
#[test]

View File

@@ -999,14 +999,12 @@ pub mod test {
fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
let vs = self.1.clone();
let ts = PostOrderIter::new(self.0.clone()).collect::<Vec<_>>();
let ts = PostOrderIter::new(self.0.clone())
.collect::<Vec<_>>()
.into_iter()
.rev();
Box::new(
ts.into_iter()
.rev()
.skip(1)
.map(move |t| PureBool(t, vs.clone())),
)
Box::new(ts.skip(1).map(move |t| PureBool(t, vs.clone())))
}
}

View File

@@ -94,7 +94,7 @@ impl Expr2Smt<()> for Value {
}
Value::Tuple(fs) => {
write!(w, "(mkTuple")?;
for t in fs {
for t in fs.iter() {
write!(w, " {}", SmtDisp(t))?;
}
write!(w, ")")?;
@@ -176,7 +176,7 @@ impl Sort2Smt for Sort {
Sort::Int => write!(w, "Int")?,
Sort::Tuple(fs) => {
write!(w, "(Tuple")?;
for t in fs {
for t in fs.iter() {
write!(w, " {}", SmtSortDisp(t))?;
}
write!(w, ")")?;
@@ -277,7 +277,7 @@ impl<'a, Br: ::std::io::BufRead> ModelParser<String, Sort, Value, &'a mut SmtPar
/// Create a solver, which can optionally parse models.
///
/// If [rsmt2::conf::CVC4_ENV_VAR] is set, uses that as the solver's invocation command.
fn make_solver<P>(parser: P, models: bool) -> rsmt2::Solver<P> {
fn make_solver<P>(parser: P, models: bool, inc: bool) -> rsmt2::Solver<P> {
let mut conf = rsmt2::conf::SmtConf::default_cvc4();
if let Ok(val) = std::env::var(rsmt2::conf::CVC4_ENV_VAR) {
conf.cmd(val);
@@ -285,12 +285,13 @@ fn make_solver<P>(parser: P, models: bool) -> rsmt2::Solver<P> {
if models {
conf.models();
}
conf.set_incremental(inc);
rsmt2::Solver::new(conf, parser).expect("Error creating SMT solver")
}
/// Check whether some term is satisfiable.
pub fn check_sat(t: &Term) -> bool {
let mut solver = make_solver((), false);
let mut solver = make_solver((), false, false);
for c in PostOrderIter::new(t.clone()) {
if let Op::Var(n, s) = &c.op {
solver.declare_const(&SmtSymDisp(n), s).unwrap();
@@ -301,9 +302,8 @@ pub fn check_sat(t: &Term) -> bool {
solver.check_sat().unwrap()
}
/// Get a satisfying assignment for `t`, assuming it is SAT.
pub fn find_model(t: &Term) -> Option<HashMap<String, Value>> {
let mut solver = make_solver(Parser, true);
fn get_model_solver(t: &Term, inc: bool) -> rsmt2::Solver<Parser> {
let mut solver = make_solver(Parser, true, inc);
//solver.path_tee("solver_com").unwrap();
for c in PostOrderIter::new(t.clone()) {
if let Op::Var(n, s) = &c.op {
@@ -311,6 +311,12 @@ pub fn find_model(t: &Term) -> Option<HashMap<String, Value>> {
}
}
assert!(check(t) == Sort::Bool);
solver
}
/// Get a satisfying assignment for `t`, assuming it is SAT.
pub fn find_model(t: &Term) -> Option<HashMap<String, Value>> {
let mut solver = get_model_solver(t, false);
solver.assert(&**t).unwrap();
if solver.check_sat().unwrap() {
Some(
@@ -326,6 +332,44 @@ pub fn find_model(t: &Term) -> Option<HashMap<String, Value>> {
}
}
/// Get a unique satisfying assignment for `t`, assuming it is SAT.
pub fn find_unique_model(t: &Term, uniqs: Vec<String>) -> Option<HashMap<String, Value>> {
let mut solver = get_model_solver(t, true);
solver.assert(&**t).unwrap();
// first, get the result
let model: HashMap<String, Value> = if solver.check_sat().unwrap() {
solver
.get_model()
.unwrap()
.into_iter()
.map(|(id, _, _, v)| (id, v))
.collect()
} else {
return None;
};
// now, assert that any value in uniq is not the value assigned and check unsat
match uniqs
.into_iter()
.flat_map(|n| {
model
.get(&n)
.map(|v| term![EQ; term![Op::Var(n, v.sort())], term![Op::Const(v.clone())]])
})
.reduce(|l, r| term![AND; l, r])
.map(|t| term![NOT; t])
{
None => Some(model),
Some(ast) => {
solver.push(1).unwrap();
solver.assert(&*ast).unwrap();
match solver.check_sat().unwrap() {
true => None,
false => Some(model),
}
}
}
}
#[cfg(test)]
mod test {
use super::*;
@@ -370,7 +414,7 @@ mod test {
fn tuple_is_sat() {
let t = term![Op::Eq; term![Op::Field(0); term![Op::Tuple; bv_lit(0,4), bv_lit(5,6)]], leaf_term(Op::Var("a".into(), Sort::BitVector(4)))];
assert!(check_sat(&t));
let t = term![Op::Eq; term![Op::Tuple; bv_lit(0,4), bv_lit(5,6)], leaf_term(Op::Var("a".into(), Sort::Tuple(vec![Sort::BitVector(4), Sort::BitVector(6)])))];
let t = term![Op::Eq; term![Op::Tuple; bv_lit(0,4), bv_lit(5,6)], leaf_term(Op::Var("a".into(), Sort::Tuple(vec![Sort::BitVector(4), Sort::BitVector(6)].into_boxed_slice())))];
assert!(check_sat(&t));
}
@@ -414,12 +458,12 @@ mod test {
#[quickcheck]
fn eval_random_bool(ArbitraryBoolEnv(t, vs): ArbitraryBoolEnv) {
assert!(smt_eval_test(t.clone(), &vs));
assert!(!smt_eval_alternate_solution(t.clone(), &vs));
assert!(!smt_eval_alternate_solution(t, &vs));
}
/// Check that `t` evaluates consistently within the SMT solver under `vs`.
pub fn smt_eval_test(t: Term, vs: &HashMap<String, Value>) -> bool {
let mut solver = make_solver((), false);
let mut solver = make_solver((), false, false);
for (var, val) in vs {
let s = val.sort();
solver.declare_const(&SmtSymDisp(&var), &s).unwrap();
@@ -434,7 +478,7 @@ mod test {
/// Check that `t` evaluates consistently within the SMT solver under `vs`.
pub fn smt_eval_alternate_solution(t: Term, vs: &HashMap<String, Value>) -> bool {
let mut solver = make_solver((), false);
let mut solver = make_solver((), false, false);
for (var, val) in vs {
let s = val.sort();
solver.declare_const(&SmtSymDisp(&var), &s).unwrap();

View File

@@ -1,189 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "block-buffer"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0940dc441f31689269e10ac70eb1002a3a1d3ad1390e030043662eb7fe4688b"
dependencies = [
"block-padding",
"byte-tools",
"byteorder",
"generic-array",
]
[[package]]
name = "block-padding"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa79dedbb091f449f1f39e53edf88d5dbe95f895dae6135a8d7b881fb5af73f5"
dependencies = [
"byte-tools",
]
[[package]]
name = "byte-tools"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3b5ca7a04898ad4bcd41c90c5285445ff5b791899bb1b0abdd2a2aa791211d7"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "digest"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3d0c8c8752312f9713efd397ff63acb9f85585afbf179282e720e7704954dd5"
dependencies = [
"generic-array",
]
[[package]]
name = "fake-simd"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed"
[[package]]
name = "generic-array"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffdf9f34f1447443d37393cc6c2b8313aebddcd96906caf34e54c68d8e57d7bd"
dependencies = [
"typenum",
]
[[package]]
name = "glob"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8be18de09a56b60ed0edf84bc9df007e30040691af7acd1c41874faac5895bfb"
[[package]]
name = "maplit"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d"
[[package]]
name = "opaque-debug"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2839e79665f131bdb5782e51f2c6c9599c133c6098982a54c794358bf432529c"
[[package]]
name = "pest"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10f4872ae94d7b90ae48754df22fd42ad52ce740b8f370b03da4835417403e53"
dependencies = [
"ucd-trie",
]
[[package]]
name = "pest_derive"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "833d1ae558dc601e9a60366421196a8d94bc0ac980476d0b67e1d0988d72b2d0"
dependencies = [
"pest",
"pest_generator",
]
[[package]]
name = "pest_generator"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99b8db626e31e5b81787b9783425769681b347011cc59471e33ea46d2ea0cf55"
dependencies = [
"pest",
"pest_meta",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "pest_meta"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54be6e404f5317079812fc8f9f5279de376d8856929e21c184ecf6bbd692a11d"
dependencies = [
"maplit",
"pest",
"sha-1",
]
[[package]]
name = "proc-macro2"
version = "1.0.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba508cc11742c0dc5c1659771673afbab7a0efab23aa17e854cbab0837ed0b43"
dependencies = [
"unicode-xid",
]
[[package]]
name = "quote"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bc8cc6a5f2e3655e0899c1b848643b2562f853f114bfec7be120678e3ace05"
dependencies = [
"proc-macro2",
]
[[package]]
name = "sha-1"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7d94d0bede923b3cea61f3f1ff57ff8cdfd77b400fb8f9998949e0cf04163df"
dependencies = [
"block-buffer",
"digest",
"fake-simd",
"opaque-debug",
]
[[package]]
name = "syn"
version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2afee18b8beb5a596ecb4a2dce128c719b4ba399d34126b9e4396e3f9860966"
dependencies = [
"proc-macro2",
"quote",
"unicode-xid",
]
[[package]]
name = "typenum"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b63708a265f51345575b27fe43f9500ad611579e764c79edbc2037b1121959ec"
[[package]]
name = "ucd-trie"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c"
[[package]]
name = "unicode-xid"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
[[package]]
name = "zokrates_parser"
version = "0.1.6"
dependencies = [
"glob",
"pest",
"pest_derive",
]

View File

@@ -1,6 +1,6 @@
[package]
name = "zokrates_parser"
version = "0.1.6"
version = "0.2.4"
authors = ["JacobEberhardt <jacob.eberhardt@tu-berlin.de>"]
edition = "2018"
@@ -9,4 +9,4 @@ pest = "2.0"
pest_derive = "2.0"
[dev-dependencies]
glob = "0.2"
glob = "0.2"

View File

@@ -37,7 +37,7 @@ ace.define("ace/mode/zokrates_highlight_rules",["require","exports","module","ac
var ZoKratesHighlightRules = function () {
var keywords = (
"assert|as|bool|byte|def|do|else|endfor|export|false|field|for|if|then|fi|import|from|in|private|public|return|struct|true|u8|u16|u32"
"assert|as|bool|byte|const|def|do|else|endfor|export|false|field|for|if|then|fi|import|from|in|private|public|return|struct|true|u8|u16|u32|u64"
);
var keywordMapper = this.createKeywordMapper({
@@ -45,8 +45,9 @@ ace.define("ace/mode/zokrates_highlight_rules",["require","exports","module","ac
}, "identifier");
var decimalInteger = "(?:(?:[1-9]\\d*)|(?:0))";
var decimalSuffix = "(?:_?(?:f|u(?:8|16|32|64)))?";
var hexInteger = "(?:0[xX][\\dA-Fa-f]+)";
var integer = "(?:" + decimalInteger + "|" + hexInteger + ")\\b";
var integer = "(?:" + decimalInteger + decimalSuffix + "|" + hexInteger + ")\\b";
this.$rules = {
"start": [
@@ -117,4 +118,4 @@ ace.define("ace/mode/zokrates",["require","exports","module","ace/lib/oop","ace/
}).call(Mode.prototype);
exports.Mode = Mode;
});
});

View File

@@ -1,6 +1,6 @@
{
"name": "ace-mode-zokrates",
"version": "1.0.2",
"version": "1.0.4",
"description": "Ace Mode for ZoKrates DSL",
"main": "index.js",
"scripts": {

View File

@@ -1,3 +1,5 @@
#![allow(clippy::upper_case_acronyms)] // we allow uppercase acronyms because the pest derive generates WHITESPACE and COMMENT which have special meaning in pest
extern crate pest;
#[macro_use]
extern crate pest_derive;
@@ -53,224 +55,263 @@ mod tests {
mod rules {
use super::*;
#[test]
fn parse_valid_identifier() {
parses_to! {
parser: ZoKratesParser,
input: "valididentifier_01",
rule: Rule::identifier,
tokens: [
identifier(0, 18)
]
};
}
#[test]
fn parse_parameter_list() {
parses_to! {
parser: ZoKratesParser,
input: "def foo(field a) -> (field, field): return 1
",
rule: Rule::function_definition,
tokens: [
function_definition(0, 45, [
identifier(4, 7),
// parameter_list is not created (silent rule)
parameter(8, 15, [
ty(8, 13, [
ty_basic(8, 13, [
ty_field(8, 13)
])
]),
identifier(14, 15)
]),
// type_list is not created (silent rule)
ty(21, 26, [
ty_basic(21, 26, [
ty_field(21, 26)
])
]),
ty(28, 33, [
ty_basic(28, 33, [
ty_field(28, 33)
])
]),
statement(36, 45, [
return_statement(36, 44, [
expression(43, 44, [
term(43, 44, [
primary_expression(43, 44, [
constant(43, 44, [
decimal_number(43, 44)
])
])
])
])
])
])
])
]
};
}
// TODO: uncomment these tests once https://github.com/pest-parser/pest/pull/493 is resolved
#[test]
fn parse_single_def_to_multi() {
parses_to! {
parser: ZoKratesParser,
input: r#"a = foo()
"#,
rule: Rule::statement,
tokens: [
statement(0, 22, [
definition_statement(0, 9, [
optionally_typed_assignee(0, 2, [
assignee(0, 2, [
identifier(0, 1)
])
]),
expression(4, 9, [
term(4, 9, [
postfix_expression(4, 9, [
identifier(4, 7),
access(7, 9, [
call_access(7, 9)
])
])
])
]),
])
])
]
};
}
// #[test]
// fn parse_valid_identifier() {
// parses_to! {
// parser: ZoKratesParser,
// input: "valididentifier_01",
// rule: Rule::identifier,
// tokens: [
// identifier(0, 18)
// ]
// };
// }
#[test]
fn parse_field_def_to_multi() {
parses_to! {
parser: ZoKratesParser,
input: r#"field a = foo()
"#,
rule: Rule::statement,
tokens: [
statement(0, 28, [
definition_statement(0, 15, [
optionally_typed_assignee(0, 8, [
ty(0, 5, [
ty_basic(0, 5, [
ty_field(0, 5)
])
]),
assignee(6, 8, [
identifier(6, 7)
])
]),
expression(10, 15, [
term(10, 15, [
postfix_expression(10, 15, [
identifier(10, 13),
access(13, 15, [
call_access(13, 15)
])
])
])
]),
])
])
]
};
}
// #[test]
// fn parse_parameter_list() {
// parses_to! {
// parser: ZoKratesParser,
// input: "def foo<P, Q>(field[P] a) -> (field, field): return 1
// ",
// rule: Rule::function_definition,
// tokens: [
// function_definition(0, 54, [
// identifier(4, 7),
// identifier(8, 9),
// identifier(11, 12),
// // parameter_list is not created (silent rule)
// parameter(14, 24, [
// ty(14, 23, [
// ty_array(14, 23, [
// ty_basic_or_struct(14, 19, [
// ty_basic(14, 19, [
// ty_field(14, 19)
// ])
// ]),
// expression(20, 21, [
// term(20, 21, [
// primary_expression(20, 21, [
// identifier(20, 21)
// ])
// ])
// ])
// ])
// ]),
// identifier(23, 24)
// ]),
// // type_list is not created (silent rule)
// ty(30, 35, [
// ty_basic(30, 35, [
// ty_field(30, 35)
// ])
// ]),
// ty(37, 42, [
// ty_basic(37, 42, [
// ty_field(37, 42)
// ])
// ]),
// statement(45, 54, [
// return_statement(45, 53, [
// expression(52, 53, [
// term(52, 53, [
// primary_expression(52, 53, [
// literal(52, 53, [
// decimal_literal(52, 53, [
// decimal_number(52, 53)
// ])
// ])
// ])
// ])
// ])
// ])
// ])
// ])
// ]
// };
// }
#[test]
fn parse_u8_def_to_multi() {
parses_to! {
parser: ZoKratesParser,
input: r#"u32 a = foo()
"#,
rule: Rule::statement,
tokens: [
statement(0, 26, [
definition_statement(0, 13, [
optionally_typed_assignee(0, 6, [
ty(0, 3, [
ty_basic(0, 3, [
ty_u32(0, 3)
])
]),
assignee(4, 6, [
identifier(4, 5)
])
]),
expression(8, 13, [
term(8, 13, [
postfix_expression(8, 13, [
identifier(8, 11),
access(11, 13, [
call_access(11, 13)
])
])
])
]),
])
])
]
};
}
// #[test]
// fn parse_single_def_to_multi() {
// parses_to! {
// parser: ZoKratesParser,
// input: r#"a = foo::<_>(x)
// "#,
// rule: Rule::statement,
// tokens: [
// statement(0, 28, [
// definition_statement(0, 15, [
// optionally_typed_assignee(0, 2, [
// assignee(0, 2, [
// identifier(0, 1)
// ])
// ]),
// expression(4, 15, [
// term(4, 15, [
// postfix_expression(4, 15, [
// identifier(4, 7),
// access(7, 15, [
// call_access(7, 15, [
// explicit_generics(7, 12, [
// constant_generics_value(10, 11, [
// underscore(10, 11)
// ])
// ]),
// arguments(13, 14, [
// expression(13, 14, [
// term(13, 14, [
// primary_expression(13, 14, [
// identifier(13, 14)
// ])
// ])
// ])
// ])
// ])
// ])
// ])
// ])
// ]),
// ])
// ])
// ]
// };
// }
#[test]
fn parse_invalid_identifier() {
fails_with! {
parser: ZoKratesParser,
input: "0_invalididentifier",
rule: Rule::identifier,
positives: vec![Rule::identifier],
negatives: vec![],
pos: 0
};
}
// #[test]
// fn parse_field_def_to_multi() {
// parses_to! {
// parser: ZoKratesParser,
// input: r#"field a = foo()
// "#,
// rule: Rule::statement,
// tokens: [
// statement(0, 28, [
// definition_statement(0, 15, [
// optionally_typed_assignee(0, 8, [
// ty(0, 5, [
// ty_basic(0, 5, [
// ty_field(0, 5)
// ])
// ]),
// assignee(6, 8, [
// identifier(6, 7)
// ])
// ]),
// expression(10, 15, [
// term(10, 15, [
// postfix_expression(10, 15, [
// identifier(10, 13),
// access(13, 15, [
// call_access(13, 15, [
// arguments(14, 14)
// ])
// ])
// ])
// ])
// ]),
// ])
// ])
// ]
// };
// }
#[test]
fn parse_struct_def() {
parses_to! {
parser: ZoKratesParser,
input: "struct Foo { field foo\n field[2] bar }
",
rule: Rule::ty_struct_definition,
tokens: [
ty_struct_definition(0, 39, [
identifier(7, 10),
struct_field(13, 22, [
ty(13, 18, [
ty_basic(13, 18, [
ty_field(13, 18)
])
]),
identifier(19, 22)
]),
struct_field(24, 36, [
ty(24, 33, [
ty_array(24, 33, [
ty_basic_or_struct(24, 29, [
ty_basic(24, 29, [
ty_field(24, 29)
])
]),
expression(30, 31, [
term(30, 31, [
primary_expression(30, 31, [
constant(30, 31, [
decimal_number(30, 31)
])
])
])
])
])
]),
identifier(33, 36)
])
])
]
};
}
// #[test]
// fn parse_u8_def_to_multi() {
// parses_to! {
// parser: ZoKratesParser,
// input: r#"u32 a = foo()
// "#,
// rule: Rule::statement,
// tokens: [
// statement(0, 26, [
// definition_statement(0, 13, [
// optionally_typed_assignee(0, 6, [
// ty(0, 3, [
// ty_basic(0, 3, [
// ty_u32(0, 3)
// ])
// ]),
// assignee(4, 6, [
// identifier(4, 5)
// ])
// ]),
// expression(8, 13, [
// term(8, 13, [
// postfix_expression(8, 13, [
// identifier(8, 11),
// access(11, 13, [
// call_access(11, 13, [
// arguments(12, 12)
// ])
// ])
// ])
// ])
// ]),
// ])
// ])
// ]
// };
// }
// #[test]
// fn parse_invalid_identifier() {
// fails_with! {
// parser: ZoKratesParser,
// input: "0_invalididentifier",
// rule: Rule::identifier,
// positives: vec![Rule::identifier],
// negatives: vec![],
// pos: 0
// };
// }
// #[test]
// fn parse_struct_def() {
// parses_to! {
// parser: ZoKratesParser,
// input: "struct Foo { field foo\n field[2] bar }
// ",
// rule: Rule::ty_struct_definition,
// tokens: [
// ty_struct_definition(0, 39, [
// identifier(7, 10),
// struct_field(13, 22, [
// ty(13, 18, [
// ty_basic(13, 18, [
// ty_field(13, 18)
// ])
// ]),
// identifier(19, 22)
// ]),
// struct_field(24, 36, [
// ty(24, 33, [
// ty_array(24, 33, [
// ty_basic_or_struct(24, 29, [
// ty_basic(24, 29, [
// ty_field(24, 29)
// ])
// ]),
// expression(30, 31, [
// term(30, 31, [
// primary_expression(30, 31, [
// literal(30, 31, [
// decimal_literal(30, 31, [
// decimal_number(30, 31)
// ])
// ])
// ])
// ])
// ])
// ])
// ]),
// identifier(33, 36)
// ])
// ])
// ]
// };
// }
#[test]
fn parse_invalid_identifier_because_keyword() {

Some files were not shown because too many files have changed in this diff Show More