Compare commits

...

9 Commits

Author SHA1 Message Date
jmjac
de9e3f2673 Add __version__ to python bindings (#739) 2024-03-13 14:22:20 +00:00
dante
a1450f8df7 feat: gather_nd/scatter_nd support (#737) 2024-03-11 22:05:40 +00:00
dante
ea535e2ecd refactor: use linear index constraints for gather and scatter (#735) 2024-03-09 18:00:21 +00:00
Alexander Camuto
f8aa91ed08 fix: windows compile 2024-03-06 11:40:44 +00:00
dante
a59e3780b2 chore: rm recip_int helper (#733) 2024-03-05 21:51:14 +00:00
dante
345fb5672a chore: cleanup unused args (#732) 2024-03-05 13:43:29 +00:00
dante
70daaff2e4 chore: cleanup calibrate (#731) 2024-03-04 17:52:11 +00:00
dante
a437d8a51f feat: "sub"-dynamic tables (#730) 2024-03-04 10:35:28 +00:00
Ethan Cemer
fe535c1cac feat: wasm felt to little endian string (#729)
---------

Co-authored-by: Alexander Camuto <45801863+alexander-camuto@users.noreply.github.com>
2024-03-01 14:06:20 +00:00
33 changed files with 2262 additions and 1161 deletions

View File

@@ -575,7 +575,7 @@ jobs:
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Run pytest
run: source .env/bin/activate; pytest -vv
@@ -599,7 +599,7 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
- name: Public inputs
@@ -634,7 +634,7 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
# - name: authenticate-kaggle-cli
# shell: bash
# env:

View File

@@ -343,7 +343,6 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" compress_selectors=True,\n",
" )\n",
"\n",
" assert res == True\n",

View File

@@ -0,0 +1,48 @@
from torch import nn
import json
import numpy as np
import tf2onnx
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
# gather_nd in tf then export to onnx
x = in1 = Input((15, 18,))
w = in2 = Input((15, 1), dtype=tf.int32)
x = tf.gather_nd(x, w, batch_dims=1)
tm = Model((in1, in2), x )
tm.summary()
tm.compile(optimizer='adam', loss='mse')
shape = [1, 15, 18]
index_shape = [1, 15, 1]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = 0.1*np.random.rand(1,*shape)
# w = random int tensor
w = np.random.randint(0, 10, index_shape)
spec = tf.TensorSpec(shape, tf.float32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
model_path = "network.onnx"
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()
data = dict(
input_data=[d, d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import sys
import json
sys.path.append("..")
class Model(nn.Module):
"""
Just one Linear layer
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
# Use this line if you want to visualize the weights
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
self.channels = configs.enc_in
self.individual = configs.individual
if self.individual:
self.Linear = nn.ModuleList()
for i in range(self.channels):
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
else:
self.Linear = nn.Linear(self.seq_len, self.pred_len)
def forward(self, x):
# x: [Batch, Input length, Channel]
if self.individual:
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
for i in range(self.channels):
output[:,:,i] = self.Linear[i](x[:,:,i])
x = output
else:
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
return x # [Batch, Output length, Channel]
class Configs:
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.individual = individual
model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3
configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)
x = torch.randn(1, seq_len, pred_len)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}

Binary file not shown.

View File

@@ -12,15 +12,11 @@ pub enum BaseOp {
DotInit,
CumProdInit,
CumProd,
Identity,
Add,
Mult,
Sub,
SumInit,
Sum,
Neg,
Range { tol: i32 },
IsZero,
IsBoolean,
}
@@ -36,12 +32,8 @@ impl BaseOp {
let (a, b) = inputs;
match &self {
BaseOp::Add => a + b,
BaseOp::Identity => b,
BaseOp::Neg => -b,
BaseOp::Sub => a - b,
BaseOp::Mult => a * b,
BaseOp::Range { .. } => b,
BaseOp::IsZero => b,
BaseOp::IsBoolean => b,
_ => panic!("nonaccum_f called on accumulating operation"),
}
@@ -73,19 +65,15 @@ impl BaseOp {
/// display func
pub fn as_str(&self) -> &'static str {
match self {
BaseOp::Identity => "IDENTITY",
BaseOp::Dot => "DOT",
BaseOp::DotInit => "DOTINIT",
BaseOp::CumProdInit => "CUMPRODINIT",
BaseOp::CumProd => "CUMPROD",
BaseOp::Add => "ADD",
BaseOp::Neg => "NEG",
BaseOp::Sub => "SUB",
BaseOp::Mult => "MULT",
BaseOp::Sum => "SUM",
BaseOp::SumInit => "SUMINIT",
BaseOp::Range { .. } => "RANGE",
BaseOp::IsZero => "ISZERO",
BaseOp::IsBoolean => "ISBOOLEAN",
}
}
@@ -93,8 +81,6 @@ impl BaseOp {
/// Returns the range of the query offset for this operation.
pub fn query_offset_rng(&self) -> (i32, usize) {
match self {
BaseOp::Identity => (0, 1),
BaseOp::Neg => (0, 1),
BaseOp::DotInit => (0, 1),
BaseOp::Dot => (-1, 2),
BaseOp::CumProd => (-1, 2),
@@ -104,8 +90,6 @@ impl BaseOp {
BaseOp::Mult => (0, 1),
BaseOp::Sum => (-1, 2),
BaseOp::SumInit => (0, 1),
BaseOp::Range { .. } => (0, 1),
BaseOp::IsZero => (0, 1),
BaseOp::IsBoolean => (0, 1),
}
}
@@ -113,8 +97,6 @@ impl BaseOp {
/// Returns the number of inputs for this operation.
pub fn num_inputs(&self) -> usize {
match self {
BaseOp::Identity => 1,
BaseOp::Neg => 1,
BaseOp::DotInit => 2,
BaseOp::Dot => 2,
BaseOp::CumProdInit => 1,
@@ -124,8 +106,6 @@ impl BaseOp {
BaseOp::Mult => 2,
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::Range { .. } => 1,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}
@@ -133,19 +113,15 @@ impl BaseOp {
/// Returns the number of outputs for this operation.
pub fn constraint_idx(&self) -> usize {
match self {
BaseOp::Identity => 0,
BaseOp::Neg => 0,
BaseOp::DotInit => 0,
BaseOp::Dot => 1,
BaseOp::Add => 0,
BaseOp::Sub => 0,
BaseOp::Mult => 0,
BaseOp::Range { .. } => 0,
BaseOp::Sum => 1,
BaseOp::SumInit => 0,
BaseOp::CumProd => 1,
BaseOp::CumProdInit => 0,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}

View File

@@ -188,38 +188,158 @@ impl<'source> FromPyObject<'source> for Tolerance {
}
}
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub lookup_selectors: BTreeMap<(usize, usize), Selector>,
/// Selectors for the dynamic lookup tables
pub table_selectors: Vec<Selector>,
/// Inputs:
pub inputs: Vec<VarTensor>,
/// tables
pub tables: Vec<VarTensor>,
}
impl DynamicLookups {
/// Returns a new [DynamicLookups] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
let single_col_dummy_var = VarTensor::dummy(col_size, 1);
Self {
lookup_selectors: BTreeMap::new(),
table_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
tables: vec![
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
],
}
}
}
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct Shuffles {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub input_selectors: BTreeMap<(usize, usize), Selector>,
/// Selectors for the dynamic lookup tables
pub reference_selectors: Vec<Selector>,
/// Inputs:
pub inputs: Vec<VarTensor>,
/// tables
pub references: Vec<VarTensor>,
}
impl Shuffles {
/// Returns a new [DynamicLookups] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
let single_col_dummy_var = VarTensor::dummy(col_size, 1);
Self {
input_selectors: BTreeMap::new(),
reference_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone()],
references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
}
}
}
/// A struct representing the selectors for the static lookup tables
#[derive(Clone, Debug, Default)]
pub struct StaticLookups<F: PrimeField + TensorType + PartialOrd> {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
/// Selectors for the dynamic lookup tables
pub tables: BTreeMap<LookupOp, Table<F>>,
///
pub index: VarTensor,
///
pub output: VarTensor,
///
pub input: VarTensor,
}
impl<F: PrimeField + TensorType + PartialOrd> StaticLookups<F> {
/// Returns a new [StaticLookups] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
selectors: BTreeMap::new(),
tables: BTreeMap::new(),
index: dummy_var.clone(),
output: dummy_var.clone(),
input: dummy_var,
}
}
}
/// A struct representing the selectors for custom gates
#[derive(Clone, Debug, Default)]
pub struct CustomGates {
/// the inputs to the accumulated operations.
pub inputs: Vec<VarTensor>,
/// the (currently singular) output of the accumulated operations.
pub output: VarTensor,
/// selector
pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>,
}
impl CustomGates {
/// Returns a new [CustomGates] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
inputs: vec![dummy_var.clone(), dummy_var.clone()],
output: dummy_var,
selectors: BTreeMap::new(),
}
}
}
/// A struct representing the selectors for the range checks
#[derive(Clone, Debug, Default)]
pub struct RangeChecks<F: PrimeField + TensorType + PartialOrd> {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub selectors: BTreeMap<(Range, usize, usize), Selector>,
/// Selectors for the dynamic lookup tables
pub ranges: BTreeMap<Range, RangeCheck<F>>,
///
pub index: VarTensor,
///
pub input: VarTensor,
}
impl<F: PrimeField + TensorType + PartialOrd> RangeChecks<F> {
/// Returns a new [RangeChecks] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
selectors: BTreeMap::new(),
ranges: BTreeMap::new(),
index: dummy_var.clone(),
input: dummy_var,
}
}
}
/// Configuration for an accumulated arg.
#[derive(Clone, Debug, Default)]
pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
/// the inputs to the accumulated operations.
pub inputs: Vec<VarTensor>,
/// the VarTensor reserved for lookup operations (could be an element of inputs)
/// Note that you should be careful to ensure that the lookup_input is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
pub lookup_input: VarTensor,
/// the (currently singular) output of the accumulated operations.
pub output: VarTensor,
/// The VarTensor reserved for dynamic lookup operations (could be an element of inputs or the same as output)
/// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
pub dynamic_lookup_tables: Vec<VarTensor>,
/// the VarTensor reserved for lookup operations (could be an element of inputs or the same as output)
/// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
pub lookup_output: VarTensor,
///
pub lookup_index: VarTensor,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure [BaseOp].
pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub dynamic_lookup_selectors: BTreeMap<(usize, usize), Vec<Selector>>,
///
pub dynamic_table_selectors: Vec<Selector>,
///
pub tables: BTreeMap<LookupOp, Table<F>>,
///
pub range_checks: BTreeMap<Range, RangeCheck<F>>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
pub range_check_selectors: BTreeMap<(Range, usize, usize), Selector>,
/// Custom gates
pub custom_gates: CustomGates,
/// StaticLookups
pub static_lookups: StaticLookups<F>,
/// [Selector]s for the dynamic lookup tables
pub dynamic_lookups: DynamicLookups,
/// [Selector]s for the range checks
pub range_checks: RangeChecks<F>,
/// [Selector]s for the shuffles
pub shuffles: Shuffles,
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
@@ -228,22 +348,12 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
inputs: vec![dummy_var.clone(), dummy_var.clone()],
lookup_input: dummy_var.clone(),
output: dummy_var.clone(),
lookup_output: dummy_var.clone(),
dynamic_lookup_tables: vec![VarTensor::dummy(col_size, 2), dummy_var.clone()],
lookup_index: dummy_var,
selectors: BTreeMap::new(),
lookup_selectors: BTreeMap::new(),
dynamic_lookup_selectors: BTreeMap::new(),
dynamic_table_selectors: vec![],
range_check_selectors: BTreeMap::new(),
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
custom_gates: CustomGates::dummy(col_size, num_inner_cols),
static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
_marker: PhantomData,
}
@@ -276,10 +386,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
for j in 0..output.num_inner_cols() {
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Neg, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsZero, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Identity, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
}
}
@@ -324,12 +431,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
}
BaseOp::IsZero => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
vec![expected_output[base_op.constraint_idx()].clone()]
}
_ => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
@@ -383,19 +484,15 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
.collect();
Self {
selectors,
lookup_selectors: BTreeMap::new(),
range_check_selectors: BTreeMap::new(),
dynamic_lookup_selectors: BTreeMap::new(),
inputs: inputs.to_vec(),
lookup_input: VarTensor::Empty,
lookup_output: VarTensor::Empty,
lookup_index: VarTensor::Empty,
dynamic_table_selectors: vec![],
dynamic_lookup_tables: vec![],
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
output: output.clone(),
custom_gates: CustomGates {
inputs: inputs.to_vec(),
output: output.clone(),
selectors,
},
static_lookups: StaticLookups::default(),
dynamic_lookups: DynamicLookups::default(),
shuffles: Shuffles::default(),
range_checks: RangeChecks::default(),
check_mode,
_marker: PhantomData,
}
@@ -428,9 +525,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
// we borrow mutably twice so we need to do this dance
let table = if !self.tables.contains_key(nl) {
let table = if !self.static_lookups.tables.contains_key(nl) {
// as all tables have the same input we see if there's another table who's input we can reuse
let table = if let Some(table) = self.tables.values().next() {
let table = if let Some(table) = self.static_lookups.tables.values().next() {
Table::<F>::configure(
cs,
lookup_range,
@@ -441,7 +538,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
} else {
Table::<F>::configure(cs, lookup_range, logrows, nl, None)
};
self.tables.insert(nl.clone(), table.clone());
self.static_lookups.tables.insert(nl.clone(), table.clone());
table
} else {
return Ok(());
@@ -525,22 +622,23 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
res
});
}
self.lookup_selectors
self.static_lookups
.selectors
.insert((nl.clone(), x, y), multi_col_selector);
}
}
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
if let VarTensor::Empty = self.static_lookups.input {
debug!("assigning lookup input");
self.lookup_input = input.clone();
self.static_lookups.input = input.clone();
}
if let VarTensor::Empty = self.lookup_output {
if let VarTensor::Empty = self.static_lookups.output {
debug!("assigning lookup output");
self.lookup_output = output.clone();
self.static_lookups.output = output.clone();
}
if let VarTensor::Empty = self.lookup_index {
if let VarTensor::Empty = self.static_lookups.index {
debug!("assigning lookup index");
self.lookup_index = index.clone();
self.static_lookups.index = index.clone();
}
Ok(())
}
@@ -550,8 +648,8 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
pub fn configure_dynamic_lookup(
&mut self,
cs: &mut ConstraintSystem<F>,
lookups: &[VarTensor; 2],
tables: &[VarTensor; 2],
lookups: &[VarTensor; 3],
tables: &[VarTensor; 3],
) -> Result<(), Box<dyn Error>>
where
F: Field,
@@ -607,18 +705,105 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
expression
});
self.dynamic_lookup_selectors
self.dynamic_lookups
.lookup_selectors
.entry((x, y))
.or_default()
.push(s_lookup);
.or_insert(s_lookup);
}
}
self.dynamic_table_selectors.push(s_ltable);
self.dynamic_lookups.table_selectors.push(s_ltable);
// if we haven't previously initialized the input/output, do so now
if self.dynamic_lookup_tables.is_empty() {
if self.dynamic_lookups.tables.is_empty() {
debug!("assigning dynamic lookup table");
self.dynamic_lookup_tables = tables.to_vec();
self.dynamic_lookups.tables = tables.to_vec();
}
if self.dynamic_lookups.inputs.is_empty() {
debug!("assigning dynamic lookup input");
self.dynamic_lookups.inputs = lookups.to_vec();
}
Ok(())
}
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_shuffles(
&mut self,
cs: &mut ConstraintSystem<F>,
inputs: &[VarTensor; 2],
references: &[VarTensor; 2],
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
for l in inputs.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
}
}
for t in references.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
}
}
let one = Expression::Constant(F::ONE);
let s_reference = cs.complex_selector();
for x in 0..inputs[0].num_blocks() {
for y in 0..inputs[0].num_inner_cols() {
let s_input = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
let s_inputq = cs.query_selector(s_input);
let mut expression = vec![];
let s_referenceq = cs.query_selector(s_reference);
let mut input_queries = vec![one.clone()];
for input in inputs {
input_queries.push(match input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
let mut ref_queries = vec![one.clone()];
for reference in references {
ref_queries.push(match reference {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[0][0], Rotation(0))
}
_ => unreachable!(),
});
}
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
expression.extend(lhs.zip(rhs));
expression
});
self.shuffles
.input_selectors
.entry((x, y))
.or_insert(s_input);
}
}
self.shuffles.reference_selectors.push(s_reference);
// if we haven't previously initialized the input/output, do so now
if self.shuffles.references.is_empty() {
debug!("assigning shuffles reference");
self.shuffles.references = references.to_vec();
}
if self.shuffles.inputs.is_empty() {
debug!("assigning shuffles input");
self.shuffles.inputs = inputs.to_vec();
}
Ok(())
@@ -643,15 +828,16 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
// we borrow mutably twice so we need to do this dance
let range_check =
if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
e.insert(range_check.clone());
range_check
} else {
return Ok(());
};
let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
self.range_checks.ranges.entry(range)
{
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
e.insert(range_check.clone());
range_check
} else {
return Ok(());
};
for x in 0..input.num_blocks() {
for y in 0..input.num_inner_cols() {
@@ -708,19 +894,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
res
});
}
self.range_check_selectors
self.range_checks
.selectors
.insert((range, x, y), multi_col_selector);
}
}
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
debug!("assigning lookup input");
self.lookup_input = input.clone();
if let VarTensor::Empty = self.range_checks.input {
debug!("assigning range check input");
self.range_checks.input = input.clone();
}
if let VarTensor::Empty = self.lookup_index {
debug!("assigning lookup index");
self.lookup_index = index.clone();
if let VarTensor::Empty = self.range_checks.index {
debug!("assigning range check index");
self.range_checks.index = index.clone();
}
Ok(())
@@ -728,7 +915,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
/// layout_tables must be called before layout.
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
for (i, table) in self.tables.values_mut().enumerate() {
for (i, table) in self.static_lookups.tables.values_mut().enumerate() {
if !table.is_assigned {
debug!(
"laying out table for {}",
@@ -749,7 +936,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
&mut self,
layouter: &mut impl Layouter<F>,
) -> Result<(), Box<dyn Error>> {
for range_check in self.range_checks.values_mut() {
for range_check in self.range_checks.ranges.values_mut() {
if !range_check.is_assigned {
debug!("laying out range check for {:?}", range_check.range);
range_check.layout(layouter)?;

File diff suppressed because it is too large Load Diff

View File

@@ -14,10 +14,17 @@ pub enum PolyOp {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
GatherND {
batch_dims: usize,
indices: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterND {
constant_idx: Option<Tensor<usize>>,
},
MultiBroadcastTo {
shape: Vec<usize>,
},
@@ -60,8 +67,6 @@ pub enum PolyOp {
len_prod: usize,
},
Pow(u32),
Pack(u32, u32),
GlobalSumPool,
Concat {
axis: usize,
},
@@ -91,7 +96,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -110,8 +117,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Sum { .. } => "SUM".into(),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Pack(_, _) => "PACK".into(),
PolyOp::GlobalSumPool => "GLOBALSUMPOOL".into(),
PolyOp::Conv { .. } => "CONV".into(),
PolyOp::DeConv { .. } => "DECONV".into(),
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -181,13 +186,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
output_padding,
stride,
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
PolyOp::Pack(base, scale) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pack inputs".to_string()));
}
tensor::ops::pack(&inputs[0], F::from(*base as u64), *scale)
}
PolyOp::Pow(u) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pow inputs".to_string()));
@@ -206,7 +204,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
tensor::ops::prod_axes(&inputs[0], axes)
}
PolyOp::GlobalSumPool => unreachable!(),
PolyOp::Concat { axis } => {
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
}
@@ -225,6 +222,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::GatherND {
indices,
batch_dims,
} => {
let x = inputs[0].clone();
let y = if let Some(idx) = indices {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_nd(&x, &y, *batch_dims)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
@@ -241,6 +250,21 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
PolyOp::ScatterND { constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
}?;
Ok(ForwardResult { output: res })
@@ -288,7 +312,17 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
}
}
PolyOp::GatherND {
batch_dims,
indices,
} => {
if let Some(idx) = indices {
tensor::ops::gather_nd(values[0].get_inner_tensor()?, idx, *batch_dims)?.into()
} else {
layouts::gather_nd(config, region, values[..].try_into()?, *batch_dims)?.0
}
}
PolyOp::ScatterElements { dim, constant_idx } => {
@@ -304,6 +338,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::ScatterND { constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter_nd(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
)?
.into()
} else {
layouts::scatter_nd(config, region, values[..].try_into()?)?
}
}
PolyOp::DeConv {
padding,
output_padding,
@@ -334,10 +380,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
input
}
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
PolyOp::Pack(base, scale) => {
layouts::pack(config, region, values[..].try_into()?, *base, *scale)?
}
PolyOp::GlobalSumPool => unreachable!(),
PolyOp::Concat { axis } => layouts::concat(values[..].try_into()?, axis)?,
PolyOp::Slice { axis, start, end } => {
layouts::slice(config, region, values[..].try_into()?, axis, start, end)?
@@ -405,7 +447,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else if matches!(self, PolyOp::ScatterElements { .. }) {
} else if matches!(self, PolyOp::ScatterElements { .. })
| matches!(self, PolyOp::ScatterND { .. })
{
vec![0, 2]
} else {
vec![]

View File

@@ -23,28 +23,61 @@ use super::lookup::LookupOp;
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
lookup_index: usize,
index: usize,
col_coord: usize,
}
impl DynamicLookupIndex {
/// Create a new dynamic lookup index
pub fn new(lookup_index: usize, col_coord: usize) -> DynamicLookupIndex {
DynamicLookupIndex {
lookup_index,
col_coord,
}
pub fn new(index: usize, col_coord: usize) -> DynamicLookupIndex {
DynamicLookupIndex { index, col_coord }
}
/// Get the lookup index
pub fn lookup_index(&self) -> usize {
self.lookup_index
pub fn index(&self) -> usize {
self.index
}
/// Get the column coord
pub fn col_coord(&self) -> usize {
self.col_coord
}
/// update with another dynamic lookup index
pub fn update(&mut self, other: &DynamicLookupIndex) {
self.index += other.index;
self.col_coord += other.col_coord;
}
}
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct ShuffleIndex {
index: usize,
col_coord: usize,
}
impl ShuffleIndex {
/// Create a new dynamic lookup index
pub fn new(index: usize, col_coord: usize) -> ShuffleIndex {
ShuffleIndex { index, col_coord }
}
/// Get the lookup index
pub fn index(&self) -> usize {
self.index
}
/// Get the column coord
pub fn col_coord(&self) -> usize {
self.col_coord
}
/// update with another shuffle index
pub fn update(&mut self, other: &ShuffleIndex) {
self.index += other.index;
self.col_coord += other.col_coord;
}
}
/// Region error
@@ -94,6 +127,7 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
num_inner_cols: usize,
total_constants: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
@@ -110,7 +144,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
///
pub fn increment_dynamic_lookup_index(&mut self, n: usize) {
self.dynamic_lookup_index.lookup_index += n;
self.dynamic_lookup_index.index += n;
}
///
@@ -118,6 +152,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.dynamic_lookup_index.col_coord += n;
}
///
pub fn increment_shuffle_index(&mut self, n: usize) {
self.shuffle_index.index += n;
}
///
pub fn increment_shuffle_col_coord(&mut self, n: usize) {
self.shuffle_index.col_coord += n;
}
///
pub fn throw_range_check_error(&self) -> bool {
self.throw_range_check_error
@@ -135,6 +179,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
@@ -149,6 +194,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
row: usize,
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
) -> RegionCtx<'a, F> {
let linear_coord = row * num_inner_cols;
RegionCtx {
@@ -158,6 +204,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
row,
total_constants: 0,
dynamic_lookup_index,
shuffle_index,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
@@ -183,6 +230,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
row,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
@@ -198,9 +246,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord: usize,
total_constants: usize,
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
let region = None;
@@ -210,9 +255,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants,
dynamic_lookup_index,
used_lookups,
used_range_checks,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
@@ -272,6 +318,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
*output = output
.par_enum_map(|idx, _| {
@@ -287,9 +334,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
starting_linear_coord,
starting_constants,
self.num_inner_cols,
DynamicLookupIndex::default(),
HashSet::new(),
HashSet::new(),
self.throw_range_check_error,
);
let res = inner_loop_function(idx, &mut local_reg);
@@ -309,18 +353,19 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
// update the lookups
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
// update the range checks
let mut range_checks = range_checks.lock().unwrap();
range_checks.extend(local_reg.used_range_checks());
// update the dynamic lookup index
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.lookup_index += local_reg.dynamic_lookup_index.lookup_index;
dynamic_lookup_index.col_coord += local_reg.dynamic_lookup_index.col_coord;
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
res
})
.map_err(|e| {
log::error!("dummy_loop: {:?}", e);
Error::Synthesis
})?;
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
@@ -357,6 +402,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
e
))
})?;
self.shuffle_index = Arc::try_unwrap(shuffle_index)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?;
Ok(())
}
@@ -429,7 +482,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
/// Get the dynamic lookup index
pub fn dynamic_lookup_index(&self) -> usize {
self.dynamic_lookup_index.lookup_index
self.dynamic_lookup_index.index
}
/// Get the dynamic lookup column coordinate
@@ -437,6 +490,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.dynamic_lookup_index.col_coord
}
/// Get the shuffle index
pub fn shuffle_index(&self) -> usize {
self.shuffle_index.index
}
/// Get the shuffle column coordinate
pub fn shuffle_col_coord(&self) -> usize {
self.shuffle_index.col_coord
}
/// get used lookups
pub fn used_lookups(&self) -> HashSet<LookupOp> {
self.used_lookups.clone()
@@ -486,6 +549,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
}
///
pub fn combined_dynamic_shuffle_coord(&self) -> usize {
self.dynamic_lookup_col_coord() + self.shuffle_col_coord()
}
/// Assign a valtensor to a vartensor
pub fn assign_dynamic_lookup(
&mut self,
@@ -496,7 +564,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
if let Some(region) = &self.region {
var.assign(
&mut region.borrow_mut(),
self.dynamic_lookup_col_coord(),
self.combined_dynamic_shuffle_coord(),
values,
)
} else {
@@ -504,6 +572,15 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
}
/// Assign a valtensor to a vartensor
pub fn assign_shuffle(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.assign_dynamic_lookup(var, values)
}
/// Assign a valtensor to a vartensor
pub fn assign_with_omissions(
&mut self,

View File

@@ -357,8 +357,6 @@ mod matmul_col_ultra_overflow_double_col {
);
assert!(result.is_ok());
println!("done.");
}
}
@@ -476,8 +474,6 @@ mod matmul_col_ultra_overflow {
);
assert!(result.is_ok());
println!("done.");
}
}
@@ -1280,8 +1276,6 @@ mod conv_col_ultra_overflow {
);
assert!(result.is_ok());
println!("done.");
}
}
@@ -1435,8 +1429,6 @@ mod conv_relu_col_ultra_overflow {
);
assert!(result.is_ok());
println!("done.");
}
}
@@ -1605,15 +1597,19 @@ mod dynamic_lookup {
let d = VarTensor::new_advice(cs, K, 1, LEN);
let e = VarTensor::new_advice(cs, K, 1, LEN);
let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
for _ in 0..NUM_LOOP {
config
.configure_dynamic_lookup(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
.unwrap();
}
config
.configure_dynamic_lookup(
cs,
&[a.clone(), b.clone(), c.clone()],
&[d.clone(), e.clone(), f.clone()],
)
.unwrap();
config
}
@@ -1690,11 +1686,18 @@ mod dynamic_lookup {
let lookups = (0..NUM_LOOP)
.map(|loop_idx| {
let prev_idx = if loop_idx == 0 {
NUM_LOOP - 1
} else {
loop_idx - 1
};
[
ValTensor::from(Tensor::from(
(0..2).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
(0..3).map(|i| Value::known(F::from((i * prev_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((prev_idx * i * i) as u64 + 1))),
)),
ValTensor::from(Tensor::from((0..2).map(|_| Value::known(F::from(10000))))),
]
})
.collect::<Vec<_>>();
@@ -1710,6 +1713,133 @@ mod dynamic_lookup {
}
}
#[cfg(test)]
mod shuffle {
use super::*;
const K: usize = 6;
const LEN: usize = 4;
const NUM_LOOP: usize = 5;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [[ValTensor<F>; 1]; NUM_LOOP],
references: [[ValTensor<F>; 1]; NUM_LOOP],
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 2, LEN);
let b = VarTensor::new_advice(cs, K, 2, LEN);
let c: VarTensor = VarTensor::new_advice(cs, K, 2, LEN);
let d = VarTensor::new_advice(cs, K, 1, LEN);
let e = VarTensor::new_advice(cs, K, 1, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
config
.configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
.unwrap();
config
}
fn synthesize(
&self,
config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
for i in 0..NUM_LOOP {
layouts::shuffles(
&config,
&mut region,
&self.inputs[i],
&self.references[i],
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.shuffle_col_coord(),
NUM_LOOP * self.references[0][0].len()
);
assert_eq!(region.shuffle_index(), NUM_LOOP);
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
fn shufflecircuit() {
// parameters
let references = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
let inputs = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
references: references.clone().try_into().unwrap(),
inputs: inputs.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
let inputs = (0..NUM_LOOP)
.map(|loop_idx| {
let prev_idx = if loop_idx == 0 {
NUM_LOOP - 1
} else {
loop_idx - 1
};
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * prev_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
references: references.try_into().unwrap(),
inputs: inputs.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
assert!(prover.verify().is_err());
}
}
#[cfg(test)]
mod add_with_overflow {
use super::*;
@@ -2113,75 +2243,6 @@ mod pow {
}
}
#[cfg(test)]
mod pack {
use super::*;
const K: usize = 8;
const LEN: usize = 4;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 1],
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
config
.layout(
&mut region,
&self.inputs.clone(),
Box::new(PolyOp::Pack(2, 1)),
)
.map_err(|_| Error::Synthesis)
},
)
.unwrap();
Ok(())
}
}
#[test]
fn packcircuit() {
// parameters
let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
let circuit = MyCircuit::<F> {
inputs: [ValTensor::from(a)],
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
}
#[cfg(test)]
mod matmul_relu {
use super::*;
@@ -2469,7 +2530,5 @@ mod lookup_ultra_overflow {
);
assert!(result.is_ok());
println!("done.");
}
}

View File

@@ -77,7 +77,7 @@ pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
/// Default lookup safety margin
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
/// Default Compress selectors
pub const DEFAULT_COMPRESS_SELECTORS: &str = "false";
pub const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";
/// Default render vk seperately
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
/// Default VK sol path
@@ -402,9 +402,6 @@ pub enum Commands {
/// Number of logrows to use for srs. Overrides settings_path if specified.
#[arg(long, default_value = None)]
logrows: Option<u32>,
/// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
#[arg(long, default_value = DEFAULT_CHECKMODE)]
check: CheckMode,
},
/// Loads model and input and runs mock prover (for testing)
Mock {
@@ -450,8 +447,8 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
},
/// Aggregates proofs :)
Aggregate {
@@ -515,8 +512,8 @@ pub enum Commands {
#[arg(short = 'W', long)]
witness: Option<PathBuf>,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -540,8 +537,8 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_FUZZ_RUNS)]
num_runs: usize,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information

View File

@@ -23,6 +23,7 @@ use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::RunArgs;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::poly::commitment::Params;
@@ -63,7 +64,11 @@ use std::process::Command;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
#[cfg(not(target_arch = "wasm32"))]
use std::sync::OnceLock;
#[cfg(not(target_arch = "wasm32"))]
use crate::EZKL_BUF_CAPACITY;
#[cfg(not(target_arch = "wasm32"))]
use std::io::BufWriter;
use std::time::Duration;
use tabled::Tabled;
use thiserror::Error;
@@ -140,13 +145,13 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
compiled_circuit,
transcript,
num_runs,
compress_selectors,
disable_selector_compression,
} => fuzz(
compiled_circuit,
witness,
transcript,
num_runs,
compress_selectors,
disable_selector_compression,
),
Commands::GenSrs { srs_path, logrows } => gen_srs_cmd(srs_path, logrows as u32),
#[cfg(not(target_arch = "wasm32"))]
@@ -154,8 +159,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
settings_path,
logrows,
check,
} => get_srs_cmd(srs_path, settings_path, logrows, check).await,
} => get_srs_cmd(srs_path, settings_path, logrows).await,
Commands::Table { model, args } => table(model, args),
#[cfg(feature = "render")]
Commands::RenderCircuit {
@@ -260,14 +264,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path,
pk_path,
witness,
compress_selectors,
disable_selector_compression,
} => setup(
compiled_circuit,
srs_path,
vk_path,
pk_path,
witness,
compress_selectors,
disable_selector_compression,
),
#[cfg(not(target_arch = "wasm32"))]
Commands::SetupTestEvmData {
@@ -331,7 +335,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
split_proofs,
compress_selectors,
disable_selector_compression,
} => setup_aggregate(
sample_snarks,
vk_path,
@@ -339,7 +343,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
split_proofs,
compress_selectors,
disable_selector_compression,
),
Commands::Aggregate {
proof_path,
@@ -487,16 +491,28 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
use std::io::Read;
let path = get_srs_path(logrows, srs_path);
let file = std::fs::File::open(path.clone())?;
let file = std::fs::File::open(path)?;
let mut reader = std::io::BufReader::new(file);
let mut buffer = vec![];
let bytes_read = std::io::BufReader::new(file).read_to_end(&mut buffer)?;
debug!("read {} bytes from SRS file", bytes_read);
let bytes_read = reader.read_to_end(&mut buffer)?;
info!(
"read {} bytes from file (vector of len = {})",
bytes_read,
buffer.len()
);
let hash = sha256::digest(buffer);
info!("SRS hash: {}", hash);
info!("file hash: {}", hash);
Ok(hash)
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
let path = get_srs_path(logrows, srs_path);
let hash = get_file_hash(&path)?;
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
Some(h) => h,
@@ -520,7 +536,6 @@ pub(crate) async fn get_srs_cmd(
srs_path: Option<PathBuf>,
settings_path: Option<PathBuf>,
logrows: Option<u32>,
check_mode: CheckMode,
) -> Result<String, Box<dyn Error>> {
// logrows overrides settings
@@ -548,18 +563,21 @@ pub(crate) async fn get_srs_cmd(
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
// check the SRS
if matches!(check_mode, CheckMode::SAFE) {
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated");
}
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated.");
info!("Saving SRS to disk...");
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone()))?;
file.write_all(reader.get_ref())?;
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
params.write(&mut buffer)?;
info!("Saved SRS to disk.");
info!("SRS downloaded");
} else {
info!("SRS already exists at that path");
@@ -806,7 +824,6 @@ pub(crate) fn calibrate(
let settings = GraphSettings::load(&settings_path)?;
// now retrieve the run args
// we load the model to get the input and output shapes
// check if gag already exists
let model = Model::from_run_args(&settings.run_args, &model_path)?;
@@ -887,7 +904,12 @@ pub(crate) fn calibrate(
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
));
let key = (input_scale, param_scale, scale_rebase_multiplier);
let key = (
input_scale,
param_scale,
scale_rebase_multiplier,
div_rebasing,
);
forward_pass_res.insert(key, vec![]);
let local_run_args = RunArgs {
@@ -898,6 +920,18 @@ pub(crate) fn calibrate(
..settings.run_args.clone()
};
// if unix get a gag
#[cfg(unix)]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
#[cfg(unix)]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
@@ -938,29 +972,27 @@ pub(crate) fn calibrate(
}
}
let min_lookup_range = forward_pass_res
.get(&key)
.unwrap()
// drop the gag
#[cfg(unix)]
drop(_r);
#[cfg(unix)]
drop(_g);
let result = forward_pass_res.get(&key).ok_or("key not found")?;
let min_lookup_range = result
.iter()
.map(|x| x.min_lookup_inputs)
.min()
.unwrap_or(0);
let max_lookup_range = forward_pass_res
.get(&key)
.unwrap()
let max_lookup_range = result
.iter()
.map(|x| x.max_lookup_inputs)
.max()
.unwrap_or(0);
let max_range_size = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.max_range_size)
.max()
.unwrap_or(0);
let max_range_size = result.iter().map(|x| x.max_range_size).max().unwrap_or(0);
let res = circuit.calc_min_logrows(
(min_lookup_range, max_lookup_range),
@@ -1081,6 +1113,7 @@ pub(crate) fn calibrate(
best_params.run_args.input_scale,
best_params.run_args.param_scale,
best_params.run_args.scale_rebase_multiplier,
best_params.run_args.div_rebasing,
))
.ok_or("no params found")?
.iter()
@@ -1499,7 +1532,7 @@ pub(crate) fn setup(
vk_path: PathBuf,
pk_path: PathBuf,
witness: Option<PathBuf>,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<String, Box<dyn Error>> {
// these aren't real values so the sanity checks are mostly meaningless
let mut circuit = GraphCircuit::load(compiled_circuit)?;
@@ -1513,7 +1546,7 @@ pub(crate) fn setup(
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
disable_selector_compression,
)
.map_err(Box::<dyn Error>::from)?;
@@ -1654,7 +1687,7 @@ pub(crate) fn fuzz(
data_path: PathBuf,
transcript: TranscriptType,
num_runs: usize,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let passed = AtomicBool::new(true);
@@ -1664,7 +1697,7 @@ pub(crate) fn fuzz(
let logrows = circuit.settings().run_args.logrows;
info!("setting up tests");
#[cfg(unix)]
let _r = Gag::stdout()?;
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
@@ -1673,7 +1706,7 @@ pub(crate) fn fuzz(
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
disable_selector_compression,
)
.map_err(Box::<dyn Error>::from)?;
@@ -1682,6 +1715,7 @@ pub(crate) fn fuzz(
let public_inputs = circuit.prepare_public_inputs(&data)?;
let strategy = KZGSingleStrategy::new(&params);
#[cfg(unix)]
std::mem::drop(_r);
info!("starting fuzzing");
@@ -1694,7 +1728,7 @@ pub(crate) fn fuzz(
let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&new_params,
compress_selectors,
disable_selector_compression,
)
.map_err(|_| ())?;
@@ -1772,7 +1806,7 @@ pub(crate) fn fuzz(
let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&new_params,
compress_selectors,
disable_selector_compression,
)
.map_err(|_| ())?;
@@ -1872,6 +1906,7 @@ pub(crate) fn run_fuzz_fn(
passed: &AtomicBool,
) {
let num_failures = AtomicI64::new(0);
#[cfg(unix)]
let _r = Gag::stdout().unwrap();
let pb = init_bar(num_runs as u64);
@@ -1885,6 +1920,7 @@ pub(crate) fn run_fuzz_fn(
pb.inc(1);
});
pb.finish_with_message("Done.");
#[cfg(unix)]
std::mem::drop(_r);
info!(
"num failures: {} out of {}",
@@ -1951,7 +1987,7 @@ pub(crate) fn setup_aggregate(
srs_path: Option<PathBuf>,
logrows: u32,
split_proofs: bool,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<String, Box<dyn Error>> {
// the K used for the aggregation circuit
let params = load_params_cmd(srs_path, logrows)?;
@@ -1965,7 +2001,7 @@ pub(crate) fn setup_aggregate(
let agg_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(
&agg_circuit,
&params,
compress_selectors,
disable_selector_compression,
)?;
let agg_vk = agg_pk.get_vk();
@@ -2042,7 +2078,8 @@ pub(crate) fn verify(
let circuit_settings = GraphSettings::load(&settings_path)?;
let params = if reduced_srs {
load_params_cmd(srs_path, circuit_settings.log2_total_instances())?
// only need G_0 for the verification with shplonk
load_params_cmd(srs_path, 1)?
} else {
load_params_cmd(srs_path, circuit_settings.run_args.logrows)?
};

View File

@@ -12,6 +12,8 @@ pub mod utilities;
pub mod vars;
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
pub use input::DataSource;
@@ -453,6 +455,10 @@ pub struct GraphSettings {
pub total_dynamic_col_size: usize,
/// number of dynamic lookups
pub num_dynamic_lookups: usize,
/// number of shuffles
pub num_shuffles: usize,
/// total shuffle column size
pub total_shuffle_col_size: usize,
/// the shape of public inputs to the model (in order of appearance)
pub model_instance_shapes: Vec<Vec<usize>>,
/// model output scales
@@ -482,8 +488,14 @@ impl GraphSettings {
.ceil() as u32
}
fn dynamic_lookup_logrows(&self) -> u32 {
(self.total_dynamic_col_size as f64).log2().ceil() as u32
fn dynamic_lookup_and_shuffle_logrows(&self) -> u32 {
(self.total_dynamic_col_size as f64 + self.total_shuffle_col_size as f64)
.log2()
.ceil() as u32
}
fn dynamic_lookup_and_shuffle_col_size(&self) -> usize {
self.total_dynamic_col_size + self.total_shuffle_col_size
}
fn module_constraint_logrows(&self) -> u32 {
@@ -583,6 +595,11 @@ impl GraphSettings {
self.num_dynamic_lookups > 0
}
/// requires dynamic shuffle
pub fn requires_shuffle(&self) -> bool {
self.num_shuffles > 0
}
/// any kzg visibility
pub fn module_requires_kzg(&self) -> bool {
self.run_args.input_visibility.is_kzgcommit()
@@ -1097,7 +1114,7 @@ impl GraphCircuit {
// These are hard lower limits, we can't overflow instances or modules constraints
let instance_logrows = self.settings().log2_total_instances();
let module_constraint_logrows = self.settings().module_constraint_logrows();
let dynamic_lookup_logrows = self.settings().dynamic_lookup_logrows();
let dynamic_lookup_logrows = self.settings().dynamic_lookup_and_shuffle_logrows();
min_logrows = std::cmp::max(
min_logrows,
// max of the instance logrows and the module constraint logrows and the dynamic lookup logrows is the lower limit
@@ -1181,9 +1198,9 @@ impl GraphCircuit {
max_range_size: i128,
) -> bool {
// if num cols is too large then the extended k is too large
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS {
return false;
} else if Self::calc_num_cols(max_range_size, k) > MAX_NUM_LOOKUP_COLS {
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
|| Self::calc_num_cols(max_range_size, k) > MAX_NUM_LOOKUP_COLS
{
return false;
}
@@ -1192,7 +1209,26 @@ impl GraphCircuit {
settings.run_args.logrows = k;
settings.required_range_checks = vec![(0, max_range_size)];
let mut cs = ConstraintSystem::default();
// if unix get a gag
#[cfg(unix)]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
#[cfg(unix)]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
Self::configure_with_params(&mut cs, settings);
// drop the gag
#[cfg(unix)]
drop(_r);
#[cfg(unix)]
drop(_g);
#[cfg(feature = "mv-lookup")]
let cs = cs.chunk_lookups();
// quotient_poly_degree * params.n - 1 is the degree of the quotient polynomial

View File

@@ -103,6 +103,10 @@ pub struct DummyPassRes {
pub num_dynamic_lookups: usize,
/// dynamic lookup col size
pub dynamic_lookup_col_coord: usize,
/// num shuffles
pub num_shuffles: usize,
/// shuffle
pub shuffle_col_coord: usize,
/// linear coordinate
pub linear_coord: usize,
/// total const size
@@ -546,6 +550,8 @@ impl Model {
model_input_scales: self.graph.get_input_scales(),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
num_shuffles: res.num_shuffles,
total_shuffle_col_size: res.shuffle_col_coord,
total_const_size: res.total_const_size,
check_mode,
version: env!("CARGO_PKG_VERSION").to_string(),
@@ -1019,7 +1025,6 @@ impl Model {
let lookup_range = settings.run_args.lookup_range;
let logrows = settings.run_args.logrows as usize;
let num_dynamic_lookups = settings.num_dynamic_lookups;
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();
@@ -1041,8 +1046,16 @@ impl Model {
base_gate.configure_range_check(meta, input, index, range, logrows)?;
}
for _ in 0..num_dynamic_lookups {
if settings.requires_dynamic_lookup() {
base_gate.configure_dynamic_lookup(
meta,
vars.advices[0..3].try_into()?,
vars.advices[3..6].try_into()?,
)?;
}
if settings.requires_shuffle() {
base_gate.configure_shuffles(
meta,
vars.advices[0..2].try_into()?,
vars.advices[3..5].try_into()?,
@@ -1456,6 +1469,8 @@ impl Model {
max_range_size: region.max_range_size(),
num_dynamic_lookups: region.dynamic_lookup_index(),
dynamic_lookup_col_coord: region.dynamic_lookup_col_coord(),
num_shuffles: region.shuffle_index(),
shuffle_col_coord: region.shuffle_col_coord(),
outputs,
};

View File

@@ -23,7 +23,10 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::ops::{
array::{Gather, GatherElements, MultiBroadcastTo, OneHot, ScatterElements, Slice, Topk},
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
},
change_axes::AxisOp,
cnn::{Conv, Deconv},
einsum::EinSum,
@@ -467,6 +470,78 @@ pub fn new_op_from_onnx(
// Extract the max value
}
"ScatterNd" => {
if inputs.len() != 3 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"scatter nd".to_string(),
)));
};
// just verify it deserializes correctly
let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
}
// }
if inputs[1].opkind().is_input() {
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
}));
inputs[1].bump_scale(0);
}
op
}
"GatherNd" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"gather nd".to_string(),
)));
};
let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
let batch_dims = op.batch_dims;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
})
}
// }
if inputs[1].opkind().is_input() {
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
}));
inputs[1].bump_scale(0);
}
op
}
"GatherElements" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(

View File

@@ -429,17 +429,20 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
let num_constants = params.total_const_size;
let module_requires_fixed = params.module_requires_fixed();
let requires_dynamic_lookup = params.requires_dynamic_lookup();
let dynamic_lookup_size = params.total_dynamic_col_size;
let requires_shuffle = params.requires_shuffle();
let dynamic_lookup_and_shuffle_size = params.dynamic_lookup_and_shuffle_col_size();
let mut advices = (0..3)
.map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len))
.collect_vec();
if requires_dynamic_lookup {
for _ in 0..2 {
let dynamic_lookup = VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_size);
if requires_dynamic_lookup || requires_shuffle {
let num_cols = if requires_dynamic_lookup { 3 } else { 2 };
for _ in 0..num_cols {
let dynamic_lookup =
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);
if dynamic_lookup.num_blocks() > 1 {
panic!("dynamic lookup should only have one block");
panic!("dynamic lookup or shuffle should only have one block");
};
advices.push(dynamic_lookup);
}

View File

@@ -484,7 +484,7 @@ where
pub fn create_keys<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
circuit: &C,
params: &'_ Scheme::ParamsProver,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
@@ -496,7 +496,7 @@ where
// Initialize verifying key
let now = Instant::now();
trace!("preparing VK");
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
let vk = keygen_vk_custom(params, &empty_circuit, !disable_selector_compression)?;
let elapsed = now.elapsed();
info!("VK took {}.{}", elapsed.as_secs(), elapsed.subsec_millis());

View File

@@ -484,7 +484,6 @@ fn get_srs(
srs_path,
settings_path,
logrows,
CheckMode::SAFE,
))
.map_err(|e| {
let err_str = format!("Failed to get srs: {}", e);
@@ -622,7 +621,7 @@ fn mock_aggregate(
pk_path=PathBuf::from(DEFAULT_PK),
srs_path=None,
witness_path = None,
compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
))]
fn setup(
model: PathBuf,
@@ -630,7 +629,7 @@ fn setup(
pk_path: PathBuf,
srs_path: Option<PathBuf>,
witness_path: Option<PathBuf>,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<bool, PyErr> {
crate::execute::setup(
model,
@@ -638,7 +637,7 @@ fn setup(
vk_path,
pk_path,
witness_path,
compress_selectors,
disable_selector_compression,
)
.map_err(|e| {
let err_str = format!("Failed to run setup: {}", e);
@@ -719,7 +718,7 @@ fn verify(
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
split_proofs = false,
srs_path = None,
compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
))]
fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
@@ -728,7 +727,7 @@ fn setup_aggregate(
logrows: u32,
split_proofs: bool,
srs_path: Option<PathBuf>,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<bool, PyErr> {
crate::execute::setup_aggregate(
sample_snarks,
@@ -737,7 +736,7 @@ fn setup_aggregate(
srs_path,
logrows,
split_proofs,
compress_selectors,
disable_selector_compression,
)
.map_err(|e| {
let err_str = format!("Failed to setup aggregate: {}", e);
@@ -1104,6 +1103,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;

View File

@@ -673,6 +673,68 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&res), &dims)
}
/// Set a slice of the Tensor.
/// ```
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
/// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
/// assert_eq!(a, b);
/// ```
pub fn set_slice(
&mut self,
indices: &[Range<usize>],
value: &Tensor<T>,
) -> Result<(), TensorError>
where
T: Send + Sync,
{
if indices.is_empty() {
return Ok(());
}
if self.dims.len() < indices.len() {
return Err(TensorError::DimError(format!(
"The dimensionality of the slice {:?} is greater than the tensor's {:?}",
indices, self.dims
)));
}
// if indices weren't specified we fill them in as required
let mut full_indices = indices.to_vec();
let omitted_dims = (indices.len()..self.dims.len())
.map(|i| self.dims[i])
.collect::<Vec<_>>();
for dim in &omitted_dims {
full_indices.push(0..*dim);
}
let full_dims = full_indices
.iter()
.map(|x| x.end - x.start)
.collect::<Vec<_>>();
// now broadcast the value to the full dims
let value = value.expand(&full_dims)?;
let cartesian_coord: Vec<Vec<usize>> = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
let _ = cartesian_coord
.iter()
.enumerate()
.map(|(i, e)| {
self.set(e, value[i].clone());
})
.collect::<Vec<_>>();
Ok(())
}
/// Get the array index from rows / columns indices.
///
/// ```
@@ -1526,18 +1588,20 @@ pub fn get_broadcasted_shape(
let num_dims_a = shape_a.len();
let num_dims_b = shape_b.len();
// reewrite the below using match
if num_dims_a == num_dims_b {
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
let max_dim = dim_a.max(dim_b);
broadcasted_shape.push(*max_dim);
match (num_dims_a, num_dims_b) {
(a, b) if a == b => {
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
let max_dim = dim_a.max(dim_b);
broadcasted_shape.push(*max_dim);
}
Ok(broadcasted_shape)
}
Ok(broadcasted_shape)
} else if num_dims_a < num_dims_b {
Ok(shape_b.to_vec())
} else {
Ok(shape_a.to_vec())
(a, b) if a < b => Ok(shape_b.to_vec()),
(a, b) if a > b => Ok(shape_a.to_vec()),
_ => Err(Box::new(TensorError::DimError(
"Unknown condition for broadcasting".to_string(),
))),
}
}
////////////////////////

View File

@@ -2,7 +2,7 @@ use super::TensorError;
use crate::tensor::{Tensor, TensorType};
use itertools::Itertools;
use maybe_rayon::{
iter::IndexedParallelIterator, iter::IntoParallelRefMutIterator, iter::ParallelIterator,
iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
prelude::IntoParallelRefIterator,
};
use std::collections::{HashMap, HashSet};
@@ -1328,6 +1328,316 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
Ok(output)
}
/// Gather ND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to gather
/// * `batch_dims` - Number of batch dimensions
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::gather_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3]),
/// &[2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 0, 1, 1]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 3]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3, 4, 5, 6, 7]),
/// &[2, 2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 1, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 1, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 1).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1]),
/// &[2, 2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0]),
/// &[2, 2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1, 6, 7, 4, 5]), &[2, 2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
pub fn gather_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("gather_nd".to_string()))?;
if last_value > &(input_dims.len() - batch_dims) {
return Err(TensorError::DimMismatch("gather_nd".to_string()));
}
let output_size =
// If indices_shape[-1] == r-b, since the rank of indices is q,
// indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b,
// where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
// Let us think of each such r-b ranked tensor as indices_slice.
// Each scalar value corresponding to data[0:b-1,indices_slice] is filled into
// the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
// if indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b.
// Let us think of each such tensors as indices_slice.
// Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
{
let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;
let mut dims = index_dims[..index_dims.len() - 1].to_vec();
let input_offset = batch_dims + last_value;
dims.extend(input_dims[input_offset..input_dims.len()].to_vec());
assert_eq!(output_rank, dims.len());
dims
};
// cartesian coord over batch dims
let mut batch_cartesian_coord = input_dims[0..batch_dims]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if batch_cartesian_coord.is_empty() {
batch_cartesian_coord.push(vec![]);
}
let outputs = batch_cartesian_coord
.par_iter()
.map(|batch_coord| {
let batch_slice = batch_coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut index_slice = index.get_slice(&batch_slice)?;
index_slice.reshape(&index.dims()[batch_dims..])?;
let mut input_slice = input.get_slice(&batch_slice)?;
input_slice.reshape(&input.dims()[batch_dims..])?;
let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if inner_cartesian_coord.is_empty() {
inner_cartesian_coord.push(vec![]);
}
let output = inner_cartesian_coord
.iter()
.map(|coord| {
let slice = coord
.iter()
.map(|x| *x..*x + 1)
.chain(batch_coord.iter().map(|x| *x..*x + 1))
.collect::<Vec<_>>();
let index_slice = index_slice
.get_slice(&slice)
.unwrap()
.iter()
.map(|x| *x..*x + 1)
.collect::<Vec<_>>();
input_slice.get_slice(&index_slice).unwrap()
})
.collect::<Tensor<_>>();
output.combine()
})
.collect::<Result<Vec<_>, _>>()?;
let mut outputs = outputs.into_iter().flatten().collect::<Tensor<_>>();
outputs.reshape(&output_size)?;
Ok(outputs)
}
/// Scatter ND.
/// This operator is the inverse of GatherND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to scatter
/// * `src` - Tensor of src
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::scatter_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[8],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[4, 3, 1, 7]),
/// &[4, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10, 11, 12]),
/// &[4],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 11, 3, 10, 9, 6, 7, 12]), &[8]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 2]),
/// &[2, 1],
/// ).unwrap();
///
/// let src = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// ]),
/// &[2, 4, 4],
/// ).unwrap();
///
/// let result = scatter_nd(&x, &index, &src).unwrap();
///
/// let expected = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[2, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10]),
/// &[2],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[9, 9, 9, 9, 10, 10, 10, 10]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[1, 1, 2],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9]),
/// &[1, 1],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 9, 3, 4, 5, 6, 7, 8]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ````
///
pub fn scatter_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("scatter_nd".to_string()))?;
if last_value > &input_dims.len() {
return Err(TensorError::DimMismatch("scatter_nd".to_string()));
}
let mut output = input.clone();
let cartesian_coord = index_dims[0..index_dims.len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
cartesian_coord
.iter()
.map(|coord| {
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let index_val = index.get_slice(&slice)?;
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let src_val = src.get_slice(&slice)?;
output.set_slice(&index_slice, &src_val)?;
Ok(())
})
.collect::<Result<Vec<_>, _>>()?;
Ok(output)
}
fn axes_op<T: TensorType + Send + Sync>(
a: &Tensor<T>,
axes: &[usize],

View File

@@ -4,6 +4,37 @@ use super::{
};
use halo2_proofs::{arithmetic::Field, plonk::Instance};
pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
val: F,
len: usize,
) -> ValTensor<F> {
let mut constant = Tensor::from(vec![ValType::Constant(val); len].into_iter());
constant.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(constant)
}
pub(crate) fn create_unit_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
len: usize,
) -> ValTensor<F> {
let mut unit = Tensor::from(vec![ValType::Constant(F::ONE); len].into_iter());
unit.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(unit)
}
pub(crate) fn create_zero_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
len: usize,
) -> ValTensor<F> {
let mut zero = Tensor::from(vec![ValType::Constant(F::ZERO); len].into_iter());
zero.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(zero)
}
#[derive(Debug, Clone)]
/// A [ValType] is a wrapper around Halo2 value(s).
pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> {
@@ -318,6 +349,19 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
matches!(self, ValTensor::Instance { .. })
}
/// reverse order of elements whilst preserving the shape
pub fn reverse(&mut self) -> Result<(), Box<dyn Error>> {
match self {
ValTensor::Value { inner: v, .. } => {
v.reverse();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
}
};
Ok(())
}
///
pub fn set_initial_instance_offset(&mut self, offset: usize) {
if let ValTensor::Instance { initial_offset, .. } = self {
@@ -450,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
_ => return Err(Box::new(TensorError::WrongMethod)),
};
Ok(integer_evals.into_iter().into())
let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
match tensor.reshape(self.dims()) {
_ => {}
};
Ok(tensor)
}
/// Calls `pad_to_zero_rem` on the inner tensor.

View File

@@ -78,6 +78,17 @@ pub fn feltToBigEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String,
Ok(format!("{:?}", felt))
}
/// Converts a felt to a little endian string
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToLittleEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let repr = serde_json::to_string(&felt).unwrap();
let b: String = serde_json::from_str(&repr).unwrap();
Ok(b)
}
/// Converts a hex string to a byte array
#[wasm_bindgen]
#[allow(non_snake_case)]

View File

@@ -193,7 +193,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 77] = [
const TESTS: [&str; 79] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -275,6 +275,8 @@ mod native_tests {
"ltsf",
"remainder", //75
"bitshift",
"gather_nd",
"scatter_nd",
];
const WASM_TESTS: [&str; 46] = [
@@ -502,7 +504,7 @@ mod native_tests {
}
});
seq!(N in 0..=76 {
seq!(N in 0..=78 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -589,13 +591,16 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
// currently variable output rank is not supported in ONNX
if test != "gather_nd" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
}
}
#(#[test_case(TESTS[N])])*
@@ -1849,6 +1854,7 @@ mod native_tests {
&format!("{}/{}/key.pk", test_dir, example_name),
"--vk-path",
&format!("{}/{}/key.vk", test_dir, example_name),
"--disable-selector-compression",
])
.status()
.expect("failed to execute process");

View File

@@ -9,8 +9,8 @@ mod wasm32 {
use ezkl::pfsys;
use ezkl::wasm::{
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, genPk, genVk, genWitness, inputValidation, pkValidation,
poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
pkValidation, poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
@@ -89,9 +89,16 @@ mod wasm32 {
.unwrap();
assert_eq!(integer, i as i128);
let hex_string = format!("{:?}", field_element);
let returned_string: String = feltToBigEndian(clamped).map_err(|_| "failed").unwrap();
let hex_string = format!("{:?}", field_element.clone());
let returned_string: String = feltToBigEndian(clamped.clone())
.map_err(|_| "failed")
.unwrap();
assert_eq!(hex_string, returned_string);
let repr = serde_json::to_string(&field_element).unwrap();
let little_endian_string: String = serde_json::from_str(&repr).unwrap();
let returned_string: String =
feltToLittleEndian(clamped).map_err(|_| "failed").unwrap();
assert_eq!(little_endian_string, returned_string);
}
}

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -29,6 +29,8 @@
"num_rows": 16,
"total_dynamic_col_size": 0,
"num_dynamic_lookups": 0,
"num_shuffles": 0,
"total_shuffle_col_size": 0,
"total_assignments": 32,
"total_const_size": 8,
"model_instance_shapes": [

Binary file not shown.