mirror of
https://github.com/circify/circ.git
synced 2026-01-10 06:08:02 -05:00
change metadata serialization & representation (#137)
In the new approach, all variable metadata is stored per-variable. This makes it easier to add new kinds of metadata, and to serialize that metadata.
This commit is contained in:
5
TODO.md
5
TODO.md
@@ -1,4 +1,9 @@
|
||||
Concrete:
|
||||
[ ] R1CS optimizations
|
||||
* reduce linearity without recip
|
||||
* don't debitify eagerly
|
||||
* cache pf lits?
|
||||
* LCs as vectors
|
||||
[ ] shrink bit-vectors using range analysis.
|
||||
* IR analysis infrastructure
|
||||
* shrink comparisons too
|
||||
|
||||
@@ -123,8 +123,6 @@ def test(features, extra_args):
|
||||
if "ristretto255" in features:
|
||||
test_cmd += ["--no-default-features"]
|
||||
test_cmd_release += ["--no-default-features"]
|
||||
test_cmd += ["--", "--test-threads=1"]
|
||||
test_cmd_release += ["--", "--test-threads=1"]
|
||||
if len(extra_args) > 0:
|
||||
test_cmd += [a for a in extra_args if a != "--"]
|
||||
test_cmd_release += [a for a in extra_args if a != "--"]
|
||||
|
||||
@@ -21,7 +21,7 @@ use circ::front::datalog::{self, Datalog};
|
||||
#[cfg(all(feature = "smt", feature = "zok"))]
|
||||
use circ::front::zsharp::{self, ZSharpFE};
|
||||
use circ::front::{FrontEnd, Mode};
|
||||
use circ::ir::term::{Op, BV_LSHR, BV_SHL};
|
||||
use circ::ir::term::{Node, Op, BV_LSHR, BV_SHL};
|
||||
use circ::ir::{
|
||||
opt::{opt, Opt},
|
||||
term::{
|
||||
@@ -322,9 +322,12 @@ fn main() {
|
||||
.get("main")
|
||||
.clone()
|
||||
.metadata
|
||||
.input_vis
|
||||
.ordered_inputs()
|
||||
.iter()
|
||||
.map(|(name, (sort, _))| (name.clone(), check(sort)))
|
||||
.map(|term| match term.op() {
|
||||
Op::Var(n, s) => (n.clone(), s.clone()),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect();
|
||||
let ilp = to_ilp(cs.get("main").clone());
|
||||
let solver_result = ilp.solve(default_solver);
|
||||
|
||||
@@ -115,7 +115,7 @@ mod test {
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a (bv 4)) (b (bv 4)) (c (bv 4))) ())
|
||||
(metadata (parties ) (inputs (a (bv 4)) (b (bv 4)) (c (bv 4))))
|
||||
(precompute () () (#t ))
|
||||
(let
|
||||
(
|
||||
@@ -138,7 +138,7 @@ mod test {
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a (mod 5)) (b (mod 5)) (c (mod 5))) ())
|
||||
(metadata (parties ) (inputs (a (mod 5)) (b (mod 5)) (c (mod 5))))
|
||||
(precompute () () (#t ))
|
||||
(let
|
||||
(
|
||||
@@ -161,7 +161,7 @@ mod test {
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a (bv 4)) (b (bv 4)) (c (bv 4))) ())
|
||||
(metadata (parties ) (inputs (a (bv 4)) (b (bv 4)) (c (bv 4))))
|
||||
(precompute () () (#t ))
|
||||
(let
|
||||
(
|
||||
@@ -184,7 +184,7 @@ mod test {
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a (mod 5)) (b (mod 5)) (c (mod 5))) ())
|
||||
(metadata (parties ) (inputs (a (mod 5)) (b (mod 5)) (c (mod 5))))
|
||||
(precompute () () (#t ))
|
||||
(let
|
||||
(
|
||||
|
||||
@@ -330,7 +330,7 @@ mod test {
|
||||
let cs = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () () ())
|
||||
(metadata (parties ) (inputs ))
|
||||
(precompute () () (#t ))
|
||||
(let
|
||||
(
|
||||
@@ -355,7 +355,7 @@ mod test {
|
||||
let cs = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () () ())
|
||||
(metadata (parties ) (inputs ))
|
||||
(precompute () () (#t ))
|
||||
(let
|
||||
(
|
||||
@@ -393,7 +393,7 @@ mod test {
|
||||
let cs = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a bool)) ())
|
||||
(metadata (parties ) (inputs (a bool)))
|
||||
(precompute () () (#t ))
|
||||
(let
|
||||
(
|
||||
@@ -432,7 +432,7 @@ mod test {
|
||||
let cs = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a bool)) ())
|
||||
(metadata (parties ) (inputs (a bool)))
|
||||
(precompute () () (#t ))
|
||||
(let
|
||||
(
|
||||
@@ -472,7 +472,7 @@ mod test {
|
||||
let cs = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a bool)) ())
|
||||
(metadata (parties ) (inputs (a bool)))
|
||||
(precompute () () (#t ))
|
||||
(let
|
||||
(
|
||||
|
||||
@@ -94,11 +94,7 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I)
|
||||
}
|
||||
}
|
||||
Opt::Inline => {
|
||||
let public_inputs = c
|
||||
.metadata
|
||||
.public_input_names()
|
||||
.map(ToOwned::to_owned)
|
||||
.collect();
|
||||
let public_inputs = c.metadata.public_input_names_set();
|
||||
inline::inline(&mut c.outputs, &public_inputs);
|
||||
}
|
||||
Opt::Tuple => {
|
||||
|
||||
@@ -105,22 +105,10 @@ pub fn assert_all_vars_are_scalars(cs: &Computation) {
|
||||
|
||||
/// Check that every variables is a scalar.
|
||||
fn remove_non_scalar_vars_from_main_computation(cs: &mut Computation) {
|
||||
let new_inputs = cs
|
||||
.metadata
|
||||
.computation_inputs
|
||||
.clone()
|
||||
.into_iter()
|
||||
.filter(|i| cs.metadata.input_sort(i).is_scalar())
|
||||
.collect::<Vec<_>>();
|
||||
cs.metadata.computation_inputs = new_inputs;
|
||||
for t in cs.terms_postorder() {
|
||||
if let Op::Var(_name, sort) = &t.op() {
|
||||
match sort {
|
||||
Sort::Array(..) | Sort::Tuple(..) => {
|
||||
panic!("Variable {} is non-scalar", t);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
for input in cs.metadata.ordered_public_inputs() {
|
||||
if !check(&input).is_scalar() {
|
||||
cs.metadata.remove_var(input.as_var_name());
|
||||
}
|
||||
}
|
||||
assert_all_vars_are_scalars(cs);
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ pub fn as_uint_constant(t: &Term) -> Option<Integer> {
|
||||
/// Assert that all variables in the term graph are declared in the metadata.
|
||||
#[cfg(test)]
|
||||
pub fn assert_all_vars_declared(c: &Computation) {
|
||||
let vars: FxHashSet<String> = c.metadata.input_vis.iter().map(|p| p.0.clone()).collect();
|
||||
let vars: FxHashSet<String> = c.metadata.vars.iter().map(|p| p.0.clone()).collect();
|
||||
for o in &c.outputs {
|
||||
for v in free_variables(o.clone()) {
|
||||
assert!(vars.contains(&v), "Variable {} is not declared", v);
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
//! Machinery for formatting IR types
|
||||
use super::{Array, Node, Op, PostOrderIter, Sort, Term, TermMap, Value};
|
||||
use super::{
|
||||
Array, ComputationMetadata, Node, Op, PartyId, PostOrderIter, Sort, Term, TermMap, Value,
|
||||
VariableMetadata,
|
||||
};
|
||||
use crate::cfg::{cfg, is_cfg_set};
|
||||
|
||||
use circ_fields::{FieldT, FieldV};
|
||||
|
||||
use fxhash::FxHashSet as HashSet;
|
||||
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
|
||||
|
||||
use std::fmt::{Debug, Display, Error as FmtError, Formatter, Result as FmtResult, Write};
|
||||
|
||||
@@ -278,6 +281,40 @@ impl DisplayIr for Term {
|
||||
}
|
||||
}
|
||||
|
||||
impl DisplayIr for VariableMetadata {
|
||||
fn ir_fmt(&self, f: &mut IrFormatter) -> FmtResult {
|
||||
write!(f, "({} ", self.name)?;
|
||||
self.sort.ir_fmt(f)?;
|
||||
if let Some(v) = self.vis.as_ref() {
|
||||
write!(f, " (party {})", v)?;
|
||||
}
|
||||
write!(f, ")")
|
||||
}
|
||||
}
|
||||
|
||||
impl DisplayIr for ComputationMetadata {
|
||||
fn ir_fmt(&self, f: &mut IrFormatter) -> FmtResult {
|
||||
write!(f, "(metadata\n (parties ")?;
|
||||
let ids_to_parties: HashMap<PartyId, &str> = self
|
||||
.party_ids
|
||||
.iter()
|
||||
.map(|(name, id)| (*id, name.as_str()))
|
||||
.collect();
|
||||
for id in 0..self.party_ids.len() as u8 {
|
||||
let party = ids_to_parties.get(&id).unwrap();
|
||||
write!(f, " {}", party)?;
|
||||
}
|
||||
writeln!(f, ")")?;
|
||||
write!(f, "\n (inputs")?;
|
||||
for v in self.vars.values() {
|
||||
write!(f, "\n ")?;
|
||||
v.ir_fmt(f)?;
|
||||
}
|
||||
write!(f, "\n )")?;
|
||||
write!(f, "\n)")
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a term, introducing bindings.
|
||||
fn fmt_term_with_bindings(t: &Term, f: &mut IrFormatter) -> FmtResult {
|
||||
let close_dft_f = if f.cfg.use_default_field && f.default_field.is_none() {
|
||||
@@ -382,3 +419,9 @@ impl Display for Op {
|
||||
self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ComputationMetadata {
|
||||
fn fmt(&self, f: &mut Formatter) -> FmtResult {
|
||||
self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -798,6 +798,12 @@ pub enum Sort {
|
||||
Tuple(Box<[Sort]>),
|
||||
}
|
||||
|
||||
impl Default for Sort {
|
||||
fn default() -> Self {
|
||||
Self::Bool
|
||||
}
|
||||
}
|
||||
|
||||
impl Sort {
|
||||
#[track_caller]
|
||||
/// Unwrap the bitsize of this bit-vector, panicking otherwise.
|
||||
@@ -1051,6 +1057,16 @@ impl Term {
|
||||
pub fn is_const(&self) -> bool {
|
||||
matches!(&self.op(), Op::Const(..))
|
||||
}
|
||||
|
||||
/// Get the variable name; panic if not a variable.
|
||||
#[track_caller]
|
||||
pub fn as_var_name(&self) -> &str {
|
||||
if let Op::Var(n, _) = &self.op() {
|
||||
n
|
||||
} else {
|
||||
panic!("not a variable")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Value {
|
||||
@@ -1572,101 +1588,138 @@ impl std::iter::Iterator for PostOrderIter {
|
||||
/// A party identifier
|
||||
pub type PartyId = u8;
|
||||
|
||||
/// Which round the variable is given in.
|
||||
///
|
||||
/// (Relevant when compiling to/from an interactive protocol).
|
||||
pub type Round = u8;
|
||||
|
||||
/// Metadata associated with a variable.
|
||||
///
|
||||
/// We require all fields to have a [Default] implementation. This requirement is forced by
|
||||
/// deriving [Default].
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct VariableMetadata {
|
||||
/// Who knows it (None if public)
|
||||
vis: Option<PartyId>,
|
||||
/// Its type
|
||||
sort: Sort,
|
||||
/// The name
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl VariableMetadata {
|
||||
/// term (cached)
|
||||
fn term(&self) -> Term {
|
||||
leaf_term(Op::Var(self.name.clone(), self.sort.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
/// An IR constraint system.
|
||||
pub struct ComputationMetadata {
|
||||
/// A map from variables to their metadata
|
||||
vars: FxHashMap<String, VariableMetadata>,
|
||||
/// A map from party names to numbers assigned to them.
|
||||
pub party_ids: FxHashMap<String, PartyId>,
|
||||
/// The next free id.
|
||||
pub next_party_id: PartyId,
|
||||
/// All inputs, including who knows them. If no visibility is set, the input is public.
|
||||
pub input_vis: FxHashMap<String, (Term, Option<PartyId>)>,
|
||||
/// The inputs for the computation itself (not the precomputation).
|
||||
pub computation_inputs: Vec<String>,
|
||||
party_ids: FxHashMap<String, PartyId>,
|
||||
}
|
||||
|
||||
impl ComputationMetadata {
|
||||
/// Add a new party to the computation, getting a [PartyId] for them.
|
||||
pub fn add_party(&mut self, name: String) -> PartyId {
|
||||
self.party_ids.insert(name, self.next_party_id);
|
||||
self.next_party_id += 1;
|
||||
self.next_party_id - 1
|
||||
self.party_ids.insert(name, self.party_ids.len() as u8);
|
||||
self.party_ids.len() as u8 - 1
|
||||
}
|
||||
/// Add a new input to the computation, visible to `party`, or public if `party` is [None].
|
||||
pub fn new_input(&mut self, input_name: String, party: Option<PartyId>, sort: Sort) {
|
||||
let term = leaf_term(Op::Var(input_name.clone(), sort));
|
||||
pub fn new_input(&mut self, name: String, party: Option<PartyId>, sort: Sort) {
|
||||
debug_assert!(
|
||||
!self.input_vis.contains_key(&input_name)
|
||||
|| self.input_vis.get(&input_name).unwrap().1 == party,
|
||||
!self.vars.contains_key(&name),
|
||||
"Tried to create input {} (visibility {:?}), but it already existed (visibility {:?})",
|
||||
input_name,
|
||||
name,
|
||||
party,
|
||||
self.input_vis.get(&input_name).unwrap()
|
||||
self.vars.get(&name).unwrap()
|
||||
);
|
||||
self.input_vis.insert(input_name.clone(), (term, party));
|
||||
self.computation_inputs.push(input_name);
|
||||
let var_md = VariableMetadata {
|
||||
sort,
|
||||
vis: party,
|
||||
name: name.clone(),
|
||||
};
|
||||
self.vars.insert(name, var_md);
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn lookup<Q: std::borrow::Borrow<str> + ?Sized>(&self, name: &Q) -> &VariableMetadata {
|
||||
let n = name.borrow();
|
||||
self.vars
|
||||
.get(n)
|
||||
.unwrap_or_else(|| panic!("Missing input {} in inputs{:#?}", n, self.vars))
|
||||
}
|
||||
|
||||
/// Returns None if the value is public. Otherwise, the unique party that knows it.
|
||||
pub fn get_input_visibility(&self, input_name: &str) -> Option<PartyId> {
|
||||
self.input_vis
|
||||
.get(input_name)
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Missing input {} in inputs{:#?}",
|
||||
input_name, self.input_vis
|
||||
)
|
||||
})
|
||||
.1
|
||||
self.lookup(input_name).vis
|
||||
}
|
||||
/// Is this input public?
|
||||
|
||||
/// Is this an input?
|
||||
pub fn is_input(&self, input_name: &str) -> bool {
|
||||
self.input_vis.contains_key(input_name)
|
||||
self.vars.contains_key(input_name)
|
||||
}
|
||||
|
||||
/// Is this input public?
|
||||
pub fn is_input_public(&self, input_name: &str) -> bool {
|
||||
self.get_input_visibility(input_name).is_none()
|
||||
}
|
||||
|
||||
/// What sort is this input?
|
||||
pub fn input_sort(&self, input_name: &str) -> Sort {
|
||||
check(&self.input_vis.get(input_name).unwrap().0)
|
||||
self.lookup(input_name).sort.clone()
|
||||
}
|
||||
/// Get all public inputs to the computation itself.
|
||||
///
|
||||
/// Excludes pre-computation inputs
|
||||
pub fn public_input_names(&'_ self) -> impl Iterator<Item = &str> + '_ {
|
||||
self.input_vis.iter().filter_map(move |(name, party)| {
|
||||
if party.1.is_none() && self.computation_inputs.contains(name) {
|
||||
Some(name.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
|
||||
/// Give all inputs, in a fixed order.
|
||||
pub fn ordered_input_names(&self) -> Vec<String> {
|
||||
let mut out: Vec<String> = self.vars.keys().cloned().collect();
|
||||
out.sort();
|
||||
out
|
||||
}
|
||||
/// Get all public inputs to the computation itself.
|
||||
///
|
||||
/// Excludes pre-computation inputs.
|
||||
// I think the lint is just broken here.
|
||||
// TODO: submit a patch
|
||||
#[allow(clippy::needless_lifetimes)]
|
||||
pub fn public_inputs<'a>(&'a self) -> impl Iterator<Item = Term> + 'a {
|
||||
// TODO: check order?
|
||||
self.input_vis
|
||||
.iter()
|
||||
.filter_map(move |(name, (term, vis))| {
|
||||
if vis.is_none() && self.computation_inputs.contains(name) {
|
||||
Some(term.clone())
|
||||
|
||||
/// Give all public inputs, in a fixed order.
|
||||
pub fn ordered_public_inputs(&self) -> Vec<Term> {
|
||||
let mut out: Vec<Term> = self
|
||||
.vars
|
||||
.values()
|
||||
.filter_map(|v| {
|
||||
if v.vis.is_none() {
|
||||
Some(v.term())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
out.sort_by(|a, b| a.as_var_name().cmp(b.as_var_name()));
|
||||
out
|
||||
}
|
||||
|
||||
/// Give all inputs, in a fixed order.
|
||||
pub fn ordered_inputs(&self) -> Vec<Term> {
|
||||
let mut out: Vec<Term> = self.vars.values().map(|v| v.term()).collect();
|
||||
out.sort_by(|a, b| a.as_var_name().cmp(b.as_var_name()));
|
||||
out
|
||||
}
|
||||
|
||||
/// Give the set of public input names.
|
||||
pub fn public_input_names_set(&self) -> FxHashSet<String> {
|
||||
self.ordered_public_inputs()
|
||||
.iter()
|
||||
.map(|t| t.as_var_name().into())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all the inputs visible to `party`.
|
||||
pub fn get_inputs_for_party(&self, party: Option<PartyId>) -> FxHashSet<String> {
|
||||
self.input_vis
|
||||
.iter()
|
||||
.filter_map(|(name, (_, vis))| {
|
||||
if vis.is_none() || vis == &party {
|
||||
Some(name.clone())
|
||||
self.vars
|
||||
.values()
|
||||
.filter_map(|v| {
|
||||
if v.vis.is_none() || v.vis == party {
|
||||
Some(v.name.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -1674,65 +1727,9 @@ impl ComputationMetadata {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// From a list of parties, a list of inputs, and a list of visibilities,
|
||||
/// create a [ComputationMetadata].
|
||||
pub fn from_parts(
|
||||
parties: Vec<String>,
|
||||
mut inputs: FxHashMap<String, Term>,
|
||||
visibilities: FxHashMap<String, String>,
|
||||
) -> Self {
|
||||
let party_ids: FxHashMap<String, PartyId> = parties
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, n)| (n, i as u8))
|
||||
.collect();
|
||||
let next_party_id = party_ids.len() as u8;
|
||||
let computation_inputs: Vec<String> = inputs.keys().cloned().collect();
|
||||
let input_vis = computation_inputs
|
||||
.iter()
|
||||
.map(|i| {
|
||||
let vis = visibilities.get(i).map(|p| *party_ids.get(p).unwrap());
|
||||
let term = inputs.remove(i).unwrap();
|
||||
(i.clone(), (term, vis))
|
||||
})
|
||||
.collect();
|
||||
ComputationMetadata {
|
||||
party_ids,
|
||||
next_party_id,
|
||||
input_vis,
|
||||
computation_inputs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove an input
|
||||
pub fn remove_var(&mut self, name: &str) {
|
||||
self.input_vis.remove(name);
|
||||
if let Some(pos) = self.computation_inputs.iter().position(|x| *x == name) {
|
||||
self.computation_inputs.remove(pos);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ComputationMetadata {
|
||||
fn fmt(&self, f: &mut Formatter) -> FmtResult {
|
||||
write!(f, "(metadata\n (")?;
|
||||
for id in 0..self.next_party_id {
|
||||
let party = self.party_ids.iter().find(|(_, i)| **i == id).unwrap().0;
|
||||
write!(f, " {}", party)?;
|
||||
}
|
||||
write!(f, ")\n (")?;
|
||||
for input in self.input_vis.keys() {
|
||||
let sort = self.input_sort(input);
|
||||
write!(f, " ({} {})", input, sort)?;
|
||||
}
|
||||
write!(f, ")\n (")?;
|
||||
for (input, (_, vis)) in &self.input_vis {
|
||||
if let Some(id) = vis {
|
||||
let party = self.party_ids.iter().find(|(_, i)| *i == id).unwrap();
|
||||
write!(f, " ({} {})", input, party.0)?;
|
||||
}
|
||||
}
|
||||
write!(f, ")\n)")
|
||||
self.vars.remove(name);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,10 +16,11 @@
|
||||
//! * `X`: identifier
|
||||
//! * regex: `[^()0-9#; \t\n\f][^(); \t\n\f#]*`
|
||||
//! * Computation `C`: `(computation M P T)`
|
||||
//! * Metadata `M`: `(metadata PARTIES INPUTS VISIBILITIES)`
|
||||
//! * PARTIES is `(X1 .. Xn)`
|
||||
//! * INPUTS is `((X1 S1) .. (Xn Sn))`
|
||||
//! * VISIBILITIES is `((X_INPUT_1 X_PARTY_1) .. (X_INPUT_n X_PARTY_n))`
|
||||
//! * Metadata `M`: `(metadata PARTIES INPUTS)`
|
||||
//! * PARTIES is `(parties X1 .. Xn)`
|
||||
//! * INPUTS is `(inputs INPUT1 .. INPUTn)`
|
||||
//! * INPUT is `(X S PARTY)`
|
||||
//! * PARTY is `(party X)` or nothing (public)
|
||||
//! * Precompute `P`: `(precompute INPUTS OUTPUTS TUPLE_TERM)`
|
||||
//! * INPUTS is `((X1 S1) .. (Xn Sn))`
|
||||
//! * OUTPUTS is `((X1 S1) .. (Xn Sn))`
|
||||
@@ -161,6 +162,10 @@ enum CtrlOp {
|
||||
SetDefaultModulus,
|
||||
}
|
||||
|
||||
enum VariableMetadataItem {
|
||||
Party(u8),
|
||||
}
|
||||
|
||||
impl<'src> IrInterp<'src> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
@@ -513,28 +518,87 @@ impl<'src> IrInterp<'src> {
|
||||
}
|
||||
}
|
||||
|
||||
fn visibility_list(&self, tt: &TokTree<'src>) -> Vec<(String, String)> {
|
||||
#[track_caller]
|
||||
fn unwrap_list<'a>(&self, tt: &'a TokTree<'src>, err: &str) -> &'a [TokTree<'src>] {
|
||||
if let List(tts) = tt {
|
||||
tts.iter()
|
||||
.map(|tti| match tti {
|
||||
List(ls) => match &ls[..] {
|
||||
[Leaf(Token::Ident, var), Leaf(Token::Ident, party)] => {
|
||||
let var = from_utf8(var).unwrap().to_owned();
|
||||
let party = from_utf8(party).unwrap().to_owned();
|
||||
(var, party)
|
||||
}
|
||||
_ => panic!("Expected visibility pair, found {}", tti),
|
||||
},
|
||||
_ => panic!("Expected visibility pair, found {}", tti),
|
||||
})
|
||||
.collect()
|
||||
tts.as_slice()
|
||||
} else {
|
||||
panic!("Expected visibility list, found: {}", tt)
|
||||
panic!("Expected {}, found non-list: {}", err, tt)
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn unwrap_prefix_list<'a>(&self, tt: &'a TokTree<'src>, prefix: &str) -> &'a [TokTree<'src>] {
|
||||
let tts = self.unwrap_list(tt, prefix);
|
||||
assert_eq!(
|
||||
self.ident_str(&tts[0]),
|
||||
prefix,
|
||||
"Expected list head '{}', but found {}",
|
||||
prefix,
|
||||
&tts[0]
|
||||
);
|
||||
&tts[1..]
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn ident(&self, tt: &TokTree<'src>) -> &'src [u8] {
|
||||
if let Leaf(Token::Ident, i) = tt {
|
||||
i
|
||||
} else {
|
||||
panic!("Expected identifier, found {}", tt)
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn ident_str(&self, tt: &TokTree<'src>) -> &'src str {
|
||||
from_utf8(self.ident(tt)).unwrap()
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn ident_string(&self, tt: &TokTree<'src>) -> String {
|
||||
self.ident_str(tt).to_owned()
|
||||
}
|
||||
|
||||
fn variable_metadata_item(&mut self, tt: &TokTree<'src>) -> VariableMetadataItem {
|
||||
let tts = self.unwrap_list(tt, "variable metadata item");
|
||||
match self.ident(&tts[0]) {
|
||||
b"party" => {
|
||||
let id = self.int(&tts[1]).to_u8().unwrap();
|
||||
VariableMetadataItem::Party(id)
|
||||
}
|
||||
i => {
|
||||
panic!(
|
||||
"Expected variable metadata item, got {}",
|
||||
from_utf8(i).unwrap()
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn variable_metadata(&mut self, tt: &TokTree<'src>) -> (&'src [u8], VariableMetadata) {
|
||||
let tts = self.unwrap_list(tt, "variable metadata");
|
||||
let name = self.ident_string(&tts[0]);
|
||||
let name_bytes = self.ident(&tts[0]);
|
||||
let sort = self.sort(&tts[1]);
|
||||
let mut md = VariableMetadata {
|
||||
vis: None,
|
||||
sort,
|
||||
name,
|
||||
};
|
||||
for tti in &tts[2..] {
|
||||
match self.variable_metadata_item(tti) {
|
||||
VariableMetadataItem::Party(p) => {
|
||||
md.vis = Some(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
(name_bytes, md)
|
||||
}
|
||||
|
||||
/// Returns a [ComputationMetadata] and a list of sort bindings to un-bind.
|
||||
fn metadata(&mut self, tt: &TokTree<'src>) -> (ComputationMetadata, Vec<Vec<u8>>) {
|
||||
let mut md = ComputationMetadata::default();
|
||||
let mut unbind = Vec::new();
|
||||
if let List(tts) = tt {
|
||||
if tts.is_empty() || tts[0] != Leaf(Token::Ident, b"metadata") {
|
||||
panic!(
|
||||
@@ -543,22 +607,19 @@ impl<'src> IrInterp<'src> {
|
||||
)
|
||||
}
|
||||
match &tts[1..] {
|
||||
[parties, inputs, viss] => {
|
||||
[parties, inputs] => {
|
||||
let parties = self.string_list(parties);
|
||||
let input_names = self.decl_list(inputs);
|
||||
let inputs: FxHashMap<String, Term> = input_names
|
||||
.iter()
|
||||
.map(|i| (from_utf8(i).unwrap().into(), self.get_binding(i).clone()))
|
||||
.collect();
|
||||
let visibilities = self.visibility_list(viss);
|
||||
(
|
||||
ComputationMetadata::from_parts(
|
||||
parties,
|
||||
inputs,
|
||||
visibilities.into_iter().collect(),
|
||||
),
|
||||
input_names,
|
||||
)
|
||||
for p in parties.into_iter().skip(1) {
|
||||
md.add_party(p);
|
||||
}
|
||||
let tts_inputs = self.unwrap_prefix_list(inputs, "inputs");
|
||||
for tti_input in tts_inputs {
|
||||
let (name_bytes, v_md) = self.variable_metadata(tti_input);
|
||||
self.bind(name_bytes, v_md.term());
|
||||
unbind.push(name_bytes.to_owned());
|
||||
md.vars.insert(v_md.name.clone(), v_md);
|
||||
}
|
||||
(md, unbind)
|
||||
}
|
||||
_ => panic!("Expected meta-data, found {}", tt),
|
||||
}
|
||||
@@ -848,9 +909,8 @@ mod test {
|
||||
b"
|
||||
(computation
|
||||
(metadata
|
||||
(P V)
|
||||
((a bool) (b bool) (A (tuple bool bool)))
|
||||
((a P))
|
||||
(parties P V)
|
||||
(inputs (a bool (party 0)) (b bool) (A (tuple bool bool)))
|
||||
)
|
||||
(precompute
|
||||
((c bool) (d bool))
|
||||
@@ -863,7 +923,7 @@ mod test {
|
||||
((field 0) (#t false false #b0000 true))))
|
||||
)",
|
||||
);
|
||||
assert_eq!(c.metadata.input_vis.len(), 3);
|
||||
assert_eq!(c.metadata.vars.len(), 3);
|
||||
assert!(!c.metadata.is_input_public("a"));
|
||||
assert!(c.metadata.is_input_public("b"));
|
||||
assert!(c.metadata.is_input_public("A"));
|
||||
|
||||
@@ -282,7 +282,7 @@ impl<'a> ToABY<'a> {
|
||||
match &t.op() {
|
||||
Op::Var(name, Sort::Bool) => {
|
||||
let md = self.get_md();
|
||||
if !self.inputs.contains(&t) && md.input_vis.contains_key(name) {
|
||||
if !self.inputs.contains(&t) && md.is_input(name) {
|
||||
let vis = self.unwrap_vis(name);
|
||||
let s = self.get_share(&t, to_share_type);
|
||||
let op = "IN";
|
||||
@@ -409,7 +409,7 @@ impl<'a> ToABY<'a> {
|
||||
match &t.op() {
|
||||
Op::Var(name, Sort::BitVector(_)) => {
|
||||
let md = self.get_md();
|
||||
if !self.inputs.contains(&t) && md.input_vis.contains_key(name) {
|
||||
if !self.inputs.contains(&t) && md.is_input(name) {
|
||||
let vis = self.unwrap_vis(name);
|
||||
let s = self.get_share(&t, to_share_type);
|
||||
let op = "IN";
|
||||
@@ -830,7 +830,7 @@ impl<'a> ToABY<'a> {
|
||||
|
||||
let inputs: Vec<String> = comp
|
||||
.metadata
|
||||
.computation_inputs
|
||||
.ordered_input_names()
|
||||
.iter()
|
||||
.map(|x| {
|
||||
if bytecode_input_map.contains_key(x) {
|
||||
|
||||
@@ -977,10 +977,7 @@ impl<'cfg> ToR1cs<'cfg> {
|
||||
pub fn to_r1cs(mut cs: Computation, cfg: &CircCfg) -> (ProverData, VerifierData) {
|
||||
let assertions = cs.outputs.clone();
|
||||
let metadata = cs.metadata.clone();
|
||||
let public_inputs = metadata
|
||||
.public_input_names()
|
||||
.map(ToOwned::to_owned)
|
||||
.collect();
|
||||
let public_inputs = metadata.public_input_names_set();
|
||||
debug!("public inputs: {:?}", public_inputs);
|
||||
let mut converter = ToR1cs::new(cfg, public_inputs);
|
||||
debug!(
|
||||
@@ -991,7 +988,7 @@ pub fn to_r1cs(mut cs: Computation, cfg: &CircCfg) -> (ProverData, VerifierData)
|
||||
.sum::<usize>()
|
||||
);
|
||||
debug!("declaring inputs");
|
||||
for i in metadata.public_inputs() {
|
||||
for i in metadata.ordered_public_inputs() {
|
||||
debug!("input {}", i);
|
||||
converter.embed(i);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user