mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 16:27:59 -05:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
523c77c912 | ||
|
|
948e5cd4b9 | ||
|
|
00155e585f | ||
|
|
0876faa12c | ||
|
|
a3c131dac0 | ||
|
|
fd9c2305ac | ||
|
|
a0060f341d | ||
|
|
17f1d42739 | ||
|
|
ebaee9e2b1 | ||
|
|
d51cba589a |
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
@@ -240,6 +240,8 @@ jobs:
|
||||
locked: true
|
||||
# - name: The Worm Mock
|
||||
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
|
||||
- name: public outputs and bounded lookup log
|
||||
run: cargo nextest run --release --verbose tests::mock_bounded_lookup_log --test-threads 32
|
||||
- name: public outputs and tolerance > 0
|
||||
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
|
||||
- name: public outputs + batch size == 10
|
||||
|
||||
16
Cargo.lock
generated
16
Cargo.lock
generated
@@ -2543,6 +2543,12 @@ dependencies = [
|
||||
"allocator-api2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.1"
|
||||
@@ -2811,12 +2817,12 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.2.5"
|
||||
version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4"
|
||||
checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown 0.14.3",
|
||||
"hashbrown 0.15.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5628,9 +5634,9 @@ checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
|
||||
|
||||
[[package]]
|
||||
name = "tower-service"
|
||||
version = "0.3.2"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
|
||||
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
|
||||
|
||||
[[package]]
|
||||
name = "tracing"
|
||||
|
||||
@@ -56,6 +56,7 @@ alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5
|
||||
"rpc-types-eth",
|
||||
"signer-wallet",
|
||||
"node-bindings",
|
||||
|
||||
], optional = true }
|
||||
foundry-compilers = { version = "0.4.1", features = ["svm-solc"], optional = true }
|
||||
ethabi = { version = "18", optional = true }
|
||||
@@ -168,7 +169,7 @@ harness = false
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "relu"
|
||||
name = "sigmoid"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
@@ -176,12 +177,12 @@ name = "relu_lookupless"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "accum_matmul_relu"
|
||||
name = "accum_matmul_sigmoid"
|
||||
harness = false
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "accum_matmul_relu_overflow"
|
||||
name = "accum_matmul_sigmoid_overflow"
|
||||
harness = false
|
||||
|
||||
[[bin]]
|
||||
|
||||
@@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
&a,
|
||||
BITS,
|
||||
K,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
&LookupOp::Sigmoid { scale: 1.0.into() },
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -93,7 +93,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -65,7 +65,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
&a,
|
||||
BITS,
|
||||
k,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
&LookupOp::Sigmoid { scale: 1.0.into() },
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -94,7 +94,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -68,7 +68,14 @@ impl Circuit<Fr> for NLCircuit {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
|
||||
@@ -42,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
|
||||
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
|
||||
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
|
||||
|
||||
let mut config = Config::default();
|
||||
|
||||
@@ -68,7 +68,7 @@ impl Circuit<Fr> for NLCircuit {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -146,6 +146,8 @@ where
|
||||
let params = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
|
||||
let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
|
||||
|
||||
println!("INPUT COL {:#?}", input);
|
||||
|
||||
let mut layer_config = PolyConfig::configure(
|
||||
@@ -156,15 +158,11 @@ where
|
||||
);
|
||||
|
||||
layer_config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&input,
|
||||
&output,
|
||||
¶ms,
|
||||
(LOOKUP_MIN, LOOKUP_MAX),
|
||||
K,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
)
|
||||
.configure_range_check(cs, &input, ¶ms, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
.configure_range_check(cs, &input, ¶ms, (0, 1023), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
@@ -195,6 +193,11 @@ where
|
||||
) -> Result<(), Error> {
|
||||
config.layer_config.layout_tables(&mut layouter).unwrap();
|
||||
|
||||
config
|
||||
.layer_config
|
||||
.layout_range_checks(&mut layouter)
|
||||
.unwrap();
|
||||
|
||||
let x = layouter
|
||||
.assign_region(
|
||||
|| "mlp_4d",
|
||||
@@ -224,7 +227,10 @@ where
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x.unwrap()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -53,6 +53,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
// tells the config layer to add an affine op to the circuit gate
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
|
||||
|
||||
println!("INPUT COL {:#?}", input);
|
||||
|
||||
let mut layer_config = PolyConfig::<F>::configure(
|
||||
cs,
|
||||
&[input.clone(), params.clone()],
|
||||
@@ -60,17 +64,12 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
CheckMode::SAFE,
|
||||
);
|
||||
|
||||
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
|
||||
layer_config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&input,
|
||||
&output,
|
||||
¶ms,
|
||||
(LOOKUP_MIN, LOOKUP_MAX),
|
||||
K,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
)
|
||||
.configure_range_check(cs, &input, ¶ms, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
.configure_range_check(cs, &input, ¶ms, (0, 1023), K)
|
||||
.unwrap();
|
||||
|
||||
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
|
||||
@@ -104,6 +103,11 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
) -> Result<(), Error> {
|
||||
config.layer_config.layout_tables(&mut layouter).unwrap();
|
||||
|
||||
config
|
||||
.layer_config
|
||||
.layout_range_checks(&mut layouter)
|
||||
.unwrap();
|
||||
|
||||
let x = layouter
|
||||
.assign_region(
|
||||
|| "mlp_4d",
|
||||
@@ -144,7 +148,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
scale: 1,
|
||||
slope: 0.0.into(),
|
||||
}),
|
||||
)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
@@ -184,7 +191,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
scale: 1,
|
||||
slope: 0.0.into(),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
println!("6");
|
||||
|
||||
@@ -592,7 +592,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -648,10 +648,10 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.7"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,7 +271,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -171,7 +171,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -328,7 +328,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 26,
|
||||
"id": "171702d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -348,7 +348,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 27,
|
||||
"id": "671dfdd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -364,7 +364,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 28,
|
||||
"id": "50eba2f4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -399,9 +399,9 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
||||
42
examples/onnx/log/gen.py
Normal file
42
examples/onnx/log/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.log(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 3)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/log/input.json
Normal file
1
examples/onnx/log/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[1.9252371788024902, 1.8418371677398682, 0.8400403261184692, 2.083845853805542, 0.9760497808456421, 0.6940176486968994, 0.015579521656036377, 2.2689192295074463]]}
|
||||
14
examples/onnx/log/network.onnx
Normal file
14
examples/onnx/log/network.onnx
Normal file
@@ -0,0 +1,14 @@
|
||||
pytorch2.2.2:o
|
||||
|
||||
inputoutput/Log"Log
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
@@ -21,9 +21,9 @@ def main():
|
||||
torch_model = Circuit()
|
||||
# Input to the model
|
||||
shape = [3, 2, 3]
|
||||
w = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
x = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
y = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
w = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
x = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
y = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
torch_out = torch_model(w, x, y)
|
||||
# Export the model
|
||||
torch.onnx.export(torch_model, # model being run
|
||||
|
||||
@@ -1 +1,148 @@
|
||||
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.0025284828152507544, 0.04976580664515495, 0.025840921327471733, 0.0829394981265068, 0.09595223516225815, 0.08764562010765076, 0.06308566778898239, 0.062386948615312576, 0.08090643584728241, 0.09267748892307281, 0.07428313046693802, 0.08987367898225784, 0.005716216750442982, 0.0666426345705986, 0.012837404385209084, 0.05769496038556099, 0.05761152133345604, 0.08006472885608673], [0.007834953255951405, 0.011380612850189209, 0.08560049533843994, 0.022283583879470825, 0.07879520952701569, 0.04422441124916077, 0.030812596902251244, 0.006081616971641779, 0.011045408435165882, 0.08776585012674332, 0.044985152781009674, 0.015603715553879738, 0.07923348993062973, 0.04872611165046692, 0.0036642670165747404, 0.05142095685005188, 0.0963878259062767, 0.03225792199373245], [0.09952805936336517, 0.002214533044025302, 0.011696457862854004, 0.022422820329666138, 0.04151459410786629, 0.027647346258163452, 0.011919880285859108, 0.006539052817970514, 0.06569185107946396, 0.034328874200582504, 0.0032284557819366455, 0.004105025436729193, 0.022395813837647438, 0.07135921716690063, 0.07882415503263474, 0.09764843434095383, 0.05335796996951103, 0.0525360181927681]], "output_data": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]}
|
||||
{
|
||||
"input_shapes": [
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
]
|
||||
],
|
||||
"input_data": [
|
||||
[
|
||||
0.5,
|
||||
1.5,
|
||||
-0.04514765739440918,
|
||||
0.5936200618743896,
|
||||
0.9271858930587769,
|
||||
0.6688600778579712,
|
||||
-0.20331168174743652,
|
||||
-0.7016235589981079,
|
||||
0.025863051414489746,
|
||||
-0.19426143169403076,
|
||||
0.9827852249145508,
|
||||
0.4897397756576538,
|
||||
-1.5,
|
||||
-0.5,
|
||||
0.9278832674026489,
|
||||
0.5943725109100342,
|
||||
-0.573331356048584,
|
||||
0.3675816059112549
|
||||
],
|
||||
[
|
||||
0.7803324460983276,
|
||||
-0.9616303443908691,
|
||||
0.6070173978805542,
|
||||
-0.028337717056274414,
|
||||
-0.5080242156982422,
|
||||
-0.9280107021331787,
|
||||
0.6150380373001099,
|
||||
0.3865993022918701,
|
||||
-0.43668973445892334,
|
||||
0.17152702808380127,
|
||||
0.5144252777099609,
|
||||
-0.28881049156188965,
|
||||
0.8932310342788696,
|
||||
0.059034109115600586,
|
||||
0.6865451335906982,
|
||||
0.009820222854614258,
|
||||
0.23011493682861328,
|
||||
-0.9492779970169067
|
||||
],
|
||||
[
|
||||
-0.21352827548980713,
|
||||
-0.16015326976776123,
|
||||
-0.38964390754699707,
|
||||
0.13464701175689697,
|
||||
-0.8814496994018555,
|
||||
0.5037975311279297,
|
||||
-0.804405927658081,
|
||||
0.9858957529067993,
|
||||
0.19567716121673584,
|
||||
0.9777265787124634,
|
||||
0.6151977777481079,
|
||||
0.568595290184021,
|
||||
0.10584986209869385,
|
||||
-0.8975653648376465,
|
||||
0.6235959529876709,
|
||||
-0.547879695892334,
|
||||
0.9289869070053101,
|
||||
0.7567293643951416
|
||||
]
|
||||
],
|
||||
"output_data": [
|
||||
[
|
||||
1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-1.0,
|
||||
0.0
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
-1.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0
|
||||
],
|
||||
[
|
||||
-0.0,
|
||||
-0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0
|
||||
]
|
||||
]
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
pytorch2.0.1:â
|
||||
pytorch2.2.2:ă
|
||||
|
||||
woutput_w/Round"Round
|
||||
|
||||
xoutput_x/Floor"Floor
|
||||
|
||||
youtput_y/Ceil"Ceil torch_jitZ%
|
||||
youtput_y/Ceil"Ceil
|
||||
main_graphZ%
|
||||
w
|
||||
|
||||
|
||||
|
||||
42
examples/onnx/rsqrt/gen.py
Normal file
42
examples/onnx/rsqrt/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# reciprocal sqrt
|
||||
m = 1 / torch.sqrt(x)
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/rsqrt/input.json
Normal file
1
examples/onnx/rsqrt/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.8590779900550842, 0.4029041528701782, 0.6507361531257629, 0.9782488942146301, 0.37392884492874146, 0.6867020726203918, 0.11407750844955444, 0.362740159034729]]}
|
||||
17
examples/onnx/rsqrt/network.onnx
Normal file
17
examples/onnx/rsqrt/network.onnx
Normal file
@@ -0,0 +1,17 @@
|
||||
pytorch2.2.2:Ź
|
||||
$
|
||||
input/Sqrt_output_0/Sqrt"Sqrt
|
||||
1
|
||||
/Sqrt_output_0output/Reciprocal"
|
||||
Reciprocal
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
@@ -180,9 +180,6 @@ struct PyRunArgs {
|
||||
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
|
||||
pub variables: Vec<(String, usize)>,
|
||||
#[pyo3(get, set)]
|
||||
/// bool: Rebase the scale using lookup table for division instead of using a range check
|
||||
pub div_rebasing: bool,
|
||||
#[pyo3(get, set)]
|
||||
/// bool: Should constants with 0.0 fraction be rebased to scale 0
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
#[pyo3(get, set)]
|
||||
@@ -197,6 +194,9 @@ struct PyRunArgs {
|
||||
/// int: The number of legs used for decomposition
|
||||
#[pyo3(get, set)]
|
||||
pub decomp_legs: usize,
|
||||
/// bool: Should the circuit use unbounded lookups for log
|
||||
#[pyo3(get, set)]
|
||||
pub bounded_log_lookup: bool,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -212,6 +212,7 @@ impl PyRunArgs {
|
||||
impl From<PyRunArgs> for RunArgs {
|
||||
fn from(py_run_args: PyRunArgs) -> Self {
|
||||
RunArgs {
|
||||
bounded_log_lookup: py_run_args.bounded_log_lookup,
|
||||
tolerance: Tolerance::from(py_run_args.tolerance),
|
||||
input_scale: py_run_args.input_scale,
|
||||
param_scale: py_run_args.param_scale,
|
||||
@@ -223,7 +224,6 @@ impl From<PyRunArgs> for RunArgs {
|
||||
output_visibility: py_run_args.output_visibility,
|
||||
param_visibility: py_run_args.param_visibility,
|
||||
variables: py_run_args.variables,
|
||||
div_rebasing: py_run_args.div_rebasing,
|
||||
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
|
||||
check_mode: py_run_args.check_mode,
|
||||
commitment: Some(py_run_args.commitment.into()),
|
||||
@@ -236,6 +236,7 @@ impl From<PyRunArgs> for RunArgs {
|
||||
impl Into<PyRunArgs> for RunArgs {
|
||||
fn into(self) -> PyRunArgs {
|
||||
PyRunArgs {
|
||||
bounded_log_lookup: self.bounded_log_lookup,
|
||||
tolerance: self.tolerance.val,
|
||||
input_scale: self.input_scale,
|
||||
param_scale: self.param_scale,
|
||||
@@ -247,7 +248,6 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
output_visibility: self.output_visibility,
|
||||
param_visibility: self.param_visibility,
|
||||
variables: self.variables,
|
||||
div_rebasing: self.div_rebasing,
|
||||
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
|
||||
check_mode: self.check_mode,
|
||||
commitment: self.commitment.into(),
|
||||
@@ -873,8 +873,6 @@ fn gen_settings(
|
||||
/// max_logrows: int
|
||||
/// Optional max logrows to use for calibration
|
||||
///
|
||||
/// only_range_check_rebase: bool
|
||||
/// Check ranges when rebasing
|
||||
///
|
||||
/// Returns
|
||||
/// -------
|
||||
@@ -889,7 +887,6 @@ fn gen_settings(
|
||||
scales = None,
|
||||
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
|
||||
max_logrows = None,
|
||||
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
|
||||
))]
|
||||
fn calibrate_settings(
|
||||
py: Python,
|
||||
@@ -901,7 +898,6 @@ fn calibrate_settings(
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
max_logrows: Option<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::calibrate(
|
||||
@@ -912,7 +908,6 @@ fn calibrate_settings(
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase,
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -177,7 +177,7 @@ impl<'source> FromPyObject<'source> for Tolerance {
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct DynamicLookups {
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
|
||||
pub lookup_selectors: BTreeMap<(usize, usize), Selector>,
|
||||
pub lookup_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
|
||||
/// Selectors for the dynamic lookup tables
|
||||
pub table_selectors: Vec<Selector>,
|
||||
/// Inputs:
|
||||
@@ -209,7 +209,7 @@ impl DynamicLookups {
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct Shuffles {
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
|
||||
pub input_selectors: BTreeMap<(usize, usize), Selector>,
|
||||
pub input_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
|
||||
/// Selectors for the dynamic lookup tables
|
||||
pub reference_selectors: Vec<Selector>,
|
||||
/// Inputs:
|
||||
@@ -646,57 +646,73 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
}
|
||||
|
||||
for t in tables.iter() {
|
||||
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
|
||||
if !t.is_advice() || t.num_inner_cols() > 1 {
|
||||
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// assert all tables have the same number of inner columns
|
||||
if tables
|
||||
.iter()
|
||||
.map(|t| t.num_blocks())
|
||||
.collect::<Vec<_>>()
|
||||
.windows(2)
|
||||
.any(|w| w[0] != w[1])
|
||||
{
|
||||
return Err(CircuitError::WrongDynamicColumnType(
|
||||
"tables inner cols".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let one = Expression::Constant(F::ONE);
|
||||
|
||||
let s_ltable = cs.complex_selector();
|
||||
for q in 0..tables[0].num_blocks() {
|
||||
let s_ltable = cs.complex_selector();
|
||||
|
||||
for x in 0..lookups[0].num_blocks() {
|
||||
for y in 0..lookups[0].num_inner_cols() {
|
||||
let s_lookup = cs.complex_selector();
|
||||
for x in 0..lookups[0].num_blocks() {
|
||||
for y in 0..lookups[0].num_inner_cols() {
|
||||
let s_lookup = cs.complex_selector();
|
||||
|
||||
cs.lookup_any("lookup", |cs| {
|
||||
let s_lookupq = cs.query_selector(s_lookup);
|
||||
let mut expression = vec![];
|
||||
let s_ltableq = cs.query_selector(s_ltable);
|
||||
let mut lookup_queries = vec![one.clone()];
|
||||
cs.lookup_any("lookup", |cs| {
|
||||
let s_lookupq = cs.query_selector(s_lookup);
|
||||
let mut expression = vec![];
|
||||
let s_ltableq = cs.query_selector(s_ltable);
|
||||
let mut lookup_queries = vec![one.clone()];
|
||||
|
||||
for lookup in lookups {
|
||||
lookup_queries.push(match lookup {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
for lookup in lookups {
|
||||
lookup_queries.push(match lookup {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut table_queries = vec![one.clone()];
|
||||
for table in tables {
|
||||
table_queries.push(match table {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[0][0], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
let mut table_queries = vec![one.clone()];
|
||||
for table in tables {
|
||||
table_queries.push(match table {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[q][0], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
|
||||
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
|
||||
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
|
||||
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
|
||||
expression
|
||||
});
|
||||
self.dynamic_lookups
|
||||
.lookup_selectors
|
||||
.entry((x, y))
|
||||
.or_insert(s_lookup);
|
||||
expression
|
||||
});
|
||||
self.dynamic_lookups
|
||||
.lookup_selectors
|
||||
.entry((q, (x, y)))
|
||||
.or_insert(s_lookup);
|
||||
}
|
||||
}
|
||||
|
||||
self.dynamic_lookups.table_selectors.push(s_ltable);
|
||||
}
|
||||
self.dynamic_lookups.table_selectors.push(s_ltable);
|
||||
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if self.dynamic_lookups.tables.is_empty() {
|
||||
@@ -729,57 +745,72 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
}
|
||||
|
||||
for t in references.iter() {
|
||||
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
|
||||
if !t.is_advice() || t.num_inner_cols() > 1 {
|
||||
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// assert all tables have the same number of blocks
|
||||
if references
|
||||
.iter()
|
||||
.map(|t| t.num_blocks())
|
||||
.collect::<Vec<_>>()
|
||||
.windows(2)
|
||||
.any(|w| w[0] != w[1])
|
||||
{
|
||||
return Err(CircuitError::WrongDynamicColumnType(
|
||||
"references inner cols".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let one = Expression::Constant(F::ONE);
|
||||
|
||||
let s_reference = cs.complex_selector();
|
||||
for q in 0..references[0].num_blocks() {
|
||||
let s_reference = cs.complex_selector();
|
||||
|
||||
for x in 0..inputs[0].num_blocks() {
|
||||
for y in 0..inputs[0].num_inner_cols() {
|
||||
let s_input = cs.complex_selector();
|
||||
for x in 0..inputs[0].num_blocks() {
|
||||
for y in 0..inputs[0].num_inner_cols() {
|
||||
let s_input = cs.complex_selector();
|
||||
|
||||
cs.lookup_any("lookup", |cs| {
|
||||
let s_inputq = cs.query_selector(s_input);
|
||||
let mut expression = vec![];
|
||||
let s_referenceq = cs.query_selector(s_reference);
|
||||
let mut input_queries = vec![one.clone()];
|
||||
cs.lookup_any("lookup", |cs| {
|
||||
let s_inputq = cs.query_selector(s_input);
|
||||
let mut expression = vec![];
|
||||
let s_referenceq = cs.query_selector(s_reference);
|
||||
let mut input_queries = vec![one.clone()];
|
||||
|
||||
for input in inputs {
|
||||
input_queries.push(match input {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
for input in inputs {
|
||||
input_queries.push(match input {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut ref_queries = vec![one.clone()];
|
||||
for reference in references {
|
||||
ref_queries.push(match reference {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[0][0], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
let mut ref_queries = vec![one.clone()];
|
||||
for reference in references {
|
||||
ref_queries.push(match reference {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[q][0], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
|
||||
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
|
||||
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
|
||||
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
|
||||
expression
|
||||
});
|
||||
self.shuffles
|
||||
.input_selectors
|
||||
.entry((x, y))
|
||||
.or_insert(s_input);
|
||||
expression
|
||||
});
|
||||
self.shuffles
|
||||
.input_selectors
|
||||
.entry((q, (x, y)))
|
||||
.or_insert(s_input);
|
||||
}
|
||||
}
|
||||
self.shuffles.reference_selectors.push(s_reference);
|
||||
}
|
||||
self.shuffles.reference_selectors.push(s_reference);
|
||||
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if self.shuffles.references.is_empty() {
|
||||
|
||||
@@ -94,4 +94,7 @@ pub enum CircuitError {
|
||||
#[error("[io] {0}")]
|
||||
/// IO error
|
||||
IoError(#[from] std::io::Error),
|
||||
/// Invalid scale
|
||||
#[error("negative scale for an op that requires positive inputs {0}")]
|
||||
NegativeScale(String),
|
||||
}
|
||||
|
||||
@@ -13,14 +13,38 @@ use serde::{Deserialize, Serialize};
|
||||
/// An enum representing the operations that consist of both lookups and arithmetic operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HybridOp {
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Rsqrt {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
RoundHalfToEven {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Floor {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Round {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
ReduceMax {
|
||||
axes: Vec<usize>,
|
||||
@@ -45,6 +69,8 @@ pub enum HybridOp {
|
||||
ReduceArgMin {
|
||||
dim: usize,
|
||||
},
|
||||
Max,
|
||||
Min,
|
||||
Softmax {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
@@ -79,6 +105,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
| HybridOp::Less { .. }
|
||||
| HybridOp::Equals { .. }
|
||||
| HybridOp::GreaterEqual { .. }
|
||||
| HybridOp::Max
|
||||
| HybridOp::Min
|
||||
| HybridOp::LessEqual { .. } => {
|
||||
vec![0, 1]
|
||||
}
|
||||
@@ -93,21 +121,32 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!(
|
||||
"RSQRT (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
|
||||
}
|
||||
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),
|
||||
|
||||
HybridOp::Max => "MAX".to_string(),
|
||||
HybridOp::Min => "MIN".to_string(),
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})",
|
||||
input_scale, output_scale, use_range_check_for_int
|
||||
),
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"DIV (denom={}, use_range_check_for_int={})",
|
||||
denom, use_range_check_for_int
|
||||
"RECIP (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -162,6 +201,34 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => layouts::rsqrt(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
*input_scale,
|
||||
*output_scale,
|
||||
)?,
|
||||
HybridOp::Sqrt { scale } => {
|
||||
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
|
||||
}
|
||||
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Ceil { scale, legs } => {
|
||||
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Floor { scale, legs } => {
|
||||
layouts::floor(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Round { scale, legs } => {
|
||||
layouts::round(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -179,38 +246,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => {
|
||||
if input_scale.0.fract() == 0.0
|
||||
&& output_scale.0.fract() == 0.0
|
||||
&& *use_range_check_for_int
|
||||
{
|
||||
layouts::recip(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
values.try_into()?,
|
||||
&LookupOp::Recip {
|
||||
input_scale: *input_scale,
|
||||
output_scale: *output_scale,
|
||||
},
|
||||
)?
|
||||
}
|
||||
}
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
..
|
||||
} => {
|
||||
if denom.0.fract() == 0.0 && *use_range_check_for_int {
|
||||
layouts::loop_div(
|
||||
} => layouts::recip(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
if denom.0.fract() == 0.0 {
|
||||
layouts::div(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
@@ -301,9 +346,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
| HybridOp::ReduceArgMax { .. }
|
||||
| HybridOp::OneHot { .. }
|
||||
| HybridOp::ReduceArgMin { .. } => 0,
|
||||
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
|
||||
|
||||
HybridOp::Recip { output_scale, .. } | HybridOp::Rsqrt { output_scale, .. } => {
|
||||
multiplier_to_scale(output_scale.0 as f64)
|
||||
}
|
||||
HybridOp::Softmax {
|
||||
output_scale,
|
||||
input_scale,
|
||||
..
|
||||
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
|
||||
HybridOp::Ln {
|
||||
scale: output_scale,
|
||||
} => 4 * multiplier_to_scale(output_scale.0 as f64),
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::{
|
||||
circuit::{layouts, table::Range, utils},
|
||||
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
@@ -15,101 +14,27 @@ use halo2curves::ff::PrimeField;
|
||||
/// An enum representing the operations that can be used to express more complex operations via accumulation
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||
pub enum LookupOp {
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
},
|
||||
Cast {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Max {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Min {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Floor {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Round {
|
||||
scale: utils::F32,
|
||||
},
|
||||
RoundHalfToEven {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Rsqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
},
|
||||
LeakyReLU {
|
||||
slope: utils::F32,
|
||||
},
|
||||
Sigmoid {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Exp {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Erf {
|
||||
scale: utils::F32,
|
||||
},
|
||||
KroneckerDelta,
|
||||
Pow {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
HardSwish {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Div { denom: utils::F32 },
|
||||
IsOdd,
|
||||
PowersOfTwo { scale: utils::F32 },
|
||||
Ln { scale: utils::F32 },
|
||||
Sigmoid { scale: utils::F32 },
|
||||
Exp { scale: utils::F32 },
|
||||
Cos { scale: utils::F32 },
|
||||
ACos { scale: utils::F32 },
|
||||
Cosh { scale: utils::F32 },
|
||||
ACosh { scale: utils::F32 },
|
||||
Sin { scale: utils::F32 },
|
||||
ASin { scale: utils::F32 },
|
||||
Sinh { scale: utils::F32 },
|
||||
ASinh { scale: utils::F32 },
|
||||
Tan { scale: utils::F32 },
|
||||
ATan { scale: utils::F32 },
|
||||
Tanh { scale: utils::F32 },
|
||||
ATanh { scale: utils::F32 },
|
||||
Erf { scale: utils::F32 },
|
||||
Pow { scale: utils::F32, a: utils::F32 },
|
||||
HardSwish { scale: utils::F32 },
|
||||
}
|
||||
|
||||
impl LookupOp {
|
||||
@@ -123,27 +48,14 @@ impl LookupOp {
|
||||
/// as path
|
||||
pub fn as_path(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
|
||||
LookupOp::Floor { scale } => format!("floor_{}", scale),
|
||||
LookupOp::Round { scale } => format!("round_{}", scale),
|
||||
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
|
||||
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
|
||||
LookupOp::KroneckerDelta => "kronecker_delta".into(),
|
||||
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
|
||||
LookupOp::Ln { scale } => format!("ln_{}", scale),
|
||||
LookupOp::PowersOfTwo { scale } => format!("pow2_{}", scale),
|
||||
LookupOp::IsOdd => "is_odd".to_string(),
|
||||
LookupOp::Div { denom } => format!("div_{}", denom),
|
||||
LookupOp::Cast { scale } => format!("cast_{}", scale),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!("recip_{}_{}", input_scale, output_scale),
|
||||
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
|
||||
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
|
||||
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
|
||||
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
|
||||
LookupOp::Erf { scale } => format!("erf_{}", scale),
|
||||
LookupOp::Exp { scale } => format!("exp_{}", scale),
|
||||
LookupOp::Ln { scale } => format!("ln_{}", scale),
|
||||
LookupOp::Cos { scale } => format!("cos_{}", scale),
|
||||
LookupOp::ACos { scale } => format!("acos_{}", scale),
|
||||
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
|
||||
@@ -168,65 +80,28 @@ impl LookupOp {
|
||||
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
|
||||
let res =
|
||||
match &self {
|
||||
LookupOp::Ceil { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ceil(&x, scale.into()))
|
||||
LookupOp::Ln { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Floor { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::floor(&x, scale.into()))
|
||||
LookupOp::PowersOfTwo { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ipow2(&x, scale.0.into()))
|
||||
}
|
||||
LookupOp::Round { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::round(&x, scale.into()))
|
||||
}
|
||||
LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
|
||||
),
|
||||
LookupOp::IsOdd => Ok::<_, TensorError>(tensor::ops::nonlinearities::is_odd(&x)),
|
||||
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::KroneckerDelta => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::kronecker_delta(&x))
|
||||
}
|
||||
LookupOp::Max { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::max(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Min { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::min(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Div { denom } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
|
||||
),
|
||||
LookupOp::Cast { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()),
|
||||
),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => Ok::<_, TensorError>(tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.into(),
|
||||
output_scale.into(),
|
||||
)),
|
||||
LookupOp::LeakyReLU { slope: a } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
|
||||
}
|
||||
LookupOp::Sigmoid { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Sqrt { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sqrt(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Rsqrt { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::rsqrt(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Erf { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::erffunc(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Exp { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::exp(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Ln { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Cos { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
|
||||
}
|
||||
@@ -283,29 +158,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
/// Returns the name of the operation
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale),
|
||||
LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale),
|
||||
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
|
||||
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
|
||||
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
|
||||
LookupOp::KroneckerDelta => "K_DELTA".into(),
|
||||
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!(
|
||||
"RECIP(input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
|
||||
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
|
||||
LookupOp::PowersOfTwo { scale } => format!("POWERS_OF_TWO(scale={})", scale),
|
||||
LookupOp::IsOdd => "IS_ODD".to_string(),
|
||||
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
|
||||
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
|
||||
LookupOp::Rsqrt { scale } => format!("RSQRT(scale={})", scale),
|
||||
LookupOp::Exp { scale } => format!("EXP(scale={})", scale),
|
||||
LookupOp::Tan { scale } => format!("TAN(scale={})", scale),
|
||||
LookupOp::ATan { scale } => format!("ATAN(scale={})", scale),
|
||||
@@ -339,15 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
|
||||
/// Returns the scale of the output of the operation.
|
||||
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
LookupOp::Cast { scale } => {
|
||||
let in_scale = inputs_scale[0];
|
||||
in_scale + multiplier_to_scale(1. / scale.0 as f64)
|
||||
}
|
||||
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
|
||||
LookupOp::KroneckerDelta => 0,
|
||||
_ => inputs_scale[0],
|
||||
};
|
||||
let scale = inputs_scale[0];
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
|
||||
@@ -255,7 +255,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
self.raw_values = Tensor::new(None, &[0]).unwrap();
|
||||
}
|
||||
|
||||
///
|
||||
/// Pre-assign a value
|
||||
pub fn pre_assign(&mut self, val: ValTensor<F>) {
|
||||
self.pre_assigned_val = Some(val)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::{
|
||||
circuit::layouts,
|
||||
circuit::{
|
||||
layouts,
|
||||
utils::{self, F32},
|
||||
},
|
||||
tensor::{self, Tensor, TensorError},
|
||||
};
|
||||
|
||||
@@ -9,9 +12,12 @@ use super::{base::BaseOp, *};
|
||||
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum PolyOp {
|
||||
ReLU,
|
||||
Abs,
|
||||
Sign,
|
||||
LeakyReLU {
|
||||
slope: utils::F32,
|
||||
scale: i32,
|
||||
},
|
||||
GatherElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
@@ -112,9 +118,9 @@ impl<
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match &self {
|
||||
PolyOp::LeakyReLU { slope: a, .. } => format!("LEAKYRELU (slope={})", a),
|
||||
PolyOp::Abs => "ABS".to_string(),
|
||||
PolyOp::Sign => "SIGN".to_string(),
|
||||
PolyOp::ReLU => "RELU".to_string(),
|
||||
PolyOp::GatherElements { dim, constant_idx } => format!(
|
||||
"GATHERELEMENTS (dim={}, constant_idx{})",
|
||||
dim,
|
||||
@@ -198,7 +204,9 @@ impl<
|
||||
Ok(Some(match self {
|
||||
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
|
||||
PolyOp::ReLU => layouts::relu(config, region, values[..].try_into()?)?,
|
||||
PolyOp::LeakyReLU { slope, scale } => {
|
||||
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
|
||||
}
|
||||
PolyOp::MultiBroadcastTo { shape } => {
|
||||
layouts::expand(config, region, values[..].try_into()?, shape)?
|
||||
}
|
||||
@@ -329,6 +337,12 @@ impl<
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
// this corresponds to the relu operation
|
||||
PolyOp::LeakyReLU {
|
||||
slope: F32(0.0), ..
|
||||
} => in_scales[0],
|
||||
// this corresponds to the leaky relu operation with a slope which induces a change in scale
|
||||
PolyOp::LeakyReLU { scale, .. } => in_scales[0] + *scale,
|
||||
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
|
||||
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
|
||||
PolyOp::Iff => in_scales[1],
|
||||
|
||||
@@ -180,6 +180,7 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
statistics: RegionStatistics,
|
||||
settings: RegionSettings,
|
||||
assigned_constants: ConstantsMap<F>,
|
||||
max_dynamic_input_len: usize,
|
||||
}
|
||||
|
||||
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
|
||||
@@ -193,11 +194,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.settings.legs
|
||||
}
|
||||
|
||||
/// get the max dynamic input len
|
||||
pub fn max_dynamic_input_len(&self) -> usize {
|
||||
self.max_dynamic_input_len
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
///
|
||||
pub fn debug_report(&self) {
|
||||
log::debug!(
|
||||
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={})",
|
||||
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={}, max_dynamic_input_len={})",
|
||||
self.row().to_string().blue(),
|
||||
self.linear_coord().to_string().yellow(),
|
||||
self.total_constants().to_string().red(),
|
||||
@@ -205,7 +211,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.min_lookup_inputs().to_string().green(),
|
||||
self.max_range_size().to_string().green(),
|
||||
self.dynamic_lookup_col_coord().to_string().green(),
|
||||
self.shuffle_col_coord().to_string().green());
|
||||
self.shuffle_col_coord().to_string().green(),
|
||||
self.max_dynamic_input_len().to_string().green()
|
||||
);
|
||||
}
|
||||
|
||||
///
|
||||
@@ -223,6 +231,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.dynamic_lookup_index.index += n;
|
||||
}
|
||||
|
||||
/// increment the max dynamic input len
|
||||
pub fn update_max_dynamic_input_len(&mut self, n: usize) {
|
||||
self.max_dynamic_input_len = self.max_dynamic_input_len.max(n);
|
||||
}
|
||||
|
||||
///
|
||||
pub fn increment_dynamic_lookup_col_coord(&mut self, n: usize) {
|
||||
self.dynamic_lookup_index.col_coord += n;
|
||||
@@ -274,6 +287,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
statistics: RegionStatistics::default(),
|
||||
settings: RegionSettings::all_true(decomp_base, decomp_legs),
|
||||
assigned_constants: HashMap::new(),
|
||||
max_dynamic_input_len: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,6 +324,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
statistics: RegionStatistics::default(),
|
||||
settings,
|
||||
assigned_constants: HashMap::new(),
|
||||
max_dynamic_input_len: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -331,6 +346,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
statistics: RegionStatistics::default(),
|
||||
settings,
|
||||
assigned_constants: HashMap::new(),
|
||||
max_dynamic_input_len: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -458,6 +474,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min forcefully
|
||||
pub fn update_max_min_lookup_inputs_force(
|
||||
&mut self,
|
||||
min: IntegerRep,
|
||||
max: IntegerRep,
|
||||
) -> Result<(), CircuitError> {
|
||||
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
|
||||
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
|
||||
if range.0 > range.1 {
|
||||
@@ -583,9 +610,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
) -> Result<(ValTensor<F>, usize), CircuitError> {
|
||||
|
||||
self.update_max_dynamic_input_len(values.len());
|
||||
|
||||
if let Some(region) = &self.region {
|
||||
Ok(var.assign(
|
||||
Ok(var.assign_exact_column(
|
||||
&mut region.borrow_mut(),
|
||||
self.combined_dynamic_shuffle_coord(),
|
||||
values,
|
||||
@@ -596,7 +626,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
let values_map = values.create_constants_map_iterator();
|
||||
self.assigned_constants.par_extend(values_map);
|
||||
}
|
||||
Ok(values.clone())
|
||||
|
||||
let flush_len = var.get_column_flush(self.combined_dynamic_shuffle_coord(), values)?;
|
||||
|
||||
// get the diff between the current column and the next row
|
||||
Ok((values.clone(), flush_len))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -605,7 +639,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
) -> Result<(ValTensor<F>, usize), CircuitError> {
|
||||
self.assign_dynamic_lookup(var, values)
|
||||
}
|
||||
|
||||
|
||||
@@ -150,12 +150,16 @@ pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
/// get largest element represented by the range
|
||||
pub fn largest(&self) -> IntegerRep {
|
||||
self.range.0 + (self.col_size * self.table_inputs.len() - 1) as IntegerRep
|
||||
}
|
||||
fn name(&self) -> String {
|
||||
format!(
|
||||
"{}_{}_{}",
|
||||
self.nonlinearity.as_path(),
|
||||
self.range.0,
|
||||
self.range.1
|
||||
self.largest()
|
||||
)
|
||||
}
|
||||
/// Configures the table.
|
||||
@@ -222,7 +226,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
}
|
||||
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
let largest = self.largest();
|
||||
|
||||
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
|
||||
let inputs = Tensor::from(smallest..=largest)
|
||||
@@ -291,6 +295,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
|
||||
row_offset += chunk_idx * self.col_size;
|
||||
let (x, y) = self.cartesian_coord(row_offset);
|
||||
|
||||
if !preassigned_input {
|
||||
table.assign_cell(
|
||||
|| format!("nl_i_col row {}", row_offset),
|
||||
|
||||
@@ -1379,7 +1379,10 @@ mod conv_relu_col_ultra_overflow {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap().unwrap()],
|
||||
Box::new(PolyOp::ReLU),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -1516,7 +1519,7 @@ mod add_w_shape_casting {
|
||||
// parameters
|
||||
let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
|
||||
|
||||
let b = Tensor::from((0..1).map(|i| Value::known(F::from(i as u64 + 1))));
|
||||
let b = Tensor::from((0..1).map(|i| Value::known(F::from(i + 1))));
|
||||
|
||||
let circuit = MyCircuit::<F> {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(b)],
|
||||
@@ -2347,7 +2350,14 @@ mod matmul_relu {
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(&mut region, &[output.unwrap()], Box::new(PolyOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
@@ -2439,7 +2449,14 @@ mod relu {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
|
||||
Ok(config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
)
|
||||
@@ -2482,11 +2499,11 @@ mod lookup_ultra_overflow {
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
struct SigmoidCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub input: ValTensor<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for ReLUCircuit<F> {
|
||||
impl Circuit<F> for SigmoidCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
@@ -2500,7 +2517,7 @@ mod lookup_ultra_overflow {
|
||||
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
|
||||
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
|
||||
|
||||
let mut config = BaseConfig::default();
|
||||
|
||||
@@ -2533,7 +2550,7 @@ mod lookup_ultra_overflow {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
@@ -2546,13 +2563,13 @@ mod lookup_ultra_overflow {
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn relucircuit() {
|
||||
fn sigmoidcircuit() {
|
||||
// get some logs fam
|
||||
crate::logger::init_logger();
|
||||
// parameters
|
||||
let a = Tensor::from((0..4).map(|i| Value::known(F::from(i + 1))));
|
||||
|
||||
let circuit = ReLUCircuit::<F> {
|
||||
let circuit = SigmoidCircuit::<F> {
|
||||
input: ValTensor::from(a),
|
||||
};
|
||||
|
||||
@@ -2562,7 +2579,7 @@ mod lookup_ultra_overflow {
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
ReLUCircuit<F>,
|
||||
SigmoidCircuit<F>,
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -333,18 +333,6 @@ impl<'source> FromPyObject<'source> for ContractType {
|
||||
}
|
||||
}
|
||||
}
|
||||
// not wasm
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
|
||||
lazy_static! {
|
||||
/// The version of the ezkl library
|
||||
pub static ref VERSION: &'static str = if env!("CARGO_PKG_VERSION") == "0.0.0" {
|
||||
"source - no compatibility guaranteed"
|
||||
} else {
|
||||
env!("CARGO_PKG_VERSION")
|
||||
};
|
||||
}
|
||||
|
||||
/// Get the styles for the CLI
|
||||
pub fn get_styles() -> clap::builder::Styles {
|
||||
@@ -395,7 +383,7 @@ pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[command(author, about, long_about = None)]
|
||||
#[clap(version = *VERSION, styles = get_styles(), trailing_var_arg = true)]
|
||||
#[clap(version = crate::version(), styles = get_styles(), trailing_var_arg = true)]
|
||||
pub struct Cli {
|
||||
/// If provided, outputs the completion file for given shell
|
||||
#[clap(long = "generate", value_parser)]
|
||||
@@ -486,9 +474,6 @@ pub enum Commands {
|
||||
/// max logrows to use for calibration, 26 is the max public SRS size
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
max_logrows: Option<u32>,
|
||||
// whether to only range check rebases (instead of trying both range check and lookup)
|
||||
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE, action = clap::ArgAction::SetTrue)]
|
||||
only_range_check_rebase: Option<bool>,
|
||||
},
|
||||
|
||||
/// Generates a dummy SRS
|
||||
@@ -508,7 +493,7 @@ pub enum Commands {
|
||||
/// Gets an SRS from a circuit settings file.
|
||||
#[command(name = "get-srs")]
|
||||
GetSrs {
|
||||
/// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs
|
||||
/// The path to output the desired srs file, if set to None will save to ~/.ezkl/srs
|
||||
#[arg(long, default_value = None, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// Path to the circuit settings .json file to read in logrows from. Overriden by logrows if specified.
|
||||
@@ -555,7 +540,7 @@ pub enum Commands {
|
||||
/// The path to save the proving key to
|
||||
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
|
||||
pk_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
@@ -582,7 +567,7 @@ pub enum Commands {
|
||||
/// The path to output the proof file to
|
||||
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
#[arg(
|
||||
@@ -624,7 +609,7 @@ pub enum Commands {
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to output the verification key file to
|
||||
@@ -701,7 +686,7 @@ pub enum Commands {
|
||||
/// The path to output the proof file to
|
||||
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
#[arg(
|
||||
@@ -733,7 +718,7 @@ pub enum Commands {
|
||||
/// Creates an Evm verifier for a single proof
|
||||
#[command(name = "create-evm-verifier")]
|
||||
CreateEvmVerifier {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
@@ -755,7 +740,7 @@ pub enum Commands {
|
||||
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
|
||||
#[command(name = "create-evm-vka")]
|
||||
CreateEvmVKArtifact {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
@@ -798,7 +783,7 @@ pub enum Commands {
|
||||
/// Creates an Evm verifier for an aggregate proof
|
||||
#[command(name = "create-evm-verifier-aggr")]
|
||||
CreateEvmVerifierAggr {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load the desired verification key file
|
||||
@@ -831,7 +816,7 @@ pub enum Commands {
|
||||
/// The path to the verification key file (generated using the setup command)
|
||||
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
|
||||
vk_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
|
||||
@@ -849,7 +834,7 @@ pub enum Commands {
|
||||
/// reduced srs
|
||||
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
|
||||
reduced_srs: Option<bool>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
|
||||
@@ -140,7 +140,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
max_logrows,
|
||||
only_range_check_rebase,
|
||||
} => calibrate(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
@@ -149,7 +148,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap()),
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
@@ -671,17 +669,21 @@ pub(crate) async fn get_srs_cmd(
|
||||
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
|
||||
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
|
||||
// check the SRS
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
|
||||
pb.finish_with_message("SRS validated.");
|
||||
pb.finish_with_message("SRS validated.");
|
||||
|
||||
info!("Saving SRS to disk...");
|
||||
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone(), commitment))?;
|
||||
let computed_srs_path = get_srs_path(k, srs_path.clone(), commitment);
|
||||
let mut file = std::fs::File::create(&computed_srs_path)?;
|
||||
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
|
||||
params.write(&mut buffer)?;
|
||||
|
||||
info!("Saved SRS to disk.");
|
||||
info!(
|
||||
"Saved SRS to {}.",
|
||||
computed_srs_path.as_os_str().to_str().unwrap_or("disk")
|
||||
);
|
||||
|
||||
info!("SRS downloaded");
|
||||
} else {
|
||||
@@ -727,7 +729,7 @@ pub(crate) async fn gen_witness(
|
||||
None
|
||||
};
|
||||
|
||||
let mut input = circuit.load_graph_input(&data).await?;
|
||||
let mut input = circuit.load_graph_input(&data).await?;
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
let mut input = circuit.load_graph_input(&data)?;
|
||||
|
||||
@@ -967,7 +969,6 @@ pub(crate) async fn calibrate(
|
||||
lookup_safety_margin: f64,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
max_logrows: Option<u32>,
|
||||
) -> Result<GraphSettings, EZKLError> {
|
||||
use log::error;
|
||||
@@ -1003,12 +1004,6 @@ pub(crate) async fn calibrate(
|
||||
(11..14).collect::<Vec<crate::Scale>>()
|
||||
};
|
||||
|
||||
let div_rebasing = if only_range_check_rebase {
|
||||
vec![false]
|
||||
} else {
|
||||
vec![true, false]
|
||||
};
|
||||
|
||||
let mut found_params: Vec<GraphSettings> = vec![];
|
||||
|
||||
// 2 x 2 grid
|
||||
@@ -1046,12 +1041,6 @@ pub(crate) async fn calibrate(
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
|
||||
|
||||
let range_grid = range_grid
|
||||
.iter()
|
||||
.cartesian_product(div_rebasing.iter())
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
|
||||
|
||||
let mut forward_pass_res = HashMap::new();
|
||||
|
||||
let pb = init_bar(range_grid.len() as u64);
|
||||
@@ -1060,30 +1049,23 @@ pub(crate) async fn calibrate(
|
||||
let mut num_failed = 0;
|
||||
let mut num_passed = 0;
|
||||
|
||||
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
|
||||
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
|
||||
pb.set_message(format!(
|
||||
"i-scale: {}, p-scale: {}, rebase-(x): {}, div-rebase: {}, fail: {}, pass: {}",
|
||||
"i-scale: {}, p-scale: {}, rebase-(x): {}, fail: {}, pass: {}",
|
||||
input_scale.to_string().blue(),
|
||||
param_scale.to_string().blue(),
|
||||
scale_rebase_multiplier.to_string().blue(),
|
||||
div_rebasing.to_string().yellow(),
|
||||
scale_rebase_multiplier.to_string().yellow(),
|
||||
num_failed.to_string().red(),
|
||||
num_passed.to_string().green()
|
||||
));
|
||||
|
||||
let key = (
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
);
|
||||
let key = (input_scale, param_scale, scale_rebase_multiplier);
|
||||
forward_pass_res.insert(key, vec![]);
|
||||
|
||||
let local_run_args = RunArgs {
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
lookup_range: (IntegerRep::MIN, IntegerRep::MAX),
|
||||
..settings.run_args.clone()
|
||||
};
|
||||
@@ -1187,7 +1169,6 @@ pub(crate) async fn calibrate(
|
||||
let found_run_args = RunArgs {
|
||||
input_scale: new_settings.run_args.input_scale,
|
||||
param_scale: new_settings.run_args.param_scale,
|
||||
div_rebasing: new_settings.run_args.div_rebasing,
|
||||
lookup_range: new_settings.run_args.lookup_range,
|
||||
logrows: new_settings.run_args.logrows,
|
||||
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
|
||||
@@ -1203,6 +1184,7 @@ pub(crate) async fn calibrate(
|
||||
num_rows: new_settings.num_rows,
|
||||
total_assignments: new_settings.total_assignments,
|
||||
total_const_size: new_settings.total_const_size,
|
||||
total_dynamic_col_size: new_settings.total_dynamic_col_size,
|
||||
..settings.clone()
|
||||
};
|
||||
|
||||
@@ -1294,7 +1276,6 @@ pub(crate) async fn calibrate(
|
||||
best_params.run_args.input_scale,
|
||||
best_params.run_args.param_scale,
|
||||
best_params.run_args.scale_rebase_multiplier,
|
||||
best_params.run_args.div_rebasing,
|
||||
))
|
||||
.ok_or("no params found")?
|
||||
.iter()
|
||||
@@ -1320,7 +1301,9 @@ pub(crate) async fn calibrate(
|
||||
let lookup_log_rows = best_params.lookup_log_rows_with_blinding();
|
||||
let module_log_row = best_params.module_constraint_logrows_with_blinding();
|
||||
let instance_logrows = best_params.log2_total_instances_with_blinding();
|
||||
let dynamic_lookup_logrows = best_params.dynamic_lookup_and_shuffle_logrows_with_blinding();
|
||||
let dynamic_lookup_logrows =
|
||||
best_params.min_dynamic_lookup_and_shuffle_logrows_with_blinding();
|
||||
|
||||
let range_check_logrows = best_params.range_check_log_rows_with_blinding();
|
||||
|
||||
let mut reduction = std::cmp::max(lookup_log_rows, module_log_row);
|
||||
@@ -2018,7 +2001,7 @@ pub(crate) fn mock_aggregate(
|
||||
}
|
||||
}
|
||||
// proof aggregation
|
||||
let pb = {
|
||||
let pb = {
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Aggregating (may take a while)...");
|
||||
pb
|
||||
@@ -2029,7 +2012,7 @@ pub(crate) fn mock_aggregate(
|
||||
let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()])
|
||||
.map_err(|e| ExecutionError::MockProverError(e.to_string()))?;
|
||||
prover.verify().map_err(ExecutionError::VerifyError)?;
|
||||
pb.finish_with_message("Done.");
|
||||
pb.finish_with_message("Done.");
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
@@ -2123,7 +2106,7 @@ pub(crate) fn aggregate(
|
||||
}
|
||||
|
||||
// proof aggregation
|
||||
let pb = {
|
||||
let pb = {
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Aggregating (may take a while)...");
|
||||
pb
|
||||
@@ -2272,7 +2255,7 @@ pub(crate) fn aggregate(
|
||||
);
|
||||
snark.save(&proof_path)?;
|
||||
|
||||
pb.finish_with_message("Done.");
|
||||
pb.finish_with_message("Done.");
|
||||
|
||||
Ok(snark)
|
||||
}
|
||||
|
||||
@@ -133,6 +133,8 @@ pub struct GraphWitness {
|
||||
pub min_lookup_inputs: IntegerRep,
|
||||
/// max range check size
|
||||
pub max_range_size: IntegerRep,
|
||||
/// (optional) version of ezkl used
|
||||
pub version: Option<String>,
|
||||
}
|
||||
|
||||
impl GraphWitness {
|
||||
@@ -161,6 +163,7 @@ impl GraphWitness {
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
version: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -408,6 +411,8 @@ pub struct GraphSettings {
|
||||
pub total_const_size: usize,
|
||||
/// total dynamic column size
|
||||
pub total_dynamic_col_size: usize,
|
||||
/// max dynamic column input length
|
||||
pub max_dynamic_input_len: usize,
|
||||
/// number of dynamic lookups
|
||||
pub num_dynamic_lookups: usize,
|
||||
/// number of shuffles
|
||||
@@ -485,6 +490,13 @@ impl GraphSettings {
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
/// calculate the number of rows required for the dynamic lookup and shuffle
|
||||
pub fn min_dynamic_lookup_and_shuffle_logrows_with_blinding(&self) -> u32 {
|
||||
(self.max_dynamic_input_len as f64 + RESERVED_BLINDING_ROWS as f64)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
fn dynamic_lookup_and_shuffle_col_size(&self) -> usize {
|
||||
self.total_dynamic_col_size + self.total_shuffle_col_size
|
||||
}
|
||||
@@ -1341,6 +1353,7 @@ impl GraphCircuit {
|
||||
max_lookup_inputs: model_results.max_lookup_inputs,
|
||||
min_lookup_inputs: model_results.min_lookup_inputs,
|
||||
max_range_size: model_results.max_range_size,
|
||||
version: Some(crate::version().to_string()),
|
||||
};
|
||||
|
||||
witness.generate_rescaled_elements(
|
||||
|
||||
@@ -103,6 +103,8 @@ pub struct DummyPassRes {
|
||||
pub num_rows: usize,
|
||||
/// num dynamic lookups
|
||||
pub num_dynamic_lookups: usize,
|
||||
/// max dynamic lookup input len
|
||||
pub max_dynamic_input_len: usize,
|
||||
/// dynamic lookup col size
|
||||
pub dynamic_lookup_col_coord: usize,
|
||||
/// num shuffles
|
||||
@@ -360,6 +362,14 @@ impl NodeType {
|
||||
NodeType::SubGraph { .. } => SupportedOp::Unknown(Unknown),
|
||||
}
|
||||
}
|
||||
|
||||
/// check if it is a softmax
|
||||
pub fn is_softmax(&self) -> bool {
|
||||
match self {
|
||||
NodeType::Node(n) => n.is_softmax(),
|
||||
NodeType::SubGraph { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
@@ -562,6 +572,7 @@ impl Model {
|
||||
num_rows: res.num_rows,
|
||||
total_assignments: res.linear_coord,
|
||||
required_lookups: res.lookup_ops.into_iter().collect(),
|
||||
max_dynamic_input_len: res.max_dynamic_input_len,
|
||||
required_range_checks: res.range_checks.into_iter().collect(),
|
||||
model_output_scales: self.graph.get_output_scales()?,
|
||||
model_input_scales: self.graph.get_input_scales(),
|
||||
@@ -904,20 +915,9 @@ impl Model {
|
||||
if scales.contains_key(&i) {
|
||||
let scale_diff = n.out_scale - scales[&i];
|
||||
n.opkind = if scale_diff > 0 {
|
||||
RebaseScale::rebase(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
1,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1)
|
||||
} else {
|
||||
RebaseScale::rebase_up(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale)
|
||||
};
|
||||
n.out_scale = scales[&i];
|
||||
}
|
||||
@@ -1465,6 +1465,7 @@ impl Model {
|
||||
let res = DummyPassRes {
|
||||
num_rows: region.row(),
|
||||
linear_coord: region.linear_coord(),
|
||||
max_dynamic_input_len: region.max_dynamic_input_len(),
|
||||
total_const_size: region.total_constants(),
|
||||
lookup_ops: region.used_lookups(),
|
||||
range_checks: region.used_range_checks(),
|
||||
|
||||
@@ -120,7 +120,6 @@ impl RebaseScale {
|
||||
global_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
scale_rebase_multiplier: u32,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
|
||||
&& !inner.is_constant()
|
||||
@@ -137,7 +136,6 @@ impl RebaseScale {
|
||||
multiplier,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32((multiplier) as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
original_scale: op.original_scale,
|
||||
})
|
||||
@@ -148,7 +146,6 @@ impl RebaseScale {
|
||||
multiplier,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32(multiplier as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
original_scale: op_out_scale,
|
||||
})
|
||||
@@ -163,7 +160,6 @@ impl RebaseScale {
|
||||
inner: SupportedOp,
|
||||
target_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
|
||||
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
|
||||
@@ -176,7 +172,6 @@ impl RebaseScale {
|
||||
original_scale: op.original_scale,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32((multiplier) as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
@@ -187,7 +182,6 @@ impl RebaseScale {
|
||||
original_scale: op_out_scale,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32(multiplier as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -595,13 +589,7 @@ impl Node {
|
||||
let mut out_scale = opkind.out_scale(in_scales.clone())?;
|
||||
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
|
||||
let global_scale = scales.get_max();
|
||||
opkind = RebaseScale::rebase(
|
||||
opkind,
|
||||
global_scale,
|
||||
out_scale,
|
||||
scales.rebase_multiplier,
|
||||
run_args.div_rebasing,
|
||||
);
|
||||
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
|
||||
|
||||
out_scale = opkind.out_scale(in_scales)?;
|
||||
|
||||
@@ -623,6 +611,15 @@ impl Node {
|
||||
num_uses,
|
||||
})
|
||||
}
|
||||
|
||||
/// check if it is a softmax node
|
||||
pub fn is_softmax(&self) -> bool {
|
||||
if let SupportedOp::Hybrid(HybridOp::Softmax { .. }) = self.opkind {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
|
||||
@@ -763,93 +763,52 @@ pub fn new_op_from_onnx(
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_inputs.len() != 1 {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
};
|
||||
|
||||
if inputs.len() == 2 {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
if unit == 0. {
|
||||
SupportedOp::Linear(PolyOp::ReLU)
|
||||
if !const_inputs.is_empty() {
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
};
|
||||
if unit == 0. {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
SupportedOp::Linear(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
})
|
||||
} else {
|
||||
SupportedOp::Hybrid(HybridOp::Max)
|
||||
}
|
||||
} else {
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
SupportedOp::Nonlinear(LookupOp::Max {
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
SupportedOp::Hybrid(HybridOp::Max)
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
}
|
||||
"Min" => {
|
||||
// Extract the min value
|
||||
// first find the input that is a constant
|
||||
// and then extract the value
|
||||
let const_inputs = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, n)| n.is_constant())
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_inputs.len() != 1 {
|
||||
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "min".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
|
||||
};
|
||||
|
||||
if inputs.len() == 2 {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::Min {
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
SupportedOp::Hybrid(HybridOp::Min)
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "min".to_string()));
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
// If the input scale is larger than the params scale
|
||||
SupportedOp::Hybrid(HybridOp::Recip {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
use_range_check_for_int: true,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -864,8 +823,9 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
};
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::LeakyReLU {
|
||||
SupportedOp::Linear(PolyOp::LeakyReLU {
|
||||
slope: crate::circuit::utils::F32(leaky_op.alpha),
|
||||
scale: scales.params,
|
||||
})
|
||||
}
|
||||
"Scan" => {
|
||||
@@ -877,61 +837,75 @@ pub fn new_op_from_onnx(
|
||||
"Abs" => SupportedOp::Linear(PolyOp::Abs),
|
||||
"Neg" => SupportedOp::Linear(PolyOp::Neg),
|
||||
"HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sqrt" => SupportedOp::Nonlinear(LookupOp::Sqrt {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
}),
|
||||
"Rsqrt" => SupportedOp::Nonlinear(LookupOp::Rsqrt {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"Sqrt" => SupportedOp::Hybrid(HybridOp::Sqrt {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Rsqrt" => {
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
SupportedOp::Hybrid(HybridOp::Rsqrt {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
})
|
||||
}
|
||||
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
}),
|
||||
"Ln" => SupportedOp::Nonlinear(LookupOp::Ln {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Ln" => {
|
||||
if run_args.bounded_log_lookup {
|
||||
SupportedOp::Hybrid(HybridOp::Ln {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
})
|
||||
} else {
|
||||
SupportedOp::Nonlinear(LookupOp::Ln {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
"Sin" => SupportedOp::Nonlinear(LookupOp::Sin {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Cos" => SupportedOp::Nonlinear(LookupOp::Cos {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Tan" => SupportedOp::Nonlinear(LookupOp::Tan {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Asin" => SupportedOp::Nonlinear(LookupOp::ASin {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Acos" => SupportedOp::Nonlinear(LookupOp::ACos {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Atan" => SupportedOp::Nonlinear(LookupOp::ATan {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sinh" => SupportedOp::Nonlinear(LookupOp::Sinh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Cosh" => SupportedOp::Nonlinear(LookupOp::Cosh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Tanh" => SupportedOp::Nonlinear(LookupOp::Tanh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Asinh" => SupportedOp::Nonlinear(LookupOp::ASinh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Acosh" => SupportedOp::Nonlinear(LookupOp::ACosh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Atanh" => SupportedOp::Nonlinear(LookupOp::ATanh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Erf" => SupportedOp::Nonlinear(LookupOp::Erf {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Source" => {
|
||||
let dt = node.outputs[0].fact.datum_type;
|
||||
@@ -975,11 +949,9 @@ pub fn new_op_from_onnx(
|
||||
replace_const(
|
||||
0,
|
||||
0,
|
||||
SupportedOp::Nonlinear(LookupOp::Cast {
|
||||
scale: crate::circuit::utils::F32(scale_to_multiplier(
|
||||
input_scales[0],
|
||||
)
|
||||
as f32),
|
||||
SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
)?
|
||||
} else {
|
||||
@@ -1085,7 +1057,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
};
|
||||
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::Softmax {
|
||||
@@ -1123,17 +1095,21 @@ pub fn new_op_from_onnx(
|
||||
pool_dims: kernel_shape.to_vec(),
|
||||
})
|
||||
}
|
||||
"Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Floor" => SupportedOp::Nonlinear(LookupOp::Floor {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Round" => SupportedOp::Nonlinear(LookupOp::Round {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"Round" => SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Sign" => SupportedOp::Linear(PolyOp::Sign),
|
||||
"Pow" => {
|
||||
@@ -1146,10 +1122,17 @@ pub fn new_op_from_onnx(
|
||||
if c.raw_values.len() > 1 {
|
||||
unimplemented!("only support scalar pow")
|
||||
}
|
||||
SupportedOp::Nonlinear(LookupOp::Pow {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(c.raw_values[0]),
|
||||
})
|
||||
|
||||
let exponent = c.raw_values[0];
|
||||
|
||||
if exponent.fract() == 0.0 {
|
||||
SupportedOp::Linear(PolyOp::Pow(exponent as u32))
|
||||
} else {
|
||||
SupportedOp::Nonlinear(LookupOp::Pow {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
a: crate::circuit::utils::F32(exponent),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
unimplemented!("only support constant pow for now")
|
||||
}
|
||||
|
||||
@@ -443,7 +443,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
let dynamic_lookup =
|
||||
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);
|
||||
if dynamic_lookup.num_blocks() > 1 {
|
||||
panic!("dynamic lookup or shuffle should only have one block");
|
||||
warn!("dynamic lookup has {} blocks", dynamic_lookup.num_blocks());
|
||||
};
|
||||
advices.push(dynamic_lookup);
|
||||
}
|
||||
|
||||
22
src/lib.rs
22
src/lib.rs
@@ -108,6 +108,18 @@ use serde::{Deserialize, Serialize};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
|
||||
/// The version of the ezkl library
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
/// Get the version of the library
|
||||
pub fn version() -> &'static str {
|
||||
match VERSION {
|
||||
"0.0.0" => "source - no compatibility guaranteed",
|
||||
_ => VERSION,
|
||||
}
|
||||
}
|
||||
|
||||
/// Bindings managment
|
||||
#[cfg(any(
|
||||
feature = "ios-bindings",
|
||||
@@ -297,8 +309,6 @@ pub struct RunArgs {
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// Rebase the scale using lookup table for division instead of using a range check
|
||||
pub div_rebasing: bool,
|
||||
/// Should constants with 0.0 fraction be rebased to scale 0
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
@@ -317,11 +327,18 @@ pub struct RunArgs {
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
|
||||
/// the number of legs used for decompositions
|
||||
pub decomp_legs: usize,
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// use unbounded lookup for the log
|
||||
pub bounded_log_lookup: bool,
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bounded_log_lookup: false,
|
||||
tolerance: Tolerance::default(),
|
||||
input_scale: 7,
|
||||
param_scale: 7,
|
||||
@@ -333,7 +350,6 @@ impl Default for RunArgs {
|
||||
input_visibility: Visibility::Private,
|
||||
output_visibility: Visibility::Public,
|
||||
param_visibility: Visibility::Private,
|
||||
div_rebasing: false,
|
||||
rebase_frac_zero_constants: false,
|
||||
check_mode: CheckMode::UNSAFE,
|
||||
commitment: None,
|
||||
|
||||
@@ -59,10 +59,7 @@ fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(ValueEnum)
|
||||
)]
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
|
||||
pub enum ProofType {
|
||||
#[default]
|
||||
Single,
|
||||
@@ -134,10 +131,7 @@ impl<'source> pyo3::FromPyObject<'source> for ProofType {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(ValueEnum)
|
||||
)]
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
|
||||
pub enum StrategyType {
|
||||
Single,
|
||||
Accum,
|
||||
@@ -203,10 +197,7 @@ pub enum PfSysError {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(ValueEnum)
|
||||
)]
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
|
||||
pub enum TranscriptType {
|
||||
Poseidon,
|
||||
#[default]
|
||||
@@ -324,6 +315,8 @@ where
|
||||
pub timestamp: Option<u128>,
|
||||
/// commitment
|
||||
pub commitment: Option<Commitments>,
|
||||
/// (optional) version of ezkl used to generate the proof
|
||||
version: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
@@ -385,6 +378,7 @@ where
|
||||
.as_millis(),
|
||||
),
|
||||
commitment,
|
||||
version: Some(crate::version().to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -920,6 +914,7 @@ mod tests {
|
||||
pretty_public_inputs: None,
|
||||
timestamp: None,
|
||||
commitment: None,
|
||||
version: None,
|
||||
};
|
||||
|
||||
snark
|
||||
|
||||
@@ -27,7 +27,7 @@ pub fn get_rep(
|
||||
n: usize,
|
||||
) -> Result<Vec<IntegerRep>, DecompositionError> {
|
||||
// check if x is too large
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) {
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 {
|
||||
return Err(DecompositionError::TooLarge(*x, base, n));
|
||||
}
|
||||
let mut rep = vec![0; n + 1];
|
||||
@@ -1421,85 +1421,6 @@ pub fn slice<T: TensorType + Send + Sync>(
|
||||
pub mod nonlinearities {
|
||||
use super::*;
|
||||
|
||||
/// Ceiling operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
///
|
||||
/// use ezkl::tensor::ops::nonlinearities::ceil;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = ceil(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ceil(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.ceil() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Floor operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::floor;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = floor(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 2, 2, 4, 4, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn floor(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.floor() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Round operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::round;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = round(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn round(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.round() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Round half to even operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
@@ -1553,30 +1474,88 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Applies Kronecker delta to a tensor of integers.
|
||||
/// Checks if a tensor's elements are odd
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::kronecker_delta;
|
||||
/// use ezkl::tensor::ops::nonlinearities::is_odd;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = kronecker_delta(&x);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 0, 0, 1]), &[2, 3]).unwrap();
|
||||
///
|
||||
/// let result = is_odd(&x);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 1, 0, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn kronecker_delta<T: TensorType + std::cmp::PartialEq + Send + Sync>(
|
||||
a: &Tensor<T>,
|
||||
) -> Tensor<T> {
|
||||
pub fn is_odd(a: &Tensor<IntegerRep>) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
if a_i == T::zero().unwrap() {
|
||||
Ok::<_, TensorError>(T::one().unwrap())
|
||||
let rounded = if a_i % 2 == 0 { 0 } else { 1 };
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Powers of 2
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::ipow2;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = ipow2(&x, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 32768, 4, 2, 2, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ipow2(a: &Tensor<IntegerRep>, scale_output: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = a_i as f64;
|
||||
let kix = scale_output * (2.0_f64).powf(kix);
|
||||
let rounded = kix.round();
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies ln base 2 to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale_input` - Single value
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::ilog2;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 2]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = ilog2(&x, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 1, 0, 0, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ilog2(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale_input;
|
||||
let log = (kix).log2();
|
||||
let floor = log.floor();
|
||||
let ceil = log.ceil();
|
||||
let floor_dist = ((2.0_f64).powf(floor) - kix).abs();
|
||||
let ceil_dist = (kix - (2.0_f64).powf(ceil)).abs();
|
||||
|
||||
if floor_dist < ceil_dist {
|
||||
Ok::<_, TensorError>(floor as IntegerRep)
|
||||
} else {
|
||||
Ok::<_, TensorError>(T::zero().unwrap())
|
||||
Ok::<_, TensorError>(ceil as IntegerRep)
|
||||
}
|
||||
})
|
||||
.unwrap()
|
||||
@@ -1710,12 +1689,11 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies exponential to a tensor of integers.
|
||||
/// Elementwise applies ln to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `scale_input` - Single value
|
||||
/// * `scale_output` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
@@ -1750,27 +1728,6 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies sign to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::sign;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[-2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = sign(&x);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[-1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn sign(a: &Tensor<IntegerRep>) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(a_i.signum()))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies square root to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -2254,101 +2211,6 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies leaky relu to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// * `slope` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::leakyrelu;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = leakyrelu(&x, 0.1);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 15, 2, 1, 1, -1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn leakyrelu(a: &Tensor<IntegerRep>, slope: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rounded = if a_i < 0 {
|
||||
let d_inv_x = (slope) * (a_i as f64);
|
||||
d_inv_x.round() as IntegerRep
|
||||
} else {
|
||||
let d_inv_x = a_i as f64;
|
||||
d_inv_x.round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies max to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - scalar
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::max;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = max(&x, 1.0, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 15, 2, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn max(a: &Tensor<IntegerRep>, scale_input: f64, threshold: f64) -> Tensor<IntegerRep> {
|
||||
// calculate value of output
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let d_inv_x = (a_i as f64) / scale_input;
|
||||
let rounded = if d_inv_x <= threshold {
|
||||
(threshold * scale_input).round() as IntegerRep
|
||||
} else {
|
||||
(d_inv_x * scale_input).round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies min to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - scalar
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::min;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = min(&x, 1.0, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 2, 1, 1, -5]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn min(a: &Tensor<IntegerRep>, scale_input: f64, threshold: f64) -> Tensor<IntegerRep> {
|
||||
// calculate value of output
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let d_inv_x = (a_i as f64) / scale_input;
|
||||
let rounded = if d_inv_x >= threshold {
|
||||
(threshold * scale_input).round() as IntegerRep
|
||||
} else {
|
||||
(d_inv_x * scale_input).round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise divides a tensor with a const integer element.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -2429,104 +2291,6 @@ pub mod nonlinearities {
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise greater than
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::greater_than;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
/// let result = greater_than(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn greater_than(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) > 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise greater than
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::greater_than_equal;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
/// let result = greater_than_equal(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 1, 1, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn greater_than_equal(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) >= 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise less than
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::less_than;
|
||||
///
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
///
|
||||
/// let result = less_than(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 1, 0, 0, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn less_than(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) < 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise less than
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::less_than_equal;
|
||||
///
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
///
|
||||
/// let result = less_than_equal(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 0, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn less_than_equal(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) <= 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Ops that return the transcript i.e intermediate calcs of an op
|
||||
|
||||
@@ -541,7 +541,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
let mut is_empty = true;
|
||||
x.map(|_| is_empty = false);
|
||||
if is_empty {
|
||||
return Ok::<_, TensorError>(vec![Value::<F>::unknown(); n + 1]);
|
||||
Ok::<_, TensorError>(vec![Value::<F>::unknown(); n + 1])
|
||||
} else {
|
||||
let mut res = vec![Value::unknown(); n + 1];
|
||||
let mut int_rep = 0;
|
||||
|
||||
@@ -396,6 +396,53 @@ impl VarTensor {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Helper function to get the remaining size of the column
|
||||
pub fn get_column_flush<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<usize, halo2_proofs::plonk::Error> {
|
||||
if values.len() > self.col_size() {
|
||||
error!("Values are too large for the column");
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
|
||||
// this can only be called on columns that have a single inner column
|
||||
if self.num_inner_cols() != 1 {
|
||||
error!("This function can only be called on columns with a single inner column");
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
|
||||
// check if the values fit in the remaining space of the column
|
||||
let current_cartesian = self.cartesian_coord(offset);
|
||||
let final_cartesian = self.cartesian_coord(offset + values.len());
|
||||
|
||||
let mut flush_len = 0;
|
||||
if current_cartesian.0 != final_cartesian.0 {
|
||||
debug!("Values overflow the column, flushing to next column");
|
||||
// diff is the number of values that overflow the column
|
||||
flush_len += self.col_size() - current_cartesian.2;
|
||||
}
|
||||
|
||||
Ok(flush_len)
|
||||
}
|
||||
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor. Whereby the values are assigned to a single column, without overflowing.
|
||||
/// So for instance if we are assigning 10 values and we are at index 18 of the column, and the columns are of length 20, we skip the last 2 values of current column and start from the beginning of the next column.
|
||||
pub fn assign_exact_column<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
|
||||
let flush_len = self.get_column_flush(offset, values)?;
|
||||
|
||||
let assigned_vals = self.assign(region, offset + flush_len, values, constants)?;
|
||||
|
||||
Ok((assigned_vals, flush_len))
|
||||
}
|
||||
|
||||
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
|
||||
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
|
||||
pub fn dummy_assign_with_duplication<
|
||||
|
||||
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -27,12 +27,14 @@
|
||||
"check_mode": "UNSAFE",
|
||||
"commitment": "KZG",
|
||||
"decomp_base": 128,
|
||||
"decomp_legs": 2
|
||||
"decomp_legs": 2,
|
||||
"bounded_log_lookup": false
|
||||
},
|
||||
"num_rows": 46,
|
||||
"total_assignments": 92,
|
||||
"total_const_size": 3,
|
||||
"total_dynamic_col_size": 0,
|
||||
"max_dynamic_input_len": 0,
|
||||
"num_dynamic_lookups": 0,
|
||||
"num_shuffles": 0,
|
||||
"total_shuffle_col_size": 0,
|
||||
|
||||
@@ -1 +1 @@
|
||||
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":0,"max_range_size":127}
|
||||
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":0,"max_range_size":127,"version":"source - no compatibility guaranteed"}
|
||||
@@ -205,7 +205,7 @@ mod native_tests {
|
||||
"1l_tiny_div",
|
||||
];
|
||||
|
||||
const TESTS: [&str; 94] = [
|
||||
const TESTS: [&str; 96] = [
|
||||
"1l_mlp", //0
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
@@ -304,6 +304,8 @@ mod native_tests {
|
||||
"lstm_large", // 91
|
||||
"lstm_medium", // 92
|
||||
"lenet_5", // 93
|
||||
"rsqrt", // 94
|
||||
"log", // 95
|
||||
];
|
||||
|
||||
const WASM_TESTS: [&str; 46] = [
|
||||
@@ -537,12 +539,12 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
seq!(N in 0..=93 {
|
||||
seq!(N in 0..=95 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
#[ignore]
|
||||
@@ -554,15 +556,7 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_div_rebase_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, true);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_public_outputs_(test: &str) {
|
||||
@@ -570,7 +564,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, false);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -580,7 +574,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6 , false);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6 );
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -590,7 +584,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6, false);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -601,7 +595,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1, false);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -610,7 +604,17 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_bounded_lookup_log(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -621,7 +625,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// gen random number between 0.0 and 1.0
|
||||
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -636,7 +640,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -646,7 +650,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -655,7 +659,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -664,7 +668,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -673,7 +677,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -682,7 +686,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -691,7 +695,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -700,7 +704,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -710,7 +714,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -720,7 +724,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -729,7 +733,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -739,7 +743,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -748,7 +752,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -758,7 +762,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -768,7 +772,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -778,7 +782,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -788,7 +792,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -798,7 +802,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -851,9 +855,11 @@ mod native_tests {
|
||||
fn kzg_prove_and_verify_tight_lookup_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let path = test_dir.into_path();
|
||||
let path = path.to_str().unwrap();
|
||||
crate::native_tests::mv_test_(path, test);
|
||||
prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, false, "single", Commitments::KZG, 1);
|
||||
test_dir.close().unwrap();
|
||||
// test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
@@ -973,7 +979,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -1118,7 +1124,7 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
seq!(N in 0..=93 {
|
||||
seq!(N in 0..4 {
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
@@ -1450,6 +1456,7 @@ mod native_tests {
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
tolerance: f32,
|
||||
bounded_lookup_log: bool,
|
||||
) {
|
||||
let mut tolerance = tolerance;
|
||||
gen_circuit_settings_and_witness(
|
||||
@@ -1462,10 +1469,10 @@ mod native_tests {
|
||||
cal_target,
|
||||
scales_to_use,
|
||||
2,
|
||||
false,
|
||||
&mut tolerance,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
bounded_lookup_log,
|
||||
);
|
||||
|
||||
if tolerance > 0.0 {
|
||||
@@ -1603,10 +1610,10 @@ mod native_tests {
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
num_inner_columns: usize,
|
||||
div_rebasing: bool,
|
||||
tolerance: &mut f32,
|
||||
commitment: Commitments,
|
||||
lookup_safety_margin: usize,
|
||||
bounded_lookup_log: bool,
|
||||
) {
|
||||
let mut args = vec![
|
||||
"gen-settings".to_string(),
|
||||
@@ -1625,13 +1632,12 @@ mod native_tests {
|
||||
format!("--commitment={}", commitment),
|
||||
];
|
||||
|
||||
if div_rebasing {
|
||||
args.push("--div-rebasing".to_string());
|
||||
};
|
||||
if bounded_lookup_log {
|
||||
args.push("--bounded-log-lookup".to_string());
|
||||
}
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
.stdout(std::process::Stdio::null())
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
@@ -1728,7 +1734,6 @@ mod native_tests {
|
||||
batch_size: usize,
|
||||
cal_target: &str,
|
||||
target_perc: f32,
|
||||
div_rebasing: bool,
|
||||
) {
|
||||
gen_circuit_settings_and_witness(
|
||||
test_dir,
|
||||
@@ -1740,10 +1745,10 @@ mod native_tests {
|
||||
cal_target,
|
||||
None,
|
||||
2,
|
||||
div_rebasing,
|
||||
&mut 0.0,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
);
|
||||
|
||||
println!(
|
||||
@@ -2024,10 +2029,10 @@ mod native_tests {
|
||||
target_str,
|
||||
scales_to_use,
|
||||
num_inner_columns,
|
||||
false,
|
||||
&mut 0.0,
|
||||
commitment,
|
||||
lookup_safety_margin,
|
||||
false,
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
@@ -2456,10 +2461,10 @@ mod native_tests {
|
||||
// we need the accuracy
|
||||
Some(vec![4]),
|
||||
1,
|
||||
false,
|
||||
&mut 0.0,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
);
|
||||
|
||||
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
|
||||
|
||||
@@ -124,41 +124,40 @@ mod py_tests {
|
||||
}
|
||||
|
||||
const TESTS: [&str; 34] = [
|
||||
"ezkl_demo_batch.ipynb",
|
||||
"proof_splitting.ipynb", // 0
|
||||
"variance.ipynb",
|
||||
"mnist_gan.ipynb",
|
||||
// "mnist_vae.ipynb",
|
||||
"keras_simple_demo.ipynb",
|
||||
"mnist_gan_proof_splitting.ipynb", // 4
|
||||
"hashed_vis.ipynb", // 5
|
||||
"simple_demo_all_public.ipynb",
|
||||
"data_attest.ipynb",
|
||||
"little_transformer.ipynb",
|
||||
"simple_demo_aggregated_proofs.ipynb",
|
||||
"ezkl_demo.ipynb", // 10
|
||||
"lstm.ipynb",
|
||||
"set_membership.ipynb", // 12
|
||||
"decision_tree.ipynb",
|
||||
"random_forest.ipynb",
|
||||
"gradient_boosted_trees.ipynb", // 15
|
||||
"xgboost.ipynb",
|
||||
"lightgbm.ipynb",
|
||||
"svm.ipynb",
|
||||
"simple_demo_public_input_output.ipynb",
|
||||
"simple_demo_public_network_output.ipynb", // 20
|
||||
"gcn.ipynb",
|
||||
"linear_regression.ipynb",
|
||||
"stacked_regression.ipynb",
|
||||
"data_attest_hashed.ipynb",
|
||||
"kzg_vis.ipynb", // 25
|
||||
"kmeans.ipynb",
|
||||
"solvency.ipynb",
|
||||
"sklearn_mlp.ipynb",
|
||||
"generalized_inverse.ipynb",
|
||||
"mnist_classifier.ipynb", // 30
|
||||
"world_rotation.ipynb",
|
||||
"logistic_regression.ipynb",
|
||||
"ezkl_demo_batch.ipynb", // 0
|
||||
"proof_splitting.ipynb", // 1
|
||||
"variance.ipynb", // 2
|
||||
"mnist_gan.ipynb", // 3
|
||||
"keras_simple_demo.ipynb", // 4
|
||||
"mnist_gan_proof_splitting.ipynb", // 5
|
||||
"hashed_vis.ipynb", // 6
|
||||
"simple_demo_all_public.ipynb", // 7
|
||||
"data_attest.ipynb", // 8
|
||||
"little_transformer.ipynb", // 9
|
||||
"simple_demo_aggregated_proofs.ipynb", // 10
|
||||
"ezkl_demo.ipynb", // 11
|
||||
"lstm.ipynb", // 12
|
||||
"set_membership.ipynb", // 13
|
||||
"decision_tree.ipynb", // 14
|
||||
"random_forest.ipynb", // 15
|
||||
"gradient_boosted_trees.ipynb", // 16
|
||||
"xgboost.ipynb", // 17
|
||||
"lightgbm.ipynb", // 18
|
||||
"svm.ipynb", // 19
|
||||
"simple_demo_public_input_output.ipynb", // 20
|
||||
"simple_demo_public_network_output.ipynb", // 21
|
||||
"gcn.ipynb", // 22
|
||||
"linear_regression.ipynb", // 23
|
||||
"stacked_regression.ipynb", // 24
|
||||
"data_attest_hashed.ipynb", // 25
|
||||
"kzg_vis.ipynb", // 26
|
||||
"kmeans.ipynb", // 27
|
||||
"solvency.ipynb", // 28
|
||||
"sklearn_mlp.ipynb", // 29
|
||||
"generalized_inverse.ipynb", // 30
|
||||
"mnist_classifier.ipynb", // 31
|
||||
"world_rotation.ipynb", // 32
|
||||
"logistic_regression.ipynb", // 33
|
||||
];
|
||||
|
||||
macro_rules! test_func {
|
||||
|
||||
Reference in New Issue
Block a user