mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4c8daf773c | ||
|
|
80041ac523 | ||
|
|
2a1ee1102c | ||
|
|
95d4fd4a70 | ||
|
|
e0d3f4f145 | ||
|
|
bceac2fab5 | ||
|
|
04d7b5feaa | ||
|
|
45fd12a04f |
23
.github/workflows/rust.yml
vendored
23
.github/workflows/rust.yml
vendored
@@ -198,8 +198,6 @@ jobs:
|
||||
# chromedriver-version: "115.0.5790.102"
|
||||
- name: Install wasm32-unknown-unknown
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
- name: Install wasm runner
|
||||
run: cargo install wasm-server-runner
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
|
||||
- name: Run wasm verifier tests
|
||||
@@ -352,9 +350,6 @@ jobs:
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Install wasm-server-runner
|
||||
run: cargo install wasm-server-runner
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
@@ -427,21 +422,21 @@ jobs:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: KZG prove and verify tests (kzg outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public inputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (fixed params)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (hashed outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
|
||||
|
||||
fuzz-tests:
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
@@ -603,6 +598,8 @@ jobs:
|
||||
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; maturin develop --features python-bindings --release
|
||||
- name: Div rebase
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
|
||||
- name: Public inputs
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
|
||||
- name: fixed params
|
||||
|
||||
3
.github/workflows/wasm.yml
vendored
3
.github/workflows/wasm.yml
vendored
@@ -29,9 +29,6 @@ jobs:
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Install wasm-server-runner
|
||||
run: cargo install wasm-server-runner
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
|
||||
- name: Install binaryen
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::circuit::*;
|
||||
|
||||
use ezkl::circuit::lookup::LookupOp;
|
||||
@@ -16,7 +17,7 @@ use halo2_proofs::{
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
const BITS: (i128, i128) = (-32768, 32768);
|
||||
const BITS: Range = (-32768, 32768);
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 16;
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ use ezkl::circuit::*;
|
||||
|
||||
use ezkl::circuit::lookup::LookupOp;
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
@@ -16,7 +17,7 @@ use halo2_proofs::{
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
const BITS: (i128, i128) = (-8180, 8180);
|
||||
const BITS: Range = (-8180, 8180);
|
||||
static mut LEN: usize = 4;
|
||||
static mut K: usize = 16;
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
@@ -14,7 +15,7 @@ use halo2_proofs::{
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::Rng;
|
||||
|
||||
const BITS: (i128, i128) = (-32768, 32768);
|
||||
const BITS: Range = (-32768, 32768);
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 16;
|
||||
|
||||
|
||||
39
examples/onnx/1l_tiny_div/gen.py
Normal file
39
examples/onnx/1l_tiny_div/gen.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
|
||||
class Circuit(nn.Module):
|
||||
def __init__(self, inplace=False):
|
||||
super(Circuit, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x/ 10000
|
||||
|
||||
|
||||
circuit = Circuit()
|
||||
|
||||
|
||||
x = torch.empty(1, 8).random_(0, 2)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/1l_tiny_div/input.json
Normal file
1
examples/onnx/1l_tiny_div/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]]}
|
||||
BIN
examples/onnx/1l_tiny_div/network.onnx
Normal file
BIN
examples/onnx/1l_tiny_div/network.onnx
Normal file
Binary file not shown.
@@ -125,8 +125,8 @@ impl BaseOp {
|
||||
BaseOp::Sum => 1,
|
||||
BaseOp::SumInit => 1,
|
||||
BaseOp::Range { .. } => 1,
|
||||
BaseOp::IsZero => 1,
|
||||
BaseOp::IsBoolean => 1,
|
||||
BaseOp::IsZero => 0,
|
||||
BaseOp::IsBoolean => 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,10 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
circuit::ops::base::BaseOp,
|
||||
circuit::{table::Table, utils},
|
||||
circuit::{
|
||||
table::{Range, RangeCheck, Table},
|
||||
utils,
|
||||
},
|
||||
tensor::{Tensor, TensorType, ValTensor, VarTensor},
|
||||
};
|
||||
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
|
||||
@@ -176,6 +179,10 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
|
||||
///
|
||||
pub tables: BTreeMap<LookupOp, Table<F>>,
|
||||
///
|
||||
pub range_checks: BTreeMap<Range, RangeCheck<F>>,
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
|
||||
pub range_check_selectors: BTreeMap<(Range, usize, usize), Selector>,
|
||||
/// Activate sanity checks
|
||||
pub check_mode: CheckMode,
|
||||
_marker: PhantomData<F>,
|
||||
@@ -194,7 +201,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
lookup_index: dummy_var,
|
||||
selectors: BTreeMap::new(),
|
||||
lookup_selectors: BTreeMap::new(),
|
||||
range_check_selectors: BTreeMap::new(),
|
||||
tables: BTreeMap::new(),
|
||||
range_checks: BTreeMap::new(),
|
||||
check_mode: CheckMode::SAFE,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
@@ -267,9 +276,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
|
||||
let constraints = match base_op {
|
||||
BaseOp::IsBoolean => {
|
||||
vec![(qis[1].clone()) * (qis[1].clone() - Expression::Constant(F::from(1)))]
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("non accum: output query failed");
|
||||
|
||||
let output = expected_output[base_op.constraint_idx()].clone();
|
||||
|
||||
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
|
||||
}
|
||||
BaseOp::IsZero => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("non accum: output query failed");
|
||||
vec![expected_output[base_op.constraint_idx()].clone()]
|
||||
}
|
||||
BaseOp::IsZero => vec![qis[1].clone()],
|
||||
_ => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
|
||||
@@ -325,11 +345,13 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
Self {
|
||||
selectors,
|
||||
lookup_selectors: BTreeMap::new(),
|
||||
range_check_selectors: BTreeMap::new(),
|
||||
inputs: inputs.to_vec(),
|
||||
lookup_input: VarTensor::Empty,
|
||||
lookup_output: VarTensor::Empty,
|
||||
lookup_index: VarTensor::Empty,
|
||||
tables: BTreeMap::new(),
|
||||
range_checks: BTreeMap::new(),
|
||||
output: output.clone(),
|
||||
check_mode,
|
||||
_marker: PhantomData,
|
||||
@@ -344,7 +366,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
input: &VarTensor,
|
||||
output: &VarTensor,
|
||||
index: &VarTensor,
|
||||
lookup_range: (i128, i128),
|
||||
lookup_range: Range,
|
||||
logrows: usize,
|
||||
nl: &LookupOp,
|
||||
) -> Result<(), Box<dyn Error>>
|
||||
@@ -482,6 +504,74 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configures and creates lookup selectors
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn configure_range_check(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
input: &VarTensor,
|
||||
range: Range,
|
||||
) -> Result<(), Box<dyn Error>>
|
||||
where
|
||||
F: Field,
|
||||
{
|
||||
let mut selectors = BTreeMap::new();
|
||||
|
||||
if !input.is_advice() {
|
||||
return Err("wrong input type for lookup input".into());
|
||||
}
|
||||
|
||||
// we borrow mutably twice so we need to do this dance
|
||||
|
||||
let range_check = if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let range_check = RangeCheck::<F>::configure(cs, range);
|
||||
e.insert(range_check.clone());
|
||||
range_check
|
||||
} else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
for x in 0..input.num_blocks() {
|
||||
for y in 0..input.num_inner_cols() {
|
||||
let single_col_sel = cs.complex_selector();
|
||||
|
||||
cs.lookup("", |cs| {
|
||||
let mut res = vec![];
|
||||
let sel = cs.query_selector(single_col_sel);
|
||||
|
||||
let input_query = match &input {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let default_x = range_check.get_first_element();
|
||||
|
||||
let not_sel = Expression::Constant(F::ONE) - sel.clone();
|
||||
|
||||
res.extend([(
|
||||
sel.clone() * input_query.clone()
|
||||
+ not_sel.clone() * Expression::Constant(default_x),
|
||||
range_check.input,
|
||||
)]);
|
||||
|
||||
res
|
||||
});
|
||||
selectors.insert((range, x, y), single_col_sel);
|
||||
}
|
||||
}
|
||||
self.range_check_selectors.extend(selectors);
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if let VarTensor::Empty = self.lookup_input {
|
||||
debug!("assigning lookup input");
|
||||
self.lookup_input = input.clone();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// layout_tables must be called before layout.
|
||||
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
|
||||
for (i, table) in self.tables.values_mut().enumerate() {
|
||||
@@ -500,6 +590,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// layout_range_checks must be called before layout.
|
||||
pub fn layout_range_checks(
|
||||
&mut self,
|
||||
layouter: &mut impl Layouter<F>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
for range_check in self.range_checks.values_mut() {
|
||||
if !range_check.is_assigned {
|
||||
debug!("laying out range check for {:?}", range_check.range);
|
||||
range_check.layout(layouter)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Assigns variables to the regions created when calling `configure`.
|
||||
/// # Arguments
|
||||
/// * `values` - The explicit values to the operations.
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{self, layouts, utils, Tolerance},
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
@@ -13,6 +14,15 @@ use serde::{Deserialize, Serialize};
|
||||
/// An enum representing the operations that consist of both lookups and arithmetic operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HybridOp {
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
ReduceMax {
|
||||
axes: Vec<usize>,
|
||||
},
|
||||
@@ -75,6 +85,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
match self {
|
||||
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
|
||||
HybridOp::ScatterElements { .. } => vec![0, 2],
|
||||
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
|
||||
_ => vec![],
|
||||
}
|
||||
}
|
||||
@@ -113,25 +124,53 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
|
||||
(res.clone(), vec![inter_1, inter_2])
|
||||
}
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
..
|
||||
} => {
|
||||
let res = crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64);
|
||||
// if denom is a round number and use_range_check_for_int is true, use range check check
|
||||
if denom.0.fract() == 0.0 && *use_range_check_for_int {
|
||||
let divisor = Tensor::from(vec![denom.0 as i128 / 2].into_iter());
|
||||
(res, vec![-divisor.clone(), divisor])
|
||||
} else {
|
||||
(res, vec![x])
|
||||
}
|
||||
}
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => {
|
||||
let res = crate::tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.0 as f64,
|
||||
output_scale.0 as f64,
|
||||
);
|
||||
// if scale is a round number and use_range_check_for_int is true, use range check check
|
||||
if input_scale.0.fract() == 0.0 && *use_range_check_for_int {
|
||||
let err_tol = Tensor::from(
|
||||
vec![(output_scale.0 * input_scale.0) as i128 / 2].into_iter(),
|
||||
);
|
||||
(res, vec![-err_tol.clone(), err_tol])
|
||||
} else {
|
||||
(res, vec![x])
|
||||
}
|
||||
}
|
||||
HybridOp::ReduceArgMax { dim } => {
|
||||
let res = tensor::ops::argmax_axes(&x, *dim)?;
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let inter =
|
||||
Op::f(&HybridOp::ReduceMax { axes: vec![*dim] }, inputs)?.intermediate_lookups;
|
||||
inter_equals.extend(inter);
|
||||
|
||||
(res.clone(), inter_equals)
|
||||
(res, inter)
|
||||
}
|
||||
HybridOp::ReduceArgMin { dim } => {
|
||||
let res = tensor::ops::argmin_axes(&x, *dim)?;
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let inter =
|
||||
Op::f(&HybridOp::ReduceMin { axes: vec![*dim] }, inputs)?.intermediate_lookups;
|
||||
inter_equals.extend(inter);
|
||||
|
||||
(res.clone(), inter_equals)
|
||||
(res, inter)
|
||||
}
|
||||
HybridOp::Gather { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
@@ -140,18 +179,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
(res.clone(), vec![])
|
||||
} else {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
(res.clone(), vec![])
|
||||
}
|
||||
}
|
||||
HybridOp::OneHot { dim, num_classes } => {
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::one_hot(&x, *num_classes, *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
HybridOp::OneHot { dim, num_classes } => (
|
||||
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone(),
|
||||
vec![],
|
||||
),
|
||||
HybridOp::TopK { dim, k, largest } => {
|
||||
let res = tensor::ops::topk_axes(&x, *k, *dim, *largest)?;
|
||||
|
||||
@@ -183,10 +218,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
(res.clone(), vec![])
|
||||
} else {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::gather_elements(&x, &y.map(|x| x as usize), *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
(res.clone(), vec![])
|
||||
}
|
||||
}
|
||||
HybridOp::ScatterElements { dim, constant_idx } => {
|
||||
@@ -198,10 +231,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
} else {
|
||||
let idx = inputs[1].clone().map(|x| felt_to_i128(x) as usize);
|
||||
let src = inputs[2].clone().map(|x| felt_to_i128(x));
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::scatter(&x, &idx, &src, *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
(res.clone(), vec![])
|
||||
}
|
||||
}
|
||||
HybridOp::MaxPool2d {
|
||||
@@ -272,6 +303,21 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})",
|
||||
input_scale, output_scale, use_range_check_for_int
|
||||
),
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"DIV (denom={}, use_range_check_for_int={})",
|
||||
denom, use_range_check_for_int
|
||||
),
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -335,6 +381,57 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
*kernel_shape,
|
||||
*normalized,
|
||||
)?,
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => {
|
||||
if input_scale.0.fract() == 0.0
|
||||
&& output_scale.0.fract() == 0.0
|
||||
&& *use_range_check_for_int
|
||||
{
|
||||
layouts::recip(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
i128_to_felt(input_scale.0 as i128),
|
||||
i128_to_felt(output_scale.0 as i128),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
values.try_into()?,
|
||||
&LookupOp::Recip {
|
||||
input_scale: *input_scale,
|
||||
output_scale: *output_scale,
|
||||
},
|
||||
)?
|
||||
}
|
||||
}
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
..
|
||||
} => {
|
||||
if denom.0.fract() == 0.0 && *use_range_check_for_int {
|
||||
layouts::div(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
i128_to_felt(denom.0 as i128),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
values.try_into()?,
|
||||
&LookupOp::Div {
|
||||
denom: *denom,
|
||||
},
|
||||
)?
|
||||
}
|
||||
}
|
||||
HybridOp::Gather { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
@@ -422,86 +519,12 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
| HybridOp::OneHot { .. }
|
||||
| HybridOp::ReduceArgMin { .. } => 0,
|
||||
HybridOp::Softmax { .. } => 2 * in_scales[0],
|
||||
HybridOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
match self {
|
||||
HybridOp::ReduceMax { .. }
|
||||
| HybridOp::ReduceMin { .. }
|
||||
| HybridOp::MaxPool2d { .. } => Op::<F>::required_lookups(&LookupOp::ReLU),
|
||||
HybridOp::Softmax { scale, .. } => {
|
||||
vec![
|
||||
LookupOp::Exp { scale: *scale },
|
||||
LookupOp::Recip {
|
||||
scale: scale.0.powf(2.0).into(),
|
||||
},
|
||||
]
|
||||
}
|
||||
HybridOp::RangeCheck(tol) => {
|
||||
let mut lookups = vec![];
|
||||
if tol.val > 0.0 {
|
||||
let scale_squared = tol.scale.0.powf(2.0);
|
||||
lookups.extend([
|
||||
LookupOp::Recip {
|
||||
scale: scale_squared.into(),
|
||||
},
|
||||
LookupOp::GreaterThan {
|
||||
a: circuit::utils::F32((tol.val * scale_squared) / 100.0),
|
||||
},
|
||||
]);
|
||||
}
|
||||
lookups
|
||||
}
|
||||
HybridOp::Greater { .. } | HybridOp::Less { .. } => {
|
||||
vec![LookupOp::GreaterThan {
|
||||
a: circuit::utils::F32(0.),
|
||||
}]
|
||||
}
|
||||
HybridOp::GreaterEqual { .. } | HybridOp::LessEqual { .. } => {
|
||||
vec![LookupOp::GreaterThanEqual {
|
||||
a: circuit::utils::F32(0.),
|
||||
}]
|
||||
}
|
||||
HybridOp::TopK { .. } => {
|
||||
vec![
|
||||
LookupOp::GreaterThan {
|
||||
a: circuit::utils::F32(0.),
|
||||
},
|
||||
LookupOp::KroneckerDelta,
|
||||
]
|
||||
}
|
||||
HybridOp::Gather {
|
||||
constant_idx: None, ..
|
||||
}
|
||||
| HybridOp::OneHot { .. }
|
||||
| HybridOp::GatherElements {
|
||||
constant_idx: None, ..
|
||||
}
|
||||
| HybridOp::ScatterElements {
|
||||
constant_idx: None, ..
|
||||
}
|
||||
| HybridOp::Equals { .. } => {
|
||||
vec![LookupOp::KroneckerDelta]
|
||||
}
|
||||
HybridOp::ReduceArgMax { .. } | HybridOp::ReduceArgMin { .. } => {
|
||||
vec![LookupOp::ReLU, LookupOp::KroneckerDelta]
|
||||
}
|
||||
HybridOp::SumPool {
|
||||
kernel_shape,
|
||||
normalized: true,
|
||||
..
|
||||
} => {
|
||||
vec![LookupOp::Div {
|
||||
denom: utils::F32((kernel_shape.0 * kernel_shape.1) as f32),
|
||||
}]
|
||||
}
|
||||
_ => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<F>> {
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
|
||||
@@ -18,8 +18,11 @@ use super::{
|
||||
region::RegionCtx,
|
||||
};
|
||||
use crate::{
|
||||
circuit::{ops::base::BaseOp, utils},
|
||||
fieldutils::i128_to_felt,
|
||||
circuit::{
|
||||
ops::base::BaseOp,
|
||||
utils::{self},
|
||||
},
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
tensor::{
|
||||
get_broadcasted_shape,
|
||||
ops::{accumulated, add, mult, sub},
|
||||
@@ -51,6 +54,144 @@ pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usi
|
||||
total_len
|
||||
}
|
||||
|
||||
/// Div accumulated layout
|
||||
pub fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
div: F,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let input = value[0].clone();
|
||||
let input_dims = input.dims();
|
||||
|
||||
let range_check_bracket = felt_to_i128(div) / 2;
|
||||
|
||||
let mut divisor = Tensor::from(vec![ValType::Constant(div)].into_iter());
|
||||
divisor.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let divisor = region.assign(&config.inputs[1], &divisor.into())?;
|
||||
region.increment(divisor.len());
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !divisor.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.get_int_evals()?;
|
||||
tensor::ops::nonlinearities::const_div(&input_evals.clone(), felt_to_i128(div) as f64)
|
||||
.iter()
|
||||
.map(|x| Ok(Value::known(i128_to_felt(*x))))
|
||||
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); input.len()]),
|
||||
&[input.len()],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
|
||||
let product = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), divisor.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
log::debug!("product: {:?}", product.get_int_evals()?);
|
||||
|
||||
let diff_with_input = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[product.clone(), input.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[diff_with_input],
|
||||
&(-range_check_bracket, range_check_bracket),
|
||||
)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
/// recip accumulated layout
|
||||
pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
input_scale: F,
|
||||
output_scale: F,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let input = value[0].clone();
|
||||
let input_dims = input.dims();
|
||||
|
||||
let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2;
|
||||
|
||||
let mut scaled_unit =
|
||||
Tensor::from(vec![ValType::Constant(output_scale * input_scale)].into_iter());
|
||||
scaled_unit.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let scaled_unit = region.assign(&config.inputs[1], &scaled_unit.into())?;
|
||||
region.increment(scaled_unit.len());
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !scaled_unit.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.get_int_evals()?;
|
||||
tensor::ops::nonlinearities::recip(
|
||||
&input_evals,
|
||||
felt_to_i128(input_scale) as f64,
|
||||
felt_to_i128(output_scale) as f64,
|
||||
)
|
||||
.iter()
|
||||
.map(|x| Ok(Value::known(i128_to_felt(*x))))
|
||||
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); input.len()]),
|
||||
&[input.len()],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
|
||||
// this is now of scale 2 * scale
|
||||
let product = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), input.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
log::debug!("product: {:?}", product.get_int_evals()?);
|
||||
|
||||
// this is now of scale 2 * scale hence why we rescaled the unit scale
|
||||
let diff_with_input = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[product.clone(), scaled_unit.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
log::debug!("scaled_unit: {:?}", scaled_unit.get_int_evals()?);
|
||||
|
||||
// debug print the diff
|
||||
log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?);
|
||||
|
||||
log::debug!("range_check_bracket: {:?}", range_check_bracket);
|
||||
|
||||
// at most the error should be in the original unit scale's range
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[diff_with_input],
|
||||
&(-range_check_bracket, range_check_bracket),
|
||||
)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
/// Dot product accumulated layout
|
||||
pub fn dot<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -648,14 +789,7 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
|
||||
let assigned_input = region.assign(&config.inputs[0], &input)?;
|
||||
|
||||
// now assert all elems are 0 or 1
|
||||
let assigned_output = region.assign(&config.inputs[1], &output)?;
|
||||
if !region.is_dummy() {
|
||||
for i in 0..assigned_output.len() {
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
}
|
||||
}
|
||||
let assigned_output = boolean_identity(config, region, &[output.clone()], true)?;
|
||||
region.increment(std::cmp::max(assigned_output.len(), assigned_input.len()));
|
||||
|
||||
let sum = sum(config, region, &[assigned_output.clone()])?;
|
||||
@@ -1560,10 +1694,28 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
|
||||
values: &[ValTensor<F>; 2],
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let diff = pairwise(config, region, values, BaseOp::Sub)?;
|
||||
let diff_inverse = diff.inverse()?;
|
||||
let product_diff_and_invert =
|
||||
pairwise(config, region, &[diff.clone(), diff_inverse], BaseOp::Mult)?;
|
||||
|
||||
let res = nonlinearity(config, region, &[diff], &LookupOp::KroneckerDelta)?;
|
||||
// constant of 1
|
||||
let mut ones = Tensor::from(vec![ValType::Constant(F::from(1))].into_iter());
|
||||
ones.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
|
||||
Ok(res)
|
||||
// subtract
|
||||
let output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[ones.into(), product_diff_and_invert],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
// take the product of diff and output
|
||||
let prod_check = pairwise(config, region, &[diff, output.clone()], BaseOp::Mult)?;
|
||||
|
||||
is_zero_identity(config, region, &[prod_check], false)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Xor boolean operation
|
||||
@@ -1627,21 +1779,7 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd>(
|
||||
.into();
|
||||
|
||||
// make sure mask is boolean
|
||||
let assigned_mask = region.assign(&config.inputs[1], mask)?;
|
||||
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..assigned_mask.len())
|
||||
.map(|i| {
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(assigned_mask.len());
|
||||
let assigned_mask = boolean_identity(config, region, &[mask.clone()], true)?;
|
||||
|
||||
let one_minus_mask = pairwise(config, region, &[unit, assigned_mask.clone()], BaseOp::Sub)?;
|
||||
|
||||
@@ -1739,13 +1877,11 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
||||
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
|
||||
|
||||
if normalized {
|
||||
last_elem = nonlinearity(
|
||||
last_elem = div(
|
||||
config,
|
||||
region,
|
||||
&[last_elem],
|
||||
&LookupOp::Div {
|
||||
denom: utils::F32((kernel_shape.0 * kernel_shape.1) as f32),
|
||||
},
|
||||
F::from((kernel_shape.0 * kernel_shape.1) as u64),
|
||||
)?;
|
||||
}
|
||||
Ok(last_elem)
|
||||
@@ -2242,18 +2378,60 @@ pub fn identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// is zero identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
|
||||
pub fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
assign: bool,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let output = if assign || !values[0].get_const_indices()?.is_empty() {
|
||||
let output = region.assign(&config.output, &values[0])?;
|
||||
region.increment(output.len());
|
||||
output
|
||||
} else {
|
||||
values[0].clone()
|
||||
};
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..output.len())
|
||||
.map(|j| {
|
||||
let index = region.linear_coord() - j - 1;
|
||||
|
||||
let (x, y, z) = config.output.cartesian_coord(index);
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
|
||||
pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
assign: bool,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let output = region.assign(&config.inputs[1], &values[0])?;
|
||||
let output = if assign || !values[0].get_const_indices()?.is_empty() {
|
||||
// get zero constants indices
|
||||
let output = region.assign(&config.output, &values[0])?;
|
||||
region.increment(output.len());
|
||||
output
|
||||
} else {
|
||||
values[0].clone()
|
||||
};
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..output.len())
|
||||
.map(|j| {
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + j);
|
||||
let index = region.linear_coord() - j - 1;
|
||||
|
||||
let (x, y, z) = config.output.cartesian_coord(index);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
|
||||
region.enable(selector, z)?;
|
||||
@@ -2261,7 +2439,6 @@ pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
region.increment(output.len());
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
@@ -2304,6 +2481,52 @@ pub fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// layout for range check.
|
||||
pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
range: &crate::circuit::table::Range,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
region.add_used_range_check(*range);
|
||||
|
||||
// time the entire operation
|
||||
let timer = instant::Instant::now();
|
||||
|
||||
let x = values[0].clone();
|
||||
|
||||
let w = region.assign(&config.lookup_input, &x)?;
|
||||
|
||||
let assigned_len = x.len();
|
||||
|
||||
let is_dummy = region.is_dummy();
|
||||
|
||||
if !is_dummy {
|
||||
(0..assigned_len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = config
|
||||
.lookup_input
|
||||
.cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.range_check_selectors.get(&(*range, x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(assigned_len);
|
||||
|
||||
let elapsed = timer.elapsed();
|
||||
trace!(
|
||||
"range check {:?} layout took {:?}, row: {:?}",
|
||||
range,
|
||||
elapsed,
|
||||
region.row()
|
||||
);
|
||||
|
||||
Ok(w)
|
||||
}
|
||||
|
||||
/// layout for nonlinearity check.
|
||||
pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -2311,6 +2534,8 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
values: &[ValTensor<F>; 1],
|
||||
nl: &LookupOp,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
region.add_used_lookup(nl.clone());
|
||||
|
||||
// time the entire operation
|
||||
let timer = instant::Instant::now();
|
||||
|
||||
@@ -2392,22 +2617,6 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// mean function layout
|
||||
pub fn mean<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let x = &values[0];
|
||||
|
||||
let sum_x = sum(config, region, &[x.clone()])?;
|
||||
let nl = LookupOp::Div {
|
||||
denom: utils::F32((scale * x.len()) as f32),
|
||||
};
|
||||
nonlinearity(config, region, &[sum_x], &nl)
|
||||
}
|
||||
|
||||
/// Argmax
|
||||
pub fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -2520,24 +2729,8 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
|
||||
)?;
|
||||
// relu(x - max(x - 1))
|
||||
let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
|
||||
|
||||
let len = relu.dims().iter().product();
|
||||
|
||||
// y_i*(1 - y_i) =0 // assert the values are either 0 or 1
|
||||
region.assign(&config.inputs[1], &relu)?;
|
||||
|
||||
if !region.is_dummy() {
|
||||
(0..len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(len);
|
||||
// constraining relu(x - max(x - 1)) = 0/1
|
||||
boolean_identity(config, region, &[relu.clone()], false)?;
|
||||
|
||||
// sum(relu(x - max(x - 1)))
|
||||
let sum_relu = sum(config, region, &[relu])?;
|
||||
@@ -2548,13 +2741,7 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
|
||||
nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
|
||||
|
||||
// constraining 1 - sum(relu(x - max(x - 1))) = 0
|
||||
region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
|
||||
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
region.enable(selector, z)?;
|
||||
|
||||
region.increment(relu_one_minus_sum_relu.len());
|
||||
is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
|
||||
|
||||
Ok(assigned_max_val)
|
||||
}
|
||||
@@ -2599,23 +2786,8 @@ pub fn min<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
// relu(min(x + 1) - x)
|
||||
let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
|
||||
|
||||
let len = relu.dims().iter().product();
|
||||
|
||||
region.assign(&config.inputs[1], &relu)?;
|
||||
// y_i*(1 - y_i) =0 // assert the values are either 0 or 1
|
||||
if !region.is_dummy() {
|
||||
(0..len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(len);
|
||||
// constraining relu(min(x + 1) - x) = 0/1
|
||||
boolean_identity(config, region, &[relu.clone()], false)?;
|
||||
|
||||
// sum(relu(min(x + 1) - x))
|
||||
let sum_relu = sum(config, region, &[relu])?;
|
||||
@@ -2626,14 +2798,8 @@ pub fn min<F: PrimeField + TensorType + PartialOrd>(
|
||||
let relu_one_minus_sum_relu =
|
||||
nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
|
||||
|
||||
region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
|
||||
|
||||
// constraining product to 0
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
region.enable(selector, z)?;
|
||||
|
||||
region.increment(relu_one_minus_sum_relu.len());
|
||||
is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
|
||||
|
||||
Ok(assigned_min_val)
|
||||
}
|
||||
@@ -2780,7 +2946,8 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd>(
|
||||
&[denom],
|
||||
// we set to input scale + output_scale so the output scale is output)scale
|
||||
&LookupOp::Recip {
|
||||
scale: scale.0.powf(2.0).into(),
|
||||
input_scale: scale,
|
||||
output_scale: scale,
|
||||
},
|
||||
)?;
|
||||
|
||||
@@ -2808,19 +2975,22 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
|
||||
// Calculate the difference between the expected output and actual output
|
||||
let diff = pairwise(config, region, values, BaseOp::Sub)?;
|
||||
|
||||
let scale_squared = scale.0.powf(2.0);
|
||||
// Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
|
||||
let recip = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[values[0].clone()],
|
||||
&LookupOp::Recip {
|
||||
scale: scale_squared.into(),
|
||||
input_scale: scale,
|
||||
output_scale: scale,
|
||||
},
|
||||
)?;
|
||||
|
||||
// Multiply the difference by the recip
|
||||
let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;
|
||||
|
||||
let scale_squared = scale.0 * scale.0;
|
||||
|
||||
// Use the greater than look up table to check if the percent error is within the tolerance for upper bound
|
||||
let tol = tol / 100.0;
|
||||
let upper_bound = nonlinearity(
|
||||
@@ -2848,15 +3018,8 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
|
||||
// Add the lower_bound and upper_bound
|
||||
let sum = pairwise(config, region, &[lower_bound, upper_bound], BaseOp::Add)?;
|
||||
|
||||
// Assign the sum tensor to the inputs
|
||||
region.assign(&config.inputs[1], &sum)?;
|
||||
|
||||
// Constrain the sum to be all zeros
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
region.enable(selector, z)?;
|
||||
|
||||
region.increment(sum.len());
|
||||
is_zero_identity(config, region, &[sum.clone()], false)?;
|
||||
|
||||
Ok(sum)
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
|
||||
use crate::{
|
||||
circuit::{layouts, utils},
|
||||
circuit::{layouts, table::Range, utils},
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
graph::{multiplier_to_scale, scale_to_multiplier},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
@@ -17,47 +17,117 @@ use halo2curves::ff::PrimeField;
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||
pub enum LookupOp {
|
||||
Abs,
|
||||
Div { denom: utils::F32 },
|
||||
Cast { scale: utils::F32 },
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
},
|
||||
Cast {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ReLU,
|
||||
Max { scale: utils::F32, a: utils::F32 },
|
||||
Min { scale: utils::F32, a: utils::F32 },
|
||||
Ceil { scale: utils::F32 },
|
||||
Floor { scale: utils::F32 },
|
||||
Round { scale: utils::F32 },
|
||||
RoundHalfToEven { scale: utils::F32 },
|
||||
Sqrt { scale: utils::F32 },
|
||||
Rsqrt { scale: utils::F32 },
|
||||
Recip { scale: utils::F32 },
|
||||
LeakyReLU { slope: utils::F32 },
|
||||
Sigmoid { scale: utils::F32 },
|
||||
Ln { scale: utils::F32 },
|
||||
Exp { scale: utils::F32 },
|
||||
Cos { scale: utils::F32 },
|
||||
ACos { scale: utils::F32 },
|
||||
Cosh { scale: utils::F32 },
|
||||
ACosh { scale: utils::F32 },
|
||||
Sin { scale: utils::F32 },
|
||||
ASin { scale: utils::F32 },
|
||||
Sinh { scale: utils::F32 },
|
||||
ASinh { scale: utils::F32 },
|
||||
Tan { scale: utils::F32 },
|
||||
ATan { scale: utils::F32 },
|
||||
Tanh { scale: utils::F32 },
|
||||
ATanh { scale: utils::F32 },
|
||||
Erf { scale: utils::F32 },
|
||||
GreaterThan { a: utils::F32 },
|
||||
LessThan { a: utils::F32 },
|
||||
GreaterThanEqual { a: utils::F32 },
|
||||
LessThanEqual { a: utils::F32 },
|
||||
Max {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Min {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Floor {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Round {
|
||||
scale: utils::F32,
|
||||
},
|
||||
RoundHalfToEven {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Rsqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
},
|
||||
LeakyReLU {
|
||||
slope: utils::F32,
|
||||
},
|
||||
Sigmoid {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Exp {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Erf {
|
||||
scale: utils::F32,
|
||||
},
|
||||
GreaterThan {
|
||||
a: utils::F32,
|
||||
},
|
||||
LessThan {
|
||||
a: utils::F32,
|
||||
},
|
||||
GreaterThanEqual {
|
||||
a: utils::F32,
|
||||
},
|
||||
LessThanEqual {
|
||||
a: utils::F32,
|
||||
},
|
||||
Sign,
|
||||
KroneckerDelta,
|
||||
Pow { scale: utils::F32, a: utils::F32 },
|
||||
Pow {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
}
|
||||
|
||||
impl LookupOp {
|
||||
/// Returns the range of values that can be represented by the table
|
||||
pub fn bit_range(max_len: usize) -> (i128, i128) {
|
||||
pub fn bit_range(max_len: usize) -> Range {
|
||||
let range = (max_len - 1) as f64 / 2_f64;
|
||||
let range = range as i128;
|
||||
(-range, range)
|
||||
@@ -120,7 +190,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
&x,
|
||||
f32::from(*scale).into(),
|
||||
)),
|
||||
LookupOp::Recip { scale } => Ok(tensor::ops::nonlinearities::recip(&x, scale.into())),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => Ok(tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.into(),
|
||||
output_scale.into(),
|
||||
)),
|
||||
LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)),
|
||||
|
||||
LookupOp::LeakyReLU { slope: a } => {
|
||||
@@ -173,7 +250,13 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
LookupOp::GreaterThanEqual { .. } => "GREATER_THAN_EQUAL".into(),
|
||||
LookupOp::LessThan { .. } => "LESS_THAN".into(),
|
||||
LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(),
|
||||
LookupOp::Recip { scale, .. } => format!("RECIP(scale={})", scale),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!(
|
||||
"RECIP(input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
|
||||
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
@@ -220,12 +303,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
let in_scale = inputs_scale[0];
|
||||
in_scale + multiplier_to_scale(1. / scale.0 as f64)
|
||||
}
|
||||
LookupOp::Recip { scale } => {
|
||||
let mut out_scale = inputs_scale[0];
|
||||
out_scale +=
|
||||
multiplier_to_scale(scale.0 as f64 / scale_to_multiplier(out_scale).powf(2.0));
|
||||
out_scale
|
||||
}
|
||||
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
|
||||
LookupOp::Sign
|
||||
| LookupOp::GreaterThan { .. }
|
||||
| LookupOp::LessThan { .. }
|
||||
@@ -237,10 +315,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
vec![self.clone()]
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<F>> {
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
|
||||
@@ -55,11 +55,6 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns the lookups required by the operation.
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns true if the operation is an input.
|
||||
fn is_input(&self) -> bool {
|
||||
false
|
||||
@@ -206,6 +201,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
true,
|
||||
)?))
|
||||
}
|
||||
_ => Ok(Some(super::layouts::identity(
|
||||
|
||||
@@ -33,7 +33,9 @@ pub enum PolyOp {
|
||||
Sub,
|
||||
Neg,
|
||||
Mult,
|
||||
Identity,
|
||||
Identity {
|
||||
out_scale: Option<crate::Scale>,
|
||||
},
|
||||
Reshape(Vec<usize>),
|
||||
MoveAxis {
|
||||
source: usize,
|
||||
@@ -85,7 +87,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::Resize { .. } => "RESIZE".into(),
|
||||
PolyOp::Iff => "IFF".into(),
|
||||
PolyOp::Einsum { equation, .. } => format!("EINSUM {}", equation),
|
||||
PolyOp::Identity => "IDENTITY".into(),
|
||||
PolyOp::Identity { out_scale } => {
|
||||
format!("IDENTITY (out_scale={:?})", out_scale)
|
||||
}
|
||||
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
|
||||
PolyOp::Flatten(_) => "FLATTEN".into(),
|
||||
PolyOp::Pad(_) => "PAD".into(),
|
||||
@@ -135,7 +139,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
|
||||
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
|
||||
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
|
||||
PolyOp::Identity => Ok(inputs[0].clone()),
|
||||
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
|
||||
PolyOp::Reshape(new_dims) => {
|
||||
let mut t = inputs[0].clone();
|
||||
t.reshape(new_dims)?;
|
||||
@@ -264,7 +268,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::Mult => {
|
||||
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
|
||||
}
|
||||
PolyOp::Identity => layouts::identity(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
|
||||
PolyOp::Pad(p) => {
|
||||
if values.len() != 1 {
|
||||
@@ -290,12 +294,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
let scale = match self {
|
||||
PolyOp::MultiBroadcastTo { .. } => in_scales[0],
|
||||
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
|
||||
PolyOp::Neg => in_scales[0],
|
||||
PolyOp::MoveAxis { .. } => in_scales[0],
|
||||
PolyOp::Downsample { .. } => in_scales[0],
|
||||
PolyOp::Resize { .. } => in_scales[0],
|
||||
PolyOp::Iff => in_scales[1],
|
||||
PolyOp::Einsum { .. } => {
|
||||
let mut scale = in_scales[0];
|
||||
@@ -327,9 +326,8 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
output_scale
|
||||
}
|
||||
PolyOp::Add => {
|
||||
let mut scale_a = 0;
|
||||
let scale_b = in_scales[0];
|
||||
scale_a += in_scales[1];
|
||||
let scale_a = in_scales[0];
|
||||
let scale_b = in_scales[1];
|
||||
assert_eq!(scale_a, scale_b);
|
||||
scale_a
|
||||
}
|
||||
@@ -339,26 +337,21 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
scale += in_scales[1];
|
||||
scale
|
||||
}
|
||||
PolyOp::Identity => in_scales[0],
|
||||
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
|
||||
PolyOp::Pad(_) => in_scales[0],
|
||||
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
|
||||
PolyOp::Pack(_, _) => in_scales[0],
|
||||
PolyOp::GlobalSumPool => in_scales[0],
|
||||
PolyOp::Concat { axis: _ } => in_scales[0],
|
||||
PolyOp::Slice { .. } => in_scales[0],
|
||||
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
if matches!(
|
||||
self,
|
||||
PolyOp::Add { .. } | PolyOp::Sub | PolyOp::Concat { .. }
|
||||
) {
|
||||
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
|
||||
vec![0, 1]
|
||||
} else if matches!(self, PolyOp::Iff) {
|
||||
vec![1, 2]
|
||||
} else if matches!(self, PolyOp::Concat { .. }) {
|
||||
(0..100).collect()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor};
|
||||
use crate::{
|
||||
circuit::table::Range,
|
||||
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
|
||||
};
|
||||
use halo2_proofs::{
|
||||
circuit::Region,
|
||||
plonk::{Error, Selector},
|
||||
@@ -7,9 +10,14 @@ use halo2curves::ff::PrimeField;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
collections::HashSet,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc, Mutex,
|
||||
},
|
||||
};
|
||||
|
||||
use super::lookup::LookupOp;
|
||||
|
||||
/// Region error
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RegionError {
|
||||
@@ -56,6 +64,8 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
|
||||
linear_coord: usize,
|
||||
num_inner_cols: usize,
|
||||
total_constants: usize,
|
||||
used_lookups: HashSet<LookupOp>,
|
||||
used_range_checks: HashSet<Range>,
|
||||
}
|
||||
|
||||
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
@@ -75,6 +85,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
row,
|
||||
linear_coord,
|
||||
total_constants: 0,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
}
|
||||
}
|
||||
/// Create a new region context from a wrapped region
|
||||
@@ -90,6 +102,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants: 0,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +118,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants: 0,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,8 +127,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
pub fn new_dummy_with_constants(
|
||||
row: usize,
|
||||
linear_coord: usize,
|
||||
constants: usize,
|
||||
total_constants: usize,
|
||||
num_inner_cols: usize,
|
||||
used_lookups: HashSet<LookupOp>,
|
||||
used_range_checks: HashSet<Range>,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
RegionCtx {
|
||||
@@ -120,7 +138,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
num_inner_cols,
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants: constants,
|
||||
total_constants,
|
||||
used_lookups,
|
||||
used_range_checks,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,6 +190,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
let row = AtomicUsize::new(self.row());
|
||||
let linear_coord = AtomicUsize::new(self.linear_coord());
|
||||
let constants = AtomicUsize::new(self.total_constants());
|
||||
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
|
||||
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
|
||||
|
||||
*output = output
|
||||
.par_enum_map(|idx, _| {
|
||||
@@ -177,12 +199,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
let starting_offset = row.load(Ordering::SeqCst);
|
||||
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
|
||||
let starting_constants = constants.load(Ordering::SeqCst);
|
||||
// get inner value of the locked lookups
|
||||
|
||||
// we need to make sure that the region is not shared between threads
|
||||
let mut local_reg = Self::new_dummy_with_constants(
|
||||
starting_offset,
|
||||
starting_linear_coord,
|
||||
starting_constants,
|
||||
self.num_inner_cols,
|
||||
HashSet::new(),
|
||||
HashSet::new(),
|
||||
);
|
||||
let res = inner_loop_function(idx, &mut local_reg);
|
||||
// we update the offset and constants
|
||||
@@ -195,6 +221,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
local_reg.total_constants() - starting_constants,
|
||||
Ordering::SeqCst,
|
||||
);
|
||||
// update the lookups
|
||||
let mut lookups = lookups.lock().unwrap();
|
||||
lookups.extend(local_reg.used_lookups());
|
||||
let mut range_checks = range_checks.lock().unwrap();
|
||||
range_checks.extend(local_reg.used_range_checks());
|
||||
res
|
||||
})
|
||||
.map_err(|e| {
|
||||
@@ -204,6 +235,21 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.total_constants = constants.into_inner();
|
||||
self.linear_coord = linear_coord.into_inner();
|
||||
self.row = row.into_inner();
|
||||
self.used_lookups = Arc::try_unwrap(lookups)
|
||||
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e))
|
||||
})?;
|
||||
self.used_range_checks = Arc::try_unwrap(range_checks)
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
|
||||
})?
|
||||
.into_inner()
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -212,15 +258,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.region.is_none()
|
||||
}
|
||||
|
||||
/// duplicate_dummy
|
||||
pub fn duplicate_dummy(&self) -> Self {
|
||||
Self {
|
||||
region: None,
|
||||
linear_coord: self.linear_coord,
|
||||
num_inner_cols: self.num_inner_cols,
|
||||
row: self.row,
|
||||
total_constants: self.total_constants,
|
||||
}
|
||||
/// add used lookup
|
||||
pub fn add_used_lookup(&mut self, lookup: LookupOp) {
|
||||
self.used_lookups.insert(lookup);
|
||||
}
|
||||
|
||||
/// add used range check
|
||||
pub fn add_used_range_check(&mut self, range: Range) {
|
||||
self.used_range_checks.insert(range);
|
||||
}
|
||||
|
||||
/// Get the offset
|
||||
@@ -238,6 +283,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.total_constants
|
||||
}
|
||||
|
||||
/// get used lookups
|
||||
pub fn used_lookups(&self) -> HashSet<LookupOp> {
|
||||
self.used_lookups.clone()
|
||||
}
|
||||
|
||||
/// get used range checks
|
||||
pub fn used_range_checks(&self) -> HashSet<Range> {
|
||||
self.used_range_checks.clone()
|
||||
}
|
||||
|
||||
/// Assign a constant value
|
||||
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
|
||||
self.total_constants += 1;
|
||||
|
||||
@@ -19,6 +19,9 @@ use crate::circuit::lookup::LookupOp;
|
||||
|
||||
use super::Op;
|
||||
|
||||
/// The range of the lookup table.
|
||||
pub type Range = (i128, i128);
|
||||
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: i128 = 2;
|
||||
/// The safety factor offset for the number of rows in the lookup table.
|
||||
@@ -91,7 +94,7 @@ pub struct Table<F: PrimeField> {
|
||||
/// Flags if table has been previously assigned to.
|
||||
pub is_assigned: bool,
|
||||
/// Number of bits used in lookup table.
|
||||
pub range: (i128, i128),
|
||||
pub range: Range,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
@@ -129,7 +132,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range: (i128, i128), col_size: usize) -> usize {
|
||||
pub fn num_cols_required(range: Range, col_size: usize) -> usize {
|
||||
// double it to be safe
|
||||
let range_len = range.1 - range.0;
|
||||
// number of cols needed to store the range
|
||||
@@ -141,7 +144,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
range: (i128, i128),
|
||||
range: Range,
|
||||
logrows: usize,
|
||||
nonlinearity: &LookupOp,
|
||||
preexisting_inputs: Option<Vec<TableColumn>>,
|
||||
@@ -257,3 +260,86 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Halo2 range check column
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RangeCheck<F: PrimeField> {
|
||||
/// Input to table.
|
||||
pub input: TableColumn,
|
||||
/// selector cn
|
||||
pub selector_constructor: SelectorConstructor<F>,
|
||||
/// Flags if table has been previously assigned to.
|
||||
pub is_assigned: bool,
|
||||
/// Number of bits used in lookup table.
|
||||
pub range: Range,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
|
||||
/// get first_element of column
|
||||
pub fn get_first_element(&self) -> F {
|
||||
i128_to_felt(self.range.0)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range) -> RangeCheck<F> {
|
||||
log::debug!("range check range: {:?}", range);
|
||||
|
||||
let inputs = cs.lookup_table_column();
|
||||
|
||||
RangeCheck {
|
||||
input: inputs,
|
||||
is_assigned: false,
|
||||
selector_constructor: SelectorConstructor::new(2),
|
||||
range,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns values to the constraints generated when calling `configure`.
|
||||
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
|
||||
if self.is_assigned {
|
||||
return Err(Box::new(CircuitError::TableAlreadyAssigned));
|
||||
}
|
||||
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
|
||||
|
||||
self.is_assigned = true;
|
||||
|
||||
layouter.assign_table(
|
||||
|| "range check table",
|
||||
|mut table| {
|
||||
let _ = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(row_offset, input)| {
|
||||
table.assign_cell(
|
||||
|| format!("rc_i_col row {}", row_offset),
|
||||
self.input,
|
||||
row_offset,
|
||||
|| Value::known(*input),
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2154,7 +2154,7 @@ mod rangecheckpercent {
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let scale = utils::F32(SCALE.pow(2) as f32);
|
||||
let scale = utils::F32(SCALE as f32);
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
@@ -2162,11 +2162,12 @@ mod rangecheckpercent {
|
||||
Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
|
||||
// set up a new GreaterThan and Recip tables
|
||||
let nl = &LookupOp::GreaterThan {
|
||||
a: circuit::utils::F32((RANGE * scale.0) / 100.0),
|
||||
a: circuit::utils::F32((RANGE * SCALE.pow(2) as f32) / 100.0),
|
||||
};
|
||||
config
|
||||
.configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, nl)
|
||||
.unwrap();
|
||||
|
||||
config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
@@ -2175,7 +2176,10 @@ mod rangecheckpercent {
|
||||
&a,
|
||||
(-32768, 32768),
|
||||
K,
|
||||
&LookupOp::Recip { scale },
|
||||
&LookupOp::Recip {
|
||||
input_scale: scale,
|
||||
output_scale: scale,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
config
|
||||
@@ -2511,7 +2515,8 @@ mod softmax {
|
||||
(-32768, 32768),
|
||||
K,
|
||||
&LookupOp::Recip {
|
||||
scale: SCALE.powf(2.0).into(),
|
||||
input_scale: SCALE.into(),
|
||||
output_scale: SCALE.into(),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -83,6 +83,8 @@ pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
|
||||
pub const DEFAULT_VK_SOL: &str = "vk.sol";
|
||||
/// Default VK abi path
|
||||
pub const DEFAULT_VK_ABI: &str = "vk.abi";
|
||||
/// Default scale rebase multipliers for calibration
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
|
||||
|
||||
impl std::fmt::Display for TranscriptType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
@@ -323,9 +325,20 @@ pub enum Commands {
|
||||
/// Optional scales to specifically try for calibration. Example, --scales 0,4
|
||||
#[arg(long, value_delimiter = ',', allow_hyphen_values = true)]
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
/// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale. Example, --scale-rebase-multipliers 0,4
|
||||
#[arg(
|
||||
long,
|
||||
value_delimiter = ',',
|
||||
allow_hyphen_values = true,
|
||||
default_value = DEFAULT_SCALE_REBASE_MULTIPLIERS
|
||||
)]
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
/// max logrows to use for calibration, 26 is the max public SRS size
|
||||
#[arg(long)]
|
||||
max_logrows: Option<u32>,
|
||||
// whether to fix the div_rebasing value truthiness during calibration. this changes how we rebase
|
||||
#[arg(long)]
|
||||
div_rebasing: Option<bool>,
|
||||
},
|
||||
|
||||
/// Generates a dummy SRS
|
||||
|
||||
@@ -176,7 +176,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
max_logrows,
|
||||
div_rebasing,
|
||||
} => calibrate(
|
||||
model,
|
||||
data,
|
||||
@@ -184,6 +186,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
max_logrows,
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
@@ -626,6 +630,10 @@ pub(crate) async fn gen_witness(
|
||||
if let Some(output_path) = output {
|
||||
serde_json::to_writer(&File::create(output_path)?, &witness)?;
|
||||
}
|
||||
|
||||
// print the witness in debug
|
||||
debug!("witness: \n {}", witness.as_json()?.to_colored_json_auto()?);
|
||||
|
||||
Ok(witness)
|
||||
}
|
||||
|
||||
@@ -715,15 +723,14 @@ impl AccuracyResults {
|
||||
let error = (original.clone() - calibrated.clone())?;
|
||||
let abs_error = error.map(|x| x.abs());
|
||||
let squared_error = error.map(|x| x.powi(2));
|
||||
let percentage_error =
|
||||
error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
|
||||
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
|
||||
let abs_percentage_error = percentage_error.map(|x| x.abs());
|
||||
|
||||
errors.extend(error.into_iter());
|
||||
abs_errors.extend(abs_error.into_iter());
|
||||
squared_errors.extend(squared_error.into_iter());
|
||||
percentage_errors.extend(percentage_error.into_iter());
|
||||
abs_percentage_errors.extend(abs_percentage_error.into_iter());
|
||||
errors.extend(error);
|
||||
abs_errors.extend(abs_error);
|
||||
squared_errors.extend(squared_error);
|
||||
percentage_errors.extend(percentage_error);
|
||||
abs_percentage_errors.extend(abs_percentage_error);
|
||||
}
|
||||
|
||||
let mean_percent_error =
|
||||
@@ -734,22 +741,22 @@ impl AccuracyResults {
|
||||
let median_error = errors[errors.len() / 2];
|
||||
let max_error = *errors
|
||||
.iter()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap();
|
||||
let min_error = *errors
|
||||
.iter()
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap();
|
||||
|
||||
let mean_abs_error = abs_errors.iter().sum::<f32>() / abs_errors.len() as f32;
|
||||
let median_abs_error = abs_errors[abs_errors.len() / 2];
|
||||
let max_abs_error = *abs_errors
|
||||
.iter()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap();
|
||||
let min_abs_error = *abs_errors
|
||||
.iter()
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap();
|
||||
|
||||
let mean_squared_error = squared_errors.iter().sum::<f32>() / squared_errors.len() as f32;
|
||||
@@ -780,6 +787,8 @@ pub(crate) fn calibrate(
|
||||
target: CalibrationTarget,
|
||||
lookup_safety_margin: i128,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
div_rebasing: Option<bool>,
|
||||
max_logrows: Option<u32>,
|
||||
) -> Result<GraphSettings, Box<dyn Error>> {
|
||||
use std::collections::HashMap;
|
||||
@@ -823,9 +832,13 @@ pub(crate) fn calibrate(
|
||||
}
|
||||
};
|
||||
|
||||
let mut found_params: Vec<GraphSettings> = vec![];
|
||||
let div_rebasing = if let Some(div_rebasing) = div_rebasing {
|
||||
vec![div_rebasing]
|
||||
} else {
|
||||
vec![true, false]
|
||||
};
|
||||
|
||||
let scale_rebase_multiplier = [1, 2, 10];
|
||||
let mut found_params: Vec<GraphSettings> = vec![];
|
||||
|
||||
// 2 x 2 grid
|
||||
let range_grid = range
|
||||
@@ -862,15 +875,21 @@ pub(crate) fn calibrate(
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
|
||||
|
||||
let range_grid = range_grid
|
||||
.iter()
|
||||
.cartesian_product(div_rebasing.iter())
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
|
||||
|
||||
let mut forward_pass_res = HashMap::new();
|
||||
|
||||
let pb = init_bar(range_grid.len() as u64);
|
||||
pb.set_message("calibrating...");
|
||||
|
||||
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
|
||||
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
|
||||
pb.set_message(format!(
|
||||
"input scale: {}, param scale: {}, scale rebase multiplier: {}",
|
||||
input_scale, param_scale, scale_rebase_multiplier
|
||||
"input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
|
||||
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
|
||||
));
|
||||
|
||||
#[cfg(unix)]
|
||||
@@ -890,6 +909,7 @@ pub(crate) fn calibrate(
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
..settings.run_args.clone()
|
||||
};
|
||||
|
||||
@@ -964,6 +984,7 @@ pub(crate) fn calibrate(
|
||||
let found_run_args = RunArgs {
|
||||
input_scale: new_settings.run_args.input_scale,
|
||||
param_scale: new_settings.run_args.param_scale,
|
||||
div_rebasing: new_settings.run_args.div_rebasing,
|
||||
lookup_range: new_settings.run_args.lookup_range,
|
||||
logrows: new_settings.run_args.logrows,
|
||||
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
|
||||
@@ -973,6 +994,7 @@ pub(crate) fn calibrate(
|
||||
let found_settings = GraphSettings {
|
||||
run_args: found_run_args,
|
||||
required_lookups: new_settings.required_lookups,
|
||||
required_range_checks: new_settings.required_range_checks,
|
||||
model_output_scales: new_settings.model_output_scales,
|
||||
model_input_scales: new_settings.model_input_scales,
|
||||
num_rows: new_settings.num_rows,
|
||||
|
||||
@@ -23,7 +23,7 @@ use self::input::{FileSource, GraphData};
|
||||
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
use crate::circuit::table::{Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::table::{Range, Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::{CheckMode, InputType};
|
||||
use crate::fieldutils::felt_to_f64;
|
||||
use crate::pfsys::PrettyElements;
|
||||
@@ -431,6 +431,8 @@ pub struct GraphSettings {
|
||||
pub module_sizes: ModuleSizes,
|
||||
/// required_lookups
|
||||
pub required_lookups: Vec<LookupOp>,
|
||||
/// required range_checks
|
||||
pub required_range_checks: Vec<Range>,
|
||||
/// check mode
|
||||
pub check_mode: CheckMode,
|
||||
/// ezkl version used
|
||||
@@ -639,7 +641,7 @@ impl GraphCircuit {
|
||||
}
|
||||
|
||||
// dummy module settings, must load from GraphData after
|
||||
let mut settings = model.gen_params(run_args, CheckMode::UNSAFE)?;
|
||||
let mut settings = model.gen_params(run_args, run_args.check_mode)?;
|
||||
|
||||
let mut num_params = 0;
|
||||
if !model.const_shapes().is_empty() {
|
||||
@@ -960,19 +962,20 @@ impl GraphCircuit {
|
||||
min_lookup_inputs: i128,
|
||||
max_lookup_inputs: i128,
|
||||
lookup_safety_margin: i128,
|
||||
) -> (i128, i128) {
|
||||
) -> Range {
|
||||
let mut margin = (
|
||||
lookup_safety_margin * min_lookup_inputs,
|
||||
lookup_safety_margin * max_lookup_inputs,
|
||||
);
|
||||
if lookup_safety_margin == 1 {
|
||||
margin.0 -= 1;
|
||||
margin.1 += 1;
|
||||
margin.0 += 4;
|
||||
margin.1 += 4;
|
||||
}
|
||||
|
||||
margin
|
||||
}
|
||||
|
||||
fn calc_num_cols(safe_range: (i128, i128), max_logrows: u32) -> usize {
|
||||
fn calc_num_cols(safe_range: Range, max_logrows: u32) -> usize {
|
||||
let max_col_size = Table::<Fp>::cal_col_size(
|
||||
max_logrows as usize,
|
||||
Self::reserved_blinding_rows() as usize,
|
||||
@@ -1456,6 +1459,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
params.run_args.lookup_range,
|
||||
params.run_args.logrows as usize,
|
||||
params.required_lookups,
|
||||
params.required_range_checks,
|
||||
params.check_mode,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -6,6 +6,7 @@ use super::GraphError;
|
||||
use super::GraphSettings;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::region::RegionCtx;
|
||||
use crate::circuit::table::Range;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::Unknown;
|
||||
@@ -79,6 +80,21 @@ pub struct ModelConfig {
|
||||
/// Representation of execution graph
|
||||
pub type NodeGraph = BTreeMap<usize, NodeType>;
|
||||
|
||||
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub struct DummyPassRes {
|
||||
/// number of rows use
|
||||
pub num_rows: usize,
|
||||
/// linear coordinate
|
||||
pub linear_coord: usize,
|
||||
/// total const size
|
||||
pub total_const_size: usize,
|
||||
/// lookup ops
|
||||
pub lookup_ops: HashSet<LookupOp>,
|
||||
/// range checks
|
||||
pub range_checks: HashSet<Range>,
|
||||
}
|
||||
|
||||
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub struct Model {
|
||||
@@ -233,13 +249,7 @@ impl NodeType {
|
||||
NodeType::SubGraph { out_dims, .. } => out_dims.clone(),
|
||||
}
|
||||
}
|
||||
/// Returns the lookups required by a graph
|
||||
pub fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
match self {
|
||||
NodeType::Node(n) => n.opkind.required_lookups(),
|
||||
NodeType::SubGraph { model, .. } => model.required_lookups(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the scales of the node's output.
|
||||
pub fn out_scales(&self) -> Vec<crate::Scale> {
|
||||
match self {
|
||||
@@ -424,14 +434,6 @@ impl ParsedNodes {
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
self.graph
|
||||
.nodes
|
||||
.values()
|
||||
.flat_map(|n| n.required_lookups())
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Creates a `Model` from a specified path to an Onnx file.
|
||||
/// # Arguments
|
||||
/// * `reader` - A reader for an Onnx file.
|
||||
@@ -484,36 +486,21 @@ impl Model {
|
||||
);
|
||||
// this is the total number of variables we will need to allocate
|
||||
// for the circuit
|
||||
let (num_rows, linear_coord, total_const_size) =
|
||||
self.dummy_layout(run_args, &self.graph.input_shapes()?)?;
|
||||
|
||||
// extract the requisite lookup ops from the model
|
||||
let mut lookup_ops: Vec<LookupOp> = self.required_lookups();
|
||||
let res = self.dummy_layout(run_args, &self.graph.input_shapes()?)?;
|
||||
|
||||
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
|
||||
|
||||
if run_args.tolerance.val > 0.0 {
|
||||
for scale in self.graph.get_output_scales()? {
|
||||
let mut tolerance = run_args.tolerance;
|
||||
tolerance.scale = scale_to_multiplier(scale).into();
|
||||
let opkind: Box<dyn Op<Fp>> = Box::new(HybridOp::RangeCheck(tolerance));
|
||||
lookup_ops.extend(opkind.required_lookups());
|
||||
}
|
||||
}
|
||||
|
||||
let set: HashSet<_> = lookup_ops.drain(..).collect(); // dedup
|
||||
lookup_ops.extend(set.into_iter().sorted());
|
||||
|
||||
Ok(GraphSettings {
|
||||
run_args: run_args.clone(),
|
||||
model_instance_shapes: instance_shapes,
|
||||
module_sizes: crate::graph::modules::ModuleSizes::default(),
|
||||
num_rows,
|
||||
total_assignments: linear_coord,
|
||||
required_lookups: lookup_ops,
|
||||
num_rows: res.num_rows,
|
||||
total_assignments: res.linear_coord,
|
||||
required_lookups: res.lookup_ops.into_iter().collect(),
|
||||
required_range_checks: res.range_checks.into_iter().collect(),
|
||||
model_output_scales: self.graph.get_output_scales()?,
|
||||
model_input_scales: self.graph.get_input_scales(),
|
||||
total_const_size,
|
||||
total_const_size: res.total_const_size,
|
||||
check_mode,
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
num_blinding_factors: None,
|
||||
@@ -568,6 +555,8 @@ impl Model {
|
||||
inputs.iter().map(|x| x.dims()).collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
debug!("input nodes: {:?}", n.inputs());
|
||||
|
||||
if n.is_lookup() {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in &inputs {
|
||||
@@ -611,7 +600,7 @@ impl Model {
|
||||
debug!("intermediate min lookup inputs: {}", min);
|
||||
}
|
||||
debug!(
|
||||
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} ------------ scale: {}",
|
||||
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} \n ------------ scale: {}",
|
||||
idx,
|
||||
res.output.map(crate::fieldutils::felt_to_i32).show(),
|
||||
res.output
|
||||
@@ -1042,6 +1031,8 @@ impl Model {
|
||||
&run_args.param_visibility,
|
||||
i,
|
||||
symbol_values,
|
||||
run_args.div_rebasing,
|
||||
run_args.rebase_frac_zero_constants,
|
||||
)?;
|
||||
if let Some(ref scales) = override_input_scales {
|
||||
if let Some(inp) = n.opkind.get_input() {
|
||||
@@ -1058,9 +1049,20 @@ impl Model {
|
||||
if scales.contains_key(&i) {
|
||||
let scale_diff = n.out_scale - scales[&i];
|
||||
n.opkind = if scale_diff > 0 {
|
||||
RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1)
|
||||
RebaseScale::rebase(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
1,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
} else {
|
||||
RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale)
|
||||
RebaseScale::rebase_up(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
};
|
||||
n.out_scale = scales[&i];
|
||||
}
|
||||
@@ -1155,9 +1157,10 @@ impl Model {
|
||||
pub fn configure(
|
||||
meta: &mut ConstraintSystem<Fp>,
|
||||
vars: &ModelVars<Fp>,
|
||||
lookup_range: (i128, i128),
|
||||
lookup_range: Range,
|
||||
logrows: usize,
|
||||
required_lookups: Vec<LookupOp>,
|
||||
required_range_checks: Vec<Range>,
|
||||
check_mode: CheckMode,
|
||||
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
|
||||
info!("configuring model");
|
||||
@@ -1170,12 +1173,16 @@ impl Model {
|
||||
);
|
||||
// set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case
|
||||
let input = &vars.advices[0];
|
||||
let output = &vars.advices[1];
|
||||
let index = &vars.advices[2];
|
||||
let output = &vars.advices[2];
|
||||
let index = &vars.advices[1];
|
||||
for op in required_lookups {
|
||||
base_gate.configure_lookup(meta, input, output, index, lookup_range, logrows, &op)?;
|
||||
}
|
||||
|
||||
for range in required_range_checks {
|
||||
base_gate.configure_range_check(meta, input, range)?;
|
||||
}
|
||||
|
||||
Ok(base_gate)
|
||||
}
|
||||
|
||||
@@ -1216,6 +1223,7 @@ impl Model {
|
||||
let instance_idx = vars.get_instance_idx();
|
||||
|
||||
config.base.layout_tables(layouter)?;
|
||||
config.base.layout_range_checks(layouter)?;
|
||||
|
||||
let mut num_rows = 0;
|
||||
let mut linear_coord = 0;
|
||||
@@ -1482,7 +1490,7 @@ impl Model {
|
||||
&self,
|
||||
run_args: &RunArgs,
|
||||
input_shapes: &[Vec<usize>],
|
||||
) -> Result<(usize, usize, usize), Box<dyn Error>> {
|
||||
) -> Result<DummyPassRes, Box<dyn Error>> {
|
||||
info!("calculating num of constraints using dummy model layout...");
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
@@ -1567,11 +1575,15 @@ impl Model {
|
||||
region.total_constants().to_string().red()
|
||||
);
|
||||
|
||||
Ok((
|
||||
region.row(),
|
||||
region.linear_coord(),
|
||||
region.total_constants(),
|
||||
))
|
||||
let res = DummyPassRes {
|
||||
num_rows: region.row(),
|
||||
linear_coord: region.linear_coord(),
|
||||
total_const_size: region.total_constants(),
|
||||
lookup_ops: region.used_lookups(),
|
||||
range_checks: region.used_range_checks(),
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Retrieves all constants from the model.
|
||||
|
||||
@@ -12,16 +12,12 @@ use crate::circuit::Constant;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::Op;
|
||||
use crate::circuit::Unknown;
|
||||
use crate::fieldutils::felt_to_i128;
|
||||
use crate::fieldutils::i128_to_felt;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::new_op_from_onnx;
|
||||
use crate::tensor::Tensor;
|
||||
use crate::tensor::TensorError;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use itertools::Itertools;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use log::trace;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
@@ -94,10 +90,6 @@ impl Op<Fp> for Rescaled {
|
||||
Op::<Fp>::out_scale(&*self.inner, in_scales)
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
self.inner.required_lookups()
|
||||
}
|
||||
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -126,12 +118,14 @@ impl Op<Fp> for Rescaled {
|
||||
pub struct RebaseScale {
|
||||
/// The operation that has to be rescaled.
|
||||
pub inner: Box<SupportedOp>,
|
||||
/// the multiplier applied to the node output
|
||||
pub multiplier: f64,
|
||||
/// rebase op
|
||||
pub rebase_op: HybridOp,
|
||||
/// scale being rebased to
|
||||
pub target_scale: i32,
|
||||
/// The original scale of the operation's inputs.
|
||||
pub original_scale: i32,
|
||||
/// multiplier
|
||||
pub multiplier: f64,
|
||||
}
|
||||
|
||||
impl RebaseScale {
|
||||
@@ -141,6 +135,7 @@ impl RebaseScale {
|
||||
global_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
scale_rebase_multiplier: u32,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
|
||||
&& !inner.is_constant()
|
||||
@@ -149,10 +144,15 @@ impl RebaseScale {
|
||||
let multiplier =
|
||||
scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32);
|
||||
if let Some(op) = inner.get_rebased() {
|
||||
let multiplier = op.multiplier * multiplier;
|
||||
SupportedOp::RebaseScale(RebaseScale {
|
||||
inner: op.inner.clone(),
|
||||
target_scale: op.target_scale,
|
||||
multiplier: op.multiplier * multiplier,
|
||||
multiplier,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32((multiplier) as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
original_scale: op.original_scale,
|
||||
})
|
||||
} else {
|
||||
@@ -160,6 +160,10 @@ impl RebaseScale {
|
||||
inner: Box::new(inner),
|
||||
target_scale: global_scale * scale_rebase_multiplier as i32,
|
||||
multiplier,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32(multiplier as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
original_scale: op_out_scale,
|
||||
})
|
||||
}
|
||||
@@ -173,15 +177,21 @@ impl RebaseScale {
|
||||
inner: SupportedOp,
|
||||
target_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
|
||||
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
|
||||
if let Some(op) = inner.get_rebased() {
|
||||
let multiplier = op.multiplier * multiplier;
|
||||
SupportedOp::RebaseScale(RebaseScale {
|
||||
inner: op.inner.clone(),
|
||||
target_scale: op.target_scale,
|
||||
multiplier: op.multiplier * multiplier,
|
||||
multiplier,
|
||||
original_scale: op.original_scale,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32((multiplier) as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
SupportedOp::RebaseScale(RebaseScale {
|
||||
@@ -189,6 +199,10 @@ impl RebaseScale {
|
||||
target_scale,
|
||||
multiplier,
|
||||
original_scale: op_out_scale,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32(multiplier as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
@@ -203,19 +217,19 @@ impl Op<Fp> for RebaseScale {
|
||||
}
|
||||
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
|
||||
let mut res = Op::<Fp>::f(&*self.inner, x)?;
|
||||
let ri = res.output.map(felt_to_i128);
|
||||
let rescaled = crate::tensor::ops::nonlinearities::const_div(&ri, self.multiplier);
|
||||
res.output = rescaled.map(i128_to_felt);
|
||||
|
||||
res.intermediate_lookups.push(ri);
|
||||
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
|
||||
res.output = rebase_res.output;
|
||||
res.intermediate_lookups
|
||||
.extend(rebase_res.intermediate_lookups);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
format!(
|
||||
"REBASED (div={:?}) ({})",
|
||||
"REBASED (div={:?}, rebasing_op={}) ({})",
|
||||
self.multiplier,
|
||||
<HybridOp as Op<Fp>>::as_string(&self.rebase_op),
|
||||
self.inner.as_string()
|
||||
)
|
||||
}
|
||||
@@ -224,14 +238,6 @@ impl Op<Fp> for RebaseScale {
|
||||
Ok(self.target_scale)
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
let mut lookups = self.inner.required_lookups();
|
||||
lookups.push(LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(self.multiplier as f32),
|
||||
});
|
||||
lookups
|
||||
}
|
||||
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -241,16 +247,8 @@ impl Op<Fp> for RebaseScale {
|
||||
let original_res = self
|
||||
.inner
|
||||
.layout(config, region, values)?
|
||||
.ok_or("no layout")?;
|
||||
|
||||
Ok(Some(crate::circuit::layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[original_res],
|
||||
&LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(self.multiplier as f32),
|
||||
},
|
||||
)?))
|
||||
.ok_or("no inner layout")?;
|
||||
self.rebase_op.layout(config, region, &[original_res])
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
@@ -433,10 +431,6 @@ impl Op<Fp> for SupportedOp {
|
||||
self
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
self.as_op().required_lookups()
|
||||
}
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
self.as_op().out_scale(in_scales)
|
||||
}
|
||||
@@ -470,14 +464,7 @@ impl Tabled for Node {
|
||||
|
||||
fn headers() -> Vec<std::borrow::Cow<'static, str>> {
|
||||
let mut headers = Vec::with_capacity(Self::LENGTH);
|
||||
for i in [
|
||||
"idx",
|
||||
"opkind",
|
||||
"out_scale",
|
||||
"inputs",
|
||||
"out_dims",
|
||||
"required_lookups",
|
||||
] {
|
||||
for i in ["idx", "opkind", "out_scale", "inputs", "out_dims"] {
|
||||
headers.push(std::borrow::Cow::Borrowed(i));
|
||||
}
|
||||
headers
|
||||
@@ -490,14 +477,6 @@ impl Tabled for Node {
|
||||
fields.push(std::borrow::Cow::Owned(self.out_scale.to_string()));
|
||||
fields.push(std::borrow::Cow::Owned(display_vector(&self.inputs)));
|
||||
fields.push(std::borrow::Cow::Owned(display_vector(&self.out_dims)));
|
||||
fields.push(std::borrow::Cow::Owned(format!(
|
||||
"{:?}",
|
||||
self.opkind
|
||||
.required_lookups()
|
||||
.iter()
|
||||
.map(<LookupOp as Op<Fp>>::as_string)
|
||||
.collect_vec()
|
||||
)));
|
||||
fields
|
||||
}
|
||||
}
|
||||
@@ -527,9 +506,9 @@ impl Node {
|
||||
param_visibility: &Visibility,
|
||||
idx: usize,
|
||||
symbol_values: &SymbolValues,
|
||||
div_rebasing: bool,
|
||||
rebase_frac_zero_constants: bool,
|
||||
) -> Result<Self, Box<dyn Error>> {
|
||||
use log::warn;
|
||||
|
||||
trace!("Create {:?}", node);
|
||||
trace!("Create op {:?}", node.op);
|
||||
|
||||
@@ -567,6 +546,7 @@ impl Node {
|
||||
node.clone(),
|
||||
&mut inputs,
|
||||
symbol_values,
|
||||
rebase_frac_zero_constants,
|
||||
)?; // parses the op name
|
||||
|
||||
// we can only take the inputs as mutable once -- so we need to collect them first
|
||||
@@ -622,8 +602,6 @@ impl Node {
|
||||
input_node.bump_scale(out_scale);
|
||||
in_scales[input] = out_scale;
|
||||
}
|
||||
} else {
|
||||
warn!("input {} not found for rescaling, skipping ...", input);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -631,7 +609,13 @@ impl Node {
|
||||
let mut out_scale = opkind.out_scale(in_scales.clone())?;
|
||||
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
|
||||
let global_scale = scales.get_max();
|
||||
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
|
||||
opkind = RebaseScale::rebase(
|
||||
opkind,
|
||||
global_scale,
|
||||
out_scale,
|
||||
scales.rebase_multiplier,
|
||||
div_rebasing,
|
||||
);
|
||||
|
||||
out_scale = opkind.out_scale(in_scales)?;
|
||||
|
||||
|
||||
@@ -71,7 +71,6 @@ pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i12
|
||||
pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 {
|
||||
let int_rep = crate::fieldutils::felt_to_i128(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
|
||||
int_rep as f64 / multiplier - shift
|
||||
}
|
||||
|
||||
@@ -244,6 +243,7 @@ pub fn new_op_from_onnx(
|
||||
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
|
||||
inputs: &mut [super::NodeType],
|
||||
symbol_values: &SymbolValues,
|
||||
rebase_frac_zero_constants: bool,
|
||||
) -> Result<(SupportedOp, Vec<usize>), Box<dyn std::error::Error>> {
|
||||
use crate::circuit::InputType;
|
||||
|
||||
@@ -262,7 +262,9 @@ pub fn new_op_from_onnx(
|
||||
inputs[index].bump_scale(scale);
|
||||
c.rebase_scale(scale)?;
|
||||
inputs[index].replace_opkind(SupportedOp::Constant(c.clone()));
|
||||
Ok(SupportedOp::Linear(PolyOp::Identity))
|
||||
Ok(SupportedOp::Linear(PolyOp::Identity {
|
||||
out_scale: Some(scale),
|
||||
}))
|
||||
} else {
|
||||
Ok(default_op)
|
||||
}
|
||||
@@ -283,8 +285,8 @@ pub fn new_op_from_onnx(
|
||||
"shift left".to_string(),
|
||||
)));
|
||||
}
|
||||
SupportedOp::Nonlinear(LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(1.0 / 2.0f32.powf(raw_values[0])),
|
||||
SupportedOp::Linear(PolyOp::Identity {
|
||||
out_scale: Some(input_scales[0] - raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(Box::new(GraphError::OpMismatch(
|
||||
@@ -305,8 +307,8 @@ pub fn new_op_from_onnx(
|
||||
"shift right".to_string(),
|
||||
)));
|
||||
}
|
||||
SupportedOp::Nonlinear(LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(2.0f32.powf(raw_values[0])),
|
||||
SupportedOp::Linear(PolyOp::Identity {
|
||||
out_scale: Some(input_scales[0] + raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(Box::new(GraphError::OpMismatch(
|
||||
@@ -545,7 +547,7 @@ pub fn new_op_from_onnx(
|
||||
// Raw values are always f32
|
||||
let raw_value = extract_tensor_value(op.0)?;
|
||||
// If bool or a tensor dimension then don't scale
|
||||
let constant_scale = match dt {
|
||||
let mut constant_scale = match dt {
|
||||
DatumType::Bool
|
||||
| DatumType::TDim
|
||||
| DatumType::I64
|
||||
@@ -560,6 +562,12 @@ pub fn new_op_from_onnx(
|
||||
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
|
||||
};
|
||||
|
||||
// if all raw_values are round then set scale to 0
|
||||
let all_round = raw_value.iter().all(|x| (x).fract() == 0.0);
|
||||
if all_round && rebase_frac_zero_constants {
|
||||
constant_scale = 0;
|
||||
}
|
||||
|
||||
// Quantize the raw value
|
||||
let quantized_value =
|
||||
quantize_tensor(raw_value.clone(), constant_scale, param_visibility)?;
|
||||
@@ -666,8 +674,10 @@ pub fn new_op_from_onnx(
|
||||
if unit == 0. {
|
||||
SupportedOp::Nonlinear(LookupOp::ReLU)
|
||||
} else {
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
SupportedOp::Nonlinear(LookupOp::Max {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
}
|
||||
@@ -708,8 +718,11 @@ pub fn new_op_from_onnx(
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::Min {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
} else {
|
||||
@@ -717,17 +730,13 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
// Extract the slope layer hyperparams
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
let scale_diff = std::cmp::max(scales.input, scales.params) - inputs[0].out_scales()[0];
|
||||
let additional_scale = if scale_diff > 0 {
|
||||
scale_to_multiplier(scale_diff)
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::Recip {
|
||||
scale: (scale_to_multiplier(in_scale).powf(2.0) * additional_scale).into(),
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
// If the input scale is larger than the params scale
|
||||
SupportedOp::Hybrid(HybridOp::Recip {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
use_range_check_for_int: false,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -752,7 +761,9 @@ pub fn new_op_from_onnx(
|
||||
"Scan" => {
|
||||
return Err("scan should never be analyzed explicitly".into());
|
||||
}
|
||||
"QuantizeLinearU8" | "DequantizeLinearF32" => SupportedOp::Linear(PolyOp::Identity),
|
||||
"QuantizeLinearU8" | "DequantizeLinearF32" => {
|
||||
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
|
||||
}
|
||||
"Abs" => SupportedOp::Nonlinear(LookupOp::Abs),
|
||||
"Neg" => SupportedOp::Linear(PolyOp::Neg),
|
||||
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
|
||||
@@ -857,11 +868,11 @@ pub fn new_op_from_onnx(
|
||||
}),
|
||||
)?
|
||||
} else {
|
||||
SupportedOp::Linear(PolyOp::Identity)
|
||||
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
|
||||
}
|
||||
}
|
||||
DatumType::F16 | DatumType::F32 | DatumType::F64 => {
|
||||
SupportedOp::Linear(PolyOp::Identity)
|
||||
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
|
||||
}
|
||||
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
|
||||
}
|
||||
@@ -886,12 +897,15 @@ pub fn new_op_from_onnx(
|
||||
let const_idx = const_idx[0];
|
||||
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
|
||||
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
|
||||
inputs[const_idx].decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
op = SupportedOp::Nonlinear(LookupOp::Div {
|
||||
// we invert the constant for division
|
||||
denom: crate::circuit::utils::F32(1. / c.raw_values[0]),
|
||||
})
|
||||
// if not divisible by 2 then we need to add a range check
|
||||
let raw_values = 1.0 / c.raw_values[0];
|
||||
if raw_values.log2().fract() == 0.0 {
|
||||
inputs[const_idx].decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
op = SupportedOp::Linear(PolyOp::Identity {
|
||||
out_scale: Some(input_scales[0] + raw_values.log2() as i32),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,6 +237,11 @@ impl VarScales {
|
||||
std::cmp::max(self.input, self.params)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn get_min(&self) -> crate::Scale {
|
||||
std::cmp::min(self.input, self.params)
|
||||
}
|
||||
|
||||
/// Place in [VarScales] struct.
|
||||
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
|
||||
Ok(Self {
|
||||
|
||||
16
src/lib.rs
16
src/lib.rs
@@ -29,7 +29,7 @@
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
|
||||
use circuit::Tolerance;
|
||||
use circuit::{table::Range, CheckMode, Tolerance};
|
||||
use clap::Args;
|
||||
use graph::Visibility;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -91,7 +91,7 @@ pub struct RunArgs {
|
||||
pub scale_rebase_multiplier: u32,
|
||||
/// The min and max elements in the lookup table input column
|
||||
#[arg(short = 'B', long, value_parser = parse_tuple::<i128>, default_value = "(-32768,32768)")]
|
||||
pub lookup_range: (i128, i128),
|
||||
pub lookup_range: Range,
|
||||
/// The log_2 number of rows
|
||||
#[arg(short = 'K', long, default_value = "17")]
|
||||
pub logrows: u32,
|
||||
@@ -110,6 +110,15 @@ pub struct RunArgs {
|
||||
/// Flags whether params are public, private, hashed
|
||||
#[arg(long, default_value = "private")]
|
||||
pub param_visibility: Visibility,
|
||||
#[arg(long, default_value = "false")]
|
||||
/// Rebase the scale using lookup table for division instead of using a range check
|
||||
pub div_rebasing: bool,
|
||||
/// Should constants with 0.0 fraction be rebased to scale 0
|
||||
#[arg(long, default_value = "false")]
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
/// check mode (safe, unsafe, etc)
|
||||
#[arg(long, default_value = "unsafe")]
|
||||
pub check_mode: CheckMode,
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
@@ -126,6 +135,9 @@ impl Default for RunArgs {
|
||||
input_visibility: Visibility::Private,
|
||||
output_visibility: Visibility::Public,
|
||||
param_visibility: Visibility::Private,
|
||||
div_rebasing: false,
|
||||
rebase_frac_zero_constants: false,
|
||||
check_mode: CheckMode::UNSAFE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ struct PyRunArgs {
|
||||
#[pyo3(get, set)]
|
||||
pub scale_rebase_multiplier: u32,
|
||||
#[pyo3(get, set)]
|
||||
pub lookup_range: (i128, i128),
|
||||
pub lookup_range: crate::circuit::table::Range,
|
||||
#[pyo3(get, set)]
|
||||
pub logrows: u32,
|
||||
#[pyo3(get, set)]
|
||||
@@ -159,6 +159,12 @@ struct PyRunArgs {
|
||||
pub param_visibility: Visibility,
|
||||
#[pyo3(get, set)]
|
||||
pub variables: Vec<(String, usize)>,
|
||||
#[pyo3(get, set)]
|
||||
pub div_rebasing: bool,
|
||||
#[pyo3(get, set)]
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
#[pyo3(get, set)]
|
||||
pub check_mode: CheckMode,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -185,6 +191,9 @@ impl From<PyRunArgs> for RunArgs {
|
||||
output_visibility: py_run_args.output_visibility,
|
||||
param_visibility: py_run_args.param_visibility,
|
||||
variables: py_run_args.variables,
|
||||
div_rebasing: py_run_args.div_rebasing,
|
||||
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
|
||||
check_mode: py_run_args.check_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -203,6 +212,9 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
output_visibility: self.output_visibility,
|
||||
param_visibility: self.param_visibility,
|
||||
variables: self.variables,
|
||||
div_rebasing: self.div_rebasing,
|
||||
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
|
||||
check_mode: self.check_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -511,7 +523,9 @@ fn gen_settings(
|
||||
target = CalibrationTarget::default(), // default is "resources
|
||||
lookup_safety_margin = DEFAULT_LOOKUP_SAFETY_MARGIN.parse().unwrap(),
|
||||
scales = None,
|
||||
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
|
||||
max_logrows = None,
|
||||
div_rebasing = None,
|
||||
))]
|
||||
fn calibrate_settings(
|
||||
data: PathBuf,
|
||||
@@ -520,7 +534,9 @@ fn calibrate_settings(
|
||||
target: CalibrationTarget,
|
||||
lookup_safety_margin: i128,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
max_logrows: Option<u32>,
|
||||
div_rebasing: Option<bool>,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::calibrate(
|
||||
model,
|
||||
@@ -529,6 +545,8 @@ fn calibrate_settings(
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
max_logrows,
|
||||
)
|
||||
.map_err(|e| {
|
||||
|
||||
@@ -30,11 +30,11 @@ use halo2_proofs::{
|
||||
poly::Rotation,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use std::cmp::max;
|
||||
use std::error::Error;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Iterator;
|
||||
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
|
||||
use std::{cmp::max, ops::Rem};
|
||||
use thiserror::Error;
|
||||
/// A wrapper for tensor related errors.
|
||||
#[derive(Debug, Error)]
|
||||
@@ -1452,6 +1452,43 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
|
||||
}
|
||||
}
|
||||
|
||||
// implement remainder
|
||||
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Rem for Tensor<T> {
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
|
||||
/// Elementwise remainder of a tensor with another tensor.
|
||||
/// # Arguments
|
||||
/// * `self` - Tensor
|
||||
/// * `rhs` - Tensor
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use std::ops::Rem;
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[4, 1, 4, 1, 1, 4]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let y = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.rem(y).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn rem(self, rhs: Self) -> Self::Output {
|
||||
let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap();
|
||||
let mut lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
|
||||
*o = o.clone() % r;
|
||||
});
|
||||
|
||||
Ok(lhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the broadcasted shape of two tensors
|
||||
/// ```
|
||||
/// use ezkl::tensor::get_broadcasted_shape;
|
||||
|
||||
@@ -950,8 +950,7 @@ pub fn neg<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sy
|
||||
/// Elementwise multiplies multiple tensors.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Tensor
|
||||
/// * `t` - Tensors
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
@@ -3126,7 +3125,7 @@ pub mod nonlinearities {
|
||||
|
||||
let sum = sum(&exp).unwrap();
|
||||
intermediate_values.push(sum.clone());
|
||||
let inv_denom = recip(&sum, scale.powf(2.0));
|
||||
let inv_denom = recip(&sum, scale, scale);
|
||||
|
||||
((exp * inv_denom).unwrap(), intermediate_values)
|
||||
}
|
||||
@@ -3163,7 +3162,7 @@ pub mod nonlinearities {
|
||||
// the more accurate calculation is commented out and we implement as below so it matches the steps in layout
|
||||
let scale = input_scale * output_scale;
|
||||
let diff: Tensor<i128> = sub(t).unwrap();
|
||||
let recip = recip(&t[0], scale as f64);
|
||||
let recip = recip(&t[0], input_scale as f64, output_scale as f64);
|
||||
let product = mult(&[diff, recip]).unwrap();
|
||||
let _tol = ((tol / 100.0) * scale as f32).round() as f64;
|
||||
let upper_bound = greater_than(&product, _tol);
|
||||
@@ -3774,14 +3773,15 @@ pub mod nonlinearities {
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2_f64;
|
||||
/// let result = recip(&x, k);
|
||||
/// let result = recip(&x, 1.0, k);
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn recip(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
|
||||
pub fn recip(a: &Tensor<i128>, input_scale: f64, out_scale: f64) -> Tensor<i128> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let denom = (1_f64) / (a_i as f64 + f64::EPSILON);
|
||||
let d_inv_x = scale * denom;
|
||||
let rescaled = (a_i as f64) / input_scale;
|
||||
let denom = (1_f64) / (rescaled + f64::EPSILON);
|
||||
let d_inv_x = out_scale * denom;
|
||||
Ok::<_, TensorError>(d_inv_x.round() as i128)
|
||||
})
|
||||
.unwrap()
|
||||
|
||||
@@ -871,3 +871,30 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
/// inverts the inner values
|
||||
pub fn inverse(&self) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let mut cloned_self = self.clone();
|
||||
|
||||
match &mut cloned_self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
} => {
|
||||
*v = v.map(|x| match x {
|
||||
ValType::AssignedValue(v) => ValType::AssignedValue(v.invert()),
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
ValType::AssignedValue(v.value_field().invert())
|
||||
}
|
||||
ValType::Value(v) => ValType::Value(v.map(|x| x.invert().unwrap_or(F::ZERO))),
|
||||
ValType::Constant(v) => ValType::Constant(v.invert().unwrap_or(F::ZERO)),
|
||||
});
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(Box::new(TensorError::WrongMethod));
|
||||
}
|
||||
};
|
||||
Ok(cloned_self)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,95 +182,96 @@ mod native_tests {
|
||||
"mnist_gan",
|
||||
];
|
||||
|
||||
const ACCURACY_CAL_TESTS: [&str; 5] = [
|
||||
const ACCURACY_CAL_TESTS: [&str; 6] = [
|
||||
"accuracy",
|
||||
"1l_mlp",
|
||||
"4l_relu_conv_fc",
|
||||
"1l_elu",
|
||||
"1l_prelu",
|
||||
"1l_tiny_div",
|
||||
];
|
||||
|
||||
const TESTS: [&str; 77] = [
|
||||
"1l_mlp",
|
||||
"1l_mlp", //0
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
"1l_flatten",
|
||||
// "1l_average",
|
||||
"1l_div",
|
||||
"1l_pad",
|
||||
"1l_pad", // 5
|
||||
"1l_reshape",
|
||||
"1l_eltwise_div",
|
||||
"1l_sigmoid",
|
||||
"1l_sqrt",
|
||||
"1l_softmax",
|
||||
"1l_softmax", //10
|
||||
// "1l_instance_norm",
|
||||
"1l_batch_norm",
|
||||
"1l_prelu",
|
||||
"1l_leakyrelu",
|
||||
"1l_gelu_noappx",
|
||||
// "1l_gelu_tanh_appx",
|
||||
"1l_relu",
|
||||
"1l_relu", //15
|
||||
"1l_downsample",
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_fc",
|
||||
"2l_relu_small",
|
||||
"2l_relu_small", //20
|
||||
"2l_relu_sigmoid",
|
||||
"1l_conv",
|
||||
"2l_sigmoid_small",
|
||||
"2l_relu_sigmoid_conv",
|
||||
"3l_relu_conv_fc",
|
||||
"3l_relu_conv_fc", //25
|
||||
"4l_relu_conv_fc",
|
||||
"1l_erf",
|
||||
"1l_var",
|
||||
"1l_elu", //30
|
||||
"min",
|
||||
"1l_elu",
|
||||
"min", //30
|
||||
"max",
|
||||
"1l_max_pool",
|
||||
"1l_conv_transpose",
|
||||
"1l_upsample", //35
|
||||
"1l_identity",
|
||||
"1l_upsample",
|
||||
"1l_identity", //35
|
||||
"idolmodel",
|
||||
"trig",
|
||||
"prelu_gmm",
|
||||
"lstm", //40
|
||||
"rnn",
|
||||
"lstm",
|
||||
"rnn", //40
|
||||
"quantize_dequantize",
|
||||
"1l_where",
|
||||
"boolean",
|
||||
"boolean_identity",
|
||||
"decision_tree", // "variable_cnn",
|
||||
"decision_tree", // 45
|
||||
"random_forest",
|
||||
"gradient_boosted_trees",
|
||||
"1l_topk",
|
||||
"xgboost", //50
|
||||
"lightgbm",
|
||||
"xgboost",
|
||||
"lightgbm", //50
|
||||
"hummingbird_decision_tree",
|
||||
"oh_decision_tree",
|
||||
"linear_svc",
|
||||
"gather_elements",
|
||||
"less",
|
||||
"less", //55
|
||||
"xgboost_reg",
|
||||
"1l_powf",
|
||||
"scatter_elements",
|
||||
"1l_linear", //60
|
||||
"linear_regression",
|
||||
"1l_linear",
|
||||
"linear_regression", //60
|
||||
"sklearn_mlp",
|
||||
"1l_mean",
|
||||
"rounding_ops",
|
||||
// "mean_as_constrain",
|
||||
"arange",
|
||||
"layernorm",
|
||||
"layernorm", //65
|
||||
"bitwise_ops",
|
||||
"blackman_window",
|
||||
"softsign", //70
|
||||
"softsign", //68
|
||||
"softplus",
|
||||
"selu",
|
||||
"selu", //70
|
||||
"hard_sigmoid",
|
||||
"log_softmax",
|
||||
"eye",
|
||||
"ltsf",
|
||||
"remainder",
|
||||
"remainder", //75
|
||||
"bitshift",
|
||||
];
|
||||
|
||||
@@ -489,7 +490,7 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
seq!(N in 0..=4 {
|
||||
seq!(N in 0..=5 {
|
||||
#(#[test_case(ACCURACY_CAL_TESTS[N])])*
|
||||
fn mock_accuracy_cal_tests(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
@@ -512,13 +513,23 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_div_rebase_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, true);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_public_outputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -528,7 +539,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6 , false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -538,7 +549,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -549,7 +560,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -826,10 +837,10 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
env_logger::init();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, Some(vec![0,1]), true, "single");
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm");
|
||||
test_dir.close().unwrap();
|
||||
// test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(WASM_TESTS[N])])*
|
||||
@@ -839,7 +850,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
env_logger::init();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, Some(vec![0,1]), true, "single");
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, None, true, "single");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm");
|
||||
test_dir.close().unwrap();
|
||||
@@ -855,7 +866,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, Some(vec![0,6]), false, "single");
|
||||
kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, None, false, "single");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -865,7 +876,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", Some(vec![0,6]));
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -883,7 +894,7 @@ mod native_tests {
|
||||
use test_case::test_case;
|
||||
use crate::native_tests::kzg_evm_prove_and_verify;
|
||||
use crate::native_tests::kzg_evm_prove_and_verify_render_seperately;
|
||||
|
||||
|
||||
use crate::native_tests::kzg_evm_on_chain_input_prove_and_verify;
|
||||
use crate::native_tests::kzg_evm_aggr_prove_and_verify;
|
||||
use crate::native_tests::kzg_fuzz;
|
||||
@@ -1273,6 +1284,7 @@ mod native_tests {
|
||||
cal_target,
|
||||
scales_to_use,
|
||||
2,
|
||||
false,
|
||||
);
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
@@ -1299,22 +1311,29 @@ mod native_tests {
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
num_inner_columns: usize,
|
||||
div_rebasing: bool,
|
||||
) {
|
||||
let mut args = vec![
|
||||
"gen-settings".to_string(),
|
||||
"-M".to_string(),
|
||||
format!("{}/{}/network.onnx", test_dir, example_name),
|
||||
format!(
|
||||
"--settings-path={}/{}/settings.json",
|
||||
test_dir, example_name
|
||||
),
|
||||
format!("--variables=batch_size={}", batch_size),
|
||||
format!("--input-visibility={}", input_visibility),
|
||||
format!("--param-visibility={}", param_visibility),
|
||||
format!("--output-visibility={}", output_visibility),
|
||||
format!("--num-inner-cols={}", num_inner_columns),
|
||||
];
|
||||
|
||||
if div_rebasing {
|
||||
args.push("--div-rebasing".to_string());
|
||||
};
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"gen-settings",
|
||||
"-M",
|
||||
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
|
||||
&format!(
|
||||
"--settings-path={}/{}/settings.json",
|
||||
test_dir, example_name
|
||||
),
|
||||
&format!("--variables=batch_size={}", batch_size),
|
||||
&format!("--input-visibility={}", input_visibility),
|
||||
&format!("--param-visibility={}", param_visibility),
|
||||
&format!("--output-visibility={}", output_visibility),
|
||||
&format!("--num-inner-cols={}", num_inner_columns),
|
||||
])
|
||||
.args(args)
|
||||
.stdout(std::process::Stdio::null())
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -1392,6 +1411,7 @@ mod native_tests {
|
||||
batch_size: usize,
|
||||
cal_target: &str,
|
||||
target_perc: f32,
|
||||
div_rebasing: bool,
|
||||
) {
|
||||
gen_circuit_settings_and_witness(
|
||||
test_dir,
|
||||
@@ -1403,6 +1423,7 @@ mod native_tests {
|
||||
cal_target,
|
||||
None,
|
||||
2,
|
||||
div_rebasing,
|
||||
);
|
||||
|
||||
println!(
|
||||
@@ -1661,6 +1682,7 @@ mod native_tests {
|
||||
target_str,
|
||||
scales_to_use,
|
||||
num_inner_columns,
|
||||
false,
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
@@ -1737,6 +1759,7 @@ mod native_tests {
|
||||
"resources",
|
||||
None,
|
||||
2,
|
||||
false,
|
||||
);
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
@@ -2010,8 +2033,9 @@ mod native_tests {
|
||||
1,
|
||||
"resources",
|
||||
// we need the accuracy
|
||||
Some(vec![7, 8]),
|
||||
Some(vec![4]),
|
||||
1,
|
||||
false,
|
||||
);
|
||||
|
||||
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
|
||||
|
||||
@@ -78,14 +78,20 @@ def compare_outputs(zk_output, onnx_output):
|
||||
|
||||
zip_object = zip(np.array(zk_output).flatten(),
|
||||
np.array(onnx_output).flatten())
|
||||
for list1_i, list2_i in zip_object:
|
||||
for (i, (list1_i, list2_i)) in enumerate(zip_object):
|
||||
if list1_i == 0.0 and list2_i == 0.0:
|
||||
res.append(0)
|
||||
else:
|
||||
diff = list1_i - list2_i
|
||||
res.append(100 * (diff) / (list2_i))
|
||||
# iterate and print the diffs if they are greater than 0.0
|
||||
if abs(diff) > 0.0:
|
||||
print("------- index: ", i)
|
||||
print("------- diff: ", diff)
|
||||
print("------- zk_output: ", list1_i)
|
||||
print("------- onnx_output: ", list2_i)
|
||||
|
||||
|
||||
print("res: ", res)
|
||||
|
||||
return np.mean(np.abs(res))
|
||||
|
||||
|
||||
Binary file not shown.
@@ -1 +1,61 @@
|
||||
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":0,"param_scale":0,"scale_rebase_multiplier":10,"lookup_range":[-2,0],"logrows":6,"num_inner_cols":2,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_rows":16,"total_assignments":32,"total_const_size":8,"model_instance_shapes":[[1,4]],"model_output_scales":[0],"model_input_scales":[0],"module_sizes":{"kzg":[],"poseidon":[0,[0]]},"required_lookups":["ReLU"],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null,"timestamp":1702474230544}
|
||||
{
|
||||
"run_args": {
|
||||
"tolerance": {
|
||||
"val": 0.0,
|
||||
"scale": 1.0
|
||||
},
|
||||
"input_scale": 0,
|
||||
"param_scale": 0,
|
||||
"scale_rebase_multiplier": 10,
|
||||
"lookup_range": [
|
||||
-2,
|
||||
0
|
||||
],
|
||||
"logrows": 6,
|
||||
"num_inner_cols": 2,
|
||||
"variables": [
|
||||
[
|
||||
"batch_size",
|
||||
1
|
||||
]
|
||||
],
|
||||
"input_visibility": "Private",
|
||||
"output_visibility": "Public",
|
||||
"param_visibility": "Private",
|
||||
"div_rebasing": false,
|
||||
"rebase_frac_zero_constants": false,
|
||||
"check_mode": "UNSAFE"
|
||||
},
|
||||
"num_rows": 16,
|
||||
"total_assignments": 32,
|
||||
"total_const_size": 8,
|
||||
"model_instance_shapes": [
|
||||
[
|
||||
1,
|
||||
4
|
||||
]
|
||||
],
|
||||
"model_output_scales": [
|
||||
0
|
||||
],
|
||||
"model_input_scales": [
|
||||
0
|
||||
],
|
||||
"module_sizes": {
|
||||
"kzg": [],
|
||||
"poseidon": [
|
||||
0,
|
||||
[
|
||||
0
|
||||
]
|
||||
]
|
||||
},
|
||||
"required_lookups": [
|
||||
"ReLU"
|
||||
],
|
||||
"required_range_checks": [],
|
||||
"check_mode": "UNSAFE",
|
||||
"version": "0.0.0",
|
||||
"num_blinding_factors": null,
|
||||
"timestamp": 1702474230544
|
||||
}
|
||||
Reference in New Issue
Block a user