Compare commits

...

6 Commits

Author SHA1 Message Date
dante
523c77c912 feat: lookupless sqrt and rsqrt (#867) 2024-11-10 15:56:38 +00:00
dante
948e5cd4b9 chore: version proof and witness (#865) 2024-11-08 02:55:35 +00:00
dante
00155e585f feat: bounded lookup log argument (#864) 2024-11-07 12:16:55 +00:00
dante
0876faa12c feat: bounded lookup round half to even (#863) 2024-11-01 00:51:15 -04:00
dante
a3c131dac0 feat: lookupless rounding ops (#862) 2024-10-31 11:29:46 -04:00
sebastiandanconia
fd9c2305ac docs: improve cli friendliness (#861)
* Improve clarity of an info!() message

* Replace references to EZKL_REPO_PATH in `--help' output

Command `--help' messages aren't meant to be unduly verbose; we can
write them for common/simple use cases. We continue to support
EZKL_REPO_PATH for users who need it, for example to support
containerized server use cases.

To be clear, by default, EZKL_REPO_PATH = $HOME/.ezkl
2024-10-30 17:25:47 -04:00
31 changed files with 1758 additions and 704 deletions

View File

@@ -240,6 +240,8 @@ jobs:
locked: true
# - name: The Worm Mock
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: public outputs and bounded lookup log
run: cargo nextest run --release --verbose tests::mock_bounded_lookup_log --test-threads 32
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10

View File

@@ -592,7 +592,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.5"
},
"orig_nbformat": 4
},

View File

@@ -648,10 +648,10 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -271,7 +271,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.7"
}
},
"nbformat": 4,

View File

@@ -171,7 +171,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -328,7 +328,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"id": "171702d3",
"metadata": {},
"outputs": [],
@@ -348,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 27,
"id": "671dfdd5",
"metadata": {},
"outputs": [],
@@ -364,7 +364,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 28,
"id": "50eba2f4",
"metadata": {},
"outputs": [],
@@ -399,9 +399,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

42
examples/onnx/log/gen.py Normal file
View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.log(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 3)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[1.9252371788024902, 1.8418371677398682, 0.8400403261184692, 2.083845853805542, 0.9760497808456421, 0.6940176486968994, 0.015579521656036377, 2.2689192295074463]]}

View File

@@ -0,0 +1,14 @@
pytorch2.2.2:o

inputoutput/Log"Log
main_graphZ!
input


batch_size
b"
output


batch_size
B

View File

@@ -21,9 +21,9 @@ def main():
torch_model = Circuit()
# Input to the model
shape = [3, 2, 3]
w = 0.1*torch.rand(1, *shape, requires_grad=True)
x = 0.1*torch.rand(1, *shape, requires_grad=True)
y = 0.1*torch.rand(1, *shape, requires_grad=True)
w = 2 * torch.rand(1, *shape, requires_grad=True) - 1
x = 2 * torch.rand(1, *shape, requires_grad=True) - 1
y = 2 * torch.rand(1, *shape, requires_grad=True) - 1
torch_out = torch_model(w, x, y)
# Export the model
torch.onnx.export(torch_model, # model being run

View File

@@ -1 +1,148 @@
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.0025284828152507544, 0.04976580664515495, 0.025840921327471733, 0.0829394981265068, 0.09595223516225815, 0.08764562010765076, 0.06308566778898239, 0.062386948615312576, 0.08090643584728241, 0.09267748892307281, 0.07428313046693802, 0.08987367898225784, 0.005716216750442982, 0.0666426345705986, 0.012837404385209084, 0.05769496038556099, 0.05761152133345604, 0.08006472885608673], [0.007834953255951405, 0.011380612850189209, 0.08560049533843994, 0.022283583879470825, 0.07879520952701569, 0.04422441124916077, 0.030812596902251244, 0.006081616971641779, 0.011045408435165882, 0.08776585012674332, 0.044985152781009674, 0.015603715553879738, 0.07923348993062973, 0.04872611165046692, 0.0036642670165747404, 0.05142095685005188, 0.0963878259062767, 0.03225792199373245], [0.09952805936336517, 0.002214533044025302, 0.011696457862854004, 0.022422820329666138, 0.04151459410786629, 0.027647346258163452, 0.011919880285859108, 0.006539052817970514, 0.06569185107946396, 0.034328874200582504, 0.0032284557819366455, 0.004105025436729193, 0.022395813837647438, 0.07135921716690063, 0.07882415503263474, 0.09764843434095383, 0.05335796996951103, 0.0525360181927681]], "output_data": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]}
{
"input_shapes": [
[
3,
2,
3
],
[
3,
2,
3
],
[
3,
2,
3
],
[
3,
2,
3
]
],
"input_data": [
[
0.5,
1.5,
-0.04514765739440918,
0.5936200618743896,
0.9271858930587769,
0.6688600778579712,
-0.20331168174743652,
-0.7016235589981079,
0.025863051414489746,
-0.19426143169403076,
0.9827852249145508,
0.4897397756576538,
-1.5,
-0.5,
0.9278832674026489,
0.5943725109100342,
-0.573331356048584,
0.3675816059112549
],
[
0.7803324460983276,
-0.9616303443908691,
0.6070173978805542,
-0.028337717056274414,
-0.5080242156982422,
-0.9280107021331787,
0.6150380373001099,
0.3865993022918701,
-0.43668973445892334,
0.17152702808380127,
0.5144252777099609,
-0.28881049156188965,
0.8932310342788696,
0.059034109115600586,
0.6865451335906982,
0.009820222854614258,
0.23011493682861328,
-0.9492779970169067
],
[
-0.21352827548980713,
-0.16015326976776123,
-0.38964390754699707,
0.13464701175689697,
-0.8814496994018555,
0.5037975311279297,
-0.804405927658081,
0.9858957529067993,
0.19567716121673584,
0.9777265787124634,
0.6151977777481079,
0.568595290184021,
0.10584986209869385,
-0.8975653648376465,
0.6235959529876709,
-0.547879695892334,
0.9289869070053101,
0.7567293643951416
]
],
"output_data": [
[
1.0,
0.0,
-0.0,
1.0,
1.0,
1.0,
-0.0,
-1.0,
0.0,
-0.0,
1.0,
0.0,
0.0,
1.0,
1.0,
1.0,
-1.0,
0.0
],
[
0.0,
-1.0,
0.0,
-1.0,
-1.0,
-1.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
-1.0
],
[
-0.0,
-0.0,
-0.0,
1.0,
-0.0,
1.0,
-0.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
-0.0,
1.0,
-0.0,
1.0,
1.0
]
]
}

View File

@@ -1,10 +1,11 @@
pytorch2.0.1:â
pytorch2.2.2:ă

woutput_w/Round"Round

xoutput_x/Floor"Floor

youtput_y/Ceil"Ceil torch_jitZ%
youtput_y/Ceil"Ceil
main_graphZ%
w



View File

@@ -180,9 +180,6 @@ struct PyRunArgs {
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
/// bool: Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
#[pyo3(get, set)]
/// bool: Should constants with 0.0 fraction be rebased to scale 0
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
@@ -197,6 +194,9 @@ struct PyRunArgs {
/// int: The number of legs used for decomposition
#[pyo3(get, set)]
pub decomp_legs: usize,
/// bool: Should the circuit use unbounded lookups for log
#[pyo3(get, set)]
pub bounded_log_lookup: bool,
}
/// default instantiation of PyRunArgs
@@ -212,6 +212,7 @@ impl PyRunArgs {
impl From<PyRunArgs> for RunArgs {
fn from(py_run_args: PyRunArgs) -> Self {
RunArgs {
bounded_log_lookup: py_run_args.bounded_log_lookup,
tolerance: Tolerance::from(py_run_args.tolerance),
input_scale: py_run_args.input_scale,
param_scale: py_run_args.param_scale,
@@ -223,7 +224,6 @@ impl From<PyRunArgs> for RunArgs {
output_visibility: py_run_args.output_visibility,
param_visibility: py_run_args.param_visibility,
variables: py_run_args.variables,
div_rebasing: py_run_args.div_rebasing,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: Some(py_run_args.commitment.into()),
@@ -236,6 +236,7 @@ impl From<PyRunArgs> for RunArgs {
impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
tolerance: self.tolerance.val,
input_scale: self.input_scale,
param_scale: self.param_scale,
@@ -247,7 +248,6 @@ impl Into<PyRunArgs> for RunArgs {
output_visibility: self.output_visibility,
param_visibility: self.param_visibility,
variables: self.variables,
div_rebasing: self.div_rebasing,
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
check_mode: self.check_mode,
commitment: self.commitment.into(),
@@ -873,8 +873,6 @@ fn gen_settings(
/// max_logrows: int
/// Optional max logrows to use for calibration
///
/// only_range_check_rebase: bool
/// Check ranges when rebasing
///
/// Returns
/// -------
@@ -889,7 +887,6 @@ fn gen_settings(
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
))]
fn calibrate_settings(
py: Python,
@@ -901,7 +898,6 @@ fn calibrate_settings(
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
only_range_check_rebase: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::calibrate(
@@ -912,7 +908,6 @@ fn calibrate_settings(
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase,
max_logrows,
)
.await

View File

@@ -13,13 +13,38 @@ use serde::{Deserialize, Serialize};
/// An enum representing the operations that consist of both lookups and arithmetic operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HybridOp {
Ln {
scale: utils::F32,
},
Rsqrt {
input_scale: utils::F32,
output_scale: utils::F32,
},
Sqrt {
scale: utils::F32,
},
RoundHalfToEven {
scale: utils::F32,
legs: usize,
},
Ceil {
scale: utils::F32,
legs: usize,
},
Floor {
scale: utils::F32,
legs: usize,
},
Round {
scale: utils::F32,
legs: usize,
},
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
},
Div {
denom: utils::F32,
use_range_check_for_int: bool,
},
ReduceMax {
axes: Vec<usize>,
@@ -96,8 +121,24 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
fn as_string(&self) -> String {
match self {
HybridOp::Max => format!("MAX"),
HybridOp::Min => format!("MIN"),
HybridOp::Rsqrt {
input_scale,
output_scale,
} => format!(
"RSQRT (input_scale={}, output_scale={})",
input_scale, output_scale
),
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
}
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),
HybridOp::Max => "MAX".to_string(),
HybridOp::Min => "MIN".to_string(),
HybridOp::Recip {
input_scale,
output_scale,
@@ -105,13 +146,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
"RECIP (input_scale={}, output_scale={})",
input_scale, output_scale
),
HybridOp::Div {
denom,
use_range_check_for_int,
} => format!(
"DIV (denom={}, use_range_check_for_int={})",
denom, use_range_check_for_int
),
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
HybridOp::SumPool {
padding,
stride,
@@ -166,6 +201,32 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::Rsqrt {
input_scale,
output_scale,
} => layouts::rsqrt(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
)?,
HybridOp::Sqrt { scale } => {
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
}
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Ceil { scale, legs } => {
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Floor { scale, legs } => {
layouts::floor(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Round { scale, legs } => {
layouts::round(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?,
HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?,
HybridOp::SumPool {
@@ -192,13 +253,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
integer_rep_to_felt(input_scale.0 as i128),
integer_rep_to_felt(output_scale.0 as i128),
)?,
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::loop_div(
HybridOp::Div { denom, .. } => {
if denom.0.fract() == 0.0 {
layouts::div(
config,
region,
values[..].try_into()?,
@@ -289,9 +346,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
HybridOp::Recip { output_scale, .. } | HybridOp::Rsqrt { output_scale, .. } => {
multiplier_to_scale(output_scale.0 as f64)
}
HybridOp::Softmax {
output_scale,
input_scale,
..
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
HybridOp::Ln {
scale: output_scale,
} => 4 * multiplier_to_scale(output_scale.0 as f64),
_ => in_scales[0],
};
Ok(scale)

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize};
use crate::{
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType},
};
@@ -16,15 +15,10 @@ use halo2curves::ff::PrimeField;
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum LookupOp {
Div { denom: utils::F32 },
Cast { scale: utils::F32 },
Ceil { scale: utils::F32 },
Floor { scale: utils::F32 },
Round { scale: utils::F32 },
RoundHalfToEven { scale: utils::F32 },
Sqrt { scale: utils::F32 },
Rsqrt { scale: utils::F32 },
Sigmoid { scale: utils::F32 },
IsOdd,
PowersOfTwo { scale: utils::F32 },
Ln { scale: utils::F32 },
Sigmoid { scale: utils::F32 },
Exp { scale: utils::F32 },
Cos { scale: utils::F32 },
ACos { scale: utils::F32 },
@@ -54,19 +48,14 @@ impl LookupOp {
/// as path
pub fn as_path(&self) -> String {
match self {
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
LookupOp::Floor { scale } => format!("floor_{}", scale),
LookupOp::Round { scale } => format!("round_{}", scale),
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
LookupOp::Ln { scale } => format!("ln_{}", scale),
LookupOp::PowersOfTwo { scale } => format!("pow2_{}", scale),
LookupOp::IsOdd => "is_odd".to_string(),
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Cast { scale } => format!("cast_{}", scale),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
LookupOp::Erf { scale } => format!("erf_{}", scale),
LookupOp::Exp { scale } => format!("exp_{}", scale),
LookupOp::Ln { scale } => format!("ln_{}", scale),
LookupOp::Cos { scale } => format!("cos_{}", scale),
LookupOp::ACos { scale } => format!("acos_{}", scale),
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
@@ -91,45 +80,28 @@ impl LookupOp {
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
let res =
match &self {
LookupOp::Ceil { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ceil(&x, scale.into()))
LookupOp::Ln { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
}
LookupOp::Floor { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::floor(&x, scale.into()))
LookupOp::PowersOfTwo { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ipow2(&x, scale.0.into()))
}
LookupOp::Round { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::round(&x, scale.into()))
}
LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
),
LookupOp::IsOdd => Ok::<_, TensorError>(tensor::ops::nonlinearities::is_odd(&x)),
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
),
LookupOp::Div { denom } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
),
LookupOp::Cast { scale } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()),
),
LookupOp::Sigmoid { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
LookupOp::Sqrt { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sqrt(&x, scale.into()))
}
LookupOp::Rsqrt { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::rsqrt(&x, scale.into()))
}
LookupOp::Erf { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::erffunc(&x, scale.into()))
}
LookupOp::Exp { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::exp(&x, scale.into()))
}
LookupOp::Ln { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
}
LookupOp::Cos { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
}
@@ -186,18 +158,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
/// Returns the name of the operation
fn as_string(&self) -> String {
match self {
LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale),
LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale),
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
LookupOp::PowersOfTwo { scale } => format!("POWERS_OF_TWO(scale={})", scale),
LookupOp::IsOdd => "IS_ODD".to_string(),
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
LookupOp::Rsqrt { scale } => format!("RSQRT(scale={})", scale),
LookupOp::Exp { scale } => format!("EXP(scale={})", scale),
LookupOp::Tan { scale } => format!("TAN(scale={})", scale),
LookupOp::ATan { scale } => format!("ATAN(scale={})", scale),
@@ -231,13 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
/// Returns the scale of the output of the operation.
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
LookupOp::Cast { scale } => {
let in_scale = inputs_scale[0];
in_scale + multiplier_to_scale(1. / scale.0 as f64)
}
_ => inputs_scale[0],
};
let scale = inputs_scale[0];
Ok(scale)
}

View File

@@ -474,6 +474,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
Ok(())
}
/// Update the max and min forcefully
pub fn update_max_min_lookup_inputs_force(
&mut self,
min: IntegerRep,
max: IntegerRep,
) -> Result<(), CircuitError> {
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
Ok(())
}
/// Update the max and min from inputs
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
if range.0 > range.1 {

View File

@@ -150,12 +150,16 @@ pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// get largest element represented by the range
pub fn largest(&self) -> IntegerRep {
self.range.0 + (self.col_size * self.table_inputs.len() - 1) as IntegerRep
}
fn name(&self) -> String {
format!(
"{}_{}_{}",
self.nonlinearity.as_path(),
self.range.0,
self.range.1
self.largest()
)
}
/// Configures the table.
@@ -222,7 +226,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
}
let smallest = self.range.0;
let largest = self.range.1;
let largest = self.largest();
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
let inputs = Tensor::from(smallest..=largest)
@@ -291,6 +295,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
row_offset += chunk_idx * self.col_size;
let (x, y) = self.cartesian_coord(row_offset);
if !preassigned_input {
table.assign_cell(
|| format!("nl_i_col row {}", row_offset),

View File

@@ -333,18 +333,6 @@ impl<'source> FromPyObject<'source> for ContractType {
}
}
}
// not wasm
use lazy_static::lazy_static;
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
lazy_static! {
/// The version of the ezkl library
pub static ref VERSION: &'static str = if env!("CARGO_PKG_VERSION") == "0.0.0" {
"source - no compatibility guaranteed"
} else {
env!("CARGO_PKG_VERSION")
};
}
/// Get the styles for the CLI
pub fn get_styles() -> clap::builder::Styles {
@@ -395,7 +383,7 @@ pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
#[allow(missing_docs)]
#[derive(Parser, Debug, Clone)]
#[command(author, about, long_about = None)]
#[clap(version = *VERSION, styles = get_styles(), trailing_var_arg = true)]
#[clap(version = crate::version(), styles = get_styles(), trailing_var_arg = true)]
pub struct Cli {
/// If provided, outputs the completion file for given shell
#[clap(long = "generate", value_parser)]
@@ -486,9 +474,6 @@ pub enum Commands {
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long, value_hint = clap::ValueHint::Other)]
max_logrows: Option<u32>,
// whether to only range check rebases (instead of trying both range check and lookup)
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE, action = clap::ArgAction::SetTrue)]
only_range_check_rebase: Option<bool>,
},
/// Generates a dummy SRS
@@ -508,7 +493,7 @@ pub enum Commands {
/// Gets an SRS from a circuit settings file.
#[command(name = "get-srs")]
GetSrs {
/// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs
/// The path to output the desired srs file, if set to None will save to ~/.ezkl/srs
#[arg(long, default_value = None, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// Path to the circuit settings .json file to read in logrows from. Overriden by logrows if specified.
@@ -555,7 +540,7 @@ pub enum Commands {
/// The path to save the proving key to
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
@@ -582,7 +567,7 @@ pub enum Commands {
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
#[arg(
@@ -624,7 +609,7 @@ pub enum Commands {
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to output the verification key file to
@@ -701,7 +686,7 @@ pub enum Commands {
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
#[arg(
@@ -733,7 +718,7 @@ pub enum Commands {
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-verifier")]
CreateEvmVerifier {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -755,7 +740,7 @@ pub enum Commands {
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
#[command(name = "create-evm-vka")]
CreateEvmVKArtifact {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -798,7 +783,7 @@ pub enum Commands {
/// Creates an Evm verifier for an aggregate proof
#[command(name = "create-evm-verifier-aggr")]
CreateEvmVerifierAggr {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load the desired verification key file
@@ -831,7 +816,7 @@ pub enum Commands {
/// The path to the verification key file (generated using the setup command)
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
@@ -849,7 +834,7 @@ pub enum Commands {
/// reduced srs
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
reduced_srs: Option<bool>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit

View File

@@ -140,7 +140,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
scales,
scale_rebase_multiplier,
max_logrows,
only_range_check_rebase,
} => calibrate(
model.unwrap_or(DEFAULT_MODEL.into()),
data.unwrap_or(DEFAULT_DATA.into()),
@@ -149,7 +148,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap()),
max_logrows,
)
.await
@@ -671,17 +669,21 @@ pub(crate) async fn get_srs_cmd(
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
// check the SRS
let pb = init_spinner();
pb.set_message("Validating SRS (this may take a while) ...");
let pb = init_spinner();
pb.set_message("Validating SRS (this may take a while) ...");
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
pb.finish_with_message("SRS validated.");
pb.finish_with_message("SRS validated.");
info!("Saving SRS to disk...");
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone(), commitment))?;
let computed_srs_path = get_srs_path(k, srs_path.clone(), commitment);
let mut file = std::fs::File::create(&computed_srs_path)?;
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
params.write(&mut buffer)?;
info!("Saved SRS to disk.");
info!(
"Saved SRS to {}.",
computed_srs_path.as_os_str().to_str().unwrap_or("disk")
);
info!("SRS downloaded");
} else {
@@ -727,7 +729,7 @@ pub(crate) async fn gen_witness(
None
};
let mut input = circuit.load_graph_input(&data).await?;
let mut input = circuit.load_graph_input(&data).await?;
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
let mut input = circuit.load_graph_input(&data)?;
@@ -967,7 +969,6 @@ pub(crate) async fn calibrate(
lookup_safety_margin: f64,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
only_range_check_rebase: bool,
max_logrows: Option<u32>,
) -> Result<GraphSettings, EZKLError> {
use log::error;
@@ -1003,12 +1004,6 @@ pub(crate) async fn calibrate(
(11..14).collect::<Vec<crate::Scale>>()
};
let div_rebasing = if only_range_check_rebase {
vec![false]
} else {
vec![true, false]
};
let mut found_params: Vec<GraphSettings> = vec![];
// 2 x 2 grid
@@ -1046,12 +1041,6 @@ pub(crate) async fn calibrate(
.map(|(a, b)| (*a, *b))
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
let range_grid = range_grid
.iter()
.cartesian_product(div_rebasing.iter())
.map(|(a, b)| (*a, *b))
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
let mut forward_pass_res = HashMap::new();
let pb = init_bar(range_grid.len() as u64);
@@ -1060,30 +1049,23 @@ pub(crate) async fn calibrate(
let mut num_failed = 0;
let mut num_passed = 0;
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
pb.set_message(format!(
"i-scale: {}, p-scale: {}, rebase-(x): {}, div-rebase: {}, fail: {}, pass: {}",
"i-scale: {}, p-scale: {}, rebase-(x): {}, fail: {}, pass: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().blue(),
div_rebasing.to_string().yellow(),
scale_rebase_multiplier.to_string().yellow(),
num_failed.to_string().red(),
num_passed.to_string().green()
));
let key = (
input_scale,
param_scale,
scale_rebase_multiplier,
div_rebasing,
);
let key = (input_scale, param_scale, scale_rebase_multiplier);
forward_pass_res.insert(key, vec![]);
let local_run_args = RunArgs {
input_scale,
param_scale,
scale_rebase_multiplier,
div_rebasing,
lookup_range: (IntegerRep::MIN, IntegerRep::MAX),
..settings.run_args.clone()
};
@@ -1187,7 +1169,6 @@ pub(crate) async fn calibrate(
let found_run_args = RunArgs {
input_scale: new_settings.run_args.input_scale,
param_scale: new_settings.run_args.param_scale,
div_rebasing: new_settings.run_args.div_rebasing,
lookup_range: new_settings.run_args.lookup_range,
logrows: new_settings.run_args.logrows,
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
@@ -1295,7 +1276,6 @@ pub(crate) async fn calibrate(
best_params.run_args.input_scale,
best_params.run_args.param_scale,
best_params.run_args.scale_rebase_multiplier,
best_params.run_args.div_rebasing,
))
.ok_or("no params found")?
.iter()
@@ -2021,7 +2001,7 @@ pub(crate) fn mock_aggregate(
}
}
// proof aggregation
let pb = {
let pb = {
let pb = init_spinner();
pb.set_message("Aggregating (may take a while)...");
pb
@@ -2032,7 +2012,7 @@ pub(crate) fn mock_aggregate(
let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()])
.map_err(|e| ExecutionError::MockProverError(e.to_string()))?;
prover.verify().map_err(ExecutionError::VerifyError)?;
pb.finish_with_message("Done.");
pb.finish_with_message("Done.");
Ok(String::new())
}
@@ -2126,7 +2106,7 @@ pub(crate) fn aggregate(
}
// proof aggregation
let pb = {
let pb = {
let pb = init_spinner();
pb.set_message("Aggregating (may take a while)...");
pb
@@ -2275,7 +2255,7 @@ pub(crate) fn aggregate(
);
snark.save(&proof_path)?;
pb.finish_with_message("Done.");
pb.finish_with_message("Done.");
Ok(snark)
}

View File

@@ -133,6 +133,8 @@ pub struct GraphWitness {
pub min_lookup_inputs: IntegerRep,
/// max range check size
pub max_range_size: IntegerRep,
/// (optional) version of ezkl used
pub version: Option<String>,
}
impl GraphWitness {
@@ -161,6 +163,7 @@ impl GraphWitness {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
version: None,
}
}
@@ -1350,6 +1353,7 @@ impl GraphCircuit {
max_lookup_inputs: model_results.max_lookup_inputs,
min_lookup_inputs: model_results.min_lookup_inputs,
max_range_size: model_results.max_range_size,
version: Some(crate::version().to_string()),
};
witness.generate_rescaled_elements(

View File

@@ -915,20 +915,9 @@ impl Model {
if scales.contains_key(&i) {
let scale_diff = n.out_scale - scales[&i];
n.opkind = if scale_diff > 0 {
RebaseScale::rebase(
n.opkind,
scales[&i],
n.out_scale,
1,
run_args.div_rebasing,
)
RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1)
} else {
RebaseScale::rebase_up(
n.opkind,
scales[&i],
n.out_scale,
run_args.div_rebasing,
)
RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale)
};
n.out_scale = scales[&i];
}

View File

@@ -120,7 +120,6 @@ impl RebaseScale {
global_scale: crate::Scale,
op_out_scale: crate::Scale,
scale_rebase_multiplier: u32,
div_rebasing: bool,
) -> SupportedOp {
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
&& !inner.is_constant()
@@ -137,7 +136,6 @@ impl RebaseScale {
multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op.original_scale,
})
@@ -148,7 +146,6 @@ impl RebaseScale {
multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op_out_scale,
})
@@ -163,7 +160,6 @@ impl RebaseScale {
inner: SupportedOp,
target_scale: crate::Scale,
op_out_scale: crate::Scale,
div_rebasing: bool,
) -> SupportedOp {
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
@@ -176,7 +172,6 @@ impl RebaseScale {
original_scale: op.original_scale,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
})
} else {
@@ -187,7 +182,6 @@ impl RebaseScale {
original_scale: op_out_scale,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
})
}
@@ -595,13 +589,7 @@ impl Node {
let mut out_scale = opkind.out_scale(in_scales.clone())?;
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
let global_scale = scales.get_max();
opkind = RebaseScale::rebase(
opkind,
global_scale,
out_scale,
scales.rebase_multiplier,
run_args.div_rebasing,
);
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
out_scale = opkind.out_scale(in_scales)?;

View File

@@ -764,7 +764,7 @@ pub fn new_op_from_onnx(
.collect::<Vec<_>>();
if inputs.len() == 2 {
if const_inputs.len() > 0 {
if !const_inputs.is_empty() {
let const_idx = const_inputs[0];
let boxed_op = inputs[const_idx].opkind();
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
@@ -803,7 +803,7 @@ pub fn new_op_from_onnx(
}
}
"Recip" => {
let in_scale = inputs[0].out_scales()[0];
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
// If the input scale is larger than the params scale
SupportedOp::Hybrid(HybridOp::Recip {
@@ -837,61 +837,75 @@ pub fn new_op_from_onnx(
"Abs" => SupportedOp::Linear(PolyOp::Abs),
"Neg" => SupportedOp::Linear(PolyOp::Neg),
"HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Sqrt" => SupportedOp::Nonlinear(LookupOp::Sqrt {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Rsqrt" => SupportedOp::Nonlinear(LookupOp::Rsqrt {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
"Sqrt" => SupportedOp::Hybrid(HybridOp::Sqrt {
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Rsqrt" => {
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Rsqrt {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
})
}
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Ln" => SupportedOp::Nonlinear(LookupOp::Ln {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Ln" => {
if run_args.bounded_log_lookup {
SupportedOp::Hybrid(HybridOp::Ln {
scale: scale_to_multiplier(input_scales[0]).into(),
})
} else {
SupportedOp::Nonlinear(LookupOp::Ln {
scale: scale_to_multiplier(input_scales[0]).into(),
})
}
}
"Sin" => SupportedOp::Nonlinear(LookupOp::Sin {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Cos" => SupportedOp::Nonlinear(LookupOp::Cos {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Tan" => SupportedOp::Nonlinear(LookupOp::Tan {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Asin" => SupportedOp::Nonlinear(LookupOp::ASin {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Acos" => SupportedOp::Nonlinear(LookupOp::ACos {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Atan" => SupportedOp::Nonlinear(LookupOp::ATan {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Sinh" => SupportedOp::Nonlinear(LookupOp::Sinh {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Cosh" => SupportedOp::Nonlinear(LookupOp::Cosh {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Tanh" => SupportedOp::Nonlinear(LookupOp::Tanh {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Asinh" => SupportedOp::Nonlinear(LookupOp::ASinh {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Acosh" => SupportedOp::Nonlinear(LookupOp::ACosh {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Atanh" => SupportedOp::Nonlinear(LookupOp::ATanh {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Erf" => SupportedOp::Nonlinear(LookupOp::Erf {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Source" => {
let dt = node.outputs[0].fact.datum_type;
@@ -935,11 +949,9 @@ pub fn new_op_from_onnx(
replace_const(
0,
0,
SupportedOp::Nonlinear(LookupOp::Cast {
scale: crate::circuit::utils::F32(scale_to_multiplier(
input_scales[0],
)
as f32),
SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
)?
} else {
@@ -1045,7 +1057,7 @@ pub fn new_op_from_onnx(
}
};
let in_scale = inputs[0].out_scales()[0];
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Softmax {
@@ -1083,17 +1095,21 @@ pub fn new_op_from_onnx(
pool_dims: kernel_shape.to_vec(),
})
}
"Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Floor" => SupportedOp::Nonlinear(LookupOp::Floor {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Round" => SupportedOp::Nonlinear(LookupOp::Round {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
"Round" => SupportedOp::Hybrid(HybridOp::Round {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Sign" => SupportedOp::Linear(PolyOp::Sign),
"Pow" => {
@@ -1113,7 +1129,7 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Pow(exponent as u32))
} else {
SupportedOp::Nonlinear(LookupOp::Pow {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(input_scales[0]).into(),
a: crate::circuit::utils::F32(exponent),
})
}

View File

@@ -108,6 +108,18 @@ use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
/// The version of the ezkl library
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Get the version of the library
pub fn version() -> &'static str {
match VERSION {
"0.0.0" => "source - no compatibility guaranteed",
_ => VERSION,
}
}
/// Bindings managment
#[cfg(any(
feature = "ios-bindings",
@@ -297,8 +309,6 @@ pub struct RunArgs {
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
/// Should constants with 0.0 fraction be rebased to scale 0
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
@@ -317,11 +327,18 @@ pub struct RunArgs {
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
/// the number of legs used for decompositions
pub decomp_legs: usize,
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// use unbounded lookup for the log
pub bounded_log_lookup: bool,
}
impl Default for RunArgs {
fn default() -> Self {
Self {
bounded_log_lookup: false,
tolerance: Tolerance::default(),
input_scale: 7,
param_scale: 7,
@@ -333,7 +350,6 @@ impl Default for RunArgs {
input_visibility: Visibility::Private,
output_visibility: Visibility::Public,
param_visibility: Visibility::Private,
div_rebasing: false,
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
commitment: None,

View File

@@ -59,10 +59,7 @@ fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
#[allow(missing_docs)]
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
derive(ValueEnum)
)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
pub enum ProofType {
#[default]
Single,
@@ -134,10 +131,7 @@ impl<'source> pyo3::FromPyObject<'source> for ProofType {
#[allow(missing_docs)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
derive(ValueEnum)
)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
pub enum StrategyType {
Single,
Accum,
@@ -203,10 +197,7 @@ pub enum PfSysError {
#[allow(missing_docs)]
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
derive(ValueEnum)
)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
pub enum TranscriptType {
Poseidon,
#[default]
@@ -324,6 +315,8 @@ where
pub timestamp: Option<u128>,
/// commitment
pub commitment: Option<Commitments>,
/// (optional) version of ezkl used to generate the proof
version: Option<String>,
}
#[cfg(feature = "python-bindings")]
@@ -385,6 +378,7 @@ where
.as_millis(),
),
commitment,
version: Some(crate::version().to_string()),
}
}
@@ -920,6 +914,7 @@ mod tests {
pretty_public_inputs: None,
timestamp: None,
commitment: None,
version: None,
};
snark

View File

@@ -27,7 +27,7 @@ pub fn get_rep(
n: usize,
) -> Result<Vec<IntegerRep>, DecompositionError> {
// check if x is too large
if x.abs() > (base.pow(n as u32) as IntegerRep) {
if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 {
return Err(DecompositionError::TooLarge(*x, base, n));
}
let mut rep = vec![0; n + 1];
@@ -1421,85 +1421,6 @@ pub fn slice<T: TensorType + Send + Sync>(
pub mod nonlinearities {
use super::*;
/// Ceiling operator.
/// # Arguments
/// * `a` - Tensor
/// * `scale` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
///
/// use ezkl::tensor::ops::nonlinearities::ceil;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6]),
/// &[3, 2],
/// ).unwrap();
/// let result = ceil(&x, 2.0);
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn ceil(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale;
let rounded = kix.ceil() * scale;
Ok::<_, TensorError>(rounded as IntegerRep)
})
.unwrap()
}
/// Floor operator.
/// # Arguments
/// * `a` - Tensor
/// * `scale` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::floor;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6]),
/// &[3, 2],
/// ).unwrap();
/// let result = floor(&x, 2.0);
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 2, 2, 4, 4, 6]), &[3, 2]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn floor(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale;
let rounded = kix.floor() * scale;
Ok::<_, TensorError>(rounded as IntegerRep)
})
.unwrap()
}
/// Round operator.
/// # Arguments
/// * `a` - Tensor
/// * `scale` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::round;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6]),
/// &[3, 2],
/// ).unwrap();
/// let result = round(&x, 2.0);
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn round(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale;
let rounded = kix.round() * scale;
Ok::<_, TensorError>(rounded as IntegerRep)
})
.unwrap()
}
/// Round half to even operator.
/// # Arguments
/// * `a` - Tensor
@@ -1553,6 +1474,93 @@ pub mod nonlinearities {
.unwrap()
}
/// Checks if a tensor's elements are odd
/// # Arguments
/// * `a` - Tensor
/// * `scale` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::is_odd;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
/// &[2, 3],
/// ).unwrap();
///
/// let result = is_odd(&x);
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 1, 0, 1, 1, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn is_odd(a: &Tensor<IntegerRep>) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let rounded = if a_i % 2 == 0 { 0 } else { 1 };
Ok::<_, TensorError>(rounded)
})
.unwrap()
}
/// Powers of 2
/// # Arguments
/// * `a` - Tensor
/// * `scale` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::ipow2;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = ipow2(&x, 1.0);
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 32768, 4, 2, 2, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn ipow2(a: &Tensor<IntegerRep>, scale_output: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = a_i as f64;
let kix = scale_output * (2.0_f64).powf(kix);
let rounded = kix.round();
Ok::<_, TensorError>(rounded as IntegerRep)
})
.unwrap()
}
/// Elementwise applies ln base 2 to a tensor of integers.
/// # Arguments
/// * `a` - Tensor
/// * `scale_input` - Single value
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::ilog2;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 2]),
/// &[2, 3],
/// ).unwrap();
/// let result = ilog2(&x, 1.0);
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 1, 0, 0, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn ilog2(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale_input;
let log = (kix).log2();
let floor = log.floor();
let ceil = log.ceil();
let floor_dist = ((2.0_f64).powf(floor) - kix).abs();
let ceil_dist = (kix - (2.0_f64).powf(ceil)).abs();
if floor_dist < ceil_dist {
Ok::<_, TensorError>(floor as IntegerRep)
} else {
Ok::<_, TensorError>(ceil as IntegerRep)
}
})
.unwrap()
}
/// Elementwise applies sigmoid to a tensor of integers.
/// # Arguments
///
@@ -1681,12 +1689,11 @@ pub mod nonlinearities {
.unwrap()
}
/// Elementwise applies exponential to a tensor of integers.
/// Elementwise applies ln to a tensor of integers.
/// # Arguments
///
/// * `a` - Tensor
/// * `scale_input` - Single value
/// * `scale_output` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
@@ -1721,27 +1728,6 @@ pub mod nonlinearities {
.unwrap()
}
/// Elementwise applies sign to a tensor of integers.
/// # Arguments
/// * `a` - Tensor
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::sign;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[-2, 15, 2, 1, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = sign(&x);
/// let expected = Tensor::<IntegerRep>::new(Some(&[-1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn sign(a: &Tensor<IntegerRep>) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(a_i.signum()))
.unwrap()
}
/// Elementwise applies square root to a tensor of integers.
/// # Arguments
///
@@ -2225,101 +2211,6 @@ pub mod nonlinearities {
.unwrap()
}
/// Elementwise applies leaky relu to a tensor of integers.
/// # Arguments
///
/// * `a` - Tensor
/// * `scale` - Single value
/// * `slope` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::leakyrelu;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, -5]),
/// &[2, 3],
/// ).unwrap();
/// let result = leakyrelu(&x, 0.1);
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 15, 2, 1, 1, -1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn leakyrelu(a: &Tensor<IntegerRep>, slope: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let rounded = if a_i < 0 {
let d_inv_x = (slope) * (a_i as f64);
d_inv_x.round() as IntegerRep
} else {
let d_inv_x = a_i as f64;
d_inv_x.round() as IntegerRep
};
Ok::<_, TensorError>(rounded)
})
.unwrap()
}
/// Elementwise applies max to a tensor of integers.
/// # Arguments
/// * `a` - Tensor
/// * `b` - scalar
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::max;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, -5]),
/// &[2, 3],
/// ).unwrap();
/// let result = max(&x, 1.0, 1.0);
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 15, 2, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn max(a: &Tensor<IntegerRep>, scale_input: f64, threshold: f64) -> Tensor<IntegerRep> {
// calculate value of output
a.par_enum_map(|_, a_i| {
let d_inv_x = (a_i as f64) / scale_input;
let rounded = if d_inv_x <= threshold {
(threshold * scale_input).round() as IntegerRep
} else {
(d_inv_x * scale_input).round() as IntegerRep
};
Ok::<_, TensorError>(rounded)
})
.unwrap()
}
/// Elementwise applies min to a tensor of integers.
/// # Arguments
/// * `a` - Tensor
/// * `b` - scalar
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::min;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, -5]),
/// &[2, 3],
/// ).unwrap();
/// let result = min(&x, 1.0, 2.0);
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 2, 1, 1, -5]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn min(a: &Tensor<IntegerRep>, scale_input: f64, threshold: f64) -> Tensor<IntegerRep> {
// calculate value of output
a.par_enum_map(|_, a_i| {
let d_inv_x = (a_i as f64) / scale_input;
let rounded = if d_inv_x >= threshold {
(threshold * scale_input).round() as IntegerRep
} else {
(d_inv_x * scale_input).round() as IntegerRep
};
Ok::<_, TensorError>(rounded)
})
.unwrap()
}
/// Elementwise divides a tensor with a const integer element.
/// # Arguments
///
@@ -2400,104 +2291,6 @@ pub mod nonlinearities {
})
.unwrap()
}
/// Elementwise greater than
/// # Arguments
///
/// * `a` - Tensor
/// * `b` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::greater_than;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 7, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = 2.0;
/// let result = greater_than(&x, k);
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn greater_than(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) > 0_f64)))
.unwrap()
}
/// Elementwise greater than
/// # Arguments
///
/// * `a` - Tensor
/// * `b` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::greater_than_equal;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 7, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = 2.0;
/// let result = greater_than_equal(&x, k);
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 1, 1, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn greater_than_equal(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) >= 0_f64)))
.unwrap()
}
/// Elementwise less than
/// # Arguments
/// * `a` - Tensor
/// * `b` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::less_than;
///
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 7, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = 2.0;
///
/// let result = less_than(&x, k);
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 1, 0, 0, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn less_than(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) < 0_f64)))
.unwrap()
}
/// Elementwise less than
/// # Arguments
/// * `a` - Tensor
/// * `b` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::less_than_equal;
///
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 7, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = 2.0;
///
/// let result = less_than_equal(&x, k);
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 0, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn less_than_equal(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) <= 0_f64)))
.unwrap()
}
}
/// Ops that return the transcript i.e intermediate calcs of an op

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -27,7 +27,8 @@
"check_mode": "UNSAFE",
"commitment": "KZG",
"decomp_base": 128,
"decomp_legs": 2
"decomp_legs": 2,
"bounded_log_lookup": false
},
"num_rows": 46,
"total_assignments": 92,

View File

@@ -1 +1 @@
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":0,"max_range_size":127}
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":0,"max_range_size":127,"version":"source - no compatibility guaranteed"}

View File

@@ -205,7 +205,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 95] = [
const TESTS: [&str; 96] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -305,6 +305,7 @@ mod native_tests {
"lstm_medium", // 92
"lenet_5", // 93
"rsqrt", // 94
"log", // 95
];
const WASM_TESTS: [&str; 46] = [
@@ -538,12 +539,12 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false);
test_dir.close().unwrap();
}
});
seq!(N in 0..=94 {
seq!(N in 0..=95 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -555,15 +556,7 @@ mod native_tests {
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn accuracy_measurement_div_rebase_(test: &str) {
crate::native_tests::init_binary();
crate::native_tests::setup_py_env();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, true);
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn accuracy_measurement_public_outputs_(test: &str) {
@@ -571,7 +564,7 @@ mod native_tests {
crate::native_tests::setup_py_env();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, false);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6);
test_dir.close().unwrap();
}
@@ -581,7 +574,7 @@ mod native_tests {
crate::native_tests::setup_py_env();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6 , false);
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6 );
test_dir.close().unwrap();
}
@@ -591,7 +584,7 @@ mod native_tests {
crate::native_tests::setup_py_env();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6, false);
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6);
test_dir.close().unwrap();
}
@@ -602,7 +595,7 @@ mod native_tests {
crate::native_tests::setup_py_env();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1, false);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1);
test_dir.close().unwrap();
}
@@ -611,7 +604,17 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn mock_bounded_lookup_log(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true);
test_dir.close().unwrap();
}
@@ -622,7 +625,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// gen random number between 0.0 and 1.0
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false);
test_dir.close().unwrap();
}
@@ -637,7 +640,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
}
@@ -647,7 +650,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -656,7 +659,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -665,7 +668,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -674,7 +677,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -683,7 +686,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -692,7 +695,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -701,7 +704,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0);
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -711,7 +714,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -721,7 +724,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0);
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -730,7 +733,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -740,7 +743,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0);
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -749,7 +752,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -759,7 +762,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0);
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -769,7 +772,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0);
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -779,7 +782,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -789,7 +792,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0);
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -799,7 +802,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0);
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
@@ -852,9 +855,11 @@ mod native_tests {
fn kzg_prove_and_verify_tight_lookup_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let path = test_dir.into_path();
let path = path.to_str().unwrap();
crate::native_tests::mv_test_(path, test);
prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, false, "single", Commitments::KZG, 1);
test_dir.close().unwrap();
// test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
@@ -974,7 +979,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false);
test_dir.close().unwrap();
}
});
@@ -1451,6 +1456,7 @@ mod native_tests {
cal_target: &str,
scales_to_use: Option<Vec<u32>>,
tolerance: f32,
bounded_lookup_log: bool,
) {
let mut tolerance = tolerance;
gen_circuit_settings_and_witness(
@@ -1463,10 +1469,10 @@ mod native_tests {
cal_target,
scales_to_use,
2,
false,
&mut tolerance,
Commitments::KZG,
2,
bounded_lookup_log,
);
if tolerance > 0.0 {
@@ -1604,10 +1610,10 @@ mod native_tests {
cal_target: &str,
scales_to_use: Option<Vec<u32>>,
num_inner_columns: usize,
div_rebasing: bool,
tolerance: &mut f32,
commitment: Commitments,
lookup_safety_margin: usize,
bounded_lookup_log: bool,
) {
let mut args = vec![
"gen-settings".to_string(),
@@ -1626,13 +1632,12 @@ mod native_tests {
format!("--commitment={}", commitment),
];
if div_rebasing {
args.push("--div-rebasing".to_string());
};
if bounded_lookup_log {
args.push("--bounded-log-lookup".to_string());
}
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args(args)
.stdout(std::process::Stdio::null())
.status()
.expect("failed to execute process");
assert!(status.success());
@@ -1729,7 +1734,6 @@ mod native_tests {
batch_size: usize,
cal_target: &str,
target_perc: f32,
div_rebasing: bool,
) {
gen_circuit_settings_and_witness(
test_dir,
@@ -1741,10 +1745,10 @@ mod native_tests {
cal_target,
None,
2,
div_rebasing,
&mut 0.0,
Commitments::KZG,
2,
false,
);
println!(
@@ -2025,10 +2029,10 @@ mod native_tests {
target_str,
scales_to_use,
num_inner_columns,
false,
&mut 0.0,
commitment,
lookup_safety_margin,
false,
);
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
@@ -2457,10 +2461,10 @@ mod native_tests {
// we need the accuracy
Some(vec![4]),
1,
false,
&mut 0.0,
Commitments::KZG,
2,
false,
);
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);