mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 16:27:59 -05:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
523c77c912 | ||
|
|
948e5cd4b9 | ||
|
|
00155e585f | ||
|
|
0876faa12c |
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
@@ -240,6 +240,8 @@ jobs:
|
||||
locked: true
|
||||
# - name: The Worm Mock
|
||||
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
|
||||
- name: public outputs and bounded lookup log
|
||||
run: cargo nextest run --release --verbose tests::mock_bounded_lookup_log --test-threads 32
|
||||
- name: public outputs and tolerance > 0
|
||||
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
|
||||
- name: public outputs + batch size == 10
|
||||
|
||||
@@ -592,7 +592,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -648,10 +648,10 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.7"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,7 +271,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -171,7 +171,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -328,7 +328,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 26,
|
||||
"id": "171702d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -348,7 +348,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 27,
|
||||
"id": "671dfdd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -364,7 +364,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 28,
|
||||
"id": "50eba2f4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -399,9 +399,9 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
||||
42
examples/onnx/log/gen.py
Normal file
42
examples/onnx/log/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.log(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 3)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/log/input.json
Normal file
1
examples/onnx/log/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[1.9252371788024902, 1.8418371677398682, 0.8400403261184692, 2.083845853805542, 0.9760497808456421, 0.6940176486968994, 0.015579521656036377, 2.2689192295074463]]}
|
||||
14
examples/onnx/log/network.onnx
Normal file
14
examples/onnx/log/network.onnx
Normal file
@@ -0,0 +1,14 @@
|
||||
pytorch2.2.2:o
|
||||
|
||||
inputoutput/Log"Log
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
@@ -1 +1,148 @@
|
||||
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.6261028051376343, 0.49872446060180664, -0.04514765739440918, 0.5936200618743896, 0.9271858930587769, 0.6688600778579712, -0.20331168174743652, -0.7016235589981079, 0.025863051414489746, -0.19426143169403076, 0.9827852249145508, 0.4897397756576538, 0.2992602586746216, 0.7011144161224365, 0.9278832674026489, 0.5943725109100342, -0.573331356048584, 0.3675816059112549], [0.7803324460983276, -0.9616303443908691, 0.6070173978805542, -0.028337717056274414, -0.5080242156982422, -0.9280107021331787, 0.6150380373001099, 0.3865993022918701, -0.43668973445892334, 0.17152702808380127, 0.5144252777099609, -0.28881049156188965, 0.8932310342788696, 0.059034109115600586, 0.6865451335906982, 0.009820222854614258, 0.23011493682861328, -0.9492779970169067], [-0.21352827548980713, -0.16015326976776123, -0.38964390754699707, 0.13464701175689697, -0.8814496994018555, 0.5037975311279297, -0.804405927658081, 0.9858957529067993, 0.19567716121673584, 0.9777265787124634, 0.6151977777481079, 0.568595290184021, 0.10584986209869385, -0.8975653648376465, 0.6235959529876709, -0.547879695892334, 0.9289869070053101, 0.7567293643951416]], "output_data": [[1.0, 0.0, -0.0, 1.0, 1.0, 1.0, -0.0, -1.0, 0.0, -0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, -1.0, 0.0], [0.0, -1.0, 0.0, -1.0, -1.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0], [-0.0, -0.0, -0.0, 1.0, -0.0, 1.0, -0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -0.0, 1.0, -0.0, 1.0, 1.0]]}
|
||||
{
|
||||
"input_shapes": [
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
]
|
||||
],
|
||||
"input_data": [
|
||||
[
|
||||
0.5,
|
||||
1.5,
|
||||
-0.04514765739440918,
|
||||
0.5936200618743896,
|
||||
0.9271858930587769,
|
||||
0.6688600778579712,
|
||||
-0.20331168174743652,
|
||||
-0.7016235589981079,
|
||||
0.025863051414489746,
|
||||
-0.19426143169403076,
|
||||
0.9827852249145508,
|
||||
0.4897397756576538,
|
||||
-1.5,
|
||||
-0.5,
|
||||
0.9278832674026489,
|
||||
0.5943725109100342,
|
||||
-0.573331356048584,
|
||||
0.3675816059112549
|
||||
],
|
||||
[
|
||||
0.7803324460983276,
|
||||
-0.9616303443908691,
|
||||
0.6070173978805542,
|
||||
-0.028337717056274414,
|
||||
-0.5080242156982422,
|
||||
-0.9280107021331787,
|
||||
0.6150380373001099,
|
||||
0.3865993022918701,
|
||||
-0.43668973445892334,
|
||||
0.17152702808380127,
|
||||
0.5144252777099609,
|
||||
-0.28881049156188965,
|
||||
0.8932310342788696,
|
||||
0.059034109115600586,
|
||||
0.6865451335906982,
|
||||
0.009820222854614258,
|
||||
0.23011493682861328,
|
||||
-0.9492779970169067
|
||||
],
|
||||
[
|
||||
-0.21352827548980713,
|
||||
-0.16015326976776123,
|
||||
-0.38964390754699707,
|
||||
0.13464701175689697,
|
||||
-0.8814496994018555,
|
||||
0.5037975311279297,
|
||||
-0.804405927658081,
|
||||
0.9858957529067993,
|
||||
0.19567716121673584,
|
||||
0.9777265787124634,
|
||||
0.6151977777481079,
|
||||
0.568595290184021,
|
||||
0.10584986209869385,
|
||||
-0.8975653648376465,
|
||||
0.6235959529876709,
|
||||
-0.547879695892334,
|
||||
0.9289869070053101,
|
||||
0.7567293643951416
|
||||
]
|
||||
],
|
||||
"output_data": [
|
||||
[
|
||||
1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-1.0,
|
||||
0.0
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
-1.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0
|
||||
],
|
||||
[
|
||||
-0.0,
|
||||
-0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0
|
||||
]
|
||||
]
|
||||
}
|
||||
@@ -180,9 +180,6 @@ struct PyRunArgs {
|
||||
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
|
||||
pub variables: Vec<(String, usize)>,
|
||||
#[pyo3(get, set)]
|
||||
/// bool: Rebase the scale using lookup table for division instead of using a range check
|
||||
pub div_rebasing: bool,
|
||||
#[pyo3(get, set)]
|
||||
/// bool: Should constants with 0.0 fraction be rebased to scale 0
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
#[pyo3(get, set)]
|
||||
@@ -197,6 +194,9 @@ struct PyRunArgs {
|
||||
/// int: The number of legs used for decomposition
|
||||
#[pyo3(get, set)]
|
||||
pub decomp_legs: usize,
|
||||
/// bool: Should the circuit use unbounded lookups for log
|
||||
#[pyo3(get, set)]
|
||||
pub bounded_log_lookup: bool,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -212,6 +212,7 @@ impl PyRunArgs {
|
||||
impl From<PyRunArgs> for RunArgs {
|
||||
fn from(py_run_args: PyRunArgs) -> Self {
|
||||
RunArgs {
|
||||
bounded_log_lookup: py_run_args.bounded_log_lookup,
|
||||
tolerance: Tolerance::from(py_run_args.tolerance),
|
||||
input_scale: py_run_args.input_scale,
|
||||
param_scale: py_run_args.param_scale,
|
||||
@@ -223,7 +224,6 @@ impl From<PyRunArgs> for RunArgs {
|
||||
output_visibility: py_run_args.output_visibility,
|
||||
param_visibility: py_run_args.param_visibility,
|
||||
variables: py_run_args.variables,
|
||||
div_rebasing: py_run_args.div_rebasing,
|
||||
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
|
||||
check_mode: py_run_args.check_mode,
|
||||
commitment: Some(py_run_args.commitment.into()),
|
||||
@@ -236,6 +236,7 @@ impl From<PyRunArgs> for RunArgs {
|
||||
impl Into<PyRunArgs> for RunArgs {
|
||||
fn into(self) -> PyRunArgs {
|
||||
PyRunArgs {
|
||||
bounded_log_lookup: self.bounded_log_lookup,
|
||||
tolerance: self.tolerance.val,
|
||||
input_scale: self.input_scale,
|
||||
param_scale: self.param_scale,
|
||||
@@ -247,7 +248,6 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
output_visibility: self.output_visibility,
|
||||
param_visibility: self.param_visibility,
|
||||
variables: self.variables,
|
||||
div_rebasing: self.div_rebasing,
|
||||
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
|
||||
check_mode: self.check_mode,
|
||||
commitment: self.commitment.into(),
|
||||
@@ -873,8 +873,6 @@ fn gen_settings(
|
||||
/// max_logrows: int
|
||||
/// Optional max logrows to use for calibration
|
||||
///
|
||||
/// only_range_check_rebase: bool
|
||||
/// Check ranges when rebasing
|
||||
///
|
||||
/// Returns
|
||||
/// -------
|
||||
@@ -889,7 +887,6 @@ fn gen_settings(
|
||||
scales = None,
|
||||
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
|
||||
max_logrows = None,
|
||||
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
|
||||
))]
|
||||
fn calibrate_settings(
|
||||
py: Python,
|
||||
@@ -901,7 +898,6 @@ fn calibrate_settings(
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
max_logrows: Option<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::calibrate(
|
||||
@@ -912,7 +908,6 @@ fn calibrate_settings(
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase,
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -13,6 +13,20 @@ use serde::{Deserialize, Serialize};
|
||||
/// An enum representing the operations that consist of both lookups and arithmetic operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HybridOp {
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Rsqrt {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
RoundHalfToEven {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
@@ -31,7 +45,6 @@ pub enum HybridOp {
|
||||
},
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
ReduceMax {
|
||||
axes: Vec<usize>,
|
||||
@@ -108,11 +121,24 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!(
|
||||
"RSQRT (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
|
||||
}
|
||||
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Max => format!("MAX"),
|
||||
HybridOp::Min => format!("MIN"),
|
||||
|
||||
HybridOp::Max => "MAX".to_string(),
|
||||
HybridOp::Min => "MIN".to_string(),
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
@@ -120,13 +146,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
"RECIP (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"DIV (denom={}, use_range_check_for_int={})",
|
||||
denom, use_range_check_for_int
|
||||
),
|
||||
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -181,6 +201,23 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => layouts::rsqrt(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
*input_scale,
|
||||
*output_scale,
|
||||
)?,
|
||||
HybridOp::Sqrt { scale } => {
|
||||
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
|
||||
}
|
||||
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Ceil { scale, legs } => {
|
||||
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
@@ -216,13 +253,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
)?,
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
..
|
||||
} => {
|
||||
if denom.0.fract() == 0.0 && *use_range_check_for_int {
|
||||
layouts::loop_div(
|
||||
HybridOp::Div { denom, .. } => {
|
||||
if denom.0.fract() == 0.0 {
|
||||
layouts::div(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
@@ -313,9 +346,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
| HybridOp::ReduceArgMax { .. }
|
||||
| HybridOp::OneHot { .. }
|
||||
| HybridOp::ReduceArgMin { .. } => 0,
|
||||
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
|
||||
|
||||
HybridOp::Recip { output_scale, .. } | HybridOp::Rsqrt { output_scale, .. } => {
|
||||
multiplier_to_scale(output_scale.0 as f64)
|
||||
}
|
||||
HybridOp::Softmax {
|
||||
output_scale,
|
||||
input_scale,
|
||||
..
|
||||
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
|
||||
HybridOp::Ln {
|
||||
scale: output_scale,
|
||||
} => 4 * multiplier_to_scale(output_scale.0 as f64),
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
|
||||
@@ -29,41 +29,96 @@ use crate::{
|
||||
use super::*;
|
||||
use crate::circuit::ops::lookup::LookupOp;
|
||||
|
||||
/// Same as div but splits the division into N parts
|
||||
pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Calculate the L1 distance between two tensors.
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::l1_distance;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3],
|
||||
/// ).unwrap());
|
||||
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
|
||||
/// &[3, 3],
|
||||
/// ).unwrap());
|
||||
/// let result = l1_distance::<Fp>(&dummy_config, &mut dummy_region, &[x, k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1, 1, 1, 2, 2, 2]), &[3, 3]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
pub fn l1_distance<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
divisor: F,
|
||||
values: &[ValTensor<F>; 2],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
if divisor == F::ONE {
|
||||
return Ok(value[0].clone());
|
||||
}
|
||||
let diff = pairwise(config, region, values, BaseOp::Sub)?;
|
||||
let abs_diff = abs(config, region, &[diff])?;
|
||||
|
||||
// if integer val is divisible by 2, we can use a faster method and div > F::S
|
||||
let mut divisor = divisor;
|
||||
let mut num_parts = 1;
|
||||
Ok(abs_diff)
|
||||
}
|
||||
|
||||
while felt_to_integer_rep(divisor) % 2 == 0
|
||||
&& felt_to_integer_rep(divisor) > (2_i128.pow(F::S - 4))
|
||||
{
|
||||
divisor = integer_rep_to_felt(felt_to_integer_rep(divisor) / 2);
|
||||
num_parts += 1;
|
||||
}
|
||||
/// Determines if from a set of 3 tensors the 1st is closest to a reference tensor.
|
||||
/// should only be used in the context of a monotonic function like the product used in the division, recipe, and sqrt arguments;
|
||||
/// or the increasing powers of 2 in the ln argument. Which is used to construct a convex error function.
|
||||
fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
x: &ValTensor<F>,
|
||||
f: impl Fn(&BaseConfig<F>, &mut RegionCtx<F>, &ValTensor<F>) -> Result<ValTensor<F>, CircuitError>,
|
||||
) -> Result<(), CircuitError> {
|
||||
let one = create_constant_tensor(F::from(1), 1);
|
||||
|
||||
let output = div(config, region, value, divisor)?;
|
||||
if num_parts == 1 {
|
||||
return Ok(output);
|
||||
}
|
||||
let f_x = f(config, region, x)?;
|
||||
|
||||
let divisor_int = 2_i128.pow(num_parts - 1);
|
||||
let divisor_felt = integer_rep_to_felt(divisor_int);
|
||||
if divisor_int <= 2_i128.pow(F::S - 3) {
|
||||
div(config, region, &[output], divisor_felt)
|
||||
} else {
|
||||
// keep splitting the divisor until it satisfies the condition
|
||||
loop_div(config, region, &[output], divisor_felt)
|
||||
}
|
||||
let x_plus_1 = pairwise(config, region, &[x.clone(), one.clone()], BaseOp::Add)?;
|
||||
let f_x_plus_1 = f(config, region, &x_plus_1)?;
|
||||
|
||||
let x_minus_1 = pairwise(config, region, &[x.clone(), one.clone()], BaseOp::Sub)?;
|
||||
let f_x_minus_1 = f(config, region, &x_minus_1)?;
|
||||
|
||||
// because the function is convex, the result should be the minimum of the three
|
||||
// not that we offset the x by 1 to get the next value
|
||||
// f(x) <= f(x+1) and f(x) <= f(x-1)
|
||||
// the result is 1 if the function is optimal solely because of the convexity of the function
|
||||
// the distances can be equal but this is only possible if f(x) and f(x+1) are both optimal (or f(x) and f(x-1)).
|
||||
let f_x_is_opt_rhs = less_equal(config, region, &[f_x.clone(), f_x_plus_1.clone()])?;
|
||||
let f_x_is_opt_lhs = less_equal(config, region, &[f_x.clone(), f_x_minus_1.clone()])?;
|
||||
|
||||
let is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?;
|
||||
|
||||
let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
|
||||
comparison_unit.reshape(is_opt.dims())?;
|
||||
|
||||
// assert that the result is 1
|
||||
enforce_equality(config, region, &[is_opt, comparison_unit])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Err is less than some constant
|
||||
pub fn diff_less_than<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
constant: F,
|
||||
) -> Result<(), CircuitError> {
|
||||
let distance = l1_distance(config, region, values)?;
|
||||
|
||||
let constant = create_constant_tensor(constant, 1);
|
||||
let is_less = less(config, region, &[distance.clone(), constant.clone()])?;
|
||||
|
||||
// assert the result is 1
|
||||
let comparison_unit = create_constant_tensor(F::ONE, is_less.len());
|
||||
enforce_equality(config, region, &[is_less, comparison_unit])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Div accumulated layout
|
||||
@@ -80,13 +135,8 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let input = value[0].clone();
|
||||
let input_dims = input.dims();
|
||||
|
||||
let range_check_bracket = felt_to_integer_rep(div) / 2;
|
||||
|
||||
let divisor = create_constant_tensor(div, 1);
|
||||
|
||||
let divisor = region.assign(&config.custom_gates.inputs[1], &divisor)?;
|
||||
region.increment(divisor.len());
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !divisor.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
@@ -117,19 +167,7 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let diff_with_input = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[product.clone(), input.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[diff_with_input],
|
||||
&(-range_check_bracket, range_check_bracket),
|
||||
)?;
|
||||
diff_less_than(config, region, &[input.clone(), product.clone()], div)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
@@ -145,19 +183,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let input = value[0].clone();
|
||||
let input_dims = input.dims();
|
||||
|
||||
let integer_input_scale = felt_to_integer_rep(input_scale);
|
||||
let integer_output_scale = felt_to_integer_rep(output_scale);
|
||||
|
||||
// range_check_bracket is min of input_scale * output_scale and 2^F::S - 3
|
||||
let range_check_len = std::cmp::min(integer_output_scale, 2_i128.pow(F::S - 4));
|
||||
|
||||
let input_scale_ratio = if range_check_len > 0 {
|
||||
integer_rep_to_felt(integer_input_scale * integer_output_scale / range_check_len)
|
||||
} else {
|
||||
F::ONE
|
||||
};
|
||||
|
||||
let range_check_bracket = range_check_len / 2;
|
||||
let unit_scale = create_constant_tensor(output_scale * input_scale, 1);
|
||||
|
||||
let is_assigned = !input.any_unknowns()?;
|
||||
|
||||
@@ -183,25 +209,22 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
// this is now of scale 2 * scale
|
||||
let product = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), input.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
// divide by input_scale
|
||||
let rebased_div = loop_div(config, region, &[product], input_scale_ratio)?;
|
||||
|
||||
let zero_inverse_val =
|
||||
tensor::ops::nonlinearities::zero_recip(felt_to_integer_rep(output_scale) as f64)[0];
|
||||
let zero_inverse = create_constant_tensor(integer_rep_to_felt(zero_inverse_val), 1);
|
||||
|
||||
let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;
|
||||
|
||||
let not_equal_zero_mask = not(config, region, &[equal_zero_mask.clone()])?;
|
||||
let equal_inverse_mask = equals(config, region, &[claimed_output.clone(), zero_inverse])?;
|
||||
|
||||
let masked_unit_scale = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[unit_scale.clone(), not_equal_zero_mask.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
// assert the two masks are equal
|
||||
enforce_equality(
|
||||
config,
|
||||
@@ -209,24 +232,135 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&[equal_zero_mask.clone(), equal_inverse_mask],
|
||||
)?;
|
||||
|
||||
let unit_scale = create_constant_tensor(integer_rep_to_felt(range_check_len), 1);
|
||||
let err_func = |config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
x: &ValTensor<F>|
|
||||
-> Result<ValTensor<F>, CircuitError> {
|
||||
let product = pairwise(config, region, &[x.clone(), input.clone()], BaseOp::Mult)?;
|
||||
|
||||
let unit_mask = pairwise(config, region, &[equal_zero_mask, unit_scale], BaseOp::Mult)?;
|
||||
let distance = l1_distance(
|
||||
config,
|
||||
region,
|
||||
&[product.clone(), masked_unit_scale.clone()],
|
||||
)?;
|
||||
Ok(distance)
|
||||
};
|
||||
|
||||
// now add the unit mask to the rebased_div
|
||||
let rebased_offset_div = pairwise(config, region, &[rebased_div, unit_mask], BaseOp::Add)?;
|
||||
|
||||
// at most the error should be in the original unit scale's range
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[rebased_offset_div],
|
||||
&(range_check_bracket, 3 * range_check_bracket),
|
||||
)?;
|
||||
optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
/// Square root accumulated layout
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::sqrt;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 9]),
|
||||
/// &[3, 3],
|
||||
/// ).unwrap());
|
||||
/// let result = sqrt::<Fp>(&dummy_config, &mut dummy_region, &[x], 1.0.into()).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 2, 1, 2, 2, 2, 2, 3]), &[3, 3]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
input_scale: utils::F32,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let input = value[0].clone();
|
||||
let input_dims = input.dims();
|
||||
|
||||
let unit_scale = create_constant_tensor(integer_rep_to_felt(input_scale.0 as IntegerRep), 1);
|
||||
|
||||
let is_assigned = !input.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.int_evals()?;
|
||||
tensor::ops::nonlinearities::sqrt(&input_evals, input_scale.0 as f64)
|
||||
.par_iter()
|
||||
.map(|x| Value::known(integer_rep_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); input.len()]),
|
||||
&[input.len()],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
// force the output to be positive or zero
|
||||
let claimed_output = abs(config, region, &[claimed_output.clone()])?;
|
||||
|
||||
// rescaled input
|
||||
let rescaled_input = pairwise(config, region, &[input.clone(), unit_scale], BaseOp::Mult)?;
|
||||
|
||||
let err_func = |config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
x: &ValTensor<F>|
|
||||
-> Result<ValTensor<F>, CircuitError> {
|
||||
let product = pairwise(config, region, &[x.clone(), x.clone()], BaseOp::Mult)?;
|
||||
let distance = l1_distance(config, region, &[product.clone(), rescaled_input.clone()])?;
|
||||
Ok(distance)
|
||||
};
|
||||
|
||||
optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
/// Reciprocal square root accumulated layout
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::rsqrt;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3],
|
||||
/// ).unwrap());
|
||||
/// let result = rsqrt::<Fp>(&dummy_config, &mut dummy_region, &[x], 1.0.into(), 1.0.into()).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 1, 1, 1, 1, 1, 1]), &[3, 3]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
pub fn rsqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let sqrt = sqrt(config, region, value, input_scale)?;
|
||||
|
||||
let felt_output_scale = integer_rep_to_felt(output_scale.0 as IntegerRep);
|
||||
let felt_input_scale = integer_rep_to_felt(input_scale.0 as IntegerRep);
|
||||
|
||||
let recip = recip(config, region, &[sqrt], felt_input_scale, felt_output_scale)?;
|
||||
|
||||
Ok(recip)
|
||||
}
|
||||
|
||||
/// Dot product of two tensors.
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
@@ -1805,6 +1939,10 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
if values[0].len() == 1 {
|
||||
return Ok(values[0].clone());
|
||||
}
|
||||
|
||||
region.flush()?;
|
||||
// time this entire function run
|
||||
let global_start = instant::Instant::now();
|
||||
@@ -3102,7 +3240,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
|
||||
|
||||
if normalized {
|
||||
last_elem = loop_div(config, region, &[last_elem], F::from(kernel_len as u64))?;
|
||||
last_elem = div(config, region, &[last_elem], F::from(kernel_len as u64))?;
|
||||
}
|
||||
Ok(last_elem)
|
||||
}
|
||||
@@ -4507,6 +4645,293 @@ pub fn ceil<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
)
|
||||
}
|
||||
|
||||
/// integer ln layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 1]
|
||||
/// * `scale` - utils::F32
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::ln;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, 2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
///
|
||||
/// let result = ln::<Fp>(&dummy_config, &mut dummy_region, &[x], 2.0.into()).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 0, 4, -8]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
/// ```
|
||||
pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// first generate the claimed val
|
||||
|
||||
let mut input = values[0].clone();
|
||||
let scale_as_felt = integer_rep_to_felt(scale.0.round() as IntegerRep);
|
||||
|
||||
let triple_scaled_as_felt_tensor =
|
||||
create_constant_tensor(scale_as_felt * scale_as_felt * scale_as_felt, 1);
|
||||
|
||||
// natural ln is log2(x) * ln(2)
|
||||
let ln2 = utils::F32::from(2.0_f32.ln());
|
||||
// now create a constant tensor for ln2 with scale
|
||||
let ln2_tensor: ValTensor<F> = create_constant_tensor(
|
||||
integer_rep_to_felt((ln2.0 * scale.0).round() as IntegerRep),
|
||||
1,
|
||||
);
|
||||
let unit = create_constant_tensor(integer_rep_to_felt(1), 1);
|
||||
let negative_one = create_constant_tensor(integer_rep_to_felt(-1), 1);
|
||||
|
||||
// 2. assign the image
|
||||
if !input.all_prev_assigned() {
|
||||
input = region.assign(&config.custom_gates.inputs[0], &input)?;
|
||||
// don't need to increment because the claimed output is assigned to output and incremented accordingly
|
||||
}
|
||||
|
||||
let is_assigned = !input.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.int_evals()?;
|
||||
// returns an integer with the base 2 logarithm
|
||||
tensor::ops::nonlinearities::ilog2(&input_evals.clone(), scale.0 as f64)
|
||||
.par_iter()
|
||||
.map(|x| Value::known(integer_rep_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); input.len()]),
|
||||
&[input.len()],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input.dims())?;
|
||||
region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
let pow2_of_claimed_output = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone()],
|
||||
&LookupOp::PowersOfTwo { scale },
|
||||
)?;
|
||||
|
||||
let num_bits = (std::mem::size_of::<IntegerRep>() * 8) as IntegerRep;
|
||||
|
||||
region.update_max_min_lookup_inputs_force(-num_bits, num_bits)?;
|
||||
|
||||
// now subtract 1 from the claimed output
|
||||
let claimed_output_minus_one = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), unit.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
// now add 1 to the claimed output
|
||||
let claimed_output_plus_one = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), unit.clone()],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
// prior power of 2 is less than claimed output
|
||||
let prior_pow2 = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output_minus_one],
|
||||
&LookupOp::PowersOfTwo { scale },
|
||||
)?;
|
||||
|
||||
// next power of 2 is greater than claimed output
|
||||
let next_pow2 = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output_plus_one],
|
||||
&LookupOp::PowersOfTwo { scale },
|
||||
)?;
|
||||
|
||||
let distance_to_claimed = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[input.clone(), pow2_of_claimed_output.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
let abs_distance_to_claimed = abs(config, region, &[distance_to_claimed.clone()])?;
|
||||
|
||||
let abs_distance_to_next_pow2 =
|
||||
l1_distance(config, region, &[input.clone(), next_pow2.clone()])?;
|
||||
|
||||
let abs_distance_to_prior_pow2 =
|
||||
l1_distance(config, region, &[input.clone(), prior_pow2.clone()])?;
|
||||
|
||||
// because we round up this can be equal
|
||||
let is_closest_to_0: ValTensor<F> = less(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
abs_distance_to_claimed.clone(),
|
||||
abs_distance_to_next_pow2.clone(),
|
||||
],
|
||||
)?;
|
||||
|
||||
let is_closest_to_1 = less(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
abs_distance_to_claimed.clone(),
|
||||
abs_distance_to_prior_pow2.clone(),
|
||||
],
|
||||
)?;
|
||||
|
||||
let is_closest = and(
|
||||
config,
|
||||
region,
|
||||
&[is_closest_to_0.clone(), is_closest_to_1.clone()],
|
||||
)?;
|
||||
|
||||
let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_closest.len());
|
||||
comparison_unit.reshape(is_closest.dims())?;
|
||||
let assigned_unit = region.assign(&config.custom_gates.inputs[1], &comparison_unit)?;
|
||||
|
||||
enforce_equality(config, region, &[is_closest, assigned_unit])?;
|
||||
|
||||
// get a linear interpolation now
|
||||
|
||||
let sign_of_distance_to_claimed = sign(config, region, &[distance_to_claimed.clone()])?;
|
||||
let sign_of_distance_to_claimed_is_negative = equals(
|
||||
config,
|
||||
region,
|
||||
&[sign_of_distance_to_claimed.clone(), negative_one.clone()],
|
||||
)?;
|
||||
|
||||
let sign_of_distance_to_claimed_is_positive = not(
|
||||
config,
|
||||
region,
|
||||
&[sign_of_distance_to_claimed_is_negative.clone()],
|
||||
)?;
|
||||
|
||||
let pow2_prior_to_claimed_distance = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[pow2_of_claimed_output.clone(), prior_pow2.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
let pow2_next_to_claimed_distance = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[next_pow2.clone(), pow2_of_claimed_output.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
let recip_pow2_prior_to_claimed_distance = recip(
|
||||
config,
|
||||
region,
|
||||
&[pow2_prior_to_claimed_distance],
|
||||
scale_as_felt,
|
||||
scale_as_felt * scale_as_felt,
|
||||
)?;
|
||||
|
||||
let interpolated_distance = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
recip_pow2_prior_to_claimed_distance.clone(),
|
||||
distance_to_claimed.clone(),
|
||||
],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let gated_prior_interpolated_distance = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
interpolated_distance.clone(),
|
||||
sign_of_distance_to_claimed_is_negative.clone(),
|
||||
],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let recip_next_to_claimed_distance = recip(
|
||||
config,
|
||||
region,
|
||||
&[pow2_next_to_claimed_distance],
|
||||
scale_as_felt,
|
||||
scale_as_felt * scale_as_felt,
|
||||
)?;
|
||||
|
||||
let interpolated_distance_next = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
recip_next_to_claimed_distance.clone(),
|
||||
distance_to_claimed.clone(),
|
||||
],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let gated_next_interpolated_distance = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
interpolated_distance_next.clone(),
|
||||
sign_of_distance_to_claimed_is_positive.clone(),
|
||||
],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let scaled_claimed_output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), triple_scaled_as_felt_tensor],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let claimed_output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
scaled_claimed_output.clone(),
|
||||
gated_prior_interpolated_distance.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
let claimed_output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
claimed_output.clone(),
|
||||
gated_next_interpolated_distance.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
// now multiply the claimed output by ln2
|
||||
pairwise(config, region, &[claimed_output, ln2_tensor], BaseOp::Mult)
|
||||
}
|
||||
|
||||
/// round layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
@@ -4551,11 +4976,7 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
|
||||
let one = create_constant_tensor(integer_rep_to_felt(1), 1);
|
||||
let assigned_one = region.assign(&config.custom_gates.inputs[1], &one)?;
|
||||
let negative_one = create_constant_tensor(integer_rep_to_felt(-1), 1);
|
||||
let assigned_negative_one = region.assign(&config.custom_gates.output, &negative_one)?;
|
||||
|
||||
region.increment(1);
|
||||
|
||||
// if scale is not exactly divisible by 2 we warn
|
||||
if scale.0 % 2.0 != 0.0 {
|
||||
@@ -4588,8 +5009,8 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let last_elem = sliced_input.last()?;
|
||||
|
||||
let sign = sliced_input.first()?;
|
||||
let is_positive = equals(config, region, &[sign.clone(), assigned_one.clone()])?;
|
||||
let is_negative = equals(config, region, &[sign, assigned_negative_one.clone()])?;
|
||||
let is_positive = equals(config, region, &[sign.clone(), one.clone()])?;
|
||||
let is_negative = equals(config, region, &[sign, negative_one.clone()])?;
|
||||
|
||||
let is_greater_than_midway = greater_equal(
|
||||
config,
|
||||
@@ -4654,6 +5075,146 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
)
|
||||
}
|
||||
|
||||
/// round half to even layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 1]
|
||||
/// * `scale` - utils::F32
|
||||
/// * `legs` - usize
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::round;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let result = round::<Fp>(&dummy_config, &mut dummy_region, &[x], 4.0.into(), 2).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, -4, 4, 0]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn round_half_to_even<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
|
||||
// if scale is not exactly divisible by 2 we warn
|
||||
if scale.0 % 2.0 != 0.0 {
|
||||
log::warn!("Scale is not exactly divisible by 2.0, rounding may not be accurate");
|
||||
}
|
||||
|
||||
let midway_point: ValTensor<F> = create_constant_tensor(
|
||||
integer_rep_to_felt((scale.0 / 2.0).round() as IntegerRep),
|
||||
1,
|
||||
);
|
||||
|
||||
let dims = decomposition.dims().to_vec();
|
||||
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();
|
||||
|
||||
let mut incremented_tensor = Tensor::new(None, &first_dims)?;
|
||||
|
||||
let cartesian_coord = first_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = decomposition.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
let last_elem = sliced_input.last()?;
|
||||
|
||||
let penultimate_elem =
|
||||
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?;
|
||||
|
||||
let is_equal_to_midway =
|
||||
equals(config, region, &[last_elem.clone(), midway_point.clone()])?;
|
||||
// penultimate_elem is equal to midway point and even, do nothing
|
||||
let is_odd = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[penultimate_elem.clone()],
|
||||
&LookupOp::IsOdd,
|
||||
)?;
|
||||
|
||||
let is_odd_and_equal_to_midway = and(
|
||||
config,
|
||||
region,
|
||||
&[is_odd.clone(), is_equal_to_midway.clone()],
|
||||
)?;
|
||||
|
||||
let is_greater_than_midway =
|
||||
greater(config, region, &[last_elem.clone(), midway_point.clone()])?;
|
||||
|
||||
// if the number is equal to midway point and odd increment, or if it is is_greater_than_midway
|
||||
let is_odd_and_equal_to_midway_or_greater_than_midway = or(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
is_odd_and_equal_to_midway.clone(),
|
||||
is_greater_than_midway.clone(),
|
||||
],
|
||||
)?;
|
||||
|
||||
// increment the penultimate element
|
||||
let incremented_elem = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?,
|
||||
is_odd_and_equal_to_midway_or_greater_than_midway.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
let mut inner_tensor = sliced_input.get_inner_tensor()?.clone();
|
||||
inner_tensor[sliced_input.len() - 2] =
|
||||
incremented_elem.get_inner_tensor()?.clone()[0].clone();
|
||||
|
||||
// set the last elem to zero
|
||||
inner_tensor[sliced_input.len() - 1] = zero.clone();
|
||||
|
||||
Ok(inner_tensor.clone())
|
||||
};
|
||||
|
||||
region.update_max_min_lookup_inputs_force(0, scale.0 as IntegerRep)?;
|
||||
|
||||
region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?;
|
||||
|
||||
let mut incremented_tensor = incremented_tensor.combine()?;
|
||||
incremented_tensor.reshape(&dims)?;
|
||||
|
||||
recompose(
|
||||
config,
|
||||
region,
|
||||
&[incremented_tensor.into()],
|
||||
&(scale.0 as usize),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn recompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
@@ -5019,11 +5580,8 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
input_felt_scale,
|
||||
output_felt_scale,
|
||||
)?;
|
||||
// product of num * (1 / denom) = 2*output_scale
|
||||
let percent = pairwise(config, region, &[input, inv_denom], BaseOp::Mult)?;
|
||||
|
||||
// rebase the percent to 2x the scale
|
||||
loop_div(config, region, &[percent], input_felt_scale)
|
||||
// product of num * (1 / denom) = input_scale * output_scale
|
||||
pairwise(config, region, &[input, inv_denom], BaseOp::Mult)
|
||||
}
|
||||
|
||||
/// Applies softmax
|
||||
@@ -5047,7 +5605,7 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
/// ).unwrap());
|
||||
/// let result = softmax::<Fp>(&dummy_config, &mut dummy_region, &[x], 128.0.into(), (128.0 * 128.0).into()).unwrap();
|
||||
/// // doubles the scale of the input
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2734, 2734, 2756, 2734, 2734, 2691]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[350012, 350012, 352768, 350012, 350012, 344500]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
@@ -5127,17 +5685,8 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd + std::hash::
|
||||
let int_scale = scale.0 as IntegerRep;
|
||||
// felt scale
|
||||
let felt_scale = integer_rep_to_felt(int_scale);
|
||||
// range check len capped at 2^(S-3) and make it divisible 2
|
||||
let range_check_bracket = std::cmp::min(
|
||||
utils::F32(scale.0),
|
||||
utils::F32(2_f32.powf((F::S - 5) as f32)),
|
||||
)
|
||||
.0;
|
||||
|
||||
let range_check_bracket_int = range_check_bracket as IntegerRep;
|
||||
|
||||
// input scale ratio we multiply by tol such that in the new scale range_check_len represents tol percent
|
||||
let input_scale_ratio = ((scale.0.powf(2.0) / range_check_bracket) * tol) as IntegerRep / 2 * 2;
|
||||
let input_scale_ratio = (scale.0 * tol) as IntegerRep / 2 * 2;
|
||||
|
||||
let recip = recip(
|
||||
config,
|
||||
@@ -5153,7 +5702,7 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd + std::hash::
|
||||
let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;
|
||||
|
||||
log::debug!("product: {}", product.show());
|
||||
let rebased_product = loop_div(
|
||||
let rebased_product = div(
|
||||
config,
|
||||
region,
|
||||
&[product],
|
||||
@@ -5162,10 +5711,5 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd + std::hash::
|
||||
log::debug!("rebased_product: {}", rebased_product.show());
|
||||
|
||||
// check that it is within the tolerance range
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[rebased_product],
|
||||
&(-range_check_bracket_int, range_check_bracket_int),
|
||||
)
|
||||
range_check(config, region, &[rebased_product], &(-int_scale, int_scale))
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::{
|
||||
circuit::{layouts, table::Range, utils},
|
||||
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
@@ -16,12 +15,10 @@ use halo2curves::ff::PrimeField;
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||
pub enum LookupOp {
|
||||
Div { denom: utils::F32 },
|
||||
Cast { scale: utils::F32 },
|
||||
RoundHalfToEven { scale: utils::F32 },
|
||||
Sqrt { scale: utils::F32 },
|
||||
Rsqrt { scale: utils::F32 },
|
||||
Sigmoid { scale: utils::F32 },
|
||||
IsOdd,
|
||||
PowersOfTwo { scale: utils::F32 },
|
||||
Ln { scale: utils::F32 },
|
||||
Sigmoid { scale: utils::F32 },
|
||||
Exp { scale: utils::F32 },
|
||||
Cos { scale: utils::F32 },
|
||||
ACos { scale: utils::F32 },
|
||||
@@ -51,16 +48,14 @@ impl LookupOp {
|
||||
/// as path
|
||||
pub fn as_path(&self) -> String {
|
||||
match self {
|
||||
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
|
||||
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
|
||||
LookupOp::Ln { scale } => format!("ln_{}", scale),
|
||||
LookupOp::PowersOfTwo { scale } => format!("pow2_{}", scale),
|
||||
LookupOp::IsOdd => "is_odd".to_string(),
|
||||
LookupOp::Div { denom } => format!("div_{}", denom),
|
||||
LookupOp::Cast { scale } => format!("cast_{}", scale),
|
||||
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
|
||||
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
|
||||
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
|
||||
LookupOp::Erf { scale } => format!("erf_{}", scale),
|
||||
LookupOp::Exp { scale } => format!("exp_{}", scale),
|
||||
LookupOp::Ln { scale } => format!("ln_{}", scale),
|
||||
LookupOp::Cos { scale } => format!("cos_{}", scale),
|
||||
LookupOp::ACos { scale } => format!("acos_{}", scale),
|
||||
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
|
||||
@@ -85,36 +80,28 @@ impl LookupOp {
|
||||
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
|
||||
let res =
|
||||
match &self {
|
||||
LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
|
||||
),
|
||||
LookupOp::Ln { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
|
||||
}
|
||||
LookupOp::PowersOfTwo { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ipow2(&x, scale.0.into()))
|
||||
}
|
||||
LookupOp::IsOdd => Ok::<_, TensorError>(tensor::ops::nonlinearities::is_odd(&x)),
|
||||
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Div { denom } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
|
||||
),
|
||||
LookupOp::Cast { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()),
|
||||
),
|
||||
LookupOp::Sigmoid { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Sqrt { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sqrt(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Rsqrt { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::rsqrt(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Erf { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::erffunc(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Exp { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::exp(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Ln { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Cos { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
|
||||
}
|
||||
@@ -171,15 +158,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
/// Returns the name of the operation
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
|
||||
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
LookupOp::PowersOfTwo { scale } => format!("POWERS_OF_TWO(scale={})", scale),
|
||||
LookupOp::IsOdd => "IS_ODD".to_string(),
|
||||
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
|
||||
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
|
||||
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
|
||||
LookupOp::Rsqrt { scale } => format!("RSQRT(scale={})", scale),
|
||||
LookupOp::Exp { scale } => format!("EXP(scale={})", scale),
|
||||
LookupOp::Tan { scale } => format!("TAN(scale={})", scale),
|
||||
LookupOp::ATan { scale } => format!("ATAN(scale={})", scale),
|
||||
@@ -213,13 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
|
||||
/// Returns the scale of the output of the operation.
|
||||
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
LookupOp::Cast { scale } => {
|
||||
let in_scale = inputs_scale[0];
|
||||
in_scale + multiplier_to_scale(1. / scale.0 as f64)
|
||||
}
|
||||
_ => inputs_scale[0],
|
||||
};
|
||||
let scale = inputs_scale[0];
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
|
||||
@@ -474,6 +474,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min forcefully
|
||||
pub fn update_max_min_lookup_inputs_force(
|
||||
&mut self,
|
||||
min: IntegerRep,
|
||||
max: IntegerRep,
|
||||
) -> Result<(), CircuitError> {
|
||||
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
|
||||
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
|
||||
if range.0 > range.1 {
|
||||
|
||||
@@ -150,12 +150,16 @@ pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
/// get largest element represented by the range
|
||||
pub fn largest(&self) -> IntegerRep {
|
||||
self.range.0 + (self.col_size * self.table_inputs.len() - 1) as IntegerRep
|
||||
}
|
||||
fn name(&self) -> String {
|
||||
format!(
|
||||
"{}_{}_{}",
|
||||
self.nonlinearity.as_path(),
|
||||
self.range.0,
|
||||
self.range.1
|
||||
self.largest()
|
||||
)
|
||||
}
|
||||
/// Configures the table.
|
||||
@@ -222,7 +226,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
}
|
||||
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
let largest = self.largest();
|
||||
|
||||
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
|
||||
let inputs = Tensor::from(smallest..=largest)
|
||||
@@ -291,6 +295,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
|
||||
row_offset += chunk_idx * self.col_size;
|
||||
let (x, y) = self.cartesian_coord(row_offset);
|
||||
|
||||
if !preassigned_input {
|
||||
table.assign_cell(
|
||||
|| format!("nl_i_col row {}", row_offset),
|
||||
|
||||
@@ -333,18 +333,6 @@ impl<'source> FromPyObject<'source> for ContractType {
|
||||
}
|
||||
}
|
||||
}
|
||||
// not wasm
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
|
||||
lazy_static! {
|
||||
/// The version of the ezkl library
|
||||
pub static ref VERSION: &'static str = if env!("CARGO_PKG_VERSION") == "0.0.0" {
|
||||
"source - no compatibility guaranteed"
|
||||
} else {
|
||||
env!("CARGO_PKG_VERSION")
|
||||
};
|
||||
}
|
||||
|
||||
/// Get the styles for the CLI
|
||||
pub fn get_styles() -> clap::builder::Styles {
|
||||
@@ -395,7 +383,7 @@ pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[command(author, about, long_about = None)]
|
||||
#[clap(version = *VERSION, styles = get_styles(), trailing_var_arg = true)]
|
||||
#[clap(version = crate::version(), styles = get_styles(), trailing_var_arg = true)]
|
||||
pub struct Cli {
|
||||
/// If provided, outputs the completion file for given shell
|
||||
#[clap(long = "generate", value_parser)]
|
||||
@@ -486,9 +474,6 @@ pub enum Commands {
|
||||
/// max logrows to use for calibration, 26 is the max public SRS size
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
max_logrows: Option<u32>,
|
||||
// whether to only range check rebases (instead of trying both range check and lookup)
|
||||
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE, action = clap::ArgAction::SetTrue)]
|
||||
only_range_check_rebase: Option<bool>,
|
||||
},
|
||||
|
||||
/// Generates a dummy SRS
|
||||
|
||||
@@ -140,7 +140,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
max_logrows,
|
||||
only_range_check_rebase,
|
||||
} => calibrate(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
@@ -149,7 +148,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap()),
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
@@ -671,10 +669,10 @@ pub(crate) async fn get_srs_cmd(
|
||||
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
|
||||
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
|
||||
// check the SRS
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
|
||||
pb.finish_with_message("SRS validated.");
|
||||
pb.finish_with_message("SRS validated.");
|
||||
|
||||
info!("Saving SRS to disk...");
|
||||
let computed_srs_path = get_srs_path(k, srs_path.clone(), commitment);
|
||||
@@ -682,7 +680,10 @@ pub(crate) async fn get_srs_cmd(
|
||||
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
|
||||
params.write(&mut buffer)?;
|
||||
|
||||
info!("Saved SRS to {}.", computed_srs_path.as_os_str().to_str().unwrap_or("disk"));
|
||||
info!(
|
||||
"Saved SRS to {}.",
|
||||
computed_srs_path.as_os_str().to_str().unwrap_or("disk")
|
||||
);
|
||||
|
||||
info!("SRS downloaded");
|
||||
} else {
|
||||
@@ -728,7 +729,7 @@ pub(crate) async fn gen_witness(
|
||||
None
|
||||
};
|
||||
|
||||
let mut input = circuit.load_graph_input(&data).await?;
|
||||
let mut input = circuit.load_graph_input(&data).await?;
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
let mut input = circuit.load_graph_input(&data)?;
|
||||
|
||||
@@ -968,7 +969,6 @@ pub(crate) async fn calibrate(
|
||||
lookup_safety_margin: f64,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
max_logrows: Option<u32>,
|
||||
) -> Result<GraphSettings, EZKLError> {
|
||||
use log::error;
|
||||
@@ -1004,12 +1004,6 @@ pub(crate) async fn calibrate(
|
||||
(11..14).collect::<Vec<crate::Scale>>()
|
||||
};
|
||||
|
||||
let div_rebasing = if only_range_check_rebase {
|
||||
vec![false]
|
||||
} else {
|
||||
vec![true, false]
|
||||
};
|
||||
|
||||
let mut found_params: Vec<GraphSettings> = vec![];
|
||||
|
||||
// 2 x 2 grid
|
||||
@@ -1047,12 +1041,6 @@ pub(crate) async fn calibrate(
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
|
||||
|
||||
let range_grid = range_grid
|
||||
.iter()
|
||||
.cartesian_product(div_rebasing.iter())
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
|
||||
|
||||
let mut forward_pass_res = HashMap::new();
|
||||
|
||||
let pb = init_bar(range_grid.len() as u64);
|
||||
@@ -1061,30 +1049,23 @@ pub(crate) async fn calibrate(
|
||||
let mut num_failed = 0;
|
||||
let mut num_passed = 0;
|
||||
|
||||
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
|
||||
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
|
||||
pb.set_message(format!(
|
||||
"i-scale: {}, p-scale: {}, rebase-(x): {}, div-rebase: {}, fail: {}, pass: {}",
|
||||
"i-scale: {}, p-scale: {}, rebase-(x): {}, fail: {}, pass: {}",
|
||||
input_scale.to_string().blue(),
|
||||
param_scale.to_string().blue(),
|
||||
scale_rebase_multiplier.to_string().blue(),
|
||||
div_rebasing.to_string().yellow(),
|
||||
scale_rebase_multiplier.to_string().yellow(),
|
||||
num_failed.to_string().red(),
|
||||
num_passed.to_string().green()
|
||||
));
|
||||
|
||||
let key = (
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
);
|
||||
let key = (input_scale, param_scale, scale_rebase_multiplier);
|
||||
forward_pass_res.insert(key, vec![]);
|
||||
|
||||
let local_run_args = RunArgs {
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
lookup_range: (IntegerRep::MIN, IntegerRep::MAX),
|
||||
..settings.run_args.clone()
|
||||
};
|
||||
@@ -1188,7 +1169,6 @@ pub(crate) async fn calibrate(
|
||||
let found_run_args = RunArgs {
|
||||
input_scale: new_settings.run_args.input_scale,
|
||||
param_scale: new_settings.run_args.param_scale,
|
||||
div_rebasing: new_settings.run_args.div_rebasing,
|
||||
lookup_range: new_settings.run_args.lookup_range,
|
||||
logrows: new_settings.run_args.logrows,
|
||||
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
|
||||
@@ -1296,7 +1276,6 @@ pub(crate) async fn calibrate(
|
||||
best_params.run_args.input_scale,
|
||||
best_params.run_args.param_scale,
|
||||
best_params.run_args.scale_rebase_multiplier,
|
||||
best_params.run_args.div_rebasing,
|
||||
))
|
||||
.ok_or("no params found")?
|
||||
.iter()
|
||||
@@ -2022,7 +2001,7 @@ pub(crate) fn mock_aggregate(
|
||||
}
|
||||
}
|
||||
// proof aggregation
|
||||
let pb = {
|
||||
let pb = {
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Aggregating (may take a while)...");
|
||||
pb
|
||||
@@ -2033,7 +2012,7 @@ pub(crate) fn mock_aggregate(
|
||||
let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()])
|
||||
.map_err(|e| ExecutionError::MockProverError(e.to_string()))?;
|
||||
prover.verify().map_err(ExecutionError::VerifyError)?;
|
||||
pb.finish_with_message("Done.");
|
||||
pb.finish_with_message("Done.");
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
@@ -2127,7 +2106,7 @@ pub(crate) fn aggregate(
|
||||
}
|
||||
|
||||
// proof aggregation
|
||||
let pb = {
|
||||
let pb = {
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Aggregating (may take a while)...");
|
||||
pb
|
||||
@@ -2276,7 +2255,7 @@ pub(crate) fn aggregate(
|
||||
);
|
||||
snark.save(&proof_path)?;
|
||||
|
||||
pb.finish_with_message("Done.");
|
||||
pb.finish_with_message("Done.");
|
||||
|
||||
Ok(snark)
|
||||
}
|
||||
|
||||
@@ -133,6 +133,8 @@ pub struct GraphWitness {
|
||||
pub min_lookup_inputs: IntegerRep,
|
||||
/// max range check size
|
||||
pub max_range_size: IntegerRep,
|
||||
/// (optional) version of ezkl used
|
||||
pub version: Option<String>,
|
||||
}
|
||||
|
||||
impl GraphWitness {
|
||||
@@ -161,6 +163,7 @@ impl GraphWitness {
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
version: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1350,6 +1353,7 @@ impl GraphCircuit {
|
||||
max_lookup_inputs: model_results.max_lookup_inputs,
|
||||
min_lookup_inputs: model_results.min_lookup_inputs,
|
||||
max_range_size: model_results.max_range_size,
|
||||
version: Some(crate::version().to_string()),
|
||||
};
|
||||
|
||||
witness.generate_rescaled_elements(
|
||||
|
||||
@@ -915,20 +915,9 @@ impl Model {
|
||||
if scales.contains_key(&i) {
|
||||
let scale_diff = n.out_scale - scales[&i];
|
||||
n.opkind = if scale_diff > 0 {
|
||||
RebaseScale::rebase(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
1,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1)
|
||||
} else {
|
||||
RebaseScale::rebase_up(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale)
|
||||
};
|
||||
n.out_scale = scales[&i];
|
||||
}
|
||||
|
||||
@@ -120,7 +120,6 @@ impl RebaseScale {
|
||||
global_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
scale_rebase_multiplier: u32,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
|
||||
&& !inner.is_constant()
|
||||
@@ -137,7 +136,6 @@ impl RebaseScale {
|
||||
multiplier,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32((multiplier) as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
original_scale: op.original_scale,
|
||||
})
|
||||
@@ -148,7 +146,6 @@ impl RebaseScale {
|
||||
multiplier,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32(multiplier as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
original_scale: op_out_scale,
|
||||
})
|
||||
@@ -163,7 +160,6 @@ impl RebaseScale {
|
||||
inner: SupportedOp,
|
||||
target_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
|
||||
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
|
||||
@@ -176,7 +172,6 @@ impl RebaseScale {
|
||||
original_scale: op.original_scale,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32((multiplier) as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
@@ -187,7 +182,6 @@ impl RebaseScale {
|
||||
original_scale: op_out_scale,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32(multiplier as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -595,13 +589,7 @@ impl Node {
|
||||
let mut out_scale = opkind.out_scale(in_scales.clone())?;
|
||||
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
|
||||
let global_scale = scales.get_max();
|
||||
opkind = RebaseScale::rebase(
|
||||
opkind,
|
||||
global_scale,
|
||||
out_scale,
|
||||
scales.rebase_multiplier,
|
||||
run_args.div_rebasing,
|
||||
);
|
||||
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
|
||||
|
||||
out_scale = opkind.out_scale(in_scales)?;
|
||||
|
||||
|
||||
@@ -764,7 +764,7 @@ pub fn new_op_from_onnx(
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if inputs.len() == 2 {
|
||||
if const_inputs.len() > 0 {
|
||||
if !const_inputs.is_empty() {
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
@@ -803,7 +803,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
// If the input scale is larger than the params scale
|
||||
SupportedOp::Hybrid(HybridOp::Recip {
|
||||
@@ -837,61 +837,75 @@ pub fn new_op_from_onnx(
|
||||
"Abs" => SupportedOp::Linear(PolyOp::Abs),
|
||||
"Neg" => SupportedOp::Linear(PolyOp::Neg),
|
||||
"HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sqrt" => SupportedOp::Nonlinear(LookupOp::Sqrt {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
}),
|
||||
"Rsqrt" => SupportedOp::Nonlinear(LookupOp::Rsqrt {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"Sqrt" => SupportedOp::Hybrid(HybridOp::Sqrt {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Rsqrt" => {
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
SupportedOp::Hybrid(HybridOp::Rsqrt {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
})
|
||||
}
|
||||
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
}),
|
||||
"Ln" => SupportedOp::Nonlinear(LookupOp::Ln {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Ln" => {
|
||||
if run_args.bounded_log_lookup {
|
||||
SupportedOp::Hybrid(HybridOp::Ln {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
})
|
||||
} else {
|
||||
SupportedOp::Nonlinear(LookupOp::Ln {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
"Sin" => SupportedOp::Nonlinear(LookupOp::Sin {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Cos" => SupportedOp::Nonlinear(LookupOp::Cos {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Tan" => SupportedOp::Nonlinear(LookupOp::Tan {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Asin" => SupportedOp::Nonlinear(LookupOp::ASin {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Acos" => SupportedOp::Nonlinear(LookupOp::ACos {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Atan" => SupportedOp::Nonlinear(LookupOp::ATan {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sinh" => SupportedOp::Nonlinear(LookupOp::Sinh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Cosh" => SupportedOp::Nonlinear(LookupOp::Cosh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Tanh" => SupportedOp::Nonlinear(LookupOp::Tanh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Asinh" => SupportedOp::Nonlinear(LookupOp::ASinh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Acosh" => SupportedOp::Nonlinear(LookupOp::ACosh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Atanh" => SupportedOp::Nonlinear(LookupOp::ATanh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Erf" => SupportedOp::Nonlinear(LookupOp::Erf {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Source" => {
|
||||
let dt = node.outputs[0].fact.datum_type;
|
||||
@@ -935,11 +949,9 @@ pub fn new_op_from_onnx(
|
||||
replace_const(
|
||||
0,
|
||||
0,
|
||||
SupportedOp::Nonlinear(LookupOp::Cast {
|
||||
scale: crate::circuit::utils::F32(scale_to_multiplier(
|
||||
input_scales[0],
|
||||
)
|
||||
as f32),
|
||||
SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
)?
|
||||
} else {
|
||||
@@ -1045,7 +1057,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
};
|
||||
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::Softmax {
|
||||
@@ -1084,19 +1096,20 @@ pub fn new_op_from_onnx(
|
||||
})
|
||||
}
|
||||
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Round" => SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Sign" => SupportedOp::Linear(PolyOp::Sign),
|
||||
"Pow" => {
|
||||
@@ -1116,7 +1129,7 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Linear(PolyOp::Pow(exponent as u32))
|
||||
} else {
|
||||
SupportedOp::Nonlinear(LookupOp::Pow {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
a: crate::circuit::utils::F32(exponent),
|
||||
})
|
||||
}
|
||||
|
||||
22
src/lib.rs
22
src/lib.rs
@@ -108,6 +108,18 @@ use serde::{Deserialize, Serialize};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
|
||||
/// The version of the ezkl library
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
/// Get the version of the library
|
||||
pub fn version() -> &'static str {
|
||||
match VERSION {
|
||||
"0.0.0" => "source - no compatibility guaranteed",
|
||||
_ => VERSION,
|
||||
}
|
||||
}
|
||||
|
||||
/// Bindings managment
|
||||
#[cfg(any(
|
||||
feature = "ios-bindings",
|
||||
@@ -297,8 +309,6 @@ pub struct RunArgs {
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// Rebase the scale using lookup table for division instead of using a range check
|
||||
pub div_rebasing: bool,
|
||||
/// Should constants with 0.0 fraction be rebased to scale 0
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
@@ -317,11 +327,18 @@ pub struct RunArgs {
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
|
||||
/// the number of legs used for decompositions
|
||||
pub decomp_legs: usize,
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// use unbounded lookup for the log
|
||||
pub bounded_log_lookup: bool,
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bounded_log_lookup: false,
|
||||
tolerance: Tolerance::default(),
|
||||
input_scale: 7,
|
||||
param_scale: 7,
|
||||
@@ -333,7 +350,6 @@ impl Default for RunArgs {
|
||||
input_visibility: Visibility::Private,
|
||||
output_visibility: Visibility::Public,
|
||||
param_visibility: Visibility::Private,
|
||||
div_rebasing: false,
|
||||
rebase_frac_zero_constants: false,
|
||||
check_mode: CheckMode::UNSAFE,
|
||||
commitment: None,
|
||||
|
||||
@@ -59,10 +59,7 @@ fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(ValueEnum)
|
||||
)]
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
|
||||
pub enum ProofType {
|
||||
#[default]
|
||||
Single,
|
||||
@@ -134,10 +131,7 @@ impl<'source> pyo3::FromPyObject<'source> for ProofType {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(ValueEnum)
|
||||
)]
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
|
||||
pub enum StrategyType {
|
||||
Single,
|
||||
Accum,
|
||||
@@ -203,10 +197,7 @@ pub enum PfSysError {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(ValueEnum)
|
||||
)]
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
|
||||
pub enum TranscriptType {
|
||||
Poseidon,
|
||||
#[default]
|
||||
@@ -324,6 +315,8 @@ where
|
||||
pub timestamp: Option<u128>,
|
||||
/// commitment
|
||||
pub commitment: Option<Commitments>,
|
||||
/// (optional) version of ezkl used to generate the proof
|
||||
version: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
@@ -385,6 +378,7 @@ where
|
||||
.as_millis(),
|
||||
),
|
||||
commitment,
|
||||
version: Some(crate::version().to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -920,6 +914,7 @@ mod tests {
|
||||
pretty_public_inputs: None,
|
||||
timestamp: None,
|
||||
commitment: None,
|
||||
version: None,
|
||||
};
|
||||
|
||||
snark
|
||||
|
||||
@@ -1474,6 +1474,93 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Checks if a tensor's elements are odd
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::is_odd;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// let result = is_odd(&x);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 1, 0, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn is_odd(a: &Tensor<IntegerRep>) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rounded = if a_i % 2 == 0 { 0 } else { 1 };
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Powers of 2
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::ipow2;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = ipow2(&x, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 32768, 4, 2, 2, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ipow2(a: &Tensor<IntegerRep>, scale_output: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = a_i as f64;
|
||||
let kix = scale_output * (2.0_f64).powf(kix);
|
||||
let rounded = kix.round();
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies ln base 2 to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale_input` - Single value
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::ilog2;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 2]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = ilog2(&x, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 1, 0, 0, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ilog2(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale_input;
|
||||
let log = (kix).log2();
|
||||
let floor = log.floor();
|
||||
let ceil = log.ceil();
|
||||
let floor_dist = ((2.0_f64).powf(floor) - kix).abs();
|
||||
let ceil_dist = (kix - (2.0_f64).powf(ceil)).abs();
|
||||
|
||||
if floor_dist < ceil_dist {
|
||||
Ok::<_, TensorError>(floor as IntegerRep)
|
||||
} else {
|
||||
Ok::<_, TensorError>(ceil as IntegerRep)
|
||||
}
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies sigmoid to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -1602,12 +1689,11 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies exponential to a tensor of integers.
|
||||
/// Elementwise applies ln to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `scale_input` - Single value
|
||||
/// * `scale_output` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
|
||||
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -27,7 +27,8 @@
|
||||
"check_mode": "UNSAFE",
|
||||
"commitment": "KZG",
|
||||
"decomp_base": 128,
|
||||
"decomp_legs": 2
|
||||
"decomp_legs": 2,
|
||||
"bounded_log_lookup": false
|
||||
},
|
||||
"num_rows": 46,
|
||||
"total_assignments": 92,
|
||||
|
||||
@@ -1 +1 @@
|
||||
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":0,"max_range_size":127}
|
||||
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":0,"max_range_size":127,"version":"source - no compatibility guaranteed"}
|
||||
@@ -205,7 +205,7 @@ mod native_tests {
|
||||
"1l_tiny_div",
|
||||
];
|
||||
|
||||
const TESTS: [&str; 95] = [
|
||||
const TESTS: [&str; 96] = [
|
||||
"1l_mlp", //0
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
@@ -305,6 +305,7 @@ mod native_tests {
|
||||
"lstm_medium", // 92
|
||||
"lenet_5", // 93
|
||||
"rsqrt", // 94
|
||||
"log", // 95
|
||||
];
|
||||
|
||||
const WASM_TESTS: [&str; 46] = [
|
||||
@@ -538,12 +539,12 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
seq!(N in 0..=94 {
|
||||
seq!(N in 0..=95 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
#[ignore]
|
||||
@@ -555,15 +556,7 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_div_rebase_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, true);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_public_outputs_(test: &str) {
|
||||
@@ -571,7 +564,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, false);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -581,7 +574,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6 , false);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6 );
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -591,7 +584,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6, false);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -602,7 +595,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1, false);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -611,7 +604,17 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_bounded_lookup_log(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -622,7 +625,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// gen random number between 0.0 and 1.0
|
||||
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -637,7 +640,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -647,7 +650,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -656,7 +659,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -665,7 +668,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -674,7 +677,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -683,7 +686,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -692,7 +695,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -701,7 +704,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -711,7 +714,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -721,7 +724,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -730,7 +733,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -740,7 +743,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -749,7 +752,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -759,7 +762,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -769,7 +772,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -779,7 +782,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -789,7 +792,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -799,7 +802,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -976,7 +979,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -1453,6 +1456,7 @@ mod native_tests {
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
tolerance: f32,
|
||||
bounded_lookup_log: bool,
|
||||
) {
|
||||
let mut tolerance = tolerance;
|
||||
gen_circuit_settings_and_witness(
|
||||
@@ -1465,10 +1469,10 @@ mod native_tests {
|
||||
cal_target,
|
||||
scales_to_use,
|
||||
2,
|
||||
false,
|
||||
&mut tolerance,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
bounded_lookup_log,
|
||||
);
|
||||
|
||||
if tolerance > 0.0 {
|
||||
@@ -1606,10 +1610,10 @@ mod native_tests {
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
num_inner_columns: usize,
|
||||
div_rebasing: bool,
|
||||
tolerance: &mut f32,
|
||||
commitment: Commitments,
|
||||
lookup_safety_margin: usize,
|
||||
bounded_lookup_log: bool,
|
||||
) {
|
||||
let mut args = vec![
|
||||
"gen-settings".to_string(),
|
||||
@@ -1628,9 +1632,9 @@ mod native_tests {
|
||||
format!("--commitment={}", commitment),
|
||||
];
|
||||
|
||||
if div_rebasing {
|
||||
args.push("--div-rebasing".to_string());
|
||||
};
|
||||
if bounded_lookup_log {
|
||||
args.push("--bounded-log-lookup".to_string());
|
||||
}
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
@@ -1730,7 +1734,6 @@ mod native_tests {
|
||||
batch_size: usize,
|
||||
cal_target: &str,
|
||||
target_perc: f32,
|
||||
div_rebasing: bool,
|
||||
) {
|
||||
gen_circuit_settings_and_witness(
|
||||
test_dir,
|
||||
@@ -1742,10 +1745,10 @@ mod native_tests {
|
||||
cal_target,
|
||||
None,
|
||||
2,
|
||||
div_rebasing,
|
||||
&mut 0.0,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
);
|
||||
|
||||
println!(
|
||||
@@ -2026,10 +2029,10 @@ mod native_tests {
|
||||
target_str,
|
||||
scales_to_use,
|
||||
num_inner_columns,
|
||||
false,
|
||||
&mut 0.0,
|
||||
commitment,
|
||||
lookup_safety_margin,
|
||||
false,
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
@@ -2458,10 +2461,10 @@ mod native_tests {
|
||||
// we need the accuracy
|
||||
Some(vec![4]),
|
||||
1,
|
||||
false,
|
||||
&mut 0.0,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
);
|
||||
|
||||
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
|
||||
|
||||
Reference in New Issue
Block a user