Merge updates needed for SHA with lookups. (#196)

This is highly unoptimized, for now.
This commit is contained in:
Alex Ozdemir
2024-06-19 13:09:43 -07:00
committed by GitHub
parent aa318e55a5
commit 2cdc019b86
60 changed files with 2300 additions and 330 deletions

View File

@@ -371,14 +371,22 @@ impl PartialOrd for FieldV {
impl Ord for FieldV {
#[inline]
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.full_cow().cmp(&other.full_cow())
if self.is_full() || other.is_full() {
self.full_cow().cmp(&other.full_cow())
} else {
self.inline_i64().cmp(&other.inline_i64())
}
}
}
impl PartialEq for FieldV {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.full_cow().eq(&other.full_cow())
if self.is_full() || other.is_full() {
self.full_cow().eq(&other.full_cow())
} else {
self.inline_i64().eq(&other.inline_i64())
}
}
}

View File

@@ -49,7 +49,7 @@ impl crate::Table<u8> for Table {
"hashconsing"
}
fn for_each(f: impl FnMut(&u8, &[Self::Node])) {
fn for_each(_f: impl FnMut(&u8, &[Self::Node])) {
panic!()
}

View File

@@ -52,7 +52,7 @@ macro_rules! generate_hashcons_hashconsing {
"hashconsing"
}
fn for_each(f: impl FnMut(&$Op, &[Self::Node])) {
fn for_each(_f: impl FnMut(&$Op, &[Self::Node])) {
panic!()
}

View File

@@ -50,7 +50,7 @@ impl crate::Table<TemplateOp> for Table {
"hashconsing"
}
fn for_each(f: impl FnMut(&TemplateOp, &[Self::Node])) {
fn for_each(_f: impl FnMut(&TemplateOp, &[Self::Node])) {
panic!()
}

View File

@@ -3,20 +3,18 @@ use fxhash::FxHashMap as HashMap;
use crate::Id;
use log::trace;
use std::borrow::Borrow;
use std::cell::{Cell, RefCell};
use std::rc::Rc;
use std::sync::atomic::AtomicU64;
use std::thread_local;
#[allow(dead_code)]
struct NodeData {
op: u8,
hash: AtomicU64,
cs: Box<[Node]>,
}
#[allow(dead_code)]
struct NodeDataRef<'a, Q: Borrow<[Node]>>(&'a u8, &'a Q);
#[derive(Clone)]
pub struct Node {
data: Rc<NodeData>,
@@ -158,6 +156,7 @@ impl Manager {
let mut table = self.table.borrow_mut();
let data = Rc::new(NodeData {
op: op.clone(),
hash: Default::default(),
cs: children.into(),
});
@@ -218,10 +217,10 @@ impl Manager {
let duration = start.elapsed();
self.in_gc.set(false);
trace!(
"GC: {} terms -> {} terms in {} us",
"GC: {} terms -> {} terms in {} ns",
collected,
new_size,
duration.as_micros()
duration.as_nanos()
);
collected
} else {
@@ -296,37 +295,45 @@ impl crate::Weak<u8> for Weak {
}
mod hash {
use super::{Node, NodeData, NodeDataRef, Weak};
use std::borrow::Borrow;
use super::{Node, NodeData, Weak};
use fxhash::FxHasher;
use std::hash::{Hash, Hasher};
use std::sync::atomic::Ordering::SeqCst;
impl Hash for Node {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state)
}
}
impl Hash for Weak {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state)
}
}
impl Hash for NodeData {
fn hash<H: Hasher>(&self, state: &mut H) {
self.op.hash(state);
impl NodeData {
fn rehash(&self) -> u64 {
let mut hasher = FxHasher::default();
self.op.hash(&mut hasher);
for c in self.cs.iter() {
c.hash(state);
c.hash(&mut hasher);
}
let current_hash = hasher.finish();
self.hash.store(current_hash, SeqCst);
current_hash
}
}
impl<'a, Q: Borrow<[Node]>> Hash for NodeDataRef<'a, Q> {
impl Hash for NodeData {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.hash(state);
for c in self.1.borrow().iter() {
c.hash(state);
let mut current_hash: u64 = self.hash.load(SeqCst);
if current_hash == 0 {
current_hash = self.rehash();
}
state.write_u64(current_hash);
}
}
}

View File

@@ -5,21 +5,19 @@ macro_rules! generate_hashcons_rc {
use fxhash::FxHashMap as HashMap;
use log::trace;
use std::borrow::Borrow;
use std::cell::{Cell, RefCell};
use std::rc::Rc;
use std::sync::atomic::AtomicU64;
use std::thread_local;
use $crate::Id;
#[allow(dead_code)]
struct NodeData {
op: $Op,
hash: AtomicU64,
cs: Box<[Node]>,
}
#[allow(dead_code)]
struct NodeDataRef<'a, Q: Borrow<[Node]>>(&'a $Op, &'a Q);
#[derive(Clone)]
pub struct Node {
data: Rc<NodeData>,
@@ -161,6 +159,7 @@ macro_rules! generate_hashcons_rc {
let mut table = self.table.borrow_mut();
let data = Rc::new(NodeData {
op: op.clone(),
hash: Default::default(),
cs: children.into(),
});
@@ -221,10 +220,10 @@ macro_rules! generate_hashcons_rc {
let duration = start.elapsed();
self.in_gc.set(false);
trace!(
"GC: {} terms -> {} terms in {} us",
"GC: {} terms -> {} terms in {} ns",
collected,
new_size,
duration.as_micros()
duration.as_nanos()
);
collected
} else {
@@ -299,37 +298,45 @@ macro_rules! generate_hashcons_rc {
}
mod hash {
use super::{Node, NodeData, NodeDataRef, Weak};
use std::borrow::Borrow;
use super::{Node, NodeData, Weak};
use fxhash::FxHasher;
use std::hash::{Hash, Hasher};
use std::sync::atomic::Ordering::SeqCst;
impl Hash for Node {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state)
}
}
impl Hash for Weak {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state)
}
}
impl Hash for NodeData {
fn hash<H: Hasher>(&self, state: &mut H) {
self.op.hash(state);
impl NodeData {
fn rehash(&self) -> u64 {
let mut hasher = FxHasher::default();
self.op.hash(&mut hasher);
for c in self.cs.iter() {
c.hash(state);
c.hash(&mut hasher);
}
let current_hash = hasher.finish();
self.hash.store(current_hash, SeqCst);
current_hash
}
}
impl<'a, Q: Borrow<[Node]>> Hash for NodeDataRef<'a, Q> {
impl Hash for NodeData {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.hash(state);
for c in self.1.borrow().iter() {
c.hash(state);
let mut current_hash: u64 = self.hash.load(SeqCst);
if current_hash == 0 {
current_hash = self.rehash();
}
state.write_u64(current_hash);
}
}
}

View File

@@ -2,21 +2,19 @@ use fxhash::FxHashMap as HashMap;
use crate::Id;
use log::trace;
use std::borrow::Borrow;
use std::cell::{Cell, RefCell};
use std::net::SocketAddrV6 as TemplateOp;
use std::rc::Rc;
use std::sync::atomic::AtomicU64;
use std::thread_local;
#[allow(dead_code)]
struct NodeData {
op: TemplateOp,
hash: AtomicU64,
cs: Box<[Node]>,
}
#[allow(dead_code)]
struct NodeDataRef<'a, Q: Borrow<[Node]>>(&'a TemplateOp, &'a Q);
#[derive(Clone)]
pub struct Node {
data: Rc<NodeData>,
@@ -158,6 +156,7 @@ impl Manager {
let mut table = self.table.borrow_mut();
let data = Rc::new(NodeData {
op: op.clone(),
hash: Default::default(),
cs: children.into(),
});
@@ -218,10 +217,10 @@ impl Manager {
let duration = start.elapsed();
self.in_gc.set(false);
trace!(
"GC: {} terms -> {} terms in {} us",
"GC: {} terms -> {} terms in {} ns",
collected,
new_size,
duration.as_micros()
duration.as_nanos()
);
collected
} else {
@@ -296,37 +295,46 @@ impl crate::Weak<TemplateOp> for Weak {
}
mod hash {
use super::{Node, NodeData, NodeDataRef, Weak};
use std::borrow::Borrow;
use super::{Node, NodeData, Weak};
use fxhash::FxHasher;
use std::hash::{Hash, Hasher};
use std::sync::atomic::Ordering::SeqCst;
impl Hash for Node {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state)
}
}
impl Hash for Weak {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state)
}
}
impl Hash for NodeData {
fn hash<H: Hasher>(&self, state: &mut H) {
self.op.hash(state);
impl NodeData {
fn rehash(&self) -> u64 {
let mut hasher = FxHasher::default();
self.op.hash(&mut hasher);
for c in self.cs.iter() {
c.hash(state);
c.hash(&mut hasher);
}
let current_hash = hasher.finish();
self.hash.store(current_hash, SeqCst);
current_hash
}
}
impl<'a, Q: Borrow<[Node]>> Hash for NodeDataRef<'a, Q> {
impl Hash for NodeData {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.hash(state);
for c in self.1.borrow().iter() {
c.hash(state);
let mut current_hash: u64 = self.hash.load(SeqCst);
if current_hash == 0 {
current_hash = self.rehash();
}
state.write_u64(current_hash);
}
}
}

View File

@@ -194,6 +194,14 @@ pub struct IrOpt {
default_value = "true"
)]
pub fits_in_bits_ip: bool,
/// Time operator evaluations
#[arg(
long = "ir-time-eval-ops",
env = "IR_TIME_EVAL_OPS",
action = ArgAction::Set,
default_value = "false"
)]
pub time_eval_ops: bool,
}
impl Default for IrOpt {
@@ -202,6 +210,7 @@ impl Default for IrOpt {
field_to_bv: Default::default(),
frequent_gc: Default::default(),
fits_in_bits_ip: true,
time_eval_ops: false,
}
}
}

View File

@@ -1,7 +1,7 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(x #f6)
(return #f6)
(return #f0)
) false ; ignored
))

View File

@@ -0,0 +1 @@
This directory contains a SHA256 implementation by Anna Woo.

View File

@@ -0,0 +1,157 @@
// #pragma curve bn128
from "big_nat" import BigNatb, BigNatb_v2, BigNat, BigNatParams, GpBigNats
import "utils/pack/bool/unpack" as unpack
import "utils/pack/bool/unpack_unchecked"
import "utils/pack/bool/pack" as pack
// from "field" import FIELD_SIZE_IN_BITS
from "EMBED" import bit_array_le, u32_to_u64, value_in_array //, reverse_lookup //, fits_in_bits
from "const_range_check" import D_1, D_2, D_3, D_4, D_5, D_6, D_7, D_8, D_9, D_10, D_TO_S_1, D_TO_S_2, D_TO_S_3, D_TO_S_4, D_TO_S_5, D_TO_S_6, D_TO_S_7, D_TO_S_8, D_TO_S_9, D_TO_S_10, D_TO_S_11
// Check that x has N bits
def fits_in_bits<N>(field x) -> bool:
assert(N!=1 || value_in_array(x, D_1))
assert(N!=2 || value_in_array(x, D_2))
assert(N!=3 || value_in_array(x, D_3))
assert(N!=4 || value_in_array(x, D_4))
assert(N!=5 || value_in_array(x, D_5))
assert(N!=6 || value_in_array(x, D_6))
assert(N!=7 || value_in_array(x, D_7))
assert(N!=8 || value_in_array(x, D_8))
assert(N!=9 || value_in_array(x, D_9))
assert(N!=10 || value_in_array(x, D_10))
return (N >= 1) && (N <= 10) // maximum bitwidth of range check
// Check that x is a N-bit value in sparse form
def fits_in_bits_sparse<N>(field x) -> bool:
assert(N!=1 || value_in_array(x, D_TO_S_1))
assert(N!=2 || value_in_array(x, D_TO_S_2))
assert(N!=3 || value_in_array(x, D_TO_S_3))
assert(N!=4 || value_in_array(x, D_TO_S_4))
assert(N!=5 || value_in_array(x, D_TO_S_5))
assert(N!=6 || value_in_array(x, D_TO_S_6))
assert(N!=7 || value_in_array(x, D_TO_S_7))
assert(N!=8 || value_in_array(x, D_TO_S_8))
assert(N!=9 || value_in_array(x, D_TO_S_9))
assert(N!=10 || value_in_array(x, D_TO_S_10))
assert(N!=11 || value_in_array(x, D_TO_S_11))
return (N >= 1) && (N <= 11) // maximum bitwidth of range check
// // Convert sparse form to dense form
// def sparse_to_dense<N>(field x) -> field:
// assert(N!=3 || reverse_lookup(x, D_TO_S_3))
// return x
// check if the input is non-zero
def is_non_zero<NQ>(BigNat<NQ> input) -> bool:
bool non_zero = false
for u32 i in 0..NQ do
non_zero = non_zero || (input.limbs[i] != 0)
endfor
return non_zero
def group_bignat<N, W>(BigNat<N> left, BigNat<N> right) -> GpBigNats<2>: // assume we can pack N-1 limbs into one field element
u32 end = N-1
BigNat<2> gp_left = BigNat {limbs: [0, left.limbs[end]]}
BigNat<2> gp_right = BigNat {limbs: [0, right.limbs[end]]}
field base = 2 ** W
field shift = 1
for u32 i in 0..end do
gp_left.limbs[0] = gp_left.limbs[0] + left.limbs[i] * shift
gp_right.limbs[0] = gp_right.limbs[0] + right.limbs[i] * shift
shift = shift * base
endfor
GpBigNats<2> output = GpBigNats {left: gp_left, right: gp_right}
return output
def is_equal<N, W>(BigNat<N> left, BigNat<N> right) -> bool: // assume we can pack N-1 limbs into one field element
field base = 2 ** W
GpBigNats<2> output = group_bignat::<N, W>(left, right)
return (output.left.limbs[0] == output.right.limbs[0] && output.left.limbs[1] == output.right.limbs[1])
def bignat_to_field<N, W>(BigNat<N> input) -> field: // assume left and right have the same limbwidth
field output = 0
field base = 2 ** W
field shift = 1
for u32 i in 0..N do
output = output + input.limbs[i] * shift
shift = shift * base
endfor
return output
def less_than_threshold_inner<P, P2>(BigNat<P> input, field input_value, field carry, field threshold) -> bool:
// The case input <= threshold is true if and only if the followings are true
// - If threshold_bignat[P2..P] is a trailing sequence of zeros in its limb representation,
// then input[P2..P] is a sequence of zeros
// - There exists carry such that
// i) the bit-length of carry is at most the bit-length of threshold
// ii) carry + input = threshold
bool notlessthan = false
for u32 i in P2..P do
notlessthan = notlessthan || (input.limbs[i] != 0) // set notlessthan to be true if one of the last several limbs of input is non-zero
endfor
notlessthan = notlessthan || (input_value + carry != threshold)
return !notlessthan
// return true if input<=threshold; return false otherwise
// assume that the prover is only incentivized to prove that the result is true; But the result is false does not allow him to trick on the final result
// Assume P2 * W does not exceed the number of bits of field characteristics
def less_than_threshold<P, P2, W>(BigNat<P> input, field carry, field threshold) -> bool: // assume P is even
assert(P2 == 4)
BigNat<P2> trunc_input = BigNat{ limbs: input.limbs[0..P2]}
field input_value = bignat_to_field::<P2, W>(trunc_input)
return less_than_threshold_inner::<P, P2>(input, input_value, carry, threshold)
// return !notlessthan
def assert_well_formed<N, K>(BigNat<N> value) -> bool:
//u64 limb_width = value.bparams.limb_width
bool[K] res = [false; K]
for u32 i in 0..N do //ensure elements in 'limb_values' fit in 'limb_width' bits
res = unpack_unchecked(value.limbs[i]) //assume K < FIELD_SIZE_IN_BITS
//assert(if K >= FIELD_SIZE_IN_BITS then bit_array_le(res, [...[false; K - FIELD_SIZE_IN_BITS], ...unpack_unchecked::<FIELD_SIZE_IN_BITS>(-1)]) else true fi)
endfor
return true
def bool_to_field<W>(bool[W] x) -> field:
return pack(x)
def bignat_fit_in_bits<N, W>(BigNat<N> x) -> bool:
for u32 i in 0..N do
assert(fits_in_bits::<W>(x.limbs[i]))
endfor
return true
def BigNatb_to_BigNat<N, W>(BigNatb<N, W> x) -> BigNat<N>:
BigNat<N> res = BigNat{limbs: [0; N]}
for u32 i in 0..N do
res.limbs[i] = pack(x.limbs[i])
endfor
return res
def BigNatb_to_BigNat_v2<N, Nm1, W, W2>(BigNatb_v2<Nm1, W, W2> x) -> BigNat<N>: // Nm1 = N - 1 // difference from BigNatb_to_BigNat is that BigNatb_to_BigNat_v2 allows the last limb has a smaller bitwidth
// field[N] limbsres = [0; N]
BigNat<N> res = BigNat{limbs: [0; N]}
for u32 i in 0..Nm1 do
res.limbs[i] = pack(x.limbs[i])
endfor
res.limbs[Nm1] = pack::<W2>(x.limb)
// BigNat<N> res = BigNat{limbs: limbsres}
return res
def check_limbwidth<W>(u32 limbwidth) -> bool:
//return u32_to_u64(W) == limbwidth
return W == limbwidth
def main(BigNatb<10, 256> a, BigNat<10> b) -> bool:
//BigNatParams res = BigNatb_to_BigNat(a)
//BigNat<10> res = BigNatb_to_BigNat(a)
//bool res = check_limbwidth::<256>(a.bparams.limb_width)
return true
//return check_limbwidth<256>(a.bparams.limb_width)

View File

@@ -0,0 +1,166 @@
from "assert_well_formed" import fits_in_bits, fits_in_bits_sparse
from "utils" import Dual, unsafe_split, split_limbs_in_sparse, unsafe_split_dyn, unsafe_split_dyn_sparse, split_even_dual_10, split_even_dual_11, split_odd_dual_10, split_odd_dual_11, dense_limb_to_dual_limb, dual_limbs_to_sparse_limbs, dual_limbs_to_dense_limbs, combine_limbs, split_even_dual_for_all_limbs
from "const_range_check" import S_ONES_10, S_ONES_11
// Compute right and left parts of input s.t.
// i. input[N-1]||0||..||input[1]||0||input[0] = left||0||right
// ii. left is sparse form of bitwidth RED_L = LIMBWIDTH[SPLIT_IDX]-RED_R bits
// iii. right = input[SPLIT_IDX] - left * (2 ** (2 * RED_R))
def split_for_shift<N, R>(field[N] input, u32[N] LIMBWIDTH) -> field[2]:
u32 CUR_WIDTH = 0
u32 SPLIT_IDX = 0 // input[split_idx] needs to be split
u32 RED_R = R // limbwidth of the right part of the splited limb
for u32 i in 0..N do
SPLIT_IDX = if CUR_WIDTH < R then i else SPLIT_IDX fi // When i=0, CUR_WIDTH=0; When i=1, CUR_WIDTH=LIMBWIDTH[0]; When i=2, CUR_WIDTH=LIMBWIDTH[0]+LIMBWIDTH[1]; ...
RED_R = if CUR_WIDTH < R then R-CUR_WIDTH else RED_R fi
CUR_WIDTH = CUR_WIDTH + LIMBWIDTH[i]
endfor
u32 TOTAL_WIDTH = CUR_WIDTH
u32 LOW_BITS = RED_R * 2
u32 HIGH_BITS = 2*LIMBWIDTH[SPLIT_IDX] - 1 - LOW_BITS
unsafe witness field[2] split = unsafe_split::<LOW_BITS, HIGH_BITS>(input[SPLIT_IDX]) // would input[SPLIT_IDX] incur lookup cost?
field[2] safe_split = [0, split[1]]
safe_split[0] = input[SPLIT_IDX] - split[1] * (2 ** LOW_BITS)
// Check that the split limbs are well-formed
u32 RED_L = LIMBWIDTH[SPLIT_IDX] - RED_R
assert(fits_in_bits_sparse::<RED_L>(safe_split[1]))
// split[0] = input[SPLIT_IDX] - split[1] * (2 ** LOW_BITS)
// assert(input[SPLIT_IDX] == split[1] * (2 ** LOW_BITS) + split[0])
assert(fits_in_bits_sparse::<RED_R>(safe_split[0]))
CUR_WIDTH = 0
field right = 0
for u32 i in 0..SPLIT_IDX do
right = right + input[i] * (2 ** (2 * CUR_WIDTH))
CUR_WIDTH = CUR_WIDTH + LIMBWIDTH[i]
endfor
right = right + safe_split[0] * (2 ** (2 * CUR_WIDTH))
// CUR_WIDTH = RED_R
CUR_WIDTH = RED_L
field left = safe_split[1]
for u32 i in (SPLIT_IDX+1)..N do
left = left + input[i] * (2 ** (2 * CUR_WIDTH))
CUR_WIDTH = CUR_WIDTH + LIMBWIDTH[i]
endfor
return [right, left] // right = low_bits, left = high_bits
// constant-offset rotation (sparse->sparse) (when LIMBWIDTH[0] != R and LIMBWIDTH[0] + LIMBWIDTH[1] != R)
def rotr<N, R>(field[N] input, u32[N] LIMBWIDTH_ORI, u32[N] LIMBWIDTH_NEW) -> field:
field[2] overall_split = split_for_shift::<N, R>(input, LIMBWIDTH_ORI)
u32 TOTAL_WIDTH = 0
for u32 i in 0..N do
TOTAL_WIDTH = TOTAL_WIDTH + LIMBWIDTH_ORI[i]
endfor
assert(TOTAL_WIDTH == 32)
field output_val = overall_split[0] * (2 ** (2 * (TOTAL_WIDTH - R))) + overall_split[1]
// return split_limbs_in_sparse::<N>(output_val, LIMBWIDTH_NEW)
return output_val
// constant-offset shift (sparse->sparse) (when LIMBWIDTH[0] != R and LIMBWIDTH[0] + LIMBWIDTH[1] != R)
def shr<N, R>(field[N] input,u32[N] LIMBWIDTH_ORI, u32[N] LIMBWIDTH_NEW) -> field:
field[2] overall_split = split_for_shift::<N, R>(input, LIMBWIDTH_ORI)
field output_val = overall_split[1]
// return split_limbs_in_sparse::<N>(output_val, LIMBWIDTH_NEW)
return output_val
// N-ary XOR for 10-bit values (sparse to dense) where N = 2 or 3
def xor_10<N>(field[N] input) -> field:
assert(N == 2 || N == 3)
field sum = 0
for u32 i in 0..N do
sum = sum + input[i]
endfor
Dual dual = split_even_dual_10(sum)
return dual.d
// N-ary XOR for 11-bit values (sparse to dense) where N = 2 or 3
def xor_11<N>(field[N] input) -> field:
assert(N == 2 || N == 3)
field sum = 0
for u32 i in 0..N do
sum = sum + input[i]
endfor
Dual dual = split_even_dual_11(sum)
return dual.d
// N-ary XOR for value in limb representation (sparse to dense) where N = 2 or 3
def xor_for_all_limbs<N>(field[3] input, u32[3] LIMBWIDTH) -> field[3]:
field int = 0
for u32 i in 0..3 do
int = int + input[i]
endfor
return split_even_dual_for_all_limbs(int, LIMBWIDTH)
// 2-ary AND for 10-bit values (sparse to Dual)
def and_10(field[2] input) -> Dual:
// Dual dual = split_odd_dual_10(input[0] + input[1])
// return dual.s
return split_odd_dual_10(input[0] + input[1])
// 2-ary AND for 11-bit values (sparse to Dual)
def and_11(field[2] input) -> Dual:
// Dual dual = split_odd_dual_11(input[0] + input[1])
// return dual.s
return split_odd_dual_11(input[0] + input[1])
// 2-ary AND for value in limb representation (sparse to dual)
def and(field[3] x, field[3] y) -> Dual[3]:
Dual[3] output = [Dual {d: 0, s: 0} ; 3]
output[0] = and_11([x[0], y[0]])
output[1] = and_11([x[1], y[1]])
output[2] = and_10([x[2], y[2]])
return output
// // 2-ary AND for value in limb representation (sparse to sparse)
// // LIMBWIDTH = [11, 11, 10]
// def and_s2s(field[3] x, field[3] y) -> field[3]:
// // field[3] output = [0; 3]
// // output[0] = and_11([x[0], y[0]])
// // output[1] = and_11([x[1], y[1]])
// // output[2] = and_10([x[2], y[2]])
// // return output
// Dual[3] output = and(x, y)
// return dual_limbs_to_sparse_limbs(output)
// 2-ary AND for value in limb representation (sparse to dense)
// LIMBWIDTH = [11, 11, 10]
def and_s2d(field[3] x, field[3] y) -> field[3]:
Dual[3] output = and(x, y)
return dual_limbs_to_dense_limbs(output)
// NOT for 10-bit values (sparse to sparse)
def not_10(field input) -> field:
return S_ONES_10 - input
// NOT for 11-bit values (sparse to sparse)
def not_11(field input) -> field:
return S_ONES_11 - input
// 2-ary NOT for value in limb representation (sparse to sparse)
// LIMBWIDTH = [11, 11, 10]
def not(field[3] input) -> field[3]:
field[3] output = [0; 3]
output[0] = not_11(input[0])
output[1] = not_11(input[1])
output[2] = not_10(input[2])
return output
// N-ary ADD modulo 2^32 (Convert N dense-single values to M limbs in dual form)
// C = \ceil{log2 N}
// Note: Should also work for modulo 2^K
def sum<N, M, C, CM>(field[N] input, u32[M] LIMBWIDTH) -> Dual[M]:
assert((1 << C) >= N)
field sum = 0
for u32 i in 0..N do
sum = sum + input[i]
endfor
u32 MP1 = M + 1
u32[MP1] SPLITWIDTH = [...LIMBWIDTH, C]
unsafe witness field[MP1] split = unsafe_split_dyn::<MP1>(sum, SPLITWIDTH)
field[MP1] safe_split = [0, ...split[1..MP1]]
safe_split[0] = sum - combine_limbs::<M>(safe_split[1..MP1], SPLITWIDTH[1..MP1]) * (2 ** (LIMBWIDTH[0]))
assert(fits_in_bits::<CM>(safe_split[M]))
field res_sum = combine_limbs::<M>(safe_split[0..MP1], LIMBWIDTH)
// assert(sum == split[M] * (2 ** TOTAL_WIDTH) + res_sum)
return dense_limb_to_dual_limb::<M>(safe_split[0..M], LIMBWIDTH)

View File

@@ -0,0 +1,125 @@
// from "certificate" import Certificate
struct BigNatParams {
field max_words //max value for each limb
//u32 limb_width//should be no need now
//u64 n_limbs
}
struct BigNatb<N, W> {
bool[N][W] limbs
//BigNatParams bparams
}
struct BigNatb_v2<Nm1, W, W2> {
bool[Nm1][W] limbs
bool[W2] limb
}
struct BigNat<N> {
field[N] limbs
//BigNatParams bparams
}
struct GpBigNats<NG> {
BigNat<NG> left
BigNat<NG> right
}
struct BigNatModMult<W, A, Z, ZG, CW, Q, V> {
BigNat<Z> z
BigNat<V> v
BigNatb<Q, W> quotientb
bool[ZG][CW] carry
BigNatb<A, W> res
}
struct BigNatModMult_v4<W, A, Z, CW, Q, V> { // be careful of the generics
BigNat<Z> z
BigNat<V> v
BigNatb<Q, W> quotientb
bool[CW] carry
BigNatb<A, W> res
}
struct BigNatModMult_v5<W, W2, A, Z, CW, Qm1, V> { // be careful of the generics
BigNat<Z> z
BigNat<V> v
BigNatb_v2<Qm1, W, W2> quotientb
// BigNatb<Q, W> quotientb
bool[CW] carry
BigNatb<A, W> res
}
struct BigNatModMult_v6<W, W2, A, Z, ZG, Qm1, V> { // be careful of the generics
BigNat<Z> z
BigNat<V> v
BigNatb_v2<Qm1, W, W2> quotientb
// BigNatb<Q, W> quotientb
// bool[CW] carry
field[ZG] carry
BigNatb<A, W> res
}
struct BigNatModMultwores_v5<W, W2, Z, V, Qm1, CW> { // be careful of the generics
BigNat<Z> z
BigNat<V> v
BigNatb_v2<Qm1, W, W2> quotientb
bool[CW] carry
}
struct BigNatModMult_v2<W, W2, Am1, Z, ZG, CW, Qm1, V> {
BigNat<Z> z
BigNat<V> v
BigNatb_v2<Qm1, W, W2> quotientb
bool[ZG][CW] carry
BigNatb_v2<Am1, W, W2> res
}
struct BigNatMod<W, A, ZG, CW, Q, V> {
BigNat<V> v
BigNatb<Q, W> quotientb
bool[ZG][CW] carry
BigNatb<A, W> res
}
// BigNatMont<W, Z, ZG, CW, P, Q, V>[EXPBITS] mont
// def MonPro<W, Z, ZG, ZGW, P, Q, QW, V, CW>(BigNat<P> a, BigNat<P> b, BigNat<P> modulus, BigNat<Q> mod_prim, BigNatb<P, W>[3] res, BigNatModMult<W, Z, ZG, CW, Q, V>[3] mm, bool greaterthanp, bool[ZG][ZGW] carry) -> BigNat<P>: //assume we know the number of limbs at compile time
// BigNat<P> cur_x = MonPro::<W, Z, ZG, ZGW, P, Q, W, V, CW>(init_mont, x, modul, mod_prim, mont[0].res, mont[0].mm, mont[0].greaterthanp, mont[0].carry) // compute MonPro(a~, x~) // assume A = P
struct BigNatMont<W, Z, ZG, CW, P, Q, V> {
BigNatb<P, W>[3] res
BigNatModMult<W, Z, ZG, CW, Q, V>[3] mm
bool greaterthanp
bool[ZG][CW] carry
}
struct BigNatAdd<Z, ZG, ZGW, Q, QW, V> {
BigNat<V> v
BigNatb<Q, QW> quotientb
bool[ZG][ZGW] carry
}
// u32 AC = NG+1
// u32 ZG = NG-1
struct ModuloConst<ZG, NG, AC>{
u8[ZG] CW_list
field[NG] gp_maxword
field[AC] aux_const
}
struct ModuloHelperConst<ZG, NG, AC>{
ModuloConst<ZG, NG, AC> moduloconst
field shift
}
// r = 2^4096
const BigNat<34> r = BigNat {limbs: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10141204801825835211973625643008]}
// const BigNat<NLIMBSP1> r = BigNat {limbs: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]}
def main(BigNatb<10, 256> a, BigNat<10> b) -> bool:
return true

View File

@@ -0,0 +1,6 @@
from "utils" import Dual
const field[64] K_DD = [1116352408, 1899447441, 3049323471, 3921009573, 961987163, 1508970993, 2453635748, 2870763221, 3624381080, 310598401, 607225278, 1426881987, 1925078388, 2162078206, 2614888103, 3248222580, 3835390401, 4022224774, 264347078, 604807628, 770255983, 1249150122, 1555081692, 1996064986, 2554220882, 2821834349, 2952996808, 3210313671, 3336571891, 3584528711, 113926993, 338241895, 666307205, 773529912, 1294757372, 1396182291, 1695183700, 1986661051, 2177026350, 2456956037, 2730485921, 2820302411, 3259730800, 3345764771, 3516065817, 3600352804, 4094571909, 275423344, 430227734, 506948616, 659060556, 883997877, 958139571, 1322822218, 1537002063, 1747873779, 1955562222, 2024104815, 2227730452, 2361852424, 2428436474, 2756734187, 3204031479, 3329325298]
// const field[64][3] K_D = [[1944, 325, 266], [1169, 1768, 452], [975, 31, 727], [933, 1723, 934], [603, 728, 229], [497, 1570, 359], [676, 2032, 584], [1749, 907, 684], [664, 245, 864], [769, 107, 74], [1470, 1584, 144], [1475, 399, 340], [1396, 1995, 458], [510, 982, 515], [1703, 896, 623], [372, 894, 774], [449, 877, 914], [1926, 1992, 958], [1478, 51, 63], [460, 404, 144], [1135, 1317, 183], [1194, 1680, 297], [476, 1557, 370], [218, 1841, 475], [338, 1994, 608], [1645, 1592, 672], [1992, 100, 704], [1991, 815, 765], [1011, 1025, 795], [327, 1266, 854], [849, 332, 27], [359, 1317, 80], [645, 1761, 158], [312, 868, 184], [1532, 1421, 308], [1299, 1793, 332], [852, 334, 404], [699, 1345, 473], [302, 89, 519], [1157, 1605, 585], [161, 2045, 650], [1611, 844, 672], [880, 369, 777], [419, 1418, 797], [25, 605, 838], [1572, 800, 858], [1413, 454, 976], [112, 1364, 65], [278, 1176, 102], [1032, 1773, 120], [1868, 270, 157], [1205, 1559, 210], [1203, 897, 228], [586, 789, 315], [591, 921, 366], [2035, 1485, 416], [750, 496, 466], [879, 1196, 482], [20, 271, 531], [520, 224, 563], [2042, 2015, 578], [1259, 525, 657], [1015, 1844, 763], [242, 1583, 793]]
// const field[64][3] K_S = [[1392960, 69649, 65604], [1065217, 1332288, 86032], [348245, 341, 282901], [345105, 1328453, 345108], [266565, 282944, 21521], [87297, 1311748, 70677], [279568, 1398016, 266304], [1331473, 344133, 279632], [278848, 21777, 332800], [327681, 5189, 4164], [1131860, 1312000, 16640], [1134597, 82005, 69904], [1119504, 1396805, 86084], [87380, 348436, 262149], [1328149, 344064, 267349], [70928, 333140, 327700], [86017, 332881, 344324], [1392660, 1396800, 345428], [1134612, 1285, 1365], [86096, 82192, 16640], [1053781, 1115153, 17685], [1066052, 1327360, 66625], [86352, 1310993, 70916], [20804, 1377537, 86341], [69892, 1396804, 267264], [1315921, 1312064, 279552], [1396800, 5136, 282624], [1396757, 328789, 283985], [349445, 1048577, 328005], [69653, 1070340, 332052], [332033, 69712, 325], [70677, 1115153, 4352], [278545, 1332225, 16724], [66880, 332816, 17728], [1135952, 1130577, 66832], [1114373, 1376257, 69712], [332048, 69716, 82192], [279877, 1118209, 86337], [66644, 4417, 262165], [1064977, 1314833, 266305], [17409, 1398097, 278596], [1314885, 331856, 279552], [333056, 70913, 327745], [82949, 1130564, 328017], [321, 266577, 331796], [1311760, 328704, 332100], [1130513, 86036, 348416], [5376, 1118480, 4097], [65812, 1065280, 5140], [1048640, 1332305, 5440], [1380432, 65620, 16721], [1066257, 1310997, 20740], [1066245, 344065, 21520], [266308, 327953, 66885], [266325, 344385, 70740], [1398021, 1134673, 82944], [283732, 87296, 86276], [332885, 1066064, 87044], [272, 65621, 262405], [262208, 21504, 263429], [1398084, 1397077, 266244], [1070149, 262225, 278785], [349461, 1377552, 283973], [21764, 1311829, 328001]]
const Dual[8][3] IV_S = [[Dual {d: 1639, s: 1315861},Dual {d: 316, s: 66896},Dual {d: 424, s: 83008}], [Dual {d: 1669, s: 1327121},Dual {d: 1269, s: 1070353},Dual {d: 749, s: 283729}], [Dual {d: 882, s: 333060},Dual {d: 1502, s: 1134932},Dual {d: 241, s: 21761}], [Dual {d: 1338, s: 1115460},Dual {d: 510, s: 87380},Dual {d: 661, s: 278801}], [Dual {d: 639, s: 267605},Dual {d: 458, s: 86084},Dual {d: 324, s: 69648}], [Dual {d: 140, s: 16464},Dual {d: 173, s: 17489},Dual {d: 620, s: 267344}], [Dual {d: 427, s: 83013},Dual {d: 123, s: 5445},Dual {d: 126, s: 5460}], [Dual {d: 1305, s: 1114433},Dual {d: 1049, s: 1048897},Dual {d: 367, s: 70741}]]

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,74 @@
from "basic_op" import xor_11, xor_10, xor_for_all_limbs, rotr, shr, and_s2s, and_s2d, not
from "utils" import combine_limbs, combine_sparse_limbs, split_odd_dual_11, split_odd_dual_10, Dual, dual_limbs_to_dense_limbs
// SSIG0 (sparse to dense-single) function for SHA-256
def ssig0<N>(field[N] input, u32[N] LIMBWIDTH) -> field:
// u32[N] LIMBWIDTH = [11, 11, 10]
field[3] int = [0; 3]
int[0] = rotr::<N, 7>(input, LIMBWIDTH, LIMBWIDTH)
int[1] = rotr::<N, 18>(input, LIMBWIDTH, LIMBWIDTH)
int[2] = shr::<N, 3>(input, LIMBWIDTH, LIMBWIDTH)
field[N] output_limbs = xor_for_all_limbs::<3>(int, LIMBWIDTH)
return combine_limbs::<N>(output_limbs, LIMBWIDTH)
// SSIG1 (sparse to dense-single) function for SHA-256
def ssig1<N>(field[N] input, u32[N] LIMBWIDTH) -> field:
// u32[N] LIMBWIDTH = [11, 11, 10]
field[3] int = [0; 3]
int[0] = rotr::<N, 17>(input, LIMBWIDTH, LIMBWIDTH)
int[1] = rotr::<N, 19>(input, LIMBWIDTH, LIMBWIDTH)
int[2] = shr::<N, 10>(input, LIMBWIDTH, LIMBWIDTH)
field[N] output_limbs = xor_for_all_limbs::<3>(int, LIMBWIDTH)
return combine_limbs::<N>(output_limbs, LIMBWIDTH)
// bsig0 (sparse to dense-single) function for SHA-256
def bsig0<N>(field[N] input) -> field:
u32[N] LIMBWIDTH_ORI = [11, 11, 10]
u32[N] LIMBWIDTH_NEW = [10, 11, 11]
field[3] int = [0; 3]
int[0] = rotr::<N, 2>(input, LIMBWIDTH_ORI, LIMBWIDTH_NEW)
int[1] = rotr::<N, 13>(input, LIMBWIDTH_ORI, LIMBWIDTH_NEW)
int[2] = combine_sparse_limbs::<N>([input[2], input[0], input[1]], LIMBWIDTH_NEW) // ROTR^22
field[N] output_limbs = xor_for_all_limbs::<3>(int, LIMBWIDTH_ORI)
return combine_limbs::<N>(output_limbs, LIMBWIDTH_ORI)
// bsig1 (sparse to dense-single) function for SHA-256
def bsig1<N>(field[N] input) -> field:
u32[N] LIMBWIDTH_ORI = [11, 11, 10]
u32[N] LIMBWIDTH_NEW = [11, 10, 11]
field[3] int = [0; 3]
int[0] = rotr::<N, 6>(input, LIMBWIDTH_ORI, LIMBWIDTH_NEW)
int[1] = combine_sparse_limbs::<N>([input[1], input[2], input[0]], LIMBWIDTH_NEW)// ROTR^11
int[2] = rotr::<N, 25>(input, LIMBWIDTH_ORI, LIMBWIDTH_NEW)
field[N] output_limbs = xor_for_all_limbs::<3>(int, LIMBWIDTH_ORI)
return combine_limbs::<N>(output_limbs, LIMBWIDTH_ORI)
// MAJ (sparse to dense-single) function for SHA-256
// LIMBWIDTH = [11, 11, 10];
def maj<N>(field[3][N] input) -> field:
field[N] intermediate = [0; N]
for u32 i in 0..N do
intermediate[i] = input[0][i] + input[1][i] + input[2][i]
endfor
Dual[N] output_dual = [Dual{d: 0, s: 0}; N]
output_dual[0] = split_odd_dual_11(intermediate[0])
output_dual[1] = split_odd_dual_11(intermediate[1])
output_dual[2] = split_odd_dual_10(intermediate[2])
u32[N] LIMBWIDTH = [11, 11, 10]
field[N] output_limbs = dual_limbs_to_dense_limbs::<N>(output_dual)
return combine_limbs::<N>(output_limbs, LIMBWIDTH)
// CH (sparse to dense-single) function for SHA-256
// LIMBWIDTH = [11, 11, 10];
def ch<N>(field[3][N] input) -> field:
field[2][N] int = [[0; N]; 2]
int[0] = and_s2d(input[0], input[1]) // of type field[N]
int[1] = and_s2d(not(input[0]), input[2]) // of type field[N]
field[N] output_limbs = [0; N]
for u32 i in 0..N do
output_limbs[i] = int[0][i] + int[1][i] // replace xor with pure addition
endfor
u32[N] LIMBWIDTH = [11, 11, 10]
return combine_limbs::<N>(output_limbs, LIMBWIDTH)

View File

@@ -0,0 +1,25 @@
import "./shaRound" as shaRound
from "utils" import Dual, dual_limbs_to_dense_limbs, dense_limbs_to_dual_limbs, combine_limbs
from "const" import IV_S
// N: Number of invocations of sha256 blocks
// NL: Number of limbs
// output dense form of sha256(message)
// def main<N, NL>(field[N][16][NL] message) -> field[8][NL]:
def main<N, NL>(field[N][16][NL] message) -> field[8]: // for debug purpose
u32[NL] LIMBWIDTH = [11, 11, 10]
Dual[8][NL] current = IV_S
for u32 i in 0..N do
Dual[16][NL] cur_msg = dense_limbs_to_dual_limbs::<16, NL>(message[i], LIMBWIDTH) // implicitly do range checks for message
current = shaRound::<NL>(cur_msg, current, LIMBWIDTH)
endfor
// field[8][NL] output = [[0; NL]; 8]
// for u32 i in 0..8 do
// output[i] = dual_limbs_to_dense_limbs(current[i])
// endfor
field[8] output = [0; 8]
for u32 i in 0..8 do
output[i] = combine_limbs(dual_limbs_to_dense_limbs(current[i]), LIMBWIDTH)
endfor
return output

View File

@@ -0,0 +1,69 @@
from "logic_func" import ssig0, ssig1, bsig0, bsig1, ch, maj
from "utils" import Dual, combine_limbs, dual_limbs_to_sparse_limbs, dual_limbs_to_dense_limbs
from "basic_op" import sum
from "const" import K_DD // K_S
// N = number of limbs
def one_extend<N, CM>(Dual[4][N] w_input, u32[N] LIMBWIDTH) -> Dual[N]:
field[4] addend = [0; 4]
addend[0] = ssig1::<N>(dual_limbs_to_sparse_limbs(w_input[0]), LIMBWIDTH)
addend[1] = combine_limbs::<N>(dual_limbs_to_dense_limbs(w_input[1]), LIMBWIDTH)
addend[2] = ssig0::<N>(dual_limbs_to_sparse_limbs(w_input[2]), LIMBWIDTH)
addend[3] = combine_limbs::<N>(dual_limbs_to_dense_limbs(w_input[3]), LIMBWIDTH)
return sum::<4, N, 2, CM>(addend, LIMBWIDTH)
// Extension (48 rounds)
def whole_extend<N, CM>(Dual[16][N] message, u32[N] LIMBWIDTH) -> Dual[64][N]:
Dual[64][N] w = [...message, ...[[Dual{s: 0, d: 0}; N]; 48]]
for u32 i in 16..64 do
w[i] = one_extend::<N, CM>([w[i-2], w[i-7], w[i-15], w[i-16]], LIMBWIDTH)
endfor
return w
def one_main<N, CM>(Dual[8][N] input, field k, Dual[N] w, u32[N] LIMBWIDTH) -> Dual[8][N]:
field[5] t1 = [0; 5]
t1[0] = combine_limbs::<N>(dual_limbs_to_dense_limbs(input[7]), LIMBWIDTH)
t1[1] = bsig1::<N>(dual_limbs_to_sparse_limbs(input[4]))
field[3][N] input_to_ch = [dual_limbs_to_sparse_limbs(input[4]), dual_limbs_to_sparse_limbs(input[5]), dual_limbs_to_sparse_limbs(input[6])]
t1[2] = ch::<N>(input_to_ch)
t1[3] = k
t1[4] = combine_limbs::<N>(dual_limbs_to_dense_limbs(w), LIMBWIDTH)
field[2] t2 = [0; 2]
t2[0] = bsig0::<N>(dual_limbs_to_sparse_limbs(input[0]))
field[3][N] input_to_maj = [dual_limbs_to_sparse_limbs(input[0]), dual_limbs_to_sparse_limbs(input[1]), dual_limbs_to_sparse_limbs(input[2])]
t2[1] = maj::<N>(input_to_maj)
Dual[8][N] output = [[Dual{s: 0, d: 0}; N]; 8]
for u32 i in 0..8 do
u32 j = (i + 7) % 8
output[i] = input[j]
endfor
output[0] = sum::<7, N, 3, CM>([...t1, ...t2], LIMBWIDTH)
field d_val = combine_limbs::<N>(dual_limbs_to_dense_limbs(input[3]), LIMBWIDTH)
output[4] = sum::<6, N, 3, CM>([d_val, ...t1], LIMBWIDTH)
return output
// Round function (64 rounds)
def whole_main<N, CM>(Dual[8][N] current, Dual[64][N] w, u32[N] LIMBWIDTH) -> Dual[8][N]:
Dual[8][N] interm = current
for u32 i in 0..64 do
interm = one_main::<N, CM>(interm, K_DD[i], w[i], LIMBWIDTH)
endfor
return interm
// H(i) = H(i-1) + output of main round function
def compute_final_output<N, CM>(Dual[8][N] interm, Dual[8][N] current, u32[N] LIMBWIDTH) -> Dual[8][N]:
Dual[8][N] output = [[Dual{s: 0, d: 0}; N]; 8]
for u32 i in 0..8 do
field cur_val = combine_limbs::<N>(dual_limbs_to_dense_limbs(current[i]), LIMBWIDTH)
field interm_val = combine_limbs::<N>(dual_limbs_to_dense_limbs(interm[i]), LIMBWIDTH)
output[i] = sum::<2, N, 1, CM>([cur_val, interm_val], LIMBWIDTH)
endfor
return output
def main<N>(Dual[16][N] input, Dual[8][N] current, u32[3] LIMBWIDTH) -> Dual[8][N]:
u32 CM = 3
Dual[64][N] w = whole_extend::<N, CM>(input, LIMBWIDTH)
Dual[8][N] interm = whole_main::<N, CM>(current, w, LIMBWIDTH)
return compute_final_output::<N, CM>(interm, current, LIMBWIDTH)

View File

@@ -0,0 +1,10 @@
import "sha256" as sha256
const u32[3] LIMBWIDTH = [11, 11, 10]
// N: Number of invocations of sha256 blocks
// NL: Number of limbs
// input message is padded already
def test_sha256<N, NL>(field[8] expected_hash, field[N][16][NL] padded_message) -> bool:
field[8] actual_hash = sha256::<N, NL>(padded_message)
assert(expected_hash == actual_hash)
return true

View File

@@ -0,0 +1,7 @@
from "test_sha256_adv" import test_sha256
const u32 N = 1
const u32 NL = 3 // Number of limbs
def main(field[8] expected_hash, private field[N][16][NL] padded_message) -> bool:
return test_sha256::<N, NL>(expected_hash, padded_message)

View File

@@ -0,0 +1,7 @@
from "test_sha256_adv" import test_sha256
const u32 N = 8
const u32 NL = 3 // Number of limbs
def main(field[8] expected_hash, private field[N][16][NL] padded_message) -> bool:
return test_sha256::<N, NL>(expected_hash, padded_message)

View File

@@ -0,0 +1,396 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(padded_message.3.0.2 #f12)
(padded_message.0.1.0 #f513)
(padded_message.1.0.0 #f531)
(padded_message.5.0.1 #f6)
(padded_message.4.15.1 #f2016)
(padded_message.0.0.1 #f65)
(padded_message.0.2.1 #f66)
(padded_message.1.3.1 #f1614)
(padded_message.1.12.2 #f197)
(padded_message.1.15.0 #f49)
(padded_message.3.10.0 #f1503)
(padded_message.3.13.2 #f751)
(padded_message.4.14.2 #f340)
(padded_message.5.15.2 #f83)
(padded_message.3.6.1 #f961)
(padded_message.6.3.1 #f32)
(padded_message.3.14.2 #f621)
(padded_message.2.0.2 #f216)
(padded_message.2.7.0 #f1036)
(padded_message.7.6.2 #f0)
(expected_hash.1 #f1327195860)
(padded_message.5.14.0 #f582)
(padded_message.3.0.0 #f1795)
(padded_message.5.9.2 #f13)
(padded_message.1.13.1 #f737)
(padded_message.3.10.2 #f163)
(padded_message.6.0.1 #f1254)
(padded_message.2.6.1 #f192)
(padded_message.4.7.2 #f8)
(padded_message.7.4.2 #f0)
(padded_message.0.9.2 #f988)
(padded_message.1.8.2 #f192)
(padded_message.7.11.1 #f0)
(padded_message.2.2.0 #f304)
(padded_message.2.4.0 #f560)
(padded_message.3.3.1 #f71)
(padded_message.7.15.1 #f1)
(padded_message.4.11.1 #f192)
(expected_hash.7 #f1529670075)
(padded_message.5.12.1 #f1520)
(padded_message.5.4.2 #f315)
(padded_message.3.13.1 #f228)
(padded_message.4.14.1 #f930)
(padded_message.5.15.1 #f1646)
(padded_message.7.13.1 #f0)
(padded_message.7.3.2 #f0)
(padded_message.2.5.1 #f1570)
(padded_message.4.2.2 #f520)
(padded_message.4.9.0 #f1036)
(padded_message.0.4.2 #f900)
(padded_message.0.6.2 #f502)
(padded_message.5.2.2 #f56)
(padded_message.0.11.1 #f1569)
(padded_message.5.9.0 #f1315)
(padded_message.6.7.2 #f192)
(padded_message.4.13.1 #f384)
(padded_message.6.9.2 #f232)
(padded_message.7.6.0 #f0)
(padded_message.7.8.0 #f0)
(padded_message.1.5.2 #f457)
(padded_message.1.8.0 #f1539)
(padded_message.7.9.1 #f0)
(padded_message.2.1.0 #f563)
(padded_message.4.8.1 #f192)
(padded_message.3.10.1 #f273)
(expected_hash.5 #f2797358084)
(padded_message.0.15.2 #f192)
(padded_message.5.1.2 #f24)
(padded_message.5.4.0 #f1892)
(padded_message.5.6.0 #f605)
(padded_message.3.14.1 #f560)
(padded_message.6.4.2 #f193)
(padded_message.7.3.0 #f0)
(padded_message.0.12.0 #f853)
(padded_message.4.2.0 #f1328)
(padded_message.7.5.0 #f0)
(padded_message.0.1.2 #f640)
(padded_message.0.3.2 #f623)
(padded_message.0.6.0 #f501)
(padded_message.1.0.2 #f340)
(padded_message.1.2.2 #f413)
(padded_message.1.9.0 #f787)
(padded_message.4.7.1 #f240)
(padded_message.4.11.2 #f172)
(padded_message.5.12.2 #f510)
(padded_message.3.7.2 #f728)
(padded_message.4.10.1 #f320)
(padded_message.5.7.1 #f512)
(padded_message.5.11.1 #f657)
(padded_message.4.15.0 #f1026)
(padded_message.5.14.1 #f540)
(padded_message.6.8.1 #f1678)
(padded_message.7.4.1 #f0)
(expected_hash.0 #f2856353870)
(padded_message.0.7.1 #f416)
(padded_message.0.9.1 #f416)
(padded_message.1.6.1 #f1636)
(padded_message.4.13.2 #f192)
(padded_message.7.0.0 #f1285)
(padded_message.1.14.2 #f204)
(padded_message.0.12.2 #f36)
(padded_message.2.9.2 #f445)
(padded_message.4.3.0 #f304)
(padded_message.5.1.0 #f1309)
(padded_message.4.12.1 #f224)
(padded_message.5.3.0 #f961)
(padded_message.5.13.1 #f697)
(padded_message.6.1.2 #f24)
(padded_message.6.4.0 #f39)
(padded_message.3.2.2 #f272)
(padded_message.3.9.0 #f2022)
(padded_message.0.3.0 #f865)
(padded_message.1.2.0 #f1312)
(padded_message.5.2.1 #f130)
(padded_message.5.4.1 #f225)
(padded_message.6.11.2 #f449)
(padded_message.4.0.1 #f23)
(padded_message.7.1.1 #f0)
(padded_message.0.4.1 #f920)
(padded_message.0.11.0 #f816)
(padded_message.1.5.1 #f1741)
(padded_message.1.7.1 #f102)
(padded_message.6.5.1 #f261)
(padded_message.6.7.1 #f48)
(padded_message.1.10.2 #f41)
(padded_message.3.8.1 #f766)
(padded_message.2.12.2 #f24)
(padded_message.6.13.2 #f189)
(padded_message.2.2.2 #f192)
(padded_message.1.14.0 #f1073)
(padded_message.2.9.0 #f1125)
(padded_message.6.15.1 #f193)
(padded_message.0.13.1 #f194)
(padded_message.1.1.0 #f1903)
(padded_message.7.8.2 #f0)
(padded_message.6.1.0 #f774)
(padded_message.2.13.1 #f455)
(padded_message.3.2.0 #f1342)
(padded_message.5.1.1 #f106)
(padded_message.6.2.1 #f160)
(padded_message.0.14.0 #f290)
(padded_message.4.9.2 #f116)
(padded_message.2.8.1 #f1484)
(padded_message.3.11.1 #f1138)
(padded_message.0.1.1 #f96)
(padded_message.1.0.1 #f129)
(padded_message.2.10.1 #f1133)
(padded_message.6.14.2 #f197)
(padded_message.2.1.2 #f92)
(padded_message.2.3.2 #f192)
(padded_message.2.6.0 #f853)
(padded_message.2.15.2 #f537)
(padded_message.3.5.1 #f1629)
(padded_message.3.15.1 #f280)
(padded_message.2.12.0 #f646)
(padded_message.2.14.1 #f193)
(padded_message.3.1.0 #f1109)
(padded_message.3.3.0 #f1951)
(padded_message.5.6.2 #f983)
(padded_message.5.8.2 #f668)
(padded_message.7.5.2 #f0)
(padded_message.2.7.1 #f97)
(padded_message.4.4.2 #f56)
(padded_message.7.7.2 #f0)
(padded_message.0.8.2 #f170)
(padded_message.4.12.0 #f769)
(padded_message.6.10.2 #f397)
(padded_message.2.11.2 #f193)
(padded_message.7.12.2 #f0)
(padded_message.3.0.1 #f32)
(padded_message.6.11.1 #f1389)
(padded_message.1.7.2 #f305)
(padded_message.1.9.2 #f340)
(expected_hash.2 #f3085693120)
(padded_message.2.3.0 #f310)
(padded_message.2.15.0 #f1597)
(padded_message.1.15.1 #f1543)
(padded_message.1.12.1 #f102)
(padded_message.5.3.2 #f80)
(padded_message.5.5.2 #f961)
(padded_message.5.8.0 #f1798)
(padded_message.3.12.0 #f81)
(padded_message.6.6.2 #f4)
(padded_message.6.10.0 #f46)
(padded_message.2.0.1 #f1638)
(padded_message.4.1.2 #f567)
(padded_message.4.3.2 #f520)
(padded_message.0.5.2 #f67)
(padded_message.0.8.0 #f134)
(padded_message.1.4.2 #f464)
(padded_message.2.14.2 #f4)
(padded_message.4.4.0 #f853)
(padded_message.4.6.0 #f1027)
(padded_message.4.9.1 #f1184)
(padded_message.6.13.1 #f1262)
(padded_message.7.0.2 #f24)
(padded_message.3.9.2 #f290)
(padded_message.5.9.1 #f675)
(padded_message.7.6.1 #f0)
(padded_message.7.7.0 #f0)
(padded_message.7.8.1 #f0)
(padded_message.7.10.1 #f0)
(padded_message.7.12.0 #f0)
(padded_message.7.14.1 #f0)
(padded_message.7.15.2 #f0)
(padded_message.1.8.1 #f544)
(padded_message.7.2.0 #f0)
(padded_message.2.1.1 #f422)
(padded_message.4.5.0 #f257)
(padded_message.3.15.0 #f504)
(padded_message.0.0.2 #f194)
(padded_message.0.7.0 #f1545)
(padded_message.0.15.0 #f1539)
(padded_message.5.5.0 #f466)
(padded_message.7.11.2 #f0)
(padded_message.6.3.2 #f4)
(padded_message.3.12.2 #f889)
(padded_message.6.6.0 #f1287)
(padded_message.3.4.2 #f640)
(padded_message.5.6.1 #f1734)
(padded_message.0.5.0 #f902)
(padded_message.1.1.2 #f101)
(padded_message.1.4.0 #f869)
(padded_message.7.15.0 #f1568)
(padded_message.4.11.0 #f261)
(padded_message.4.2.1 #f387)
(padded_message.4.4.1 #f192)
(padded_message.5.12.0 #f1485)
(padded_message.0.6.1 #f932)
(padded_message.1.9.1 #f128)
(padded_message.7.3.1 #f0)
(padded_message.7.5.1 #f0)
(padded_message.7.13.2 #f0)
(padded_message.0.14.2 #f341)
(padded_message.4.15.2 #f7)
(padded_message.6.9.1 #f1509)
(expected_hash.6 #f186422342)
(padded_message.3.11.0 #f359)
(padded_message.5.0.0 #f29)
(padded_message.2.10.0 #f1901)
(padded_message.6.5.0 #f774)
(padded_message.6.7.0 #f1563)
(padded_message.4.13.0 #f1539)
(padded_message.2.4.2 #f204)
(padded_message.7.11.0 #f0)
(padded_message.4.10.2 #f192)
(padded_message.0.0.0 #f1316)
(padded_message.1.3.0 #f1395)
(padded_message.5.11.2 #f512)
(padded_message.7.14.2 #f0)
(padded_message.7.0.1 #f32)
(padded_message.1.11.1 #f104)
(padded_message.6.3.0 #f1118)
(padded_message.6.12.1 #f1517)
(padded_message.5.10.1 #f774)
(padded_message.3.1.2 #f264)
(padded_message.3.4.0 #f822)
(padded_message.3.6.0 #f1501)
(padded_message.3.13.0 #f1021)
(padded_message.4.14.0 #f769)
(padded_message.5.3.1 #f2039)
(padded_message.5.15.0 #f309)
(padded_message.4.1.1 #f98)
(padded_message.6.4.1 #f902)
(padded_message.4.12.2 #f20)
(padded_message.0.3.1 #f221)
(padded_message.1.2.1 #f1420)
(padded_message.0.11.2 #f280)
(padded_message.5.13.2 #f954)
(padded_message.7.10.2 #f0)
(padded_message.2.5.2 #f92)
(padded_message.2.8.0 #f1903)
(padded_message.7.12.1 #f0)
(padded_message.1.10.0 #f1107)
(padded_message.3.7.1 #f906)
(padded_message.7.13.0 #f0)
(padded_message.7.14.0 #f0)
(expected_hash.4 #f537200913)
(padded_message.6.0.0 #f106)
(expected_hash.3 #f203566965)
(padded_message.3.5.0 #f1593)
(padded_message.4.10.0 #f1544)
(padded_message.5.11.0 #f628)
(padded_message.7.9.2 #f0)
(padded_message.0.10.0 #f48)
(padded_message.2.9.1 #f1261)
(padded_message.4.6.2 #f1020)
(padded_message.0.15.1 #f1024)
(padded_message.1.1.1 #f237)
(padded_message.4.8.2 #f76)
(padded_message.6.1.1 #f261)
(padded_message.0.12.1 #f192)
(padded_message.7.10.0 #f0)
(padded_message.3.2.1 #f1291)
(padded_message.3.4.1 #f628)
(padded_message.5.13.0 #f1341)
(padded_message.6.14.0 #f816)
(padded_message.2.5.0 #f1328)
(padded_message.5.10.2 #f16)
(padded_message.5.7.2 #f567)
(padded_message.6.8.2 #f417)
(padded_message.7.2.2 #f0)
(padded_message.2.2.1 #f1766)
(padded_message.0.10.2 #f44)
(padded_message.2.4.1 #f1579)
(padded_message.0.7.2 #f192)
(padded_message.1.6.2 #f405)
(padded_message.4.5.2 #f116)
(padded_message.1.11.0 #f288)
(padded_message.4.8.0 #f853)
(padded_message.2.0.0 #f602)
(padded_message.5.10.0 #f22)
(padded_message.6.12.0 #f1895)
(padded_message.2.11.0 #f19)
(padded_message.3.1.1 #f0)
(padded_message.6.15.2 #f196)
(padded_message.7.9.0 #f0)
(padded_message.5.14.2 #f627)
(padded_message.2.13.0 #f1282)
(padded_message.1.13.0 #f1330)
(padded_message.0.14.1 #f614)
(padded_message.4.7.0 #f48)
(padded_message.1.10.1 #f234)
(padded_message.2.3.1 #f1798)
(padded_message.0.2.2 #f8)
(padded_message.0.9.0 #f257)
(padded_message.4.0.2 #f534)
(padded_message.2.10.2 #f185)
(padded_message.5.0.2 #f192)
(padded_message.5.7.0 #f72)
(padded_message.6.5.2 #f24)
(padded_message.6.8.0 #f1136)
(padded_message.1.14.1 #f1542)
(padded_message.3.6.2 #f201)
(padded_message.3.8.2 #f759)
(padded_message.0.13.0 #f770)
(padded_message.1.3.2 #f337)
(padded_message.1.6.0 #f76)
(padded_message.2.12.1 #f229)
(padded_message.1.11.2 #f129)
(padded_message.4.6.1 #f128)
(padded_message.2.14.0 #f42)
(padded_message.5.8.1 #f1539)
(padded_message.0.8.1 #f201)
(padded_message.2.13.2 #f291)
(padded_message.6.12.2 #f413)
(padded_message.6.15.0 #f43)
(padded_message.7.1.2 #f512)
(padded_message.7.4.0 #f0)
(padded_message.7.7.1 #f0)
(padded_message.5.2.0 #f1540)
(padded_message.6.0.2 #f116)
(padded_message.6.2.2 #f4)
(padded_message.1.13.2 #f120)
(padded_message.2.6.2 #f76)
(padded_message.2.8.2 #f168)
(padded_message.4.0.0 #f887)
(padded_message.0.2.0 #f256)
(padded_message.0.4.0 #f776)
(padded_message.1.5.0 #f355)
(padded_message.1.7.0 #f275)
(padded_message.3.15.2 #f691)
(padded_message.4.5.1 #f480)
(padded_message.6.9.0 #f1903)
(padded_message.6.11.0 #f302)
(padded_message.7.1.0 #f0)
(padded_message.3.3.2 #f209)
(padded_message.3.5.2 #f859)
(padded_message.1.15.2 #f220)
(padded_message.0.13.2 #f16)
(padded_message.3.8.0 #f1905)
(padded_message.5.5.1 #f1696)
(padded_message.0.10.1 #f160)
(padded_message.4.3.1 #f387)
(padded_message.6.6.1 #f160)
(padded_message.1.12.0 #f816)
(padded_message.0.5.1 #f1977)
(padded_message.1.4.1 #f1034)
(padded_message.6.13.0 #f1139)
(padded_message.3.12.1 #f1173)
(padded_message.3.14.0 #f1764)
(padded_message.2.7.2 #f16)
(padded_message.2.11.1 #f806)
(padded_message.4.1.0 #f1443)
(padded_message.6.10.1 #f1646)
(padded_message.3.9.1 #f1443)
(padded_message.3.11.2 #f198)
(padded_message.7.2.1 #f0)
(padded_message.6.2.0 #f1287)
(padded_message.6.14.1 #f1126)
(padded_message.2.15.1 #f281)
(padded_message.3.7.0 #f1559)
) true;ignored
))

View File

@@ -0,0 +1,219 @@
from "assert_well_formed" import fits_in_bits, fits_in_bits_sparse
from "EMBED" import unpack, reverse_lookup //, value_in_array
from "const_range_check" import D_TO_S_10, D_TO_S_11
struct Dual {
field s
field d
}
def ceildiv(u32 x, u32 y) -> u32:
return (x + y - 1) / y
// Reverse the limbs
def reverse_limbs<N>(field[N] input) -> field[N]:
field[N] output = [0; N]
for u32 i in 0..N do
output[i] = input[N-1-i]
endfor
return output
// convert the limb representation (in dense form) into a value
def combine_limbs<N>(field[N] input, u32[N] LIMBWIDTH) -> field:
field output = 0
u32 CUR_WIDTH = 0
for u32 i in 0..N do
u32 W = LIMBWIDTH[i]
output = output + input[i] * (2 ** CUR_WIDTH)
CUR_WIDTH = CUR_WIDTH + LIMBWIDTH[i]
endfor
return output
// convert the limb representation (in sparse form) into a value
def combine_sparse_limbs<N>(field[N] input, u32[N] LIMBWIDTH) -> field:
u32[N] SPARSE_LIMBWIDTH = [0; N]
for u32 i in 0..N do
SPARSE_LIMBWIDTH[i] = 2 * LIMBWIDTH[i]
endfor
return combine_limbs::<N>(input, SPARSE_LIMBWIDTH)
// split a number into (unchecked) high and low bits
def unsafe_split<LOW_BITS,HIGH_BITS>(field x) -> field[2]:
u32 TOTAL_BITS = LOW_BITS + HIGH_BITS
bool[TOTAL_BITS] bits = unpack(x)
field low = 0
field high = 0
for u32 i in 0..LOW_BITS do
low = low + (2 ** i) * (if bits[TOTAL_BITS-1-i] then 1 else 0 fi)
endfor
// for u32 i in LOW_BITS..HIGH_BITS do
for u32 i in LOW_BITS..TOTAL_BITS do
// high = high + 2 ** i * (if bits[LOW_BITS+HIGH_BITS-1-i] then 1 else 0 fi)
high = high + (2 ** (i-LOW_BITS)) * (if bits[TOTAL_BITS-1-i] then 1 else 0 fi)
endfor
return [low, high]
// split a number into (unchecked) N limbs
def unsafe_split_dyn<N>(field x, u32[N] LIMBWIDTH) -> field[N]:
u32 TOTAL_WIDTH = 0
for u32 i in 0..N do
TOTAL_WIDTH = TOTAL_WIDTH + LIMBWIDTH[i]
endfor
bool[TOTAL_WIDTH] bits = unpack(x)
field[N] output = [0; N]
u32 idx = TOTAL_WIDTH-1
for u32 i in 0..N do
for u32 j in 0..LIMBWIDTH[i] do
output[i] = output[i] + 2 ** j * (if bits[idx] then 1 else 0 fi)
idx = idx - 1
endfor
endfor
return output
// split a number in sparse form into (unchecked) N limbs
// Note: LIMBWIDTH is unsparsed
def unsafe_split_dyn_sparse<N>(field x, u32[N] LIMBWIDTH) -> field[N]:
u32[N] LIMBWIDTH_SPARSE = [0; N]
for u32 i in 0..N do
LIMBWIDTH_SPARSE[i] = 2 * LIMBWIDTH[i]
endfor
return unsafe_split_dyn::<N>(x, LIMBWIDTH_SPARSE)
// split a 2W bit number into (unchecked) even and odd bits (in sparse form)
def unsafe_separate_sparse<N>(field x) -> field[2]:
bool[2*N] bits = unpack(x)
field even = 0
field odd = 0
for u32 i in 0..N do
even = even + 4 ** i * (if bits[2*N-1-(2*i)] then 1 else 0 fi)
odd = odd + 4 ** i * (if bits[2*N-1-(2*i+1)] then 1 else 0 fi)
endfor
return [even, odd]
// - Split input into limbs according to LIMBWIDTH
// - Check that the split limbs are sparse forms of desired bitwidths
def split_limbs_in_sparse<N>(field input, u32[N] LIMBWIDTH) -> field[N]:
unsafe witness field[N] output_limbs = unsafe_split_dyn_sparse::<N>(input, LIMBWIDTH) // should not cost any constraint
field[N] safe_output_limbs = [0, ...output_limbs[1..N]]
u32 Nm1 = N - 1
safe_output_limbs[0] = input - combine_sparse_limbs::<Nm1>(safe_output_limbs[1..N], LIMBWIDTH[1..N]) * (2 ** (2 * LIMBWIDTH[0])) // output_limbs[N-1]||..||output_limbs[0] = overall_split[0]||overall_split[1]
field check_left = 0
// u32 CUR_WIDTH = 0
for u32 i in 0..N do
u32 W = LIMBWIDTH[i]
// Check that the output limbs are well-formed
assert(fits_in_bits_sparse::<W>(output_limbs[i]))
endfor
return output_limbs
// ** to test
def split_limbs_in_sparse_to_dense<N>(field input, u32[N] LIMBWIDTH) -> field[N]:
unsafe witness field[N] output_limbs = unsafe_split_dyn_sparse::<N>(input, LIMBWIDTH) // should not cost any constraint
field[N] safe_output_limbs = [0, ...output_limbs[1..N]]
u32 Nm1 = N - 1
safe_output_limbs[0] = input - combine_sparse_limbs::<Nm1>(safe_output_limbs[1..N], LIMBWIDTH[1..N]) * (2 ** (2 * LIMBWIDTH[0])) // output_limbs[N-1]||..||output_limbs[0] = overall_split[0]||overall_split[1]
field check_left = 0
field[N] output_limbs_sparse = [0; N]
output_limbs_sparse[0] = reverse_lookup(D_TO_S_11, output_limbs[0])
output_limbs_sparse[1] = reverse_lookup(D_TO_S_11, output_limbs[1])
output_limbs_sparse[2] = reverse_lookup(D_TO_S_10, output_limbs[2])
return output_limbs_sparse
// get the old and even bits of a 2N-bit value in sparse form (without checking if they are well-formed)
def split_both_sparse_inner<W>(field x) -> field[2]:
unsafe witness field[2] split = unsafe_separate_sparse::<W>(x)
field[2] safe_split = [0, split[1]]
safe_split[0] = x - 2 * safe_split[1]
return safe_split
// get the even bits of a 2*10-bit value in dual form; ensures the value fits in 2*10 bits.
def split_even_dual_10(field x) -> Dual:
field[2] split = split_both_sparse_inner::<10>(x) // do I need to add unsafe witness here?
field even = split[0]
field odd = split[1]
field even_d = reverse_lookup(D_TO_S_10, even)
assert(fits_in_bits_sparse::<10>(odd))
return Dual { s: even, d: even_d }
// get the odd bits of a 2*10-bit value in dual form; ensures the value fits in 2*10 bits.
def split_odd_dual_10(field x) -> Dual:
field[2] split = split_both_sparse_inner::<10>(x) // do I need to add unsafe witness here?
field even = split[0]
field odd = split[1]
field odd_d = reverse_lookup(D_TO_S_10, odd) // implicitly does fits_in_bits_sparse::<10>(odd)
assert(fits_in_bits_sparse::<10>(even))
return Dual { s: odd, d: odd_d }
// get the even bits of a 2*11-bit value in dual form; ensures the value fits in 2*11 bits.
def split_even_dual_11(field x) -> Dual: // it can probably merged with split_even_dual_10
field[2] split = split_both_sparse_inner::<11>(x) // do I need to add unsafe witness here?
field even = split[0]
field odd = split[1]
field even_d = reverse_lookup(D_TO_S_11, even)
assert(fits_in_bits_sparse::<11>(odd))
return Dual { s: even, d: even_d }
// ** to test
// return dense form of even bits
def split_even_dual_for_all_limbs(field x, u32[3] LIMBWIDTH) -> field[3]:
u32 TOTAL_WIDTH = 32
field[2] split = split_both_sparse_inner::<TOTAL_WIDTH>(x)
field even = split[0]
field odd = split[1]
field[3] even_dense = split_limbs_in_sparse_to_dense::<3>(even, LIMBWIDTH)
field[3] odd_sparse = split_limbs_in_sparse::<3>(odd, LIMBWIDTH) // for range check only
return even_dense
// get the odd bits of a 2*11-bit value in dual form; ensures the value fits in 2*11 bits.
def split_odd_dual_11(field x) -> Dual:
field[2] split = split_both_sparse_inner::<11>(x) // do I need to add unsafe witness here?
field even = split[0]
field odd = split[1]
field odd_d = reverse_lookup(D_TO_S_11, odd)
assert(fits_in_bits_sparse::<11>(even))
return Dual { s: odd, d: odd_d }
def dual_limbs_to_sparse_limbs<N>(Dual[N] input) -> field[N]:
field[N] output = [0; N]
for u32 i in 0..N do
output[i] = input[i].s
endfor
return output
def dual_limbs_to_dense_limbs<N>(Dual[N] input) -> field[N]:
field[N] output = [0; N]
for u32 i in 0..N do
output[i] = input[i].d
endfor
return output
// convert a dense W-bit value to dual form; ensures the value fits in W bits.
// Note: Lookup implicitly checks that the value fits in W bits
// Assume W = 10 or 11
def dense_to_dual<W>(field x) -> Dual:
assert(W == 10 || W == 11)
field s = if W == 10 then D_TO_S_10[x] else D_TO_S_11[x] fi
return Dual {s: s, d: x}
// def dense_to_dual_11_11_10(field[3] input) -> Dual[3]:
// return [dense_to_dual::<11>(input[0]), dense_to_dual::<11>(input[1]), dense_to_dual::<10>(input[2])]
// Convert input in dense form to dual form
def dense_limb_to_dual_limb<N>(field[N] input, u32[N] LIMBWIDTH) -> Dual[N]:
Dual[N] output = [Dual {s: 0, d: 0}; N]
for u32 i in 0..N do
u32 W = LIMBWIDTH[i]
output[i] = dense_to_dual::<W>(input[i])
endfor
return output
// Convert input in dense form to dual form
def dense_limbs_to_dual_limbs<N, NL>(field[N][NL] input, u32[N] LIMBWIDTH) -> Dual[N][NL]:
Dual[N][NL] output = [[Dual {s: 0, d: 0}; NL]; N]
for u32 i in 0..N do
output[i] = dense_limb_to_dual_limb::<NL>(input[i], LIMBWIDTH)
endfor
return output

View File

@@ -1,8 +1,8 @@
def main(field x) -> field:
transcript field[25] A = [0; 25]
for field counter in 0..30 do
bool oob = counter < x
cond_store(A, if oob then counter else 0 fi, x, oob)
bool inbound = counter < x
cond_store(A, if inbound then counter else 0 fi, x, inbound)
endfor
return A[x]

View File

@@ -1,7 +1,7 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(x #f6)
(return #f6)
(return #f0)
) false ; ignored
))

View File

@@ -109,18 +109,14 @@ def normalize_sum_4(field s) -> Dual:
//do a bitwise AND.
def main(private field dense_x, private field dense_y) -> field:
Dual z = dense_to_dual_4(0)
Dual x = dense_to_dual_4(dense_x) // 1010 (10)
Dual y = dense_to_dual_4(dense_y) // 1001 (9)
Dual a = and_4(x, y) // 1000 (8)
for field i in 0..10 do
a = and_4(a, y) // idempotent
endfor
Dual b = or_4(x, y) // 1011 (11)
Dual s = normalize_sum_4(b.d + a.d) // 0011 (3)
Dual x = dense_to_dual_4(dense_x) // 10001000 (136)
Dual y = dense_to_dual_4(dense_y) // 10000001 (129)
Dual a = and_4(x, y) // 10000000
Dual b = or_4(x, y) // 10001001
Dual c = xor_4(x, y, z) // 00001001
Dual d = maj_4(x, y, c) // 10001001
Dual s = normalize_sum_4(d.d + c.d + b.d + a.d) // 10011011 (128+27=155)
return s.d
// return reverse_lookup(D_TO_S_4, dense_x) * dense_y
// return reverse_lookup(D_TO_S_4, dense_x) * reverse_lookup(D_TO_S_4, dense_y)
// return dense_x * dense_y

View File

@@ -266,6 +266,7 @@ fn main() {
Mode::Proof | Mode::ProofOfHighValue(_) => {
let mut opts = Vec::new();
opts.push(Opt::ConstantFold(Box::new([])));
opts.push(Opt::DeskolemizeWitnesses);
opts.push(Opt::ScalarizeVars);
opts.push(Opt::Flatten);
@@ -309,11 +310,17 @@ fn main() {
let cs = cs.get("main");
trace!("IR: {}", circ::ir::term::text::serialize_computation(cs));
let mut r1cs = to_r1cs(cs, cfg());
if cfg().r1cs.profile {
println!("R1CS stats: {:#?}", r1cs.stats());
}
println!("Pre-opt R1cs size: {}", r1cs.constraints().len());
r1cs = reduce_linearities(r1cs, cfg());
println!("Final R1cs size: {}", r1cs.constraints().len());
if cfg().r1cs.profile {
println!("R1CS stats: {:#?}", r1cs.stats());
}
let (prover_data, verifier_data) = r1cs.finalize(cs);
match action {
ProofAction::Count => (),

View File

@@ -71,8 +71,8 @@ transcript_type_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok "covering ROM"
# A=400; N=20; L=2; expected cost ~= N + A(L+1) = 1220
cs_count_test ./examples/ZoKrates/pf/mem/rom.zok 1230
ram_test ./examples/ZoKrates/pf/mem/2024_05_31_benny_bug_tr.zok mirage ""
ram_test ./examples/ZoKrates/pf/mem/2024_05_24_benny_bug_tr.zok mirage ""
ram_test ./examples/ZoKrates/pf/mem/2024_05_31_benny_bug_tr.zok mirage ""
ram_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
ram_test ./examples/ZoKrates/pf/mem/volatile.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
# waksman is broken for non-scalar array values

View File

@@ -1007,7 +1007,7 @@ impl<'ast> ZGen<'ast> {
format!(
"Undefined const identifier {} in {}",
&i.value,
self.cur_path().canonicalize().unwrap().to_string_lossy()
self.cur_path().to_string_lossy()
)
}),
_ => match self

View File

@@ -46,7 +46,7 @@ pub fn fold(node: &Term, ignore: &[Op]) -> Term {
// make the cache unbounded during the fold_cache call
let old_capacity = cache.cap();
cache.resize(std::usize::MAX);
cache.resize(usize::MAX);
let ret = fold_cache(node, &mut cache, ignore);
// shrink cache to its max size
@@ -59,10 +59,15 @@ pub fn fold(node: &Term, ignore: &[Op]) -> Term {
pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> Term {
// (node, children pushed)
let mut stack = vec![(node.clone(), false)];
let mut retainer: Vec<Term> = Vec::new();
// Maps terms to their rewritten versions.
while let Some((t, children_pushed)) = stack.pop() {
if cache.contains(&t.downgrade()) {
if cache
.get(&t.downgrade())
.and_then(|x| x.upgrade())
.is_some()
{
continue;
}
if !children_pushed {
@@ -72,10 +77,12 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T
}
let mut c_get = |x: &Term| -> Term {
cache
let weak = cache
.get(&x.downgrade())
.and_then(|x| x.upgrade())
.expect("postorder cache")
.unwrap_or_else(|| panic!("postorder cache missing key: {} {}", x.id(), x));
weak.upgrade().unwrap_or_else(|| {
panic!("postorder cache missing value: {} -> {}", x.id(), weak.id())
})
};
if ignore.contains(t.op()) {
@@ -401,11 +408,14 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T
new_t_opt.unwrap_or_else(|| term(t.op().clone(), t.cs().iter().map(cc_get).collect()))
};
cache.put(t.downgrade(), new_t.downgrade());
retainer.push(new_t);
}
cache
let result = cache
.get(&node.downgrade())
.and_then(|x| x.upgrade())
.expect("postorder cache")
.expect("postorder cache");
std::mem::drop(retainer);
result
}
fn neg_bool(t: Term) -> Term {

View File

@@ -54,7 +54,8 @@ pub fn deskolemize_witnesses(comp: &mut Computation) {
if let Op::Witness(prefix) = orig.op() {
let name = self.0.mk_uniq(prefix);
let sort = check(orig);
let var = computation.new_var(&name, sort, Some(PROVER_ID), Some(orig.clone()));
let var =
computation.new_var(&name, sort, Some(PROVER_ID), Some(orig.cs()[0].clone()));
Some(var)
} else {
None
@@ -69,7 +70,7 @@ pub fn deskolemize_witnesses(comp: &mut Computation) {
.chain(comp.precomputes.outputs().iter().map(|o| o.0.clone()))
.chain(comp.precomputes.inputs().iter().map(|o| o.0.clone()));
let uniqer = Uniquer::new(names_iterator);
WitPass(uniqer).traverse_full(comp, false, false);
WitPass(uniqer).traverse_full(comp, true, true);
}
/// Replace the challenge terms in this computation with random inputs.
@@ -101,7 +102,10 @@ pub fn deskolemize_challenges(comp: &mut Computation) {
let round = match t.op() {
Op::Var(n, _) => {
if let Some(v) = comp.precomputes.outputs().get(n) {
*min_round.borrow().get(v).unwrap()
*min_round
.borrow()
.get(v)
.unwrap_or_else(|| panic!("missing key: {}", v))
} else {
0
}
@@ -177,12 +181,15 @@ pub fn deskolemize_challenges(comp: &mut Computation) {
let mut challs = TermMap::default();
for t in comp.terms_postorder() {
if let Op::PfChallenge(name, field) = t.op() {
let round = *actual_round.get(&t).unwrap();
debug!("challenge {name}: round = {round}");
trace!("challenge term {t}");
let md = VariableMetadata {
name: name.clone(),
random: true,
vis: None,
sort: Sort::Field(field.clone()),
round: *actual_round.get(&t).unwrap(),
round,
..Default::default()
};
let var = comp.new_var_metadata(md, None);
@@ -247,7 +254,12 @@ impl RewritePass for Pass {
_rewritten_children: F,
) -> Option<Term> {
if let Op::PfChallenge(..) = orig.op() {
Some(self.0.get(orig).unwrap().clone())
Some(
self.0
.get(orig)
.unwrap_or_else(|| panic!("missing key: {}", orig))
.clone(),
)
} else {
None
}

View File

@@ -122,6 +122,15 @@ impl OblivRewriter {
None,
)
}
Op::Witness(s) => {
let arg = &t.cs()[0];
(
self.tups
.get(arg)
.map(|targ| term![Op::Witness(s.clone()); targ.clone()]),
None,
)
}
Op::Eq => {
let a = &t.cs()[0];
let b = &t.cs()[1];
@@ -188,10 +197,19 @@ pub fn elim_obliv(c: &mut Computation) {
for t in c.terms_postorder() {
pass.visit(&t);
}
for t in PostOrderIter::from_roots_and_skips(
c.precomputes.outputs.values().cloned(),
Default::default(),
) {
pass.visit(&t);
}
for o in &mut c.outputs {
debug_assert!(check(o).is_scalar());
*o = pass.get(o).clone();
}
for v in c.precomputes.outputs.values_mut() {
*v = pass.get(v).clone();
}
}
fn arr_val_to_tup(v: &Value) -> Value {

View File

@@ -28,10 +28,9 @@ pub fn lookup(c: &mut Computation, ns: Namespace, haystack: Vec<Term>, needles:
}
let sort = check(&haystack[0]);
let f = sort.as_pf().clone();
let array_op = Op::Array(sort.clone(), sort.clone());
let haystack_array = term(array_op.clone(), haystack.clone());
let needles_array = term(array_op.clone(), needles.clone());
let counts_pre = unmake_array(term![Op::ExtOp(ExtOp::Haboeck); haystack_array, needles_array]);
let haystack_tup = term(Op::Tuple, haystack.clone());
let needles_tup = term(Op::Tuple, needles.clone());
let counts_pre = tuple_terms(term![Op::ExtOp(ExtOp::Haboeck); haystack_tup, needles_tup]);
let counts: Vec<Term> = counts_pre
.into_iter()
.enumerate()
@@ -53,22 +52,69 @@ pub fn lookup(c: &mut Computation, ns: Namespace, haystack: Vec<Term>, needles:
.cloned()
.collect(),
);
let haysum = term(
PF_ADD,
counts
.into_iter()
.zip(haystack)
.map(|(ct, hay)| term![PF_DIV; ct, term![PF_ADD; hay, key.clone()]])
.collect(),
// x_i + k
let needle_shifts: Vec<Term> = needles
.into_iter()
.map(|needle| term![PF_ADD; needle, key.clone()])
.collect();
// tup(1 / (x_i + k))
let needle_invs_tup: Term =
term![Op::ExtOp(ExtOp::PfBatchInv); term(Op::Tuple, needle_shifts.clone())];
// 1 / (x_i + k)
let needle_invs: Vec<Term> = (0..needle_shifts.len())
.map(|i| {
c.new_var(
&ns.fqn(format!("needi{i}")),
sort.clone(),
PROVER_VIS,
Some(term_c![Op::Field(i); needle_invs_tup]),
)
})
.collect();
let one = pf_lit(f.new_v(1));
let mut assertions: Vec<Term> = Vec::new();
// check 1 / (x_i + k)
assertions.extend(
needle_invs
.iter()
.zip(&needle_shifts)
.map(|(ix, x)| term![EQ; term_c![PF_MUL; ix, x], one.clone()]),
);
let needlesum = term(
PF_ADD,
needles
.into_iter()
.map(|needle| term![PF_RECIP; term![PF_ADD; needle, key.clone()]])
.collect(),
// 1 / (x_i + k)
let hay_shifts: Vec<Term> = haystack
.clone()
.into_iter()
.map(|hay| term![PF_ADD; hay, key.clone()])
.collect();
// tup(1 / (x_i + k))
let hay_invs_tup: Term =
term![Op::ExtOp(ExtOp::PfBatchInv); term(Op::Tuple, hay_shifts.clone())];
// ct_i / (x_i + k)
let hay_divs: Vec<Term> = (0..hay_shifts.len())
.zip(&counts)
.map(|(i, ct)| {
c.new_var(
&ns.fqn(format!("hayi{i}")),
sort.clone(),
PROVER_VIS,
Some(term![PF_MUL; ct.clone(), term_c![Op::Field(i); hay_invs_tup]]),
)
})
.collect();
assertions.extend(
hay_divs
.iter()
.zip(hay_shifts)
.zip(counts)
.map(|((div, hay_shift), ct)| term![EQ; term_c![PF_MUL; div, hay_shift.clone()], ct]),
);
term![Op::Eq; haysum, needlesum]
let needlesum = term(PF_ADD, needle_invs);
let haysum = term(PF_ADD, hay_divs);
assertions.push(term![Op::Eq; haysum, needlesum]);
term(AND, assertions)
}
/// Returns a term to assert.

View File

@@ -105,7 +105,6 @@ pub fn check_ram(c: &mut Computation, mut ram: Ram, cfg: &AccessCfg) {
let field_s = Sort::Field(field.clone());
let mut new_f_var =
|name: &str, val: Term| c.new_var(name, field_s.clone(), PROVER_VIS, Some(val));
let field_tuple_s = Sort::Tuple(Box::new([field_s.clone(), field_s.clone()]));
let mut uhf_inputs = inital_terms.clone();
uhf_inputs.extend(final_terms.iter().cloned());
let uhf_key = term(
@@ -113,22 +112,20 @@ pub fn check_ram(c: &mut Computation, mut ram: Ram, cfg: &AccessCfg) {
uhf_inputs,
);
let uhf = |idx: Term, val: Term| term![PF_ADD; val, term![PF_MUL; uhf_key.clone(), idx]];
let init_and_fin_values = make_array(
field_s.clone(),
field_tuple_s,
let init_and_fin_values = term(
Op::Tuple,
inital_terms
.iter()
.zip(&final_terms)
.map(|(i, f)| term![Op::Tuple; i.clone(), f.clone()])
.collect(),
);
let used_indices = make_array(
field_s.clone(),
field_s.clone(),
let used_indices = term(
Op::Tuple,
ram.accesses.iter().map(|a| a.idx.clone()).collect(),
);
let split = term![Op::ExtOp(ExtOp::PersistentRamSplit); init_and_fin_values, used_indices];
let unused_hashes: Vec<Term> = unmake_array(term![Op::Field(0); split.clone()])
let unused_hashes: Vec<Term> = tuple_terms(term![Op::Field(0); split.clone()])
.into_iter()
.enumerate()
.map(|(i, entry)| {
@@ -137,8 +134,8 @@ pub fn check_ram(c: &mut Computation, mut ram: Ram, cfg: &AccessCfg) {
new_f_var(&format!("__unused_hash.{j}.{i}"), uhf(idx_term, val_term))
})
.collect();
let mut declare_access_vars = |array: Term, name: &str| -> Vec<(Term, Term)> {
unmake_array(array)
let mut declare_access_vars = |tuple: Term, name: &str| -> Vec<(Term, Term)> {
tuple_terms(tuple)
.into_iter()
.enumerate()
.map(|(i, access)| {

View File

@@ -332,6 +332,7 @@ impl RewritePass for Extactor {
}
fn traverse(&mut self, computation: &mut Computation) {
let initial_precompute_len = computation.precomputes.outputs.len();
let terms: Vec<Term> = computation.terms_postorder().collect();
let term_refs: HashSet<&Term> = terms.iter().collect();
let mut cache = TermMap::<Term>::default();
@@ -355,10 +356,44 @@ impl RewritePass for Extactor {
.iter()
.map(|o| cache.get(o).unwrap().clone())
.collect();
if !self.cfg.waksman {
for ram in &mut self.rams {
if ram.is_covering_rom() {
ram.cfg.covering_rom = true;
let final_precompute_len = computation.precomputes.outputs.len();
debug!("from {initial_precompute_len} to {final_precompute_len} pre-variables");
if !self.rams.is_empty() {
for t in PostOrderIter::from_roots_and_skips(
computation.precomputes.sequence()[..initial_precompute_len]
.iter()
.map(|(name, _)| computation.precomputes.outputs.get(name).unwrap())
.cloned()
.collect::<Vec<_>>(),
cache.keys().cloned().collect(),
) {
// false positive: the value constructor uses `cache`.
#[allow(clippy::map_entry)]
if !cache.contains_key(&t) {
let new_t = term(
t.op().clone(),
t.cs()
.iter()
.map(|c| cache.get(c).unwrap().clone())
.collect(),
);
cache.insert(t, new_t);
}
}
// false positive; need to clone to drop reference.
#[allow(clippy::unnecessary_to_owned)]
for (name, _sort) in
computation.precomputes.sequence()[..initial_precompute_len].to_owned()
{
let term = computation.precomputes.outputs.get_mut(&name).unwrap();
*term = cache.get(term).unwrap().clone();
}
computation.precomputes.reorder();
if !self.cfg.waksman {
for ram in &mut self.rams {
if ram.is_covering_rom() {
ram.cfg.covering_rom = true;
}
}
}
}

View File

@@ -61,6 +61,9 @@ pub enum Opt {
/// Run optimizations on `cs`, in this order, returning the new constraint system.
pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I) -> Computations {
for c in cs.comps.values() {
trace!("Before all opts: {}", text::serialize_computation(c));
}
for i in optimizations {
debug!("Applying: {:?}", i);
@@ -79,10 +82,13 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I)
}
Opt::ConstantFold(ignore) => {
let mut cache = TermCache::with_capacity(TERM_CACHE_LIMIT);
cache.resize(std::usize::MAX);
cache.resize(usize::MAX);
for a in &mut c.outputs {
*a = cfold::fold_cache(a, &mut cache, &ignore.clone());
}
for v in &mut c.precomputes.outputs.values_mut() {
*v = cfold::fold_cache(v, &mut cache, &ignore.clone());
}
c.ram_arrays = c
.ram_arrays
.iter()
@@ -157,9 +163,11 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I)
}
}
debug!("After {:?}: {} outputs", i, c.outputs.len());
trace!("After {:?}: {}", i, c.outputs[0]);
trace!("After {:?}: {}", i, text::serialize_computation(c));
//debug!("After {:?}: {}", i, Letified(cs.outputs[0].clone()));
debug!("After {:?}: {} terms", i, c.terms());
#[cfg(debug_assertions)]
c.precomputes.check_topo_orderable();
}
if crate::cfg::cfg().ir.frequent_gc {
garbage_collect();

View File

@@ -1,4 +1,6 @@
//! Replacing array and tuple variables with scalars.
//!
//! Also replaces array and tuple *witnesses* with scalars.
use log::trace;
use crate::ir::opt::visit::RewritePass;
@@ -60,12 +62,43 @@ fn create_vars(
}
}
fn create_wits(prefix: &str, prefix_term: Term, sort: &Sort) -> Term {
match sort {
Sort::Tuple(sorts) => term(
Op::Tuple,
sorts
.iter()
.enumerate()
.map(|(i, sort)| {
create_wits(
&format!("{prefix}.{i}"),
term![Op::Field(i); prefix_term.clone()],
sort,
)
})
.collect(),
),
Sort::Array(key_s, val_s, size) => {
let array_elements = extras::array_elements(&prefix_term);
make_array(
(**key_s).clone(),
(**val_s).clone(),
(0..*size)
.zip(array_elements)
.map(|(i, element)| create_wits(&format!("{prefix}.{i}"), element, val_s))
.collect(),
)
}
_ => term![Op::Witness(prefix.to_owned()); prefix_term],
}
}
impl RewritePass for Pass {
fn visit<F: Fn() -> Vec<Term>>(
&mut self,
computation: &mut Computation,
orig: &Term,
_rewritten_children: F,
rewritten_children: F,
) -> Option<Term> {
if let Op::Var(name, sort) = &orig.op() {
trace!("Considering var: {}", name);
@@ -80,6 +113,14 @@ impl RewritePass for Pass {
trace!("Skipping b/c it is commited.");
None
}
} else if let Op::Witness(name) = &orig.op() {
let sort = check(orig);
let mut cs = rewritten_children();
debug_assert_eq!(cs.len(), 1);
if !sort.is_scalar() {
trace!("Considering witness: {}", name);
}
Some(create_wits(name, cs.pop().unwrap(), &sort))
} else {
None
}
@@ -89,7 +130,7 @@ impl RewritePass for Pass {
/// Run the tuple elimination pass.
pub fn scalarize_inputs(cs: &mut Computation) {
let mut pass = Pass;
pass.traverse(cs);
pass.traverse_full(cs, false, true);
#[cfg(debug_assertions)]
assert_all_vars_are_scalars(cs);
remove_non_scalar_vars_from_main_computation(cs);

View File

@@ -149,6 +149,15 @@ impl TupleTree {
_ => panic!("{:?} is tuple!", self),
}
}
#[allow(clippy::wrong_self_convention)]
fn as_term(self) -> Term {
match self {
TupleTree::NonTuple(t) => t,
TupleTree::Tuple(items) => {
term(Op::Tuple, items.into_iter().map(|i| i.as_term()).collect())
}
}
}
}
#[derive(Clone, PartialEq, Eq, Debug)]
@@ -236,7 +245,10 @@ fn tuple_free(t: Term) -> bool {
/// Run the tuple elimination pass.
pub fn eliminate_tuples(cs: &mut Computation) {
let mut lifted: TermMap<TupleTree> = TermMap::default();
for t in cs.terms_postorder() {
let terms =
PostOrderIter::from_roots_and_skips(cs.outputs().iter().cloned(), Default::default());
// .chain(cs.precomputes.outputs().values().cloned()),
for t in terms {
let mut cs: Vec<TupleTree> = t
.cs()
.iter()
@@ -294,7 +306,7 @@ pub fn eliminate_tuples(cs: &mut Computation) {
Op::Tuple => TupleTree::Tuple(cs.into()),
_ => TupleTree::NonTuple(term(
t.op().clone(),
cs.into_iter().map(|c| c.unwrap_non_tuple()).collect(),
cs.into_iter().map(|c| c.as_term()).collect(),
)),
};
lifted.insert(t, new_t);
@@ -303,6 +315,11 @@ pub fn eliminate_tuples(cs: &mut Computation) {
.into_iter()
.flat_map(|o| lifted.get(&o).unwrap().clone().flatten())
.collect();
// let os = cs.precomputes.outputs().clone();
// for (name, old_term) in os {
// let new_term = lifted.get(&old_term).unwrap().clone().as_term();
// cs.precomputes.change_output(&name, new_term);
// }
#[cfg(debug_assertions)]
for o in &cs.outputs {
if let Some(t) = find_tuple_term(o.clone()) {

View File

@@ -78,22 +78,24 @@ pub trait RewritePass {
}
}
}
let cache_get = |t: &Term| {
cache
.get(t)
.unwrap_or_else(|| panic!("Cache is missing: {}", t))
.clone()
};
if persistent_arrays {
for (_name, final_term) in &mut computation.persistent_arrays {
let new_final_term = cache.get(final_term).unwrap().clone();
let new_final_term = cache_get(final_term);
trace!("Array {} -> {}", final_term, new_final_term);
*final_term = new_final_term;
}
}
computation.outputs = computation
.outputs
.iter()
.map(|o| cache.get(o).unwrap().clone())
.collect();
computation.outputs = computation.outputs.iter().map(cache_get).collect();
if precompute {
let os = computation.precomputes.outputs().clone();
for (name, old_term) in os {
let new_term = cache.get(&old_term).unwrap().clone();
let new_term = cache_get(&old_term);
computation.precomputes.change_output(&name, new_term);
}
}

View File

@@ -253,14 +253,52 @@ impl BitVector {
impl Display for BitVector {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "#b")?;
for i in 0..self.width {
if self.width % 4 == 0 {
write!(
f,
"{}",
self.uint.get_bit((self.width - i - 1) as u32) as u8
"#x{:0>width$}",
&format!("{:x}", self.uint).as_str(),
width = self.width / 4
)?;
} else {
write!(
f,
"#b{:0>width$}",
&format!("{:b}", self.uint).as_str(),
width = self.width
)?;
}
Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn formatting() {
for bits in 0..8 {
for i in 0..(1 << bits) {
let int = Integer::from(i);
let bv = BitVector::new(int.clone(), bits);
let fmt = format!("{}", bv);
let hex = bits % 4 == 0;
assert_eq!(
&fmt[..2],
if hex { "#x" } else { "#b" },
"formatted {} ({} bits) as {}",
i,
bits,
fmt
);
let integer = Integer::from_str_radix(&fmt[2..], if hex { 16 } else { 2 }).unwrap();
assert_eq!(
int, integer,
"formatted {} ({} bits) as {}, which parsed as {}",
i, bits, fmt, integer
);
}
}
}
}

View File

@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
mod haboeck;
mod map;
mod pf_batch_inv;
mod poly;
mod ram;
mod sort;
@@ -19,6 +20,8 @@ mod waksman;
pub enum ExtOp {
/// See [haboeck].
Haboeck,
/// See [pf_batch_inv].
PfBatchInv,
/// See [ram::eval]
PersistentRamSplit,
/// Given an array of tuples, returns a reordering such that the result is sorted.
@@ -42,6 +45,7 @@ impl ExtOp {
pub fn arity(&self) -> Option<usize> {
match self {
ExtOp::Haboeck => Some(2),
ExtOp::PfBatchInv => Some(1),
ExtOp::PersistentRamSplit => Some(2),
ExtOp::Sort => Some(1),
ExtOp::Waksman => Some(1),
@@ -56,6 +60,7 @@ impl ExtOp {
pub fn check(&self, arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
match self {
ExtOp::Haboeck => haboeck::check(arg_sorts),
ExtOp::PfBatchInv => pf_batch_inv::check(arg_sorts),
ExtOp::PersistentRamSplit => ram::check(arg_sorts),
ExtOp::Sort => sort::check(arg_sorts),
ExtOp::Waksman => waksman::check(arg_sorts),
@@ -70,6 +75,7 @@ impl ExtOp {
pub fn eval(&self, args: &[&Value]) -> Value {
match self {
ExtOp::Haboeck => haboeck::eval(args),
ExtOp::PfBatchInv => pf_batch_inv::eval(args),
ExtOp::PersistentRamSplit => ram::eval(args),
ExtOp::Sort => sort::eval(args),
ExtOp::Waksman => waksman::eval(args),
@@ -88,6 +94,7 @@ impl ExtOp {
pub fn parse(bytes: &[u8]) -> Option<Self> {
match bytes {
b"haboeck" => Some(ExtOp::Haboeck),
b"pf_batch_inv" => Some(ExtOp::PfBatchInv),
b"persistent_ram_split" => Some(ExtOp::PersistentRamSplit),
b"uniq_deri_gcd" => Some(ExtOp::UniqDeriGcd),
b"sort" => Some(ExtOp::Sort),
@@ -103,6 +110,7 @@ impl ExtOp {
pub fn to_str(&self) -> &'static str {
match self {
ExtOp::Haboeck => "haboeck",
ExtOp::PfBatchInv => "pf_batch_inv",
ExtOp::PersistentRamSplit => "persistent_ram_split",
ExtOp::UniqDeriGcd => "uniq_deri_gcd",
ExtOp::Sort => "sort",

View File

@@ -12,30 +12,29 @@ use crate::ir::term::*;
/// Type-check [super::ExtOp::UniqDeriGcd].
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
if let &[haystack, needles] = arg_sorts {
let (key0, value0, _n) = ty::array_or(haystack, "haystack must be an array")?;
let (key1, value1, _a) = ty::array_or(needles, "needles must be an array")?;
let key0 = pf_or(key0, "haystack indices must be field")?;
let key1 = pf_or(key1, "needles indices must be field")?;
let value0 = pf_or(value0, "haystack values must be field")?;
let value1 = pf_or(value1, "needles values must be field")?;
eq_or(key0, key1, "field must be the same")?;
eq_or(key1, value0, "field must be the same")?;
eq_or(value0, value1, "field must be the same")?;
Ok(haystack.clone())
} else {
// wrong arg count
Err(TypeErrorReason::ExpectedArgs(2, arg_sorts.len()))
}
let &[haystack, needles] = ty::count_or_ref(arg_sorts)?;
let (_n, value0) = ty::homogenous_tuple_or(haystack, "haystack must be a tuple")?;
let (_a, value1) = ty::homogenous_tuple_or(needles, "needles must be a tuple")?;
let value0 = pf_or(value0, "haystack values must be field")?;
let value1 = pf_or(value1, "needles values must be field")?;
eq_or(value0, value1, "field must be the same")?;
Ok(haystack.clone())
}
/// Evaluate [super::ExtOp::UniqDeriGcd].
pub fn eval(args: &[&Value]) -> Value {
let haystack = args[0].as_array().values();
let sort = args[0].sort().as_array().0.clone();
let field = sort.as_pf().clone();
let needles = args[1].as_array().values();
let haystack_item_index: FxHashMap<Value, usize> = haystack
let haystack: Vec<FieldV> = args[0]
.as_tuple()
.iter()
.map(|v| v.as_pf().clone())
.collect();
let needles: Vec<FieldV> = args[1]
.as_tuple()
.iter()
.map(|v| v.as_pf().clone())
.collect();
let field = haystack[0].ty();
let haystack_item_index: FxHashMap<FieldV, usize> = haystack
.iter()
.enumerate()
.map(|(i, v)| (v.clone(), i))
@@ -55,5 +54,5 @@ pub fn eval(args: &[&Value]) -> Value {
.into_iter()
.map(|c| Value::Field(field.new_v(c)))
.collect();
Value::Array(Array::from_vec(sort.clone(), sort, field_counts))
Value::Tuple(field_counts.into())
}

View File

@@ -5,7 +5,7 @@ use crate::ir::term::*;
/// Type-check [super::ExtOp::ArrayToMap].
pub fn check_array_to_map(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
let [array] = ty::count_or(arg_sorts)?;
let [array] = ty::count_or_ref(arg_sorts)?;
let (k, v, _size) = ty::array_or(array, "ArrayToMap expects array")?;
Ok(Sort::Map(Box::new(k.clone()), Box::new(v.clone())))
}
@@ -20,7 +20,7 @@ pub fn eval_array_to_map(args: &[&Value]) -> Value {
/// Type-check [super::ExtOp::MapFlip].
pub fn check_map_flip(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
let [map] = ty::count_or(arg_sorts)?;
let [map] = ty::count_or_ref(arg_sorts)?;
let (k, v) = ty::map_or(map, "MapFlip expects map")?;
Ok(Sort::Map(Box::new(k.clone()), Box::new(v.clone())))
}
@@ -37,7 +37,7 @@ pub fn eval_map_flip(args: &[&Value]) -> Value {
/// Type-check [super::ExtOp::MapSelect].
pub fn check_map_select(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
let [map, k] = ty::count_or(arg_sorts)?;
let [map, k] = ty::count_or_ref(arg_sorts)?;
let (km, v) = ty::map_or(map, "MapSelect expects map")?;
ty::eq_or(km, k, "MapSelect key")?;
Ok(v.clone())
@@ -52,7 +52,7 @@ pub fn eval_map_select(args: &[&Value]) -> Value {
/// Type-check [super::ExtOp::MapContainsKey].
pub fn check_map_contains_key(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
let [map, k] = ty::count_or(arg_sorts)?;
let [map, k] = ty::count_or_ref(arg_sorts)?;
let (km, _) = ty::map_or(map, "MapContainsKey expects map")?;
ty::eq_or(km, k, "MapContainsKey key")?;
Ok(Sort::Bool)

View File

@@ -0,0 +1,59 @@
//! Batch (multiplicative) inversion for prime field elements.
//!
//! Takes a non-empty tuple of elements from the same field and returns the inverses as a tuple of
//! the same size.
use crate::ir::term::ty::*;
use crate::ir::term::*;
/// Type-check [super::ExtOp::PfBatchInv].
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
let &[values] = ty::count_or_ref(arg_sorts)?;
let (_n, v_sort) = ty::homogenous_tuple_or(values, "pf batch inversion")?;
let _f = pf_or(v_sort, "pf_batch_inv")?;
Ok(values.clone())
}
fn batch_inv(field: &FieldT, data: &mut [Value]) {
// Montgomerys Trick and Fast Implementation of Masked AES
// Genelle, Prouff and Quisquater
// Section 3.2
// First pass: compute [a, ab, abc, ...]
let mut prod = Vec::with_capacity(data.len());
let mut tmp = field.new_v(1);
for f in data.iter().map(Value::as_pf).filter(|f| !f.is_zero()) {
tmp *= f;
prod.push(tmp.clone());
}
// Invert `tmp`.
tmp = tmp.recip_ref(); // Guaranteed to be nonzero.
// Second pass: iterate backwards to compute inverses
for (f, s) in data
.iter_mut()
// Backwards
.rev()
// Ignore normalized elements
.filter(|f| !f.as_pf().is_zero())
// Backwards, skip last element, fill in one for last term.
.zip(prod.into_iter().rev().skip(1).chain(Some(field.new_v(1))))
{
// tmp := tmp * f; f := tmp * s = 1/f
let new_tmp = tmp.clone() * f.as_pf();
*f = Value::Field(tmp.clone() * &s);
tmp = new_tmp;
}
}
/// Evaluate [super::ExtOp::PfBatchInv].
pub fn eval(args: &[&Value]) -> Value {
// adapted from ark_ff
let mut values: Vec<Value> = args[0].as_tuple().to_owned();
if !values.is_empty() {
let field = values[0].as_pf().ty();
batch_inv(&field, &mut values)
}
Value::Tuple(values.into())
}

View File

@@ -6,50 +6,22 @@ use fxhash::FxHashSet as HashSet;
/// Type-check [super::ExtOp::PersistentRamSplit].
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
if let &[entries, indices] = arg_sorts {
let (key, value, size) = ty::array_or(entries, "PersistentRamSplit entries")?;
let f = pf_or(key, "PersistentRamSplit entries: indices must be field")?;
let value_tup = ty::tuple_or(value, "PersistentRamSplit entries: value must be a tuple")?;
if let &[old, new] = &value_tup {
eq_or(
f,
old,
"PersistentRamSplit entries: value must be a field pair",
)?;
eq_or(
f,
new,
"PersistentRamSplit entries: value must be a field pair",
)?;
let (i_key, i_value, i_size) = ty::array_or(indices, "PersistentRamSplit indices")?;
eq_or(f, i_key, "PersistentRamSplit indices: key must be a field")?;
eq_or(
f,
i_value,
"PersistentRamSplit indices: value must be a field",
)?;
let n_touched = i_size.min(size);
let n_ignored = size - n_touched;
let box_f = Box::new(f.clone());
let f_pair = Sort::Tuple(Box::new([f.clone(), f.clone()]));
let ignored_entries_sort =
Sort::Array(box_f.clone(), Box::new(f_pair.clone()), n_ignored);
let selected_entries_sort = Sort::Array(box_f, Box::new(f_pair), n_touched);
Ok(Sort::Tuple(Box::new([
ignored_entries_sort,
selected_entries_sort.clone(),
selected_entries_sort,
])))
} else {
// non-pair entries value
Err(TypeErrorReason::Custom(
"PersistentRamSplit: entries value must be a pair".into(),
))
}
} else {
// wrong arg count
Err(TypeErrorReason::ExpectedArgs(2, arg_sorts.len()))
}
let &[entries, indices] = ty::count_or_ref(arg_sorts)?;
let (size, value) = ty::homogenous_tuple_or(entries, "PersistentRamSplit entries")?;
let [old, new] = ty::count_or(ty::tuple_or(value, "PersistentRamSplit entries")?)?;
eq_or(old, new, "PersistentRamSplit entries")?;
let (i_size, i_value) = ty::homogenous_tuple_or(indices, "PersistentRamSplit indices")?;
let f = pf_or(i_value, "PersistentRamSplit indices")?;
let n_touched = i_size.min(size);
let n_ignored = size - n_touched;
let f_pair = Sort::Tuple(Box::new([f.clone(), f.clone()]));
let ignored_entries_sort = Sort::Tuple(vec![f_pair.clone(); n_ignored].into());
let selected_entries_sort = Sort::Tuple(vec![f_pair.clone(); n_touched].into());
Ok(Sort::Tuple(Box::new([
ignored_entries_sort,
selected_entries_sort.clone(),
selected_entries_sort,
])))
}
/// Evaluate [super::ExtOp::PersistentRamSplit].
@@ -72,14 +44,14 @@ pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
/// * init_reads (array field (tuple (field field)) (length I))
/// * fin_writes (array field (tuple (field field)) (length I))
pub fn eval(args: &[&Value]) -> Value {
let entries = &args[0].as_array().values();
let entries = &args[0].as_tuple();
let (init_vals, fin_vals): (Vec<Value>, Vec<Value>) = entries
.iter()
.map(|t| (t.as_tuple()[0].clone(), t.as_tuple()[1].clone()))
.unzip();
let indices = &args[1].as_array().values();
let indices = &args[1].as_tuple();
let num_accesses = indices.len();
let field = args[0].as_array().key_sort.as_pf();
let field = args[1].sort().as_tuple()[0].as_pf().clone();
let uniq_indices = {
let mut uniq_indices = Vec::<usize>::new();
let mut used_indices = HashSet::<usize>::default();
@@ -108,19 +80,14 @@ pub fn eval(args: &[&Value]) -> Value {
untouched_entries.push((i, init_val));
}
}
let key_sort = Sort::Field(field.clone());
let entry_to_vals =
|e: (usize, Value)| Value::Tuple(Box::new([Value::Field(field.new_v(e.0)), e.1]));
let vec_to_arr = |v: Vec<(usize, Value)>| {
let vec_to_tuple = |v: Vec<(usize, Value)>| {
let vals: Vec<Value> = v.into_iter().map(entry_to_vals).collect();
Value::Array(Array::from_vec(
key_sort.clone(),
vals.first().unwrap().sort(),
vals,
))
Value::Tuple(vals.into())
};
let init_reads = vec_to_arr(init_reads);
let untouched_entries = vec_to_arr(untouched_entries);
let fin_writes = vec_to_arr(fin_writes);
let init_reads = vec_to_tuple(init_reads);
let untouched_entries = vec_to_tuple(untouched_entries);
let fin_writes = vec_to_tuple(fin_writes);
Value::Tuple(vec![untouched_entries, init_reads, fin_writes].into_boxed_slice())
}

View File

@@ -5,7 +5,13 @@ use crate::ir::term::*;
/// Type-check [super::ExtOp::Sort].
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
array_or(arg_sorts[0], "sort argument").map(|_| arg_sorts[0].clone())
let [arg_sort] = ty::count_or_ref(arg_sorts)?;
match arg_sort {
Sort::Tuple(_) | Sort::Array(..) => Ok((**arg_sort).clone()),
_ => Err(TypeErrorReason::Custom(
"sort takes an array or tuple".into(),
)),
}
}
/// Evaluate [super::ExtOp::Sort].

View File

@@ -84,8 +84,8 @@ fn persistent_ram_split_eval() {
let t = text::parse_term(
b"
(declare (
(entries (array (mod 17) (tuple (mod 17) (mod 17)) 5))
(indices (array (mod 17) (mod 17) 3))
(entries (tuple 5 (tuple (mod 17) (mod 17))))
(indices (tuple 3 (mod 17)))
)
(persistent_ram_split entries indices))",
);
@@ -95,8 +95,8 @@ fn persistent_ram_split_eval() {
(set_default_modulus 17
(let
(
(entries (#l (mod 17) ( (#t #f0 #f1) (#t #f1 #f1) (#t #f2 #f3) (#t #f3 #f4) (#t #f4 #f4) )))
(indices (#l (mod 17) (#f0 #f2 #f3)))
(entries (#t (#t #f0 #f1) (#t #f1 #f1) (#t #f2 #f3) (#t #f3 #f4) (#t #f4 #f4) ))
(indices (#t #f0 #f2 #f3))
) false))
",
);
@@ -107,19 +107,13 @@ fn persistent_ram_split_eval() {
(let
(
(output (#t
(#l (mod 17) ( (#t #f1 #f1) (#t #f4 #f4) )) ; untouched
(#l (mod 17) ( (#t #f0 #f0) (#t #f2 #f2) (#t #f3 #f3) )) ; init_reads
(#l (mod 17) ( (#t #f0 #f1) (#t #f2 #f3) (#t #f3 #f4) )) ; fin_writes
(#t (#t #f1 #f1) (#t #f4 #f4) ) ; untouched
(#t (#t #f0 #f0) (#t #f2 #f2) (#t #f3 #f3) ) ; init_reads
(#t (#t #f0 #f1) (#t #f2 #f3) (#t #f3 #f4) ) ; fin_writes
))
) false))
",
);
dbg!(&actual_output.as_tuple()[0].as_array().default);
dbg!(
&expected_output.get("output").unwrap().as_tuple()[0]
.as_array()
.default
);
assert_eq!(&actual_output, expected_output.get("output").unwrap());
// duplicates
@@ -128,8 +122,8 @@ fn persistent_ram_split_eval() {
(set_default_modulus 17
(let
(
(entries (#l (mod 17) ( (#t #f0 #f0) (#t #f1 #f2) (#t #f2 #f2) (#t #f3 #f3) (#t #f4 #f4) )))
(indices (#l (mod 17) (#f1 #f1 #f1)))
(entries (#t (#t #f0 #f0) (#t #f1 #f2) (#t #f2 #f2) (#t #f3 #f3) (#t #f4 #f4) ))
(indices (#t #f1 #f1 #f1))
) false))
",
);
@@ -140,9 +134,9 @@ fn persistent_ram_split_eval() {
(let
(
(output (#t
(#l (mod 17) ( (#t #f3 #f3) (#t #f4 #f4) )) ; untouched
(#l (mod 17) ( (#t #f0 #f0) (#t #f1 #f1) (#t #f2 #f2) )) ; init_reads
(#l (mod 17) ( (#t #f0 #f0) (#t #f1 #f2) (#t #f2 #f2) )) ; fin_writes
(#t (#t #f3 #f3) (#t #f4 #f4) ) ; untouched
(#t (#t #f0 #f0) (#t #f1 #f1) (#t #f2 #f2) ) ; init_reads
(#t (#t #f0 #f0) (#t #f1 #f2) (#t #f2 #f2) ) ; fin_writes
))
) false))
",
@@ -155,8 +149,8 @@ fn haboeck_eval(haystack: &[usize], needles: &[usize], counts: &[usize]) {
format!(
"
(declare (
(haystack (array (mod 17) (mod 17) {}))
(needles (array (mod 17) (mod 17) {}))
(haystack (tuple {} (mod 17)))
(needles (tuple {} (mod 17)))
)
(haboeck haystack needles))",
haystack.len(),
@@ -174,8 +168,8 @@ fn haboeck_eval(haystack: &[usize], needles: &[usize], counts: &[usize]) {
"(set_default_modulus 17
(let
(
(haystack (#l (mod 17) ({})))
(needles (#l (mod 17) ({})))
(haystack (#t {}))
(needles (#t {}))
) false))",
haystack.join(" "),
needles.join(" ")
@@ -187,7 +181,7 @@ fn haboeck_eval(haystack: &[usize], needles: &[usize], counts: &[usize]) {
"(set_default_modulus 17
(let
(
(counts (#l (mod 17) ({})))
(counts (#t {}))
) false))",
counts.join(" ")
)

View File

@@ -246,3 +246,38 @@ pub fn collect_asserted_ops(
}
}
}
/// Iterator over a node's children.
pub fn node_cs_iter(node: Term) -> impl Iterator<Item = Term> {
(0..node.cs().len()).map(move |i| node.cs()[i].clone())
}
// impl ChildrenIter {
// fn new(node: Term) -> Self {
// Self {node, next_child: 0}
// }
// fn next(&mut self) {
// if self.next_child < self.node.cs().len() {
// self.next_child += 1;
// }
// }
// fn is_done(&self)
// fn this_child(&self)
// }
//
// #[allow(unused_variables)]
// /// Term traversal control
// pub trait TermTraversalControl {
// /// Whether to skip this term and all descendents.
// ///
// /// This term is guaranteed to be skipped, but its descendents are not guaranteed to be
// /// skipped.
// fn skip(&mut self, t: &Term) -> bool { false }
// /// Any extra dependencies that should be traversed before this term.
// fn extra_dependencies(&mut self, t: &Term) -> Option<Vec<Term>> { None }
// }
//
// pub struct TermTraversal {
//
// }

View File

@@ -148,12 +148,18 @@ impl DisplayIr for Sort {
write!(f, ")")
}
Sort::Tuple(fields) => {
write!(f, "(tuple")?;
for field in fields.iter() {
write!(f, " ")?;
field.ir_fmt(f)?;
if fields.len() > 1 && fields[1..].iter().all(|s| s == &fields[0]) {
write!(f, "(tuple {} ", fields.len())?;
fields[0].ir_fmt(f)?;
write!(f, ")")
} else {
write!(f, "(tuple")?;
for field in fields.iter() {
write!(f, " ")?;
field.ir_fmt(f)?;
}
write!(f, ")")
}
write!(f, ")")
}
}
}

View File

@@ -1076,6 +1076,11 @@ thread_local! {
static LAST_LEN: Cell<usize> = Default::default();
}
/// Size of the term table.
pub fn table_size() -> usize {
hc::Table::table_size()
}
fn should_collect() -> bool {
let last_len = LAST_LEN.with(|l| l.get());
let ret = LEN_THRESH_DEN * hc::Table::table_size() > LEN_THRESH_NUM * last_len;
@@ -1100,21 +1105,33 @@ pub fn maybe_garbage_collect() -> bool {
}
if should_collect() {
collect_terms();
collect_types();
super::opt::cfold::collect();
let orig_terms = table_size();
let n_collected = collect_terms();
if 50 * n_collected > orig_terms {
collect_types();
super::opt::cfold::collect();
}
true
} else {
false
}
}
fn collect_terms() {
fn collect_terms() -> usize {
let size_before = hc::Table::table_size();
hc::Table::gc();
let size_after = hc::Table::table_size();
let pct_removed = (size_before - size_after) as f64 / size_before as f64 * 100.0;
debug!("Term collection: {size_before} -> {size_after} (-{pct_removed}%)");
size_before - size_after
}
fn collect_types() {
let size_before = ty::TERM_TYPES.with(|tys| tys.borrow().len());
ty::TERM_TYPES.with(|tys| tys.borrow_mut().collect());
let size_after = ty::TERM_TYPES.with(|tys| tys.borrow().len());
let pct_removed = (size_before - size_after) as f64 / size_before as f64 * 100.0;
debug!("Type collection: {size_before} -> {size_after} (-{pct_removed}%)");
}
impl Term {
@@ -1673,6 +1690,19 @@ pub fn unmake_array(a: Term) -> Vec<Term> {
.collect()
}
/// Make a sequence of terms from a tuple
///
/// Requires
///
/// * a tuple term
pub fn tuple_terms(a: Term) -> Vec<Term> {
let sort = check(&a);
let size = sort.as_tuple().len();
(0..size)
.map(|idx| term(Op::Field(idx), vec![a.clone()]))
.collect()
}
/// Make a term with no arguments, just an operator.
pub fn leaf_term(op: Op) -> Term {
term(op, Vec::new())

View File

@@ -9,6 +9,8 @@ use crate::ir::term::*;
use log::trace;
use std::cell::RefCell;
/// A "precomputation".
///
/// Expresses a computation to be run in advance by a single party.
@@ -16,7 +18,7 @@ use log::trace;
pub struct PreComp {
#[serde(with = "crate::ir::term::serde_mods::map")]
/// A map from output names to the terms that compute them.
outputs: FxHashMap<String, Term>,
pub outputs: FxHashMap<String, Term>,
sequence: Vec<(String, Sort)>,
inputs: FxHashSet<(String, Sort)>,
}
@@ -133,7 +135,7 @@ impl PreComp {
}
/// Reduce the precomputation to a single, step-less map.
pub fn flatten(self) -> FxHashMap<String, Term> {
pub fn flatten(&mut self) {
let mut out: FxHashMap<String, Term> = Default::default();
let mut cache: TermMap<Term> = Default::default();
for (name, sort) in &self.sequence {
@@ -142,7 +144,124 @@ impl PreComp {
out.insert(name.into(), term.clone());
cache.insert(var_term, term);
}
out
self.outputs = out;
}
/// Compute a topo order
pub fn topo_order(&self) -> FxHashMap<String, usize> {
let mut order: FxHashMap<String, usize> = FxHashMap::default();
let mut stack: Vec<Term> = self
.outputs
.iter()
.map(|(name, t)| leaf_term(Op::Var(name.clone(), check(t))))
.collect();
let mut post_visited: TermSet = Default::default();
let mut pre_visited: TermSet = Default::default();
while let Some(t) = stack.pop() {
if post_visited.contains(&t) {
continue;
}
if pre_visited.insert(t.clone()) {
// children not yet pushed
stack.push(t.clone());
if let Op::Var(name, _) = t.op() {
if let Some(c) = self.outputs.get(name) {
if !post_visited.contains(c) {
assert!(!pre_visited.contains(c), "loop on {} {}", c.id(), c);
stack.push(c.clone());
}
}
} else {
for c in t.cs() {
if !post_visited.contains(c) {
assert!(!pre_visited.contains(c), "loop on {} {}", c.id(), c);
stack.push(c.clone());
}
}
}
} else {
post_visited.insert(t.clone());
if let Op::Var(name, _) = t.op() {
order.insert(name.clone(), order.len());
}
}
}
order
}
/// Put the outputs back into a topo order
pub fn reorder(&mut self) {
let order = self.topo_order();
self.sequence
.sort_by_cached_key(|(name, _sort)| order.get(name).unwrap());
trace!("{}", text::serialize_precompute(self));
#[cfg(debug_assertions)]
self.check_topo_order();
}
#[allow(dead_code)]
/// Check that no variables is used before defintion.
pub fn check_topo_order(&self) {
let defined: TermSet = self
.sequence
.iter()
.map(|(n, s)| leaf_term(Op::Var(n.clone(), s.clone())))
.collect();
let seen = RefCell::new(TermSet::default());
for (name, sort) in &self.inputs {
seen.borrow_mut()
.insert(leaf_term(Op::Var(name.clone(), sort.clone())));
}
for (name, sort) in &self.sequence {
let t = self.outputs.get(name).unwrap();
for desc in extras::PostOrderSkipIter::new(t.clone(), &|n| seen.borrow().contains(n)) {
if desc.is_var() && defined.contains(&desc) {
// we haven't seen this, or it would have been skipped.
panic!("variable {} used before definition", desc);
}
seen.borrow_mut().insert(desc);
}
seen.borrow_mut()
.insert(leaf_term(Op::Var(name.clone(), sort.clone())));
}
}
/// Check that a topo-order exists
#[allow(dead_code)]
pub fn check_topo_orderable(&self) {
let mut stack: Vec<Term> = self
.outputs
.iter()
.map(|(name, t)| leaf_term(Op::Var(name.clone(), check(t))))
.collect();
let mut post_visited: TermSet = Default::default();
let mut pre_visited: TermSet = Default::default();
while let Some(t) = stack.pop() {
if post_visited.contains(&t) {
continue;
}
if pre_visited.insert(t.clone()) {
// children not yet pushed
stack.push(t.clone());
if let Op::Var(name, _) = t.op() {
if let Some(c) = self.outputs.get(name) {
if !post_visited.contains(c) {
assert!(!pre_visited.contains(c), "loop on {} {}", c.id(), c);
stack.push(c.clone());
}
}
} else {
for c in t.cs() {
if !post_visited.contains(c) {
assert!(!pre_visited.contains(c), "loop on {} {}", c.id(), c);
stack.push(c.clone());
}
}
}
} else {
post_visited.insert(t.clone());
}
}
}
}

View File

@@ -41,6 +41,7 @@
//! * `(bv N)`
//! * `(mod I)`
//! * `(tuple S1 ... Sn)`
//! * `(tuple N S)` : N copies of S
//! * `(array Sk Sv N)`
//! * `(map Sk Sv)`
//! * Value `V`:
@@ -365,7 +366,16 @@ impl<'src> IrInterp<'src> {
Sort::Map(Box::new(self.sort(k)), Box::new(self.sort(v)))
}
[Leaf(Ident, b"tuple"), ..] => {
Sort::Tuple(ls[1..].iter().map(|li| self.sort(li)).collect())
if ls.len() > 1 {
if let Some(size) = self.maybe_usize(&ls[1]) {
assert_eq!(ls.len(), 3);
Sort::Tuple(vec![self.sort(&ls[2]); size].into())
} else {
Sort::Tuple(ls[1..].iter().map(|li| self.sort(li)).collect())
}
} else {
Sort::Tuple(ls[1..].iter().map(|li| self.sort(li)).collect())
}
}
_ => panic!("Expected sort, found {}", tt),
}
@@ -401,9 +411,12 @@ impl<'src> IrInterp<'src> {
}
}
fn usize(&self, tt: &TokTree) -> usize {
self.maybe_usize(tt).unwrap()
}
fn maybe_usize(&self, tt: &TokTree) -> Option<usize> {
match tt {
Leaf(Token::Int, s) => usize::from_str(from_utf8(s).unwrap()).unwrap(),
_ => panic!("Expected integer, got {}", tt),
Leaf(Token::Int, s) => usize::from_str(from_utf8(s).ok()?).ok(),
_ => None,
}
}
/// Parse lets, returning bindings, in-order.
@@ -1294,8 +1307,8 @@ mod test {
let t = parse_term(
b"
(declare (
(entries (array (mod 17) (tuple (mod 17) (mod 17)) 5))
(indices (array (mod 17) (mod 17) 3))
(entries (tuple 5 (tuple (mod 17) (mod 17))))
(indices (tuple 3 (mod 17)))
)
(persistent_ram_split entries indices))",
);
@@ -1336,6 +1349,20 @@ mod test {
assert_eq!(t, t2);
}
#[test]
fn tuple_dup_roundtrip() {
let t = parse_term(b"(declare ((a (tuple 4 bool))) a)");
let t2 = parse_term(serialize_term(&t).as_bytes());
assert_eq!(t, t2);
}
#[test]
fn tuple_nodup_roundtrip() {
let t = parse_term(b"(declare ((a (tuple (bv 4) bool))) a)");
let t2 = parse_term(serialize_term(&t).as_bytes());
assert_eq!(t, t2);
}
#[test]
fn pf_fits_in_bits_rountrip() {
let t = parse_term(b"(declare ((a bool)) ((pf_fits_in_bits 4) (ite a #f1m11 #f0m11)))");
@@ -1363,8 +1390,8 @@ mod test {
let t = parse_term(
b"
(declare (
(haystack (array (mod 17) (mod 17) 5))
(needles (array (mod 17) (mod 17) 8))
(haystack (tuple 5 (mod 17)))
(needles (tuple 8 (mod 17)))
)
(haboeck haystack needles))",
);
@@ -1373,4 +1400,19 @@ mod test {
let t2 = parse_term(s.as_bytes());
assert_eq!(t, t2);
}
#[test]
fn pf_batch_inv_roundtrip() {
let t = parse_term(
b"
(declare (
(values (tuple (mod 17) (mod 17)))
)
(pf_batch_inv values))",
);
let s = serialize_term(&t);
println!("{s}");
let t2 = parse_term(s.as_bytes());
assert_eq!(t, t2);
}
}

View File

@@ -548,6 +548,8 @@ pub enum TypeErrorReason {
ExpectedMap(Sort, &'static str),
/// A sort should be a tuple
ExpectedTuple(&'static str),
/// A sort should be a (non-empty) homogenous tuple
ExpectedHomogenousTuple(&'static str),
/// An empty n-ary operator.
EmptyNary(String),
/// Expected _ args, but got _
@@ -667,8 +669,32 @@ fn all_eq_or<'a, I: Iterator<Item = &'a Sort>>(
Ok(first)
}
pub(super) fn count_or<'a, const N: usize>(
pub(super) fn count_or_ref<'a, const N: usize>(
sorts: &'a [&'a Sort],
) -> Result<&'a [&'a Sort; N], TypeErrorReason> {
<&'a [&'a Sort; N]>::try_from(sorts).map_err(|_| TypeErrorReason::ExpectedArgs(N, sorts.len()))
}
pub(super) fn count_or<'a, const N: usize>(
sorts: &'a [Sort],
) -> Result<&'a [Sort; N], TypeErrorReason> {
<&'a [Sort; N]>::try_from(sorts).map_err(|_| TypeErrorReason::ExpectedArgs(N, sorts.len()))
}
/// Check if this is a non-empty homogenous tuple.
pub(super) fn homogenous_tuple_or<'a>(
a: &'a Sort,
ctx: &'static str,
) -> Result<(usize, &'a Sort), TypeErrorReason> {
let sorts = tuple_or(a, ctx)?;
if !sorts.is_empty() {
for i in 1..sorts.len() {
if sorts[0] != sorts[i] {
return Err(TypeErrorReason::ExpectedHomogenousTuple(ctx));
}
}
Ok((sorts.len(), &sorts[0]))
} else {
Err(TypeErrorReason::ExpectedHomogenousTuple(ctx))
}
}

View File

@@ -271,6 +271,7 @@ where
fn prove(pk: &Self::ProvingKey, witness: &FxHashMap<String, Value>) -> Self::Proof {
let rng = &mut rand::thread_rng();
#[cfg(debug_assertions)]
pk.0.check_all(witness);
Proof(groth16::create_random_proof(SynthInput(&pk.0, Some(witness)), &pk.1, rng).unwrap())
}

View File

@@ -176,6 +176,9 @@ impl<'a, F: PrimeField + PrimeFieldBits> CcCircuit<F> for SynthInput<'a, F> {
}
}
}
if let Some((_, ref evaluator, _)) = wit_comp.as_mut() {
evaluator.print_times();
}
for (i, (a, b, c)) in self.0.r1cs.constraints.iter().enumerate() {
cs.enforce(

View File

@@ -132,6 +132,8 @@ pub struct R1cs {
/// The contraints themselves
constraints: Vec<(Lc, Lc, Lc)>,
stats: R1csStats,
/// Terms for computing them.
#[serde(with = "crate::ir::term::serde_mods::map")]
terms: HashMap<Var, Term>,
@@ -222,6 +224,7 @@ impl R1cs {
num_final_wits: Default::default(),
challenge_names: Default::default(),
constraints: Vec::new(),
stats: Default::default(),
terms: Default::default(),
precompute,
}
@@ -257,6 +260,7 @@ impl R1cs {
// could check `t` dependents
self.idx_to_sig.insert(var, s);
self.terms.insert(var, t);
self.stats.n_vars += 1;
var
}
@@ -324,6 +328,13 @@ impl R1cs {
assert_eq!(&self.modulus, &a.modulus);
assert_eq!(&self.modulus, &b.modulus);
assert_eq!(&self.modulus, &c.modulus);
self.stats.n_constraints += 1;
let n_a = a.monomials.len() + !a.constant.is_zero() as usize;
let n_b = b.monomials.len() + !b.constant.is_zero() as usize;
let n_c = c.monomials.len() + !c.constant.is_zero() as usize;
self.stats.n_a_entries += n_a as u32;
self.stats.n_b_entries += n_b as u32;
self.stats.n_c_entries += n_c as u32;
debug!(
"Constraint:\n {}\n * {}\n = {}",
self.format_lc(&a),
@@ -419,6 +430,69 @@ impl R1cs {
pub fn constraints(&self) -> &Vec<(Lc, Lc, Lc)> {
&self.constraints
}
/// Statistics for this R1CS instance
pub fn stats(&self) -> &R1csStats {
&self.stats
}
/// Recalculate statistics for this R1CS instance
pub fn update_stats(&mut self) {
self.stats = R1csStats::default();
self.stats.n_vars = self.num_vars() as u32;
let s = &mut self.stats;
s.n_constraints = self.constraints.len() as u32;
for (a, b, c) in &self.constraints {
let n_a = a.monomials.len() + !a.constant.is_zero() as usize;
let n_b = b.monomials.len() + !b.constant.is_zero() as usize;
let n_c = c.monomials.len() + !c.constant.is_zero() as usize;
s.n_a_entries += n_a as u32;
s.n_b_entries += n_b as u32;
s.n_c_entries += n_c as u32;
}
}
}
/// R1CS statistics
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct R1csStats {
/// number of constraints
pub n_constraints: u32,
/// number of variables
pub n_vars: u32,
/// number of non-zero A matrix entries
pub n_a_entries: u32,
/// number of non-zero B matrix entries
pub n_b_entries: u32,
/// number of non-zero C matrix entries
pub n_c_entries: u32,
}
impl R1csStats {
/// number of non-zero A, B, and C entries
pub fn n_entries(&self) -> u64 {
self.n_a_entries as u64 + self.n_b_entries as u64 + self.n_c_entries as u64
}
}
impl std::ops::AddAssign<&R1csStats> for R1csStats {
fn add_assign(&mut self, other: &R1csStats) {
self.n_constraints += other.n_constraints;
self.n_vars += other.n_vars;
self.n_a_entries += other.n_a_entries;
self.n_b_entries += other.n_b_entries;
self.n_c_entries += other.n_c_entries;
}
}
impl std::ops::SubAssign<&R1csStats> for R1csStats {
fn sub_assign(&mut self, other: &R1csStats) {
self.n_constraints -= other.n_constraints;
self.n_vars -= other.n_vars;
self.n_a_entries -= other.n_a_entries;
self.n_b_entries -= other.n_b_entries;
self.n_c_entries -= other.n_c_entries;
}
}
impl R1csFinal {
@@ -535,6 +609,7 @@ impl ProverData {
}
}
}
eval.print_times();
var_values
}
/// Check all assertions. Puts in 1 for challenges.
@@ -922,7 +997,8 @@ impl R1cs {
// we still need to remove the non-r1cs variables
//use crate::ir::proof::PROVER_ID;
//let all_inputs = cs.metadata.get_inputs_for_party(Some(PROVER_ID));
let mut precompute_map = precompute.flatten();
precompute.flatten();
let mut precompute_map = precompute.outputs;
let mut vars: HashMap<String, Sort> = {
PostOrderIter::from_roots_and_skips(
precompute_map.values().cloned(),
@@ -999,7 +1075,8 @@ impl R1cs {
for c in &self.challenge_names {
assert!(!vars.contains_key(c));
}
let mut precompute_map = precompute.flatten();
precompute.flatten();
let mut precompute_map = precompute.outputs;
let terms = self
.insts_iter()
.map(|v| {

View File

@@ -218,7 +218,9 @@ fn constantly_true((a, b, c): &(Lc, Lc, Lc)) -> bool {
/// * `lc_size_thresh`: the maximum size LC (number of non-constant monomials) that will be used
/// for propagation. `None` means no size limit.
pub fn reduce_linearities(r1cs: R1cs, cfg: &CircCfg) -> R1cs {
LinReducer::new(r1cs, cfg.r1cs.lc_elim_thresh).run()
let mut r = LinReducer::new(r1cs, cfg.r1cs.lc_elim_thresh).run();
r.update_stats();
r
}
#[cfg(test)]

View File

@@ -14,6 +14,8 @@ use log::{debug, trace};
use rug::ops::Pow;
use rug::Integer;
use fxhash::FxHashMap;
use std::cell::RefCell;
use std::fmt::Display;
use std::iter::ExactSizeIterator;
@@ -36,19 +38,6 @@ enum EmbeddedTerm {
Tuple(Vec<EmbeddedTerm>),
}
#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Metric {
n_constraints: u32,
n_vars: u32,
}
impl std::ops::AddAssign<&Metric> for Metric {
fn add_assign(&mut self, rhs: &Metric) {
self.n_vars += rhs.n_vars;
self.n_constraints += rhs.n_constraints;
}
}
struct ToR1cs<'cfg> {
r1cs: R1cs,
cache: TermMap<EmbeddedTerm>,
@@ -59,9 +48,10 @@ struct ToR1cs<'cfg> {
cfg: &'cfg CircCfg,
field: FieldT,
used_vars: HashSet<String>,
profiling_data: TermMap<Metric>,
metric: Metric,
term_in_progress: Option<Term>,
/// Map from (operator, arity) to metric.
profiling_data: FxHashMap<(Op, usize), (usize, R1csStats)>,
old_stats: R1csStats,
op_in_progress: Option<(Op, usize)>,
}
impl<'cfg> ToR1cs<'cfg> {
@@ -82,68 +72,62 @@ impl<'cfg> ToR1cs<'cfg> {
field,
cfg,
profiling_data: Default::default(),
term_in_progress: None,
metric: Default::default(),
op_in_progress: None,
old_stats: Default::default(),
}
}
fn profile_start_term(&mut self, t: Term) {
if self.cfg.r1cs.profile {
assert!(self.term_in_progress.is_none());
self.term_in_progress = Some(t);
self.metric = Default::default();
assert!(self.op_in_progress.is_none());
let key = (t.op().clone(), t.cs().len());
// trace!("profile start {:?}", key);
self.op_in_progress = Some(key);
self.old_stats = self.r1cs.stats().clone();
}
}
fn profile_end_term(&mut self) {
if self.cfg.r1cs.profile {
assert!(self.term_in_progress.is_some());
let t = self.term_in_progress.take().unwrap();
*self.profiling_data.entry(t).or_default() += &self.metric;
self.metric = Default::default();
assert!(self.op_in_progress.is_some());
let t = self.op_in_progress.take().unwrap();
// trace!("profile end {:?}", t);
let mut current_stats = self.r1cs.stats().clone();
current_stats -= &self.old_stats;
let data = self.profiling_data.entry(t).or_default();
data.1 += &current_stats;
data.0 += 1;
self.old_stats = Default::default();
}
}
fn profile_print(&self) {
if self.cfg.r1cs.profile {
let terms: TermSet = self.profiling_data.keys().cloned().collect();
let mut cum_metrics: TermMap<Metric> = Default::default();
for t in PostOrderIter::from_roots_and_skips(terms, Default::default()) {
let mut cum = Metric::default();
if let Some(d) = self.profiling_data.get(&t) {
cum += d;
}
for c in t.cs() {
cum += cum_metrics.get(c).unwrap();
}
cum_metrics.insert(t, cum);
}
let mut data: Vec<(Term, Metric, Metric)> = cum_metrics
.into_iter()
.map(|(term, cum)| {
let indiv = self
.profiling_data
.get(&term)
.cloned()
.unwrap_or(Default::default());
(term, cum, indiv)
})
let mut data: Vec<(&Op, usize, usize, &R1csStats)> = self
.profiling_data
.iter()
.map(|((op, arity), (count, stats))| (op, *arity, *count, stats))
.collect();
data.sort_by(|(t0, c0, i0), (t1, c1, i1)| i0.cmp(i1).then(c0.cmp(c1)).then(t0.cmp(t1)));
data.sort_by_key(|(_, _, count, stats)| {
(stats.n_entries(), stats.n_constraints, stats.n_vars, *count)
});
data.reverse();
for (t, c, i) in data {
if c.n_constraints != 0 {
println!(
"{:>8}: {:>8}cs cum, {:>8}vs cum, {:>8}cs, {:>8}vs; {} {:?}",
format!("{}", t.id()),
c.n_constraints,
c.n_vars,
i.n_constraints,
i.n_vars,
t.op(),
t.cs().iter().map(|c| c.id()).collect::<Vec<_>>(),
)
}
println!(
"r1csstats,entries,constraints,count,vars,a_entries,b_entries,c_entries,op_arity,op"
);
for (op, arity, count, stats) in data {
println!(
"r1csstats,{},{},{},{},{},{},{},{},{}",
stats.n_entries(),
stats.n_constraints,
count,
stats.n_vars,
stats.n_a_entries,
stats.n_b_entries,
stats.n_c_entries,
arity,
op,
)
}
}
}
@@ -173,7 +157,6 @@ impl<'cfg> ToR1cs<'cfg> {
self.next_idx += 1;
debug_assert!(matches!(check(&comp), Sort::Field(_)));
self.r1cs.add_var(n.clone(), comp.clone(), ty);
self.metric.n_vars += 1;
debug!("fresh: {n:?}");
TermLc(comp, self.r1cs.signal_lc(&n))
}
@@ -185,7 +168,6 @@ impl<'cfg> ToR1cs<'cfg> {
/// Create a constraint
fn constraint(&mut self, a: Lc, b: Lc, c: Lc) {
self.metric.n_constraints += 1;
self.r1cs.constraint(a, b, c);
}
@@ -1045,7 +1027,17 @@ impl<'cfg> ToR1cs<'cfg> {
Op::PfNaryOp(o) => {
let args = c.cs().iter().map(|c| self.get_pf(c));
match o {
PfNaryOp::Add => args.fold(self.zero.clone(), std::ops::Add::add),
PfNaryOp::Add => {
let mut lc = args.fold(self.zero.clone(), std::ops::Add::add);
if lc.1.monomials.len() > self.cfg.r1cs.lc_elim_thresh {
let w = self.fresh_wit("add", lc.0.clone());
lc -= &w;
self.constraint(self.zero.1.clone(), self.zero.1.clone(), lc.1);
w
} else {
lc
}
}
PfNaryOp::Mul => {
// Needed to end the above closures borrow of self, before the mul call
#[allow(clippy::needless_collect)]

View File

@@ -1,11 +1,14 @@
//! A multi-stage R1CS witness evaluator.
use crate::cfg::cfg_or_default;
use crate::ir::term::*;
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
use serde::{Deserialize, Serialize};
use log::trace;
use std::time::Duration;
/// A witness computation that proceeds in stages.
///
/// In each stage:
@@ -101,6 +104,8 @@ pub struct StagedWitCompEvaluator<'a> {
step_values: Vec<Value>,
stages_evaluated: usize,
outputs_evaluted: usize,
op_times: HashMap<(Op, Vec<Sort>), (Duration, usize)>,
time_ops: bool,
}
impl<'a> StagedWitCompEvaluator<'a> {
@@ -112,6 +117,8 @@ impl<'a> StagedWitCompEvaluator<'a> {
step_values: Default::default(),
stages_evaluated: Default::default(),
outputs_evaluted: 0,
op_times: Default::default(),
time_ops: cfg_or_default().ir.time_eval_ops,
}
}
/// Have all stages been evaluated?
@@ -122,12 +129,27 @@ impl<'a> StagedWitCompEvaluator<'a> {
let next_step_idx = self.step_values.len();
assert!(next_step_idx < self.comp.steps.len());
let op = &self.comp.steps[next_step_idx].0;
let step_values = &self.step_values;
let op_times = &mut self.op_times;
let args: Vec<&Value> = self
.comp
.step_args(next_step_idx)
.map(|i| &self.step_values[i])
.map(|i| &step_values[i])
.collect();
let value = eval_op(op, &args, &self.variable_values);
let value = if self.time_ops {
let start = std::time::Instant::now();
let r = eval_op(op, &args, &self.variable_values);
let duration = start.elapsed();
let (ref mut dur, ref mut ct) = op_times
.entry((op.clone(), args.iter().map(|v| v.sort()).collect()))
.or_default();
*dur += duration;
*ct += 1;
r
} else {
eval_op(op, &args, &self.variable_values)
};
trace!(
"Eval step {}: {} on {:?} -> {}",
next_step_idx,
@@ -173,6 +195,30 @@ impl<'a> StagedWitCompEvaluator<'a> {
}
out
}
/// Prints out operator evaluation times (if self.time_ops is set)
pub fn print_times(&self) {
if self.time_ops {
// (operator, nanos total, counts, nanos/count, arg sorts (or *))
let mut rows: Vec<(String, usize, usize, f64, String)> = Default::default();
for ((op, arg_sorts), (time, count)) in &self.op_times {
let nanos = time.as_nanos() as usize;
let per = nanos as f64 / *count as f64;
rows.push((
format!("{}", op),
nanos,
*count,
per,
format!("{:?}", arg_sorts),
));
}
rows.sort_by_key(|t| t.1);
println!("time,op,nanos,counts,nanos_per,arg_sorts");
for (op, nanos, counts, nanos_per, arg_sorts) in &rows {
println!("time,{op},{nanos},{counts},{nanos_per},{arg_sorts}");
}
}
}
}
#[cfg(test)]

View File

@@ -2,7 +2,7 @@
use std::fmt::Display;
use fxhash::FxHashSet;
use fxhash::{FxHashMap, FxHashSet};
/// A namespace. Used to create unique names.
///
@@ -28,6 +28,7 @@ impl Namespace {
/// A tool for ensuring name uniqueness.
pub struct Uniquer {
used: FxHashSet<String>,
counts: FxHashMap<String, usize>,
}
impl Uniquer {
@@ -35,6 +36,7 @@ impl Uniquer {
pub fn new(used: impl IntoIterator<Item = String>) -> Self {
Uniquer {
used: used.into_iter().collect(),
counts: Default::default(),
}
}
/// Make a unique name prefixed by `prefix`, store it, and return it.
@@ -43,10 +45,12 @@ impl Uniquer {
self.used.insert(prefix.into());
return prefix.into();
}
for i in 0.. {
let counts = self.counts.entry(prefix.into()).or_default();
for i in *counts.. {
let name = format!("{prefix}_{i}");
if !self.used.contains(&name) {
self.used.insert(name.clone());
*counts = i + 1;
return name;
}
}