Committed witnesses & randomness in Z# (& tests) (#154)

A basic implementation of committed witnesses & volatile RAM extraction in the Z# front-end.

The passes in question are still a bit brittle, so I left them behind a flag.
This commit is contained in:
Alex Ozdemir
2023-03-15 16:28:19 -07:00
committed by GitHub
parent a49e03abfb
commit 706405fd4f
37 changed files with 723 additions and 129 deletions

View File

@@ -70,7 +70,7 @@ lp = ["good_lp", "lp-solvers"]
aby = ["lp"]
kahip = ["aby"]
kahypar = ["aby"]
r1cs = []
r1cs = ["bincode"]
poly = ["rug-polynomial"]
spartan = ["r1cs", "dep:spartan", "merlin", "curve25519-dalek", "bincode", "gmp-mpfr-sys"]
bellman = ["r1cs", "dep:bellman", "ff", "group", "pairing", "serde_bytes", "bincode", "gmp-mpfr-sys", "byteorder"]
@@ -82,6 +82,10 @@ name = "circ"
name = "zk"
required-features = ["r1cs"]
[[example]]
name = "cp"
required-features = ["bellman", "poly"]
[[example]]
name = "zxi"
required-features = ["smt", "zok"]

View File

@@ -298,6 +298,19 @@ impl FieldV {
}
i
}
/// Raise this element to a power.
#[inline]
pub fn pow(&self, u: u64) -> Self {
match self {
FieldV::FBls12381(f) => FieldV::FBls12381(f.pow_vartime(&[u])),
FieldV::FBn254(f) => FieldV::FBn254(f.pow_vartime(&[u])),
FieldV::IntField(i) => FieldV::IntField(IntField::new(
i.i.clone().pow_mod(&Integer::from(u), i.modulus()).unwrap(),
i.modulus_arc(),
)),
}
}
}
impl Display for FieldV {

View File

@@ -88,7 +88,7 @@ pub trait Weak<Op>: Sized + Clone + PartialEq + Eq + PartialOrd + Ord + Hash {
/// A unique term ID.
#[repr(transparent)]
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub struct Id(pub u64);
impl std::fmt::Display for Id {
@@ -97,6 +97,12 @@ impl std::fmt::Display for Id {
}
}
impl std::fmt::Debug for Id {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "id{}", self.0)
}
}
mod hash {
use super::Id;
use std::hash::{Hash, Hasher};

View File

@@ -86,6 +86,13 @@ Options:
- wrap: x % 2^b
- panic: a panic
--ram <ENABLED>
Whether to use advanced RAM techniques
[env: RAM=]
[default: false]
[possible values: true, false]
--ram-range <RANGE>
How to argue that values are in a range
@@ -178,6 +185,8 @@ Options:
Which modulus to use (overrides [FieldOpt::builtin]) [env: FIELD_CUSTOM_MODULUS=] [default: ]
--ir-field-to-bv <FIELD_TO_BV>
Which field to use [env: IR_FIELD_TO_BV=] [default: wrap] [possible values: wrap, panic]
--ram <ENABLED>
Whether to use advanced RAM techniques [env: RAM=] [default: false] [possible values: true, false]
--ram-range <RANGE>
How to argue that values are in a range [env: RAM_RANGE=] [default: sort] [possible values: bit-split, sort]
--ram-index <INDEX>
@@ -221,6 +230,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -264,6 +274,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -305,6 +316,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -346,6 +358,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -387,6 +400,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -428,6 +442,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -469,6 +484,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -510,6 +526,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -554,6 +571,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -596,6 +614,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -640,6 +659,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -682,6 +702,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -726,6 +747,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},
@@ -768,6 +790,7 @@ BinaryOpt {
field_to_bv: Wrap,
},
ram: RamOpt {
enabled: false,
range: Sort,
index: Uniqueness,
},

View File

@@ -199,6 +199,14 @@ impl Default for FieldToBv {
/// Options related to memory.
#[derive(Args, Debug, Default, Clone, PartialEq, Eq)]
pub struct RamOpt {
/// Whether to use advanced RAM techniques
#[arg(
long = "ram",
env = "RAM",
action = ArgAction::Set,
default_value = "false"
)]
pub enabled: bool,
/// How to argue that values are in a range
#[arg(
long = "ram-range",

View File

@@ -164,6 +164,8 @@ def test(features, extra_args):
log_run_check(["./scripts/spartan_zok_test.zsh"])
else: # bellman field
log_run_check(["./scripts/zokrates_test.zsh"])
if "poly" in features:
log_run_check(["./scripts/cp_test.zsh"])
if "lp" in features and "r1cs" in features:
log_run_check(["./scripts/test_zok_to_ilp_pf.zsh"])

View File

@@ -0,0 +1,5 @@
// persistent RAM
def main(committed field[4] array, private field x) -> field:
field y = array[x]
cond_store(array, x, 0, true)
return y

View File

@@ -0,0 +1,5 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(array (#l (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (#f0 #f6 #f7 #f8)))
) false ; ignored
))

View File

@@ -0,0 +1,5 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(array (#l (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (#f5 #f6 #f7 #f8)))
) false ; ignored
))

View File

@@ -0,0 +1,6 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(array (#l (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (#f5 #f6 #f7 #f8)))
(x #f0)
) false ; ignored
))

View File

@@ -0,0 +1,5 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(return #f5)
) false ; ignored
))

View File

@@ -0,0 +1,11 @@
// volatile RAM
const u32 LEN = 8196
const field ACC = 10
def main(private field x, private field y, private bool b) -> field:
field[LEN] array = [0; LEN]
for field i in 0..ACC do
cond_store(array, x+i, 1, b)
endfor
return array[y]

View File

@@ -0,0 +1,7 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(x #f0)
(y #f9)
(b true)
) false ; ignored
))

View File

@@ -0,0 +1,5 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(return #f1)
) false ; ignored
))

View File

@@ -1,4 +1,5 @@
#![allow(unused_imports)]
#![allow(clippy::vec_init_then_push)]
#[cfg(feature = "bellman")]
use bellman::{
gadgets::test::TestConstraintSystem,
@@ -37,7 +38,11 @@ use circ::target::ilp::{assignment_to_values, trans::to_ilp};
#[cfg(feature = "spartan")]
use circ::target::r1cs::spartan::write_data;
#[cfg(feature = "bellman")]
use circ::target::r1cs::{bellman::Bellman, proof::ProofSystem};
use circ::target::r1cs::{
bellman::Bellman,
mirage::Mirage,
proof::{CommitProofSystem, ProofSystem},
};
#[cfg(feature = "r1cs")]
use circ::target::r1cs::{opt::reduce_linearities, trans::to_r1cs};
#[cfg(feature = "smt")]
@@ -46,6 +51,7 @@ use circ_fields::FieldT;
use fxhash::FxHashMap as HashMap;
#[cfg(feature = "lp")]
use good_lp::default_solver;
use log::trace;
use std::fs::File;
use std::io::Read;
use std::io::Write;
@@ -134,6 +140,7 @@ pub enum CostModelType {
enum ProofAction {
Count,
Setup,
CpSetup,
SpartanSetup,
}
@@ -251,26 +258,33 @@ fn main() {
// vec![Opt::Sha, Opt::ConstantFold, Opt::Mem, Opt::ConstantFold],
)
}
Mode::Proof | Mode::ProofOfHighValue(_) => opt(
cs,
vec![
Opt::ScalarizeVars,
Opt::Flatten,
Opt::Sha,
Opt::ConstantFold(Box::new([])),
// Tuples must be eliminated before oblivious array elim
Opt::Tuple,
Opt::ConstantFold(Box::new([])),
Opt::Obliv,
// The obliv elim pass produces more tuples, that must be eliminated
Opt::Tuple,
Opt::LinearScan,
// The linear scan pass produces more tuples, that must be eliminated
Opt::Tuple,
Opt::Flatten,
Opt::ConstantFold(Box::new([])),
],
),
Mode::Proof | Mode::ProofOfHighValue(_) => {
let mut opts = Vec::new();
opts.push(Opt::ScalarizeVars);
opts.push(Opt::Flatten);
opts.push(Opt::Sha);
opts.push(Opt::ConstantFold(Box::new([])));
opts.push(Opt::ParseCondStores);
// Tuples must be eliminated before oblivious array elim
opts.push(Opt::Tuple);
opts.push(Opt::ConstantFold(Box::new([])));
opts.push(Opt::Tuple);
opts.push(Opt::Obliv);
// The obliv elim pass produces more tuples, that must be eliminated
opts.push(Opt::Tuple);
if options.circ.ram.enabled {
opts.push(Opt::PersistentRam);
opts.push(Opt::VolatileRam);
opts.push(Opt::SkolemizeChallenges);
}
opts.push(Opt::LinearScan);
// The linear scan pass produces more tuples, that must be eliminated
opts.push(Opt::Tuple);
opts.push(Opt::Flatten);
opts.push(Opt::ConstantFold(Box::new([])));
opt(cs, opts)
}
};
println!("Done with IR optimization");
@@ -280,10 +294,12 @@ fn main() {
action,
prover_key,
verifier_key,
proof_impl,
..
} => {
println!("Converting to r1cs");
let cs = cs.get("main");
trace!("IR: {}", circ::ir::term::text::serialize_computation(cs));
let mut r1cs = to_r1cs(cs, cfg());
println!("Pre-opt R1cs size: {}", r1cs.constraints().len());
@@ -296,16 +312,41 @@ fn main() {
#[cfg(feature = "bellman")]
ProofAction::Setup => {
println!("Generating Parameters");
Bellman::<Bls12>::setup_fs(
prover_data,
verifier_data,
prover_key,
verifier_key,
)
.unwrap();
match proof_impl {
ProofImpl::Groth16 => Bellman::<Bls12>::setup_fs(
prover_data,
verifier_data,
prover_key,
verifier_key,
)
.unwrap(),
ProofImpl::Mirage => Mirage::<Bls12>::setup_fs(
prover_data,
verifier_data,
prover_key,
verifier_key,
)
.unwrap(),
};
}
#[cfg(not(feature = "bellman"))]
ProofAction::Setup => panic!("Missing feature: bellman"),
#[cfg(feature = "bellman")]
ProofAction::CpSetup => {
println!("Generating Parameters");
match proof_impl {
ProofImpl::Groth16 => panic!("Groth16 is not CP"),
ProofImpl::Mirage => Mirage::<Bls12>::cp_setup_fs(
prover_data,
verifier_data,
prover_key,
verifier_key,
)
.unwrap(),
};
}
#[cfg(not(feature = "bellman"))]
ProofAction::CpSetup => panic!("Missing feature: bellman"),
#[cfg(feature = "spartan")]
ProofAction::SpartanSetup => {
write_data::<_, _>(prover_key, verifier_key, &prover_data, &verifier_data)

86
examples/cp.rs Normal file
View File

@@ -0,0 +1,86 @@
use circ::cfg::{
clap::{self, Parser, ValueEnum},
CircOpt,
};
use std::path::PathBuf;
use bls12_381::Bls12;
use circ::target::r1cs::{mirage, proof::CommitProofSystem};
#[derive(Debug, Parser)]
#[command(name = "zk", about = "The CirC Commit-Prove runner")]
struct Options {
#[arg(long, default_value = "P")]
prover_key: PathBuf,
#[arg(long, default_value = "V")]
verifier_key: PathBuf,
#[arg(long, default_value = "pi")]
proof: PathBuf,
#[arg(long, default_value = "in")]
inputs: PathBuf,
#[arg(long)]
/// Commitment randomness (path)
rands: Vec<PathBuf>,
#[arg(long)]
/// Commitments (path)
commits: Vec<PathBuf>,
#[arg(long)]
action: ProofAction,
#[command(flatten)]
circ: CircOpt,
}
#[derive(PartialEq, Debug, Clone, ValueEnum)]
/// `Prove`/`Verify` execute proving/verifying in bellman separately
/// `Spartan` executes both proving/verifying in spartan
enum ProofAction {
Prove,
Verify,
Commit,
SampleRand,
}
type Mirage = mirage::Mirage<Bls12>;
fn main() {
env_logger::Builder::from_default_env()
.format_level(false)
.format_timestamp(None)
.init();
let opts = Options::parse();
circ::cfg::set(&opts.circ);
match opts.action {
ProofAction::SampleRand => {
for r in &opts.rands {
Mirage::sample_com_rand_fs(r).unwrap();
}
}
ProofAction::Prove => {
Mirage::cp_prove_fs(&opts.prover_key, &opts.inputs, &opts.proof, opts.rands).unwrap();
}
ProofAction::Verify => {
assert!(Mirage::cp_verify_fs(
&opts.verifier_key,
&opts.inputs,
&opts.proof,
opts.commits
)
.unwrap());
}
ProofAction::Commit => {
assert_eq!(
1,
opts.rands.len(),
"Must specify *one* commitment randomness path"
);
assert_eq!(1, opts.commits.len(), "Must specify *one* commitment path");
Mirage::cp_commit_fs(
&opts.verifier_key,
&opts.inputs,
&opts.rands[0],
&opts.commits[0],
)
.unwrap();
}
}
}

View File

@@ -7,7 +7,7 @@ use std::path::PathBuf;
#[cfg(feature = "bellman")]
use bls12_381::Bls12;
#[cfg(feature = "bellman")]
use circ::target::r1cs::{bellman::Bellman, proof::ProofSystem};
use circ::target::r1cs::{bellman::Bellman, mirage::Mirage, proof::ProofSystem};
#[cfg(feature = "spartan")]
use circ::ir::term::text::parse_value_map;
@@ -60,22 +60,37 @@ fn main() {
.init();
let opts = Options::parse();
circ::cfg::set(&opts.circ);
match opts.action {
match (opts.action, opts.proof_impl) {
#[cfg(feature = "bellman")]
ProofAction::Prove => {
(ProofAction::Prove, ProofImpl::Groth16) => {
println!("Proving");
Bellman::<Bls12>::prove_fs(opts.prover_key, opts.inputs, opts.proof).unwrap();
}
#[cfg(feature = "bellman")]
ProofAction::Verify => {
(ProofAction::Prove, ProofImpl::Mirage) => {
println!("Proving");
Mirage::<Bls12>::prove_fs(opts.prover_key, opts.inputs, opts.proof).unwrap();
}
#[cfg(feature = "bellman")]
(ProofAction::Verify, ProofImpl::Groth16) => {
println!("Verifying");
assert!(
Bellman::<Bls12>::verify_fs(opts.verifier_key, opts.inputs, opts.proof).unwrap(),
"invalid proof"
);
}
#[cfg(feature = "bellman")]
(ProofAction::Verify, ProofImpl::Mirage) => {
println!("Verifying");
assert!(
Mirage::<Bls12>::verify_fs(opts.verifier_key, opts.inputs, opts.proof).unwrap(),
"invalid proof"
);
}
#[cfg(not(feature = "bellman"))]
(ProofAction::Prove | ProofAction::Verify, _) => panic!("Missing feature: bellman"),
#[cfg(feature = "spartan")]
ProofAction::Spartan => {
(ProofAction::Spartan, _) => {
let prover_input_map = parse_value_map(&std::fs::read(opts.pin).unwrap());
println!("Spartan Proving");
let (gens, inst, proof) = spartan::prove(opts.prover_key, &prover_input_map).unwrap();
@@ -84,9 +99,7 @@ fn main() {
println!("Spartan Verifying");
spartan::verify(opts.verifier_key, &verifier_input_map, &gens, &inst, proof).unwrap();
}
#[cfg(not(feature = "bellman"))]
ProofAction::Prove | ProofAction::Verify => panic!("Missing feature: bellman"),
#[cfg(not(feature = "spartan"))]
ProofAction::Spartan => panic!("Missing feature: spartan"),
(ProofAction::Spartan, _) => panic!("Missing feature: spartan"),
}
}

44
scripts/cp_test.zsh Executable file
View File

@@ -0,0 +1,44 @@
#!/usr/bin/env zsh
set -ex
disable -r time
# cargo build --release --features r1cs,smt,zok --example circ
# cargo build --example circ
MODE=release # debug or release
BIN=./target/$MODE/examples/circ
CP_BIN=./target/$MODE/examples/cp
ZK_BIN=./target/$MODE/examples/zk
# Test prove workflow, given an example name
function com_wit_test {
ex_name=$1
init_rand=$(mktemp /tmp/tmp.circ.init_rand.XXXXXXXXXX)
fin_rand=$(mktemp /tmp/tmp.circ.fin_rand.XXXXXXXXXX)
init_commit=$(mktemp /tmp/tmp.circ.init_commit.XXXXXXXXXX)
fin_commit=$(mktemp /tmp/tmp.circ.fin_commit.XXXXXXXXXX)
$BIN --ram true $ex_name r1cs --action cp-setup --proof-impl mirage
$CP_BIN --action sample-rand --rands $init_rand
$CP_BIN --action sample-rand --rands $fin_rand
$CP_BIN --action commit --inputs $ex_name.array.init --rands $init_rand --commits $init_commit
$CP_BIN --action commit --inputs $ex_name.array.fin --rands $fin_rand --commits $fin_commit
$CP_BIN --action prove --inputs $ex_name.pin --rands $init_rand --rands $fin_rand
$CP_BIN --action verify --inputs $ex_name.vin --commits $init_commit --commits $fin_commit
rm -rf P V pi
rm -rf $init_rand $fin_rand $init_commit $fin_commit
}
# Test prove workflow, given an example name
function pf_test {
proof_impl=mirage
ex_name=$1
$BIN --ram true $ex_name r1cs --action setup --proof-impl $proof_impl
$ZK_BIN --inputs $ex_name.pin --action prove --proof-impl $proof_impl
$ZK_BIN --inputs $ex_name.vin --action verify --proof-impl $proof_impl
rm -rf P V pi
}
com_wit_test ./examples/ZoKrates/pf/mem/tiny.zok
pf_test ./examples/ZoKrates/pf/mem/volatile.zok

View File

@@ -11,18 +11,18 @@ $BIN --language datalog ./examples/datalog/inv.pl r1cs --action count || true
$BIN --language datalog ./examples/datalog/call.pl r1cs --action count || true
$BIN --language datalog ./examples/datalog/arr.pl r1cs --action count || true
# Small R1cs b/c too little recursion.
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 4 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 4 r1cs --action count || true) | grep -E "Final R1cs size:" | grep -E -o "\\b[0-9]+")
[ "$size" -lt 10 ]
# Big R1cs b/c enough recursion
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 5 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 5 r1cs --action count || true) | grep -E "Final R1cs size:" | grep -E -o "\\b[0-9]+")
[ "$size" -gt 250 ]
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 10 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 10 r1cs --action count || true) | grep -E "Final R1cs size:" | grep -E -o "\\b[0-9]+")
[ "$size" -gt 250 ]
size=$(($BIN --language datalog ./examples/datalog/dec.pl --datalog-rec-limit 2 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
size=$(($BIN --language datalog ./examples/datalog/dec.pl --datalog-rec-limit 2 r1cs --action count || true) | grep -E "Final R1cs size:" | grep -E -o "\\b[0-9]+")
[ "$size" -gt 250 ]
# Test prim-rec test
$BIN --language datalog ./examples/datalog/dec.pl --datalog-lint-prim-rec true smt
($BIN --language datalog ./examples/datalog/not_dec.pl --datalog-lint-prim-rec true smt || true) | egrep 'Not prim'
($BIN --language datalog ./examples/datalog/not_dec.pl --datalog-lint-prim-rec true smt || true) | grep -E 'Not prim'

View File

@@ -129,6 +129,7 @@ impl MemManager {
alloc.size
}
}
#[cfg(all(feature = "smt", feature = "test", feature = "zok"))]
mod test {
use super::*;

View File

@@ -375,6 +375,11 @@ pub trait Embeddable {
// Because the type alias may change.
#[allow(clippy::ptr_arg)]
fn initialize_return(&self, ty: &Self::Ty, ssa_name: &SsaName) -> Self::T;
/// Wrap an IR field->field array as a language-level persistent array.
fn wrap_persistent_array(&self, _t: Term) -> Self::T {
unimplemented!("wrap_persistent_array")
}
}
/// Manager for circuit-embedded state.
@@ -836,6 +841,11 @@ impl<E: Embeddable> Circify<E> {
self.cir_ctx.mem.borrow_mut().load(id, offset)
}
/// Conditional store to an AllocId based on an explicit condition
pub fn cond_store(&mut self, id: AllocId, offset: Term, val: Term, cond: Term) {
self.cir_ctx.mem.borrow_mut().store(id, offset, val, cond);
}
/// Conditional store to an AllocId based on current path condition
pub fn store(&mut self, id: AllocId, offset: Term, val: Term) {
let cond = self.condition();
@@ -854,6 +864,36 @@ impl<E: Embeddable> Circify<E> {
.borrow_mut()
.zero_allocate(size, addr_width, val_width)
}
/// Create a new persistent array.
pub fn start_persistent_array(
&mut self,
var: &str,
size: usize,
field: circ_fields::FieldT,
party: PartyId,
) -> E::T {
let ir = self
.cir_ctx
.cs
.borrow_mut()
.start_persistent_array(var, size, field, party);
let t = self.e.wrap_persistent_array(ir);
let ssa_name = self
.declare_env_name(var.into(), &t.type_())
.unwrap()
.clone();
assert!(self.vals.insert(ssa_name, Val::Term(t.clone())).is_none());
t
}
/// Record the final state
pub fn end_persistent_array(&mut self, var: &str, final_state: Term) {
self.cir_ctx
.cs
.borrow_mut()
.end_persistent_array(var, final_state)
}
}
const RET_NAME: &str = "return";

View File

@@ -7,11 +7,11 @@ pub mod zvisit;
use super::{FrontEnd, Mode};
use crate::cfg::cfg;
use crate::circify::{CircError, Circify, Loc, Val};
use crate::front::{PROVER_VIS, PUBLIC_VIS};
use crate::front::proof::PROVER_ID;
use crate::ir::proof::ConstraintMetadata;
use crate::ir::term::*;
use log::{debug, warn};
use log::{debug, trace, warn};
use rug::Integer;
use std::cell::{Cell, RefCell};
use std::collections::HashMap;
@@ -138,6 +138,12 @@ fn loc_store(struct_: T, loc: &[ZAccess], val: T) -> Result<T, String> {
}
}
enum ZVis {
Public,
Private(u8),
Committed,
}
impl<'ast> ZGen<'ast> {
fn new(
asts: HashMap<PathBuf, ast::File<'ast>>,
@@ -675,16 +681,29 @@ impl<'ast> ZGen<'ast> {
let ret_ty = f.returns.first().map(|r| self.type_(r));
// set up stack frame for entry function
self.circ_enter_fn(n.to_owned(), ret_ty.clone());
let mut persistent_arrays: Vec<String> = Vec::new();
for p in f.parameters.iter() {
let ty = self.type_(&p.ty);
debug!("Entry param: {}: {}", p.id.value, ty);
let vis = self.interpret_visibility(&p.visibility);
if let ZVis::Committed = &vis {
persistent_arrays.push(p.id.value.clone());
}
let r = self.circ_declare_input(p.id.value.clone(), &ty, vis, None, false);
self.unwrap(r, &p.span);
}
for s in &f.statements {
self.unwrap(self.stmt_impl_::<false>(s), s.span());
}
for a in persistent_arrays {
let term = self
.circ_get_value(Loc::local(a.clone()))
.unwrap()
.unwrap_term()
.term;
trace!("End persistent_array {a}, {}", term);
self.circ.borrow_mut().end_persistent_array(&a, term);
}
if let Some(r) = self.circ_exit_fn() {
match self.mode {
Mode::Mpc(_) => {
@@ -703,7 +722,7 @@ impl<'ast> ZGen<'ast> {
let name = "return".to_owned();
let ret_val = r.unwrap_term();
let ret_var_val = self
.circ_declare_input(name, ty, PUBLIC_VIS, Some(ret_val.clone()), false)
.circ_declare_input(name, ty, ZVis::Public, Some(ret_val.clone()), false)
.expect("circ_declare return");
let ret_eq = eq(ret_val, ret_var_val).unwrap().term;
let mut assertions = std::mem::take(&mut *self.assertions.borrow_mut());
@@ -752,9 +771,13 @@ impl<'ast> ZGen<'ast> {
}
}
}
fn interpret_visibility(&self, visibility: &Option<ast::Visibility<'ast>>) -> Option<PartyId> {
fn interpret_visibility(&self, visibility: &Option<ast::Visibility<'ast>>) -> ZVis {
match visibility {
None | Some(ast::Visibility::Public(_)) => PUBLIC_VIS,
None | Some(ast::Visibility::Public(_)) => ZVis::Public,
Some(ast::Visibility::Committed(_)) => match self.mode {
Mode::Proof => ZVis::Committed,
_ => unimplemented!(),
},
Some(ast::Visibility::Private(private)) => match self.mode {
Mode::Proof | Mode::Opt | Mode::ProofOfHighValue(_) => {
if private.number.is_some() {
@@ -766,7 +789,7 @@ impl<'ast> ZGen<'ast> {
&private.span,
);
}
PROVER_VIS
ZVis::Private(PROVER_ID)
}
Mode::Mpc(n_parties) => {
let num_str = private
@@ -779,7 +802,7 @@ impl<'ast> ZGen<'ast> {
self.err(format!("Bad party number: {e}"), &private.span)
});
if num_val <= n_parties {
Some(num_val - 1)
ZVis::Private(num_val - 1)
} else {
self.err(
format!(
@@ -982,7 +1005,7 @@ impl<'ast> ZGen<'ast> {
ast::Expression::ArrayInitializer(ai) => {
let val = self.expr_impl_::<IS_CNST>(&ai.value)?;
let num = self.const_usize_impl_::<IS_CNST>(&ai.count)?;
array(vec![val; num])
fill_array(val, num)
}
ast::Expression::Postfix(p) => {
// assume no functions in arrays, etc.
@@ -1130,6 +1153,21 @@ impl<'ast> ZGen<'ast> {
}
}
}
ast::Statement::CondStore(e) => {
if IS_CNST {
return Err("cannot evaluate a const CondStore".into());
}
let a = self.identifier_impl_::<false>(&e.array)?;
let i = self.expr_impl_::<false>(&e.index)?;
let v = self.expr_impl_::<false>(&e.value)?;
let c = self.expr_impl_::<false>(&e.condition)?;
let cbool = bool(c)?;
let new = mut_array_store(a, i, v, cbool)?;
trace!("Cond store: {} to {}", e.array.value, new);
self.circ_assign(Loc::local(e.array.value.clone()), Val::Term(new))
.map_err(|e| format!("{e}"))?;
Ok(())
}
ast::Statement::Iteration(i) => {
let ty = self.type_impl_::<IS_CNST>(&i.ty)?;
let ival_cons = match ty {
@@ -1857,13 +1895,36 @@ impl<'ast> ZGen<'ast> {
&self,
name: String,
ty: &Ty,
vis: Option<PartyId>,
vis: ZVis,
precomputed_value: Option<T>,
mangle_name: bool,
) -> Result<T, CircError> {
self.circ
.borrow_mut()
.declare_input(name, ty, vis, precomputed_value, mangle_name)
match vis {
ZVis::Public => {
self.circ
.borrow_mut()
.declare_input(name, ty, None, precomputed_value, mangle_name)
}
ZVis::Private(i) => self.circ.borrow_mut().declare_input(
name,
ty,
Some(i),
precomputed_value,
mangle_name,
),
ZVis::Committed => {
let size = match ty {
Ty::Array(size, _) => *size,
_ => panic!(),
};
Ok(self.circ.borrow_mut().start_persistent_array(
&name,
size,
default_field(),
crate::front::proof::PROVER_ID,
))
}
}
}
fn circ_declare_init(&self, name: String, ty: Ty, val: Val<T>) -> Result<Val<T>, CircError> {

View File

@@ -1,5 +1,5 @@
//! Symbolic Z# terms
use std::collections::{BTreeMap, HashMap};
use std::collections::BTreeMap;
use std::fmt::{self, Display, Formatter};
use rug::Integer;
@@ -17,6 +17,7 @@ pub enum Ty {
Field,
Struct(String, FieldList<Ty>),
Array(usize, Box<Ty>),
MutArray(usize),
}
impl Display for Ty {
@@ -42,6 +43,7 @@ impl Display for Ty {
write!(f, "{bb}")?;
dims.iter().try_for_each(|d| write!(f, "[{d}]"))
}
Ty::MutArray(n) => write!(f, "MutArray({n})"),
}
}
}
@@ -52,15 +54,26 @@ impl fmt::Debug for Ty {
}
}
pub fn default_field() -> circ_fields::FieldT {
cfg().field().clone()
}
fn default_field_sort() -> Sort {
Sort::Field(default_field())
}
impl Ty {
fn sort(&self) -> Sort {
match self {
Self::Bool => Sort::Bool,
Self::Uint(w) => Sort::BitVector(*w),
Self::Field => Sort::Field(cfg().field().clone()),
Self::Array(n, b) => Sort::Array(
Box::new(Sort::Field(cfg().field().clone())),
Box::new(b.sort()),
Self::Field => default_field_sort(),
Self::Array(n, b) => {
Sort::Array(Box::new(default_field_sort()), Box::new(b.sort()), *n)
}
Self::MutArray(n) => Sort::Array(
Box::new(default_field_sort()),
Box::new(default_field_sort()),
*n,
),
Self::Struct(_name, fs) => {
@@ -85,6 +98,7 @@ impl Ty {
pub fn array_val_ty(&self) -> &Self {
match self {
Self::Array(_, b) => b,
// TODO: MutArray?
_ => panic!("Not an array type: {:?}", self),
}
}
@@ -130,6 +144,9 @@ impl T {
Ty::Array(size, _sort) => Ok((0..*size)
.map(|i| term![Op::Select; self.term.clone(), pf_lit_ir(i)])
.collect()),
Ty::MutArray(size) => Ok((0..*size)
.map(|i| term![Op::Select; self.term.clone(), pf_lit_ir(i)])
.collect()),
s => Err(format!("Not an array: {s}")),
}
}
@@ -143,6 +160,11 @@ impl T {
.map(|t| T::new(sort.clone(), t))
.collect())
}
Ty::MutArray(_size) => Ok(self
.unwrap_array_ir()?
.into_iter()
.map(|t| T::new(Ty::Field, t))
.collect()),
s => Err(format!("Not an array: {s}")),
}
}
@@ -365,7 +387,7 @@ fn rem_field(a: Term, b: Term) -> Term {
let len = cfg().field().modulus().significant_bits() as usize;
let a_bv = term![Op::PfToBv(len); a];
let b_bv = term![Op::PfToBv(len); b];
term![Op::UbvToPf(cfg().field().clone()); term![Op::BvBinOp(BvBinOp::Urem); a_bv, b_bv]]
term![Op::UbvToPf(default_field()); term![Op::BvBinOp(BvBinOp::Urem); a_bv, b_bv]]
}
fn rem_uint(a: Term, b: Term) -> Term {
@@ -661,6 +683,11 @@ pub fn slice(arr: T, start: Option<usize>, end: Option<usize>) -> Result<T, Stri
let end = end.unwrap_or(*size);
array(arr.unwrap_array()?.drain(start..end))
}
Ty::MutArray(size) => {
let start = start.unwrap_or(0);
let end = end.unwrap_or(*size);
array(arr.unwrap_array()?.drain(start..end))
}
a => Err(format!("Cannot slice {a}")),
}
}
@@ -704,25 +731,44 @@ pub fn field_store(struct_: T, field: &str, val: T) -> Result<T, String> {
}
}
fn coerce_to_field(i: T) -> Result<Term, String> {
match &i.ty {
Ty::Uint(_) => Ok(term![Op::UbvToPf(default_field()); i.term]),
Ty::Field => Ok(i.term),
_ => Err(format!("Cannot coerce {} to a field element", &i)),
}
}
pub fn array_select(array: T, idx: T) -> Result<T, String> {
match array.ty {
Ty::Array(_, elem_ty) if matches!(idx.ty, Ty::Uint(_) | Ty::Field) => {
let iterm = if matches!(idx.ty, Ty::Uint(_)) {
term![Op::UbvToPf(cfg().field().clone()); idx.term]
} else {
idx.term
};
let iterm = coerce_to_field(idx).unwrap();
Ok(T::new(*elem_ty, term![Op::Select; array.term, iterm]))
}
Ty::MutArray(_) if matches!(idx.ty, Ty::Uint(_) | Ty::Field) => {
let iterm = coerce_to_field(idx).unwrap();
Ok(T::new(Ty::Field, term![Op::Select; array.term, iterm]))
}
_ => Err(format!("Cannot index {} using {}", &array.ty, &idx.ty)),
}
}
pub fn mut_array_store(array: T, idx: T, val: T, cond: Term) -> Result<T, String> {
if !matches!(array.ty, Ty::MutArray(_) | Ty::Array(..)) {
return Err(format!(
"Can only call mut_array_store on arrays, not {array}"
));
}
let i = coerce_to_field(idx).map_err(|s| format!("{s}: mutable array index"))?;
let v = coerce_to_field(val).map_err(|s| format!("{s}: mutable array value"))?;
Ok(T::new(array.ty, term![Op::CStore; array.term, i, v, cond]))
}
pub fn array_store(array: T, idx: T, val: T) -> Result<T, String> {
if matches!(&array.ty, Ty::Array(_, _)) && matches!(&idx.ty, Ty::Uint(_) | Ty::Field) {
// XXX(q) typecheck here?
let iterm = if matches!(idx.ty, Ty::Uint(_)) {
term![Op::UbvToPf(cfg().field().clone()); idx.term]
term![Op::UbvToPf(default_field()); idx.term]
} else {
idx.term
};
@@ -735,34 +781,17 @@ pub fn array_store(array: T, idx: T, val: T) -> Result<T, String> {
}
}
fn ir_array<I: IntoIterator<Item = Term>>(sort: Sort, elems: I) -> Term {
let mut values = HashMap::new();
let to_insert = elems
.into_iter()
.enumerate()
.filter_map(|(i, t)| {
let i_val = pf_val(i);
match const_value(&t) {
Some(v) => {
values.insert(i_val, v);
None
}
None => Some((leaf_term(Op::Const(i_val)), t)),
}
})
.collect::<Vec<(Term, Term)>>();
let len = values.len() + to_insert.len();
let arr = leaf_term(Op::Const(Value::Array(Array::new(
Sort::Field(cfg().field().clone()),
Box::new(sort.default_value()),
values.into_iter().collect::<BTreeMap<_, _>>(),
len,
))));
to_insert
.into_iter()
.fold(arr, |arr, (idx, val)| term![Op::Store; arr, idx, val])
fn ir_array<I: IntoIterator<Item = Term>>(value_sort: Sort, elems: I) -> Term {
let key_sort = Sort::Field(cfg().field().clone());
term(Op::Array(key_sort, value_sort), elems.into_iter().collect())
}
pub fn fill_array(value: T, size: usize) -> Result<T, String> {
Ok(T::new(
Ty::Array(size, Box::new(value.ty)),
term![Op::Fill(default_field_sort(), size); value.term],
))
}
pub fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> {
let v: Vec<T> = elems.into_iter().collect();
if let Some(e) = v.first() {
@@ -785,7 +814,7 @@ pub fn uint_to_field(u: T) -> Result<T, String> {
match &u.ty {
Ty::Uint(_) => Ok(T::new(
Ty::Field,
term![Op::UbvToPf(cfg().field().clone()); u.term],
term![Op::UbvToPf(default_field()); u.term],
)),
u => Err(format!("Cannot do uint-to-field on {u}")),
}
@@ -923,7 +952,7 @@ impl Embeddable for ZSharp {
Ty::Field,
ctx.cs.borrow_mut().new_var(
&name,
Sort::Field(cfg().field().clone()),
default_field_sort(),
visibility,
precompute.map(|p| p.term),
),
@@ -951,6 +980,20 @@ impl Embeddable for ZSharp {
)
.unwrap()
}
Ty::MutArray(n) => {
let ps: Vec<Option<T>> = match precompute.map(|p| p.unwrap_array()) {
Some(Ok(v)) => v.into_iter().map(Some).collect(),
Some(Err(e)) => panic!("{}", e),
None => std::iter::repeat(None).take(*n).collect(),
};
debug_assert_eq!(*n, ps.len());
array(
ps.into_iter().enumerate().map(|(i, p)| {
self.declare_input(ctx, &Ty::Field, idx_name(&name, i), visibility, p)
}),
)
.unwrap()
}
Ty::Struct(n, fs) => T::new_struct(
n.clone(),
fs.fields()
@@ -980,4 +1023,9 @@ impl Embeddable for ZSharp {
fn initialize_return(&self, ty: &Self::Ty, _ssa_name: &String) -> Self::T {
ty.default()
}
fn wrap_persistent_array(&self, t: Term) -> Self::T {
let size = check(&t).as_array().2;
T::new(Ty::MutArray(size), t)
}
}

View File

@@ -196,6 +196,7 @@ pub fn walk_visibility<'ast, Z: ZVisitorMut<'ast>>(
use ast::Visibility::*;
match vis {
Public(pu) => visitor.visit_public_visibility(pu),
Committed(c) => visitor.visit_commited_visibility(c),
Private(pr) => visitor.visit_private_visibility(pr),
}
}
@@ -708,6 +709,7 @@ pub fn walk_statement<'ast, Z: ZVisitorMut<'ast>>(
Return(r) => visitor.visit_return_statement(r),
Definition(d) => visitor.visit_definition_statement(d),
Assertion(a) => visitor.visit_assertion_statement(a),
CondStore(a) => visitor.visit_cond_store_statement(a),
Iteration(i) => visitor.visit_iteration_statement(i),
}
}
@@ -786,6 +788,17 @@ pub fn walk_assertion_statement<'ast, Z: ZVisitorMut<'ast>>(
visitor.visit_span(&mut asrt.span)
}
pub fn walk_cond_store_statement<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
s: &mut ast::CondStoreStatement<'ast>,
) -> ZVisitorResult {
visitor.visit_identifier_expression(&mut s.array)?;
visitor.visit_array_index_expression(&mut s.index)?;
visitor.visit_array_index_expression(&mut s.value)?;
visitor.visit_expression(&mut s.condition)?;
visitor.visit_span(&mut s.span)
}
pub fn walk_iteration_statement<'ast, Z: ZVisitorMut<'ast>>(
visitor: &mut Z,
iter: &mut ast::IterationStatement<'ast>,

View File

@@ -757,6 +757,15 @@ impl<'ast, 'ret> ZVisitorMut<'ast> for ZStatementWalker<'ast, 'ret> {
walk_assertion_statement(self, asrt)
}
fn visit_cond_store_statement(
&mut self,
s: &mut ast::CondStoreStatement<'ast>,
) -> ZVisitorResult {
let bool_ty = ast::Type::Basic(ast::BasicType::Boolean(ast::BooleanType { span: s.span }));
self.unify(Some(bool_ty), &mut s.condition)?;
walk_cond_store_statement(self, s)
}
fn visit_iteration_statement(
&mut self,
iter: &mut ast::IterationStatement<'ast>,
@@ -848,17 +857,19 @@ impl<'ast, 'ret> ZVisitorMut<'ast> for ZStatementWalker<'ast, 'ret> {
use ast::RangeOrExpression::*;
match roe {
Range(r) => self.visit_range(r),
Expression(e) => {
let mut zty = ZExpressionTyper::new(self);
if self.type_expression(e, &mut zty)?.is_none() {
let mut zrw = ZConstLiteralRewriter::new(Some(Ty::Field));
zrw.visit_expression(e)?;
}
self.visit_expression(e)
}
Expression(e) => self.visit_array_index_expression(e),
}
}
fn visit_array_index_expression(&mut self, e: &mut ast::Expression<'ast>) -> ZVisitorResult {
let mut zty = ZExpressionTyper::new(self);
if self.type_expression(e, &mut zty)?.is_none() {
let mut zrw = ZConstLiteralRewriter::new(Some(Ty::Field));
zrw.visit_expression(e)?;
}
self.visit_expression(e)
}
fn visit_range(&mut self, rng: &mut ast::Range<'ast>) -> ZVisitorResult {
let mut zty = ZExpressionTyper::new(self);
let fty = rng

View File

@@ -113,6 +113,10 @@ pub trait ZVisitorMut<'ast>: Sized {
Ok(())
}
fn visit_commited_visibility(&mut self, _c: &mut ast::CommittedVisibility) -> ZVisitorResult {
Ok(())
}
fn visit_private_visibility(
&mut self,
pr: &mut ast::PrivateVisibility<'ast>,
@@ -343,6 +347,13 @@ pub trait ZVisitorMut<'ast>: Sized {
walk_array_access(self, aa)
}
fn visit_array_index_expression(
&mut self,
index: &mut ast::Expression<'ast>,
) -> ZVisitorResult {
walk_expression(self, index)
}
fn visit_range_or_expression(
&mut self,
roe: &mut ast::RangeOrExpression<'ast>,
@@ -446,6 +457,13 @@ pub trait ZVisitorMut<'ast>: Sized {
walk_assertion_statement(self, asrt)
}
fn visit_cond_store_statement(
&mut self,
s: &mut ast::CondStoreStatement<'ast>,
) -> ZVisitorResult {
walk_cond_store_statement(self, s)
}
fn visit_iteration_statement(
&mut self,
iter: &mut ast::IterationStatement<'ast>,

View File

@@ -129,7 +129,7 @@ pub fn skolemize_challenges(comp: &mut Computation) {
}
for name in comp.metadata.ordered_input_names() {
let md = comp.metadata.lookup_mut(&name);
md.round = *actual_round.get(&md.term()).unwrap();
md.round = *actual_round.get(&md.term()).unwrap_or(&0);
}
let mut challs = TermMap::default();

View File

@@ -354,7 +354,7 @@ pub fn elim_obliv(t: &mut Computation) {
let mut replace_pass = Replacer {
not_obliv: prop_pass.not_obliv,
};
<Replacer as RewritePass>::traverse(&mut replace_pass, t)
<Replacer as RewritePass>::traverse_full(&mut replace_pass, t, false, false)
}
#[cfg(test)]

View File

@@ -397,7 +397,7 @@ impl Ram {
self.end_of_time = true;
let time = self.next_time_term();
trace!(
"write: ops: idx {}, val {}, time {}",
"final read: ops: idx {}, val {}, time {}",
idx.op(),
val.op(),
time.op()

View File

@@ -2,10 +2,18 @@
use super::hash::{MsHasher, UniversalHasher};
use super::*;
use crate::front::PROVER_VIS;
use crate::util::ns::Namespace;
use circ_fields::FieldT;
use log::trace;
/// Check a RAM
pub fn check_ram(c: &mut Computation, ram: Ram) {
trace!(
"Checking RAM {}, size {}, {} accesses",
ram.id,
ram.size,
ram.accesses.len()
);
let f = ram.cfg.field.clone();
let (only_init, default) = match &ram.boundary_conditions {
BoundaryConditions::Default(d) => (false, d.clone()),
@@ -13,10 +21,10 @@ pub fn check_ram(c: &mut Computation, ram: Ram) {
BoundaryConditions::Persistent(..) => panic!(),
};
let id = ram.id;
let ns = Namespace::new().subspace(&format!("ram{id}"));
let f_s = Sort::Field(f.clone());
let var_name = |name: &str| format!("__ram{id}_{name}");
let mut new_var =
|name: &str, val: Term| c.new_var(&var_name(name), f_s.clone(), PROVER_VIS, Some(val));
|name: &str, val: Term| c.new_var(&ns.fqn(name), f_s.clone(), PROVER_VIS, Some(val));
// (1) sort the transcript
let field_tuples: Vec<Term> = ram
@@ -42,8 +50,8 @@ pub fn check_ram(c: &mut Computation, ram: Ram) {
.into_iter()
.chain(sorted_accesses.iter().map(|a| a.to_field_tuple(&ram.cfg)))
.collect();
let uhf = UniversalHasher::new(var_name("uhf_key"), &f, uhf_inputs.clone(), ram.cfg.len());
let msh = MsHasher::new(var_name("ms_hash_key"), &f, uhf_inputs);
let uhf = UniversalHasher::new(ns.fqn("uhf_key"), &f, uhf_inputs.clone(), ram.cfg.len());
let msh = MsHasher::new(ns.fqn("ms_hash_key"), &f, uhf_inputs);
// (2) permutation argument
let univ_hashes_unsorted: Vec<Term> = ram
@@ -130,6 +138,7 @@ pub fn check_ram(c: &mut Computation, ram: Ram) {
c,
accs.iter().map(|a| a.idx.clone()).collect(),
accs.iter().map(|a| a.create.b.clone()).collect(),
&ns,
&mut assertions,
&f,
);
@@ -140,7 +149,7 @@ pub fn check_ram(c: &mut Computation, ram: Ram) {
assertions.push(c.outputs[0].clone());
#[allow(clippy::type_complexity)]
let range_checker: Box<
dyn Fn(&mut Computation, Vec<Term>, &str, &mut Vec<Term>, usize, &FieldT),
dyn Fn(&mut Computation, Vec<Term>, &Namespace, &mut Vec<Term>, usize, &FieldT),
> = if ram.cfg.split_times {
Box::new(&bit_split_range_check)
} else {
@@ -149,7 +158,7 @@ pub fn check_ram(c: &mut Computation, ram: Ram) {
range_checker(
c,
deltas,
&format!("__ram{}_time", ram.id),
&ns.subspace("time"),
&mut assertions,
ram.next_time + 1,
&f,
@@ -165,11 +174,12 @@ pub fn check_ram(c: &mut Computation, ram: Ram) {
fn range_check(
c: &mut Computation,
mut values: Vec<Term>,
range_name: &str,
ns: &Namespace,
assertions: &mut Vec<Term>,
n: usize,
f: &FieldT,
) {
let ns = ns.subspace("range");
let f_sort = Sort::Field(f.clone());
debug_assert!(values.iter().all(|v| check(v) == f_sort));
let mut ms_hash_inputs = values.clone();
@@ -181,14 +191,14 @@ fn range_check(
.into_iter()
.enumerate()
.map(|(i, t)| {
let full_name = format!("__range{range_name}.{i}");
let full_name = ns.fqn(i);
c.new_var(&full_name, f_sort.clone(), PROVER_VIS, Some(t))
})
.collect();
// permutation argument
ms_hash_inputs.extend(sorted.iter().cloned());
let msh = MsHasher::new(format!("__range_{range_name}_key"), f, ms_hash_inputs);
let msh = MsHasher::new(ns.fqn("key"), f, ms_hash_inputs);
assertions.push(term![EQ; msh.hash(values), msh.hash(sorted.clone())]);
// delta: 0 or 1
@@ -218,7 +228,7 @@ fn range_check(
fn bit_split_range_check(
_c: &mut Computation,
values: Vec<Term>,
_range_name: &str,
_ns: &Namespace,
assertions: &mut Vec<Term>,
n: usize,
f: &FieldT,
@@ -241,9 +251,11 @@ fn derivative_gcd(
comp: &mut Computation,
values: Vec<Term>,
conditions: Vec<Term>,
ns: &Namespace,
assertions: &mut Vec<Term>,
f: &FieldT,
) {
let ns = ns.subspace("uniq");
let fs = Sort::Field(f.clone());
let pairs = term(
Op::Array(fs.clone(), Sort::Tuple(Box::new([fs.clone(), Sort::Bool]))),
@@ -257,13 +269,13 @@ fn derivative_gcd(
let two_polys = term![Op::ExtOp(ExtOp::UniqDeriGcd); pairs];
let s_coeffs = unmake_array(term![Op::Field(0); two_polys.clone()]);
let t_coeffs = unmake_array(term![Op::Field(1); two_polys]);
let mut decl_poly = |coeffs: Vec<Term>, name: &str| -> Vec<Term> {
let mut decl_poly = |coeffs: Vec<Term>, poly_name: &str| -> Vec<Term> {
coeffs
.into_iter()
.enumerate()
.map(|(i, coeff)| {
comp.new_var(
&format!("__uniq_{name}_c{i}"),
&ns.fqn(format!("{poly_name}{i}")),
fs.clone(),
PROVER_VIS,
Some(coeff),
@@ -281,7 +293,7 @@ fn derivative_gcd(
terms_that_define_all_polys.extend(t_coeffs_skolem.iter().cloned());
let n = values.len();
let x = term(
Op::PfChallenge("__uniq_key".into(), f.clone()),
Op::PfChallenge(ns.fqn("x"), f.clone()),
terms_that_define_all_polys,
);
let r = values;
@@ -302,7 +314,7 @@ fn derivative_gcd(
.enumerate()
.map(|(i, d)| {
let recip = comp.new_var(
&format!("__uniq_recip{i}"),
&ns.fqn(format!("recip{i}")),
fs.clone(),
PROVER_VIS,
Some(term![PF_RECIP; d.clone()]),

View File

@@ -110,8 +110,8 @@ pub trait ProgressAnalysisPass {
let mut progress = true;
let mut order = Vec::new();
let mut visited = TermSet::default();
let mut stack = Vec::new();
stack.extend(computation.outputs.iter().cloned());
let mut stack: Vec<Term> = computation.outputs.clone();
while let Some(top) = stack.pop() {
stack.extend(top.cs().iter().filter(|c| !visited.contains(c)).cloned());
// was it missing?

View File

@@ -148,6 +148,18 @@ impl<'cfg> ToR1cs<'cfg> {
}
}
/// Create a committed witness vector. Each input is a (name, term) pair.
fn committed_wit(&mut self, elements: Vec<(String, Term)>) {
self.r1cs.add_committed_witness(elements.clone());
for (name, value) in elements {
let lc = self.r1cs.signal_lc(&name);
let var = leaf_term(Op::Var(name, check(&value)));
self.embed.borrow_mut().insert(var.clone());
self.cache
.insert(var, EmbeddedTerm::Field(TermLc(value, lc)));
}
}
/// Get a new variable, with name dependent on `d`.
/// If values are being recorded, `value` must be provided.
///
@@ -558,6 +570,10 @@ impl<'cfg> ToR1cs<'cfg> {
Ult => self.bv_cmp(n, false, true, &c.cs()[1], &c.cs()[0]),
}
}
Op::PfToBoolTrusted => {
// we trust that this is zero or one
self.get_pf(&c.cs()[0]).clone()
}
_ => panic!("Non-boolean in embed_bool: {}", c),
};
self.cache.insert(c.clone(), EmbeddedTerm::Bool(lc));
@@ -582,6 +598,8 @@ impl<'cfg> ToR1cs<'cfg> {
for c in t.cs() {
self.assert_bool(c);
}
} else if let Op::PfFitsInBits(n) = t.op() {
self.embed(term![Op::PfToBv(*n); t.cs()[0].clone()]);
} else {
self.embed(t.clone());
let lc = self.get_bool(t).clone();
@@ -1115,6 +1133,13 @@ pub fn to_r1cs(cs: &Computation, cfg: &CircCfg) -> R1cs {
for i in &vars.instances {
converter.embed_var(i, VarType::Inst);
}
for terms in &vars.committed_wit_vecs {
let names_and_terms = terms
.iter()
.map(|t| (t.as_var_name().to_owned(), t.clone()))
.collect();
converter.committed_wit(names_and_terms);
}
for round in &vars.rounds {
for w in &round.witnesses {
converter.embed_var(w, VarType::RoundWit);
@@ -1401,4 +1426,33 @@ pub mod test {
let r1cs = to_r1cs_mod17(cs);
r1cs.check_all(&values);
}
#[test]
fn pf_fits_in_bits() {
let mut cs = text::parse_computation(
b"
(computation
(metadata (parties P) (inputs (a (mod 17)) (b (mod 17)) (c (mod 17))) (commitments))
(precompute () () (#t ))
(and
((pf_fits_in_bits 1) a)
((pf_fits_in_bits 2) (+ b c))
(= (+ b c) (+ c b))
((pf_fits_in_bits 2) #f1m17)
)
)
",
);
let values = text::parse_value_map(
b"(let(
(a #f1m17)
(b #f4m17)
(c #f-1m17)
) false; ignored
)",
);
crate::ir::opt::tuple::eliminate_tuples(&mut cs);
let r1cs = to_r1cs_mod17(cs);
r1cs.check_all(&values);
}
}

View File

@@ -1,6 +1,7 @@
//! Various data structures, etc.
pub mod hc;
pub mod ns;
pub mod once;
#[cfg(test)]

24
src/util/ns.rs Normal file
View File

@@ -0,0 +1,24 @@
//! Namespacing
use std::fmt::Display;
/// A namespace. Used to create unique names.
///
/// Doesn't check for uniqueness: just a helper.
#[derive(Default)]
pub struct Namespace(String);
impl Namespace {
/// The root namespace
pub fn new() -> Self {
Namespace("".to_owned())
}
/// Get a subspace
pub fn subspace(&self, ext: impl Display) -> Self {
Namespace(format!("{}__{ext}", self.0))
}
/// Get a (fully qualified) name in this space
pub fn fqn(&self, ext: impl Display) -> String {
format!("{}_{ext}", self.0)
}
}

View File

@@ -45,13 +45,15 @@ struct_field = { ty ~ identifier }
vis_private_num = @{ "<" ~ ASCII_DIGIT* ~ ">" }
vis_private = {"private" ~ vis_private_num? }
vis_public = {"public"}
vis = { vis_private | vis_public }
vis_committed = {"committed"}
vis = { vis_private | vis_public | vis_committed }
// Statements
statement = { (return_statement // does not require subsequent newline
| (iteration_statement
| definition_statement
| expression_statement
| cond_store_statement
) ~ NEWLINE
) ~ NEWLINE* }
@@ -59,6 +61,7 @@ iteration_statement = { "for" ~ ty ~ identifier ~ "in" ~ expression ~ ".." ~ exp
return_statement = { "return" ~ expression_list}
definition_statement = { typed_identifier_or_assignee_list ~ "=" ~ expression } // declare and assign, so only identifiers are allowed, unlike `assignment_statement`
expression_statement = {"assert" ~ "(" ~ expression ~ ("," ~ quoted_string)? ~ ")"}
cond_store_statement = {"cond_store" ~ "(" ~ identifier ~ "," ~ expression ~ "," ~ expression ~ "," ~ expression ~ ")"}
typed_identifier_or_assignee_list = _{ typed_identifier_or_assignee ~ ("," ~ typed_identifier_or_assignee)* }
typed_identifier_or_assignee = { typed_identifier | assignee } // we don't use { ty? ~ identifier } as with a single token, it gets parsed as `ty` but we want `identifier`

View File

@@ -217,10 +217,11 @@ checksum = "2839e79665f131bdb5782e51f2c6c9599c133c6098982a54c794358bf432529c"
[[package]]
name = "pest"
version = "2.1.3"
version = "2.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10f4872ae94d7b90ae48754df22fd42ad52ce740b8f370b03da4835417403e53"
checksum = "4ab62d2fa33726dbe6321cc97ef96d8cde531e3eeaf858a058de53a8a6d40d8f"
dependencies = [
"thiserror",
"ucd-trie",
]

View File

@@ -10,10 +10,10 @@ extern crate lazy_static;
pub use ast::{
Access, AnyString, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType,
AssertionStatement, Assignee, AssigneeAccess, BasicOrStructType, BasicType, BinaryExpression,
BinaryOperator, BooleanLiteralExpression, BooleanType, CallAccess, ConstantDefinition,
ConstantGenericValue, Curve, DecimalLiteralExpression, DecimalNumber, DecimalSuffix,
DefinitionStatement, ExplicitGenerics, Expression, FieldSuffix, FieldType, File,
FromExpression, FromImportDirective, FunctionDefinition, HexLiteralExpression,
BinaryOperator, BooleanLiteralExpression, BooleanType, CallAccess, CommittedVisibility,
CondStoreStatement, ConstantDefinition, ConstantGenericValue, Curve, DecimalLiteralExpression,
DecimalNumber, DecimalSuffix, DefinitionStatement, ExplicitGenerics, Expression, FieldSuffix,
FieldType, File, FromExpression, FromImportDirective, FunctionDefinition, HexLiteralExpression,
HexNumberExpression, IdentifierExpression, ImportDirective, ImportSymbol,
InlineArrayExpression, InlineStructExpression, InlineStructMember, IterationStatement,
LiteralExpression, MainImportDirective, MemberAccess, NegOperator, NotOperator, Parameter,
@@ -353,6 +353,7 @@ mod ast {
#[pest_ast(rule(Rule::vis))]
pub enum Visibility<'ast> {
Public(PublicVisibility),
Committed(CommittedVisibility),
Private(PrivateVisibility<'ast>),
}
@@ -369,6 +370,10 @@ mod ast {
#[pest_ast(rule(Rule::vis_public))]
pub struct PublicVisibility {}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::vis_committed))]
pub struct CommittedVisibility {}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::vis_private))]
pub struct PrivateVisibility<'ast> {
@@ -384,6 +389,7 @@ mod ast {
Return(ReturnStatement<'ast>),
Definition(DefinitionStatement<'ast>),
Assertion(AssertionStatement<'ast>),
CondStore(CondStoreStatement<'ast>),
Iteration(IterationStatement<'ast>),
}
@@ -393,6 +399,7 @@ mod ast {
Statement::Return(x) => &x.span,
Statement::Definition(x) => &x.span,
Statement::Assertion(x) => &x.span,
Statement::CondStore(x) => &x.span,
Statement::Iteration(x) => &x.span,
}
}
@@ -416,6 +423,17 @@ mod ast {
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::cond_store_statement))]
pub struct CondStoreStatement<'ast> {
pub array: IdentifierExpression<'ast>,
pub index: Expression<'ast>,
pub value: Expression<'ast>,
pub condition: Expression<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[pest_ast(rule(Rule::iteration_statement))]
pub struct IterationStatement<'ast> {