Compare commits

...

3 Commits

Author SHA1 Message Date
github-actions[bot]
8015a17978 ci: update version string in docs 2025-04-23 08:12:43 +00:00
dante
0de0682bfa refactor: configurable div epsilon (#968) 2025-04-23 09:12:24 +01:00
dante
bf9cf14ab7 refactor!: rpc url should be required (#965)
BREAKING CHANGE: in python the order of arguments for evm related functions has changed
2025-04-22 12:45:36 +01:00
28 changed files with 215 additions and 166 deletions

View File

@@ -258,7 +258,7 @@ jobs:
- name: Install built wheel
if: matrix.target == 'x86_64-unknown-linux-musl'
uses: addnab/docker-run-action@v3
uses: addnab/docker-run-action@3e77f186b7a929ef010f183a9e24c0f9955ea609
with:
image: alpine:latest
options: -v ${{ github.workspace }}:/io -w /io
@@ -380,7 +380,7 @@ jobs:
with:
persist-credentials: false
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
uses: dfm/rtds-action@618148c547f4b56cdf4fa4dcf3a94c91ce025f2d
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}

View File

@@ -33,7 +33,7 @@ jobs:
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3
with:
crate: cargo-nextest
locked: true
@@ -233,7 +233,7 @@ jobs:
with:
# Pin to version 0.12.1
version: "v0.12.1"
- uses: nanasess/setup-chromedriver@e93e57b843c0c92788f22483f1a31af8ee48db25 #v2.3.0
- uses: nanasess/setup-chromedriver@affb1ea8848cbb080be372c1e8d7a5c173e9298f #v2.3.0
# with:
# chromedriver-version: "115.0.5790.102"
- name: Install wasm32-unknown-unknown
@@ -256,10 +256,10 @@ jobs:
submodules: recursive
- name: Install Foundry
uses: foundry-rs/foundry-toolchain@v1
uses: foundry-rs/foundry-toolchain@3b74dacdda3c0b763089addb99ed86bc3800e68b
- name: Run tests
run: |
run: |
cd tests/foundry
forge install https://github.com/foundry-rs/forge-std --no-git --no-commit
forge test -vvvv --fuzz-runs 64

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '22.0.1'
version = release

View File

@@ -1088,7 +1088,7 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" rpc_url='http://127.0.0.1:3030'\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",

View File

@@ -472,8 +472,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -526,9 +526,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -557,8 +557,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]
@@ -566,7 +566,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"display_name": ".env",
"language": "python",
"name": "python3"
},
@@ -580,7 +580,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.9"
},
"orig_nbformat": 4
},

View File

@@ -543,8 +543,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -597,9 +597,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -628,8 +628,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]
@@ -651,7 +651,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.9"
},
"orig_nbformat": 4
},

View File

@@ -474,8 +474,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -529,9 +529,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -560,8 +560,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]

View File

@@ -453,8 +453,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -474,8 +474,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -510,4 +510,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -462,8 +462,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -483,8 +483,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -512,4 +512,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -504,8 +504,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -527,8 +527,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path\n",
")\n",
"assert res == True"
]
@@ -558,4 +558,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

View File

@@ -261,8 +261,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" \"verifier/reusable\"\n",
")\n",
"\n",
@@ -288,7 +288,7 @@
"for name in names:\n",
" addr_path_vk = \"addr_vk.txt\"\n",
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
" res = await ezkl.deploy_evm(addr_path_vk, sol_key_code_path, 'http://127.0.0.1:3030', \"vka\")\n",
" res = await ezkl.deploy_evm(addr_path_vk, 'http://127.0.0.1:3030', sol_key_code_path, \"vka\")\n",
" assert res == True\n",
"\n",
" with open(addr_path_vk, 'r') as file:\n",
@@ -298,8 +298,8 @@
" sol_code_path = os.path.join(name, 'vk.sol')\n",
" res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" addr_vk = addr_vk\n",
" )\n",
" assert res == True"

View File

@@ -562,8 +562,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -616,9 +616,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -653,8 +653,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]

View File

@@ -666,7 +666,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -689,8 +689,8 @@
"# await\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -701,7 +701,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -722,8 +722,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -743,7 +743,8 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
@@ -756,7 +757,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.9"
}
},
"nbformat": 4,

View File

@@ -849,8 +849,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -870,8 +870,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path\n",
")\n",
"assert res == True"
]
@@ -905,4 +905,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -358,8 +358,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -405,9 +405,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )"
]
},
@@ -470,8 +470,8 @@
"\n",
"res = ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]
@@ -531,7 +531,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.9"
}
},
"nbformat": 4,

View File

@@ -206,6 +206,9 @@ struct PyRunArgs {
/// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
#[pyo3(get, set)]
pub ignore_range_check_inputs_outputs: bool,
/// float: epsilon used for arguments that use division
#[pyo3(get, set)]
pub epsilon: f64,
}
/// default instantiation of PyRunArgs
@@ -238,12 +241,14 @@ impl From<PyRunArgs> for RunArgs {
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
epsilon: Some(py_run_args.epsilon),
}
}
}
impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
let eps = self.get_epsilon();
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
input_scale: self.input_scale,
@@ -262,6 +267,7 @@ impl Into<PyRunArgs> for RunArgs {
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
epsilon: eps,
}
}
}
@@ -962,7 +968,7 @@ fn gen_settings(
output=PathBuf::from(DEFAULT_SETTINGS),
variables=Vec::from([("batch_size".to_string(), 1)]),
seed=DEFAULT_SEED.parse().unwrap(),
min=None,
min=None,
max=None
))]
#[gen_stub_pyfunction]
@@ -1823,7 +1829,7 @@ fn create_evm_data_attestation(
test_data,
input_source,
output_source,
rpc_url=None
rpc_url,
))]
#[gen_stub_pyfunction]
fn setup_test_evm_data(
@@ -1833,7 +1839,7 @@ fn setup_test_evm_data(
test_data: PathBuf,
input_source: PyTestDataSource,
output_source: PyTestDataSource,
rpc_url: Option<String>,
rpc_url: String,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::setup_test_evm_data(
@@ -1857,8 +1863,8 @@ fn setup_test_evm_data(
/// deploys the solidity verifier
#[pyfunction(signature = (
addr_path,
rpc_url,
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
rpc_url=None,
contract_type=ContractType::default(),
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None,
@@ -1867,8 +1873,8 @@ fn setup_test_evm_data(
fn deploy_evm(
py: Python,
addr_path: PathBuf,
rpc_url: String,
sol_code_path: PathBuf,
rpc_url: Option<String>,
contract_type: ContractType,
optimizer_runs: usize,
private_key: Option<String>,
@@ -1896,9 +1902,9 @@ fn deploy_evm(
#[pyfunction(signature = (
addr_path,
input_data,
rpc_url,
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE_DA),
rpc_url=None,
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None
))]
@@ -1907,9 +1913,9 @@ fn deploy_da_evm(
py: Python,
addr_path: PathBuf,
input_data: String,
rpc_url: String,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
@@ -1956,8 +1962,8 @@ fn deploy_da_evm(
///
#[pyfunction(signature = (
addr_verifier,
rpc_url,
proof_path=PathBuf::from(DEFAULT_PROOF),
rpc_url=None,
addr_da = None,
addr_vk = None,
))]
@@ -1965,8 +1971,8 @@ fn deploy_da_evm(
fn verify_evm<'a>(
py: Python<'a>,
addr_verifier: &'a str,
rpc_url: String,
proof_path: PathBuf,
rpc_url: Option<String>,
addr_da: Option<&'a str>,
addr_vk: Option<&'a str>,
) -> PyResult<Bound<'a, PyAny>> {

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::{
circuit::{layouts, utils},
fieldutils::{integer_rep_to_felt, IntegerRep},
fieldutils::{IntegerRep, integer_rep_to_felt},
graph::multiplier_to_scale,
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
};
@@ -15,10 +15,12 @@ use serde::{Deserialize, Serialize};
pub enum HybridOp {
Ln {
scale: utils::F32,
eps: f64,
},
Rsqrt {
input_scale: utils::F32,
output_scale: utils::F32,
eps: f64,
},
Sqrt {
scale: utils::F32,
@@ -42,6 +44,7 @@ pub enum HybridOp {
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
eps: f64,
},
Div {
denom: utils::F32,
@@ -77,6 +80,7 @@ pub enum HybridOp {
input_scale: utils::F32,
output_scale: utils::F32,
axes: Vec<usize>,
eps: f64,
},
Output {
decomp: bool,
@@ -128,12 +132,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Rsqrt {
input_scale,
output_scale,
eps,
} => format!(
"RSQRT (input_scale={}, output_scale={})",
input_scale, output_scale
"RSQRT (input_scale={}, output_scale={}, eps={})",
input_scale, output_scale, eps
),
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
HybridOp::Ln { scale, eps } => format!("LN(scale={}, eps={})", scale, eps),
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
}
@@ -146,16 +151,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Recip {
input_scale,
output_scale,
eps,
} => format!(
"RECIP (input_scale={}, output_scale={})",
input_scale, output_scale
"RECIP (input_scale={}, output_scale={}, eps={})",
input_scale, output_scale, eps
),
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
HybridOp::SumPool {
padding,
stride,
kernel_shape,
normalized, data_format
normalized,
data_format,
} => format!(
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
padding, stride, kernel_shape, normalized, data_format
@@ -177,10 +184,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale,
output_scale,
axes,
eps,
} => {
format!(
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
input_scale, output_scale, axes
"SOFTMAX (input_scale={}, output_scale={}, axes={:?}, eps={})",
input_scale, output_scale, axes, eps
)
}
HybridOp::Output { decomp } => {
@@ -211,17 +219,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Rsqrt {
input_scale,
output_scale,
eps,
} => layouts::rsqrt(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
*eps,
)?,
HybridOp::Sqrt { scale } => {
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
}
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
HybridOp::Ln { scale, eps } => {
layouts::ln(config, region, values[..].try_into()?, *scale, *eps)?
}
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
}
@@ -255,12 +267,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Recip {
input_scale,
output_scale,
eps,
} => layouts::recip(
config,
region,
values[..].try_into()?,
integer_rep_to_felt(input_scale.0 as IntegerRep),
integer_rep_to_felt(output_scale.0 as IntegerRep),
*eps,
)?,
HybridOp::Div { denom, .. } => {
if denom.0.fract() == 0.0 {
@@ -317,6 +331,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale,
output_scale,
axes,
eps,
} => layouts::softmax_axes(
config,
region,
@@ -324,6 +339,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
*input_scale,
*output_scale,
axes,
*eps,
)?,
HybridOp::Output { decomp } => {
layouts::output(config, region, values[..].try_into()?, *decomp)?
@@ -364,6 +380,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
HybridOp::Ln {
scale: output_scale,
eps: _,
} => 4 * multiplier_to_scale(output_scale.0 as f64),
_ => in_scales[0],
};

View File

@@ -303,6 +303,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
value: &[ValTensor<F>; 1],
input_scale: F,
output_scale: F,
eps: f64,
) -> Result<ValTensor<F>, CircuitError> {
let input = value[0].clone();
let input_dims = input.dims();
@@ -317,6 +318,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&input_evals,
felt_to_integer_rep(input_scale) as f64,
felt_to_integer_rep(output_scale) as f64,
eps,
)
.par_iter()
.map(|x| Value::known(integer_rep_to_felt(*x)))
@@ -335,7 +337,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let claimed_output = identity(config, region, &[claimed_output], true)?;
// divide by input_scale
let zero_inverse_val =
tensor::ops::nonlinearities::zero_recip(felt_to_integer_rep(output_scale) as f64)[0];
tensor::ops::nonlinearities::zero_recip(felt_to_integer_rep(output_scale) as f64, eps)[0];
let zero_inverse = create_constant_tensor(integer_rep_to_felt(zero_inverse_val), 1);
let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;
@@ -473,7 +475,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
/// &[3, 3],
/// ).unwrap());
/// let result = rsqrt::<Fp>(&dummy_config, &mut dummy_region, &[x], 1.0.into(), 1.0.into()).unwrap();
/// let result = rsqrt::<Fp>(&dummy_config, &mut dummy_region, &[x], 1.0.into(), 1.0.into(), f64::EPSILON).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 1, 1, 1, 1, 1, 1]), &[3, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
/// ```
@@ -483,13 +485,21 @@ pub fn rsqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
value: &[ValTensor<F>; 1],
input_scale: utils::F32,
output_scale: utils::F32,
eps: f64,
) -> Result<ValTensor<F>, CircuitError> {
let sqrt = sqrt(config, region, value, input_scale)?;
let felt_output_scale = integer_rep_to_felt(output_scale.0 as IntegerRep);
let felt_input_scale = integer_rep_to_felt(input_scale.0 as IntegerRep);
let recip = recip(config, region, &[sqrt], felt_input_scale, felt_output_scale)?;
let recip = recip(
config,
region,
&[sqrt],
felt_input_scale,
felt_output_scale,
eps,
)?;
Ok(recip)
}
@@ -1547,7 +1557,7 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
/// * Creates pairs: (index_input, value_input) for original elements
/// * Creates pairs: (index_output, value_output) for permuted elements
/// * index_input is a fixed sequence 0,1,2... corresponding to input positions
///
///
/// - Core permutation verification:
/// * For each (index_input, value_input), verify there exists exactly one
/// (index_output, value_output) such that value_input = value_output
@@ -5702,7 +5712,7 @@ pub fn ceil<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// &[1, 1, 2, 2],
/// ).unwrap());
///
/// let result = ln::<Fp>(&dummy_config, &mut dummy_region, &[x], 2.0.into()).unwrap();
/// let result = ln::<Fp>(&dummy_config, &mut dummy_region, &[x], 2.0.into(), f64::EPSILON).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 0, 4, -8]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
@@ -5712,6 +5722,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: utils::F32,
eps: f64,
) -> Result<ValTensor<F>, CircuitError> {
// first generate the claimed val
@@ -5882,6 +5893,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&[pow2_prior_to_claimed_distance],
scale_as_felt,
scale_as_felt * scale_as_felt,
eps,
)?;
let interpolated_distance = pairwise(
@@ -5910,6 +5922,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&[pow2_next_to_claimed_distance],
scale_as_felt,
scale_as_felt * scale_as_felt,
eps,
)?;
let interpolated_distance_next = pairwise(
@@ -6698,12 +6711,13 @@ pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::
input_scale: utils::F32,
output_scale: utils::F32,
axes: &[usize],
eps: f64,
) -> Result<ValTensor<F>, CircuitError> {
let soft_max_at_scale = move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, CircuitError> {
softmax(config, region, values, input_scale, output_scale)
softmax(config, region, values, input_scale, output_scale, eps)
};
let output = multi_dim_axes_op(config, region, values, axes, soft_max_at_scale)?;
@@ -6718,6 +6732,7 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
values: &[ValTensor<F>; 1],
input_scale: utils::F32,
output_scale: utils::F32,
eps: f64,
) -> Result<ValTensor<F>, CircuitError> {
let is_assigned = values[0].all_prev_assigned();
let mut input = values[0].clone();
@@ -6736,6 +6751,7 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
&[denom],
input_felt_scale,
output_felt_scale,
eps,
)?;
// product of num * (1 / denom) = input_scale * output_scale
pairwise(config, region, &[input, inv_denom], BaseOp::Mult)
@@ -6760,7 +6776,7 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
/// Some(&[2, 2, 3, 2, 2, 0]),
/// &[2, 3],
/// ).unwrap());
/// let result = softmax::<Fp>(&dummy_config, &mut dummy_region, &[x], 128.0.into(), (128.0 * 128.0).into()).unwrap();
/// let result = softmax::<Fp>(&dummy_config, &mut dummy_region, &[x], 128.0.into(), (128.0 * 128.0).into(), f64::EPSILON).unwrap();
/// // doubles the scale of the input
/// let expected = Tensor::<IntegerRep>::new(Some(&[350012, 350012, 352768, 350012, 350012, 344500]), &[2, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
@@ -6771,6 +6787,7 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
input_scale: utils::F32,
output_scale: utils::F32,
eps: f64,
) -> Result<ValTensor<F>, CircuitError> {
// get the max then subtract it
let max_val = max(config, region, values)?;
@@ -6787,7 +6804,14 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
},
)?;
percent(config, region, &[ex.clone()], input_scale, output_scale)
percent(
config,
region,
&[ex.clone()],
input_scale,
output_scale,
eps,
)
}
/// Checks that the percent error between the expected public output and the actual output value

View File

@@ -685,9 +685,9 @@ pub enum Commands {
/// Should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'T', long, value_hint = clap::ValueHint::FilePath)]
test_data: PathBuf,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
rpc_url: String,
/// where the input data come from
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
input_source: TestDataSource,
@@ -882,9 +882,9 @@ pub enum Commands {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Url)]
rpc_url: String,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
@@ -910,9 +910,9 @@ pub enum Commands {
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
rpc_url: String,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA, value_hint = clap::ValueHint::FilePath)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
@@ -932,9 +932,9 @@ pub enum Commands {
/// The path to verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
addr_verifier: H160Flag,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
rpc_url: String,
/// does the verifier use data attestation ?
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_da: Option<H160Flag>,

View File

@@ -12,7 +12,6 @@ use alloy::dyn_abi::abi::TokenSeq;
use alloy::dyn_abi::abi::token::{DynSeqToken, PackedSeqToken, WordToken};
// use alloy::providers::Middleware;
use alloy::json_abi::JsonAbi;
use alloy::node_bindings::Anvil;
use alloy::primitives::ruint::ParseError;
use alloy::primitives::{B256, I256, ParseSignedError};
use alloy::providers::ProviderBuilder;
@@ -313,25 +312,12 @@ pub type ContractFactory<M> = CallBuilder<Http<Client>, Arc<M>, ()>;
/// Return an instance of Anvil and a client for the given RPC URL. If none is provided, a local client is used.
pub async fn setup_eth_backend(
rpc_url: Option<&str>,
rpc_url: &str,
private_key: Option<&str>,
) -> Result<(EthersClient, alloy::primitives::Address), EthError> {
// Launch anvil
let endpoint: String;
if let Some(rpc_url) = rpc_url {
endpoint = rpc_url.to_string();
} else {
let anvil = Anvil::new()
.args([
"--code-size-limit=41943040",
"--disable-block-gas-limit",
"-p",
"8545",
])
.spawn();
endpoint = anvil.endpoint();
}
let endpoint = rpc_url.to_string();
// Instantiate the wallet
let wallet: LocalWallet;
@@ -365,7 +351,7 @@ pub async fn setup_eth_backend(
///
pub async fn deploy_contract_via_solidity(
sol_code_path: PathBuf,
rpc_url: Option<&str>,
rpc_url: &str,
runs: usize,
private_key: Option<&str>,
contract_name: &str,
@@ -387,7 +373,7 @@ pub async fn deploy_da_verifier_via_solidity(
settings_path: PathBuf,
input: String,
sol_code_path: PathBuf,
rpc_url: Option<&str>,
rpc_url: &str,
runs: usize,
private_key: Option<&str>,
) -> Result<H160, EthError> {
@@ -574,7 +560,7 @@ pub async fn verify_proof_via_solidity(
proof: Snark<Fr, G1Affine>,
addr: H160,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
rpc_url: &str,
) -> Result<bool, EthError> {
let flattened_instances = proof.instances.into_iter().flatten();
@@ -676,7 +662,7 @@ pub async fn verify_proof_with_data_attestation(
addr_verifier: H160,
addr_da: H160,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
rpc_url: &str,
) -> Result<bool, EthError> {
use ethabi::{Function, Param, ParamType, StateMutability, Token};
@@ -1015,7 +1001,7 @@ pub fn fix_da_sol(commitment_bytes: Option<Vec<u8>>, only_kzg: bool) -> Result<S
require(checkKzgCommits(encoded), "Invalid KZG commitments");
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
if (success) {
return abi.decode(returndata, (bool));
} else {

View File

@@ -1608,7 +1608,7 @@ pub(crate) async fn deploy_da_evm(
data: String,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
rpc_url: String,
addr_path: PathBuf,
runs: usize,
private_key: Option<String>,
@@ -1617,7 +1617,7 @@ pub(crate) async fn deploy_da_evm(
settings_path,
data,
sol_code_path,
rpc_url.as_deref(),
&rpc_url,
runs,
private_key.as_deref(),
)
@@ -1632,7 +1632,7 @@ pub(crate) async fn deploy_da_evm(
pub(crate) async fn deploy_evm(
sol_code_path: PathBuf,
rpc_url: Option<String>,
rpc_url: String,
addr_path: PathBuf,
runs: usize,
private_key: Option<String>,
@@ -1645,7 +1645,7 @@ pub(crate) async fn deploy_evm(
};
let contract_address = deploy_contract_via_solidity(
sol_code_path,
rpc_url.as_deref(),
&rpc_url,
runs,
private_key.as_deref(),
contract_name,
@@ -1688,7 +1688,7 @@ pub(crate) fn encode_evm_calldata(
pub(crate) async fn verify_evm(
proof_path: PathBuf,
addr_verifier: H160Flag,
rpc_url: Option<String>,
rpc_url: String,
addr_da: Option<H160Flag>,
addr_vk: Option<H160Flag>,
) -> Result<String, EZKLError> {
@@ -1702,7 +1702,7 @@ pub(crate) async fn verify_evm(
addr_verifier.into(),
addr_da.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
&rpc_url,
)
.await?
} else {
@@ -1710,7 +1710,7 @@ pub(crate) async fn verify_evm(
proof.clone(),
addr_verifier.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
&rpc_url,
)
.await?
};
@@ -1851,7 +1851,7 @@ pub(crate) async fn setup_test_evm_data(
data_path: String,
compiled_circuit_path: PathBuf,
test_data: PathBuf,
rpc_url: Option<String>,
rpc_url: String,
input_source: TestDataSource,
output_source: TestDataSource,
) -> Result<String, EZKLError> {

View File

@@ -1,17 +1,17 @@
use super::errors::GraphError;
use super::quantize_float;
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::InputType;
use crate::fieldutils::integer_rep_to_felt;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::postgres::Client;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::io::BufReader;
use std::io::BufWriter;
@@ -19,7 +19,7 @@ use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::{
tract_data::{prelude::Tensor as TractTensor, TVec},
tract_data::{TVec, prelude::Tensor as TractTensor},
value::TValue,
};
@@ -201,9 +201,9 @@ impl OnChainSource {
data: &FileSource,
scales: Vec<crate::Scale>,
mut shapes: Vec<Vec<usize>>,
rpc: Option<&str>,
rpc: &str,
) -> Result<Self, GraphError> {
use crate::eth::{read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT};
use crate::eth::{read_on_chain_inputs, test_on_chain_data};
use log::debug;
// Set up local anvil instance for reading on-chain data
@@ -217,7 +217,7 @@ impl OnChainSource {
shapes[idx] = vec![i.len()];
}
}
let used_rpc = rpc.unwrap_or(DEFAULT_ANVIL_ENDPOINT).to_string();
let used_rpc = rpc.to_string();
let call_to_account = test_on_chain_data(client.clone(), data).await?;
debug!("Call to account: {:?}", call_to_account);
@@ -381,7 +381,7 @@ impl GraphData {
return Err(GraphError::InvalidDims(
0,
"non file data cannot be split into batches".to_string(),
))
));
}
}
Ok(inputs)
@@ -434,13 +434,13 @@ impl GraphData {
/// Loads graph input data from a string, first seeing if it is a file path or JSON data
/// If it is a file path, it will load the data from the file
/// Otherwise, it will attempt to parse the string as JSON data
///
///
/// # Arguments
/// * `data` - String containing the input data
/// # Returns
/// A new GraphData instance containing the loaded data
pub fn from_str(data: &str) -> Result<Self, GraphError> {
let graph_input = serde_json::from_str(data);
let graph_input = serde_json::from_str(data);
match graph_input {
Ok(graph_input) => {
return Ok(graph_input);
@@ -515,7 +515,7 @@ impl GraphData {
return Err(GraphError::InvalidDims(
0,
"on-chain data cannot be split into batches".to_string(),
))
));
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
GraphData {
@@ -538,7 +538,6 @@ impl GraphData {
input.len(),
input_size
),
));
}

View File

@@ -36,12 +36,12 @@ use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSize
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::{ConstantsMap, RegionSettings};
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::table::{RESERVED_BLINDING_ROWS_PAD, Range, Table, num_cols_required};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::{felt_to_f64, IntegerRep};
use crate::fieldutils::{IntegerRep, felt_to_f64};
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::{RunArgs, EZKL_BUF_CAPACITY};
use crate::{EZKL_BUF_CAPACITY, RunArgs};
use halo2_proofs::{
circuit::Layouter,
@@ -56,13 +56,13 @@ use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
pub use model::*;
pub use node::*;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDictMethods;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
@@ -764,7 +764,7 @@ pub struct TestOnChainData {
/// The path to the test witness
pub data: std::path::PathBuf,
/// rpc endpoint
pub rpc: Option<String>,
pub rpc: String,
/// data sources for the on chain data
pub data_sources: TestSources,
}
@@ -1027,7 +1027,7 @@ impl GraphCircuit {
scales: Vec<crate::Scale>,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
let (client, client_address) = setup_eth_backend(&source.rpc, None).await?;
let input = read_on_chain_inputs(client.clone(), client_address, &source.call).await?;
let quantized_evm_inputs =
evm_quantize(client, scales, &input, &source.call.decimals).await?;
@@ -1481,13 +1481,9 @@ impl GraphCircuit {
// print file data
debug!("file data: {:?}", file_data);
let on_chain_data: OnChainSource = OnChainSource::test_from_file_data(
&file_data,
scales,
shapes,
test_on_chain_data.rpc.as_deref(),
)
.await?;
let on_chain_data: OnChainSource =
OnChainSource::test_from_file_data(&file_data, scales, shapes, &test_on_chain_data.rpc)
.await?;
// Here we update the GraphData struct with the on-chain data
if input_data.is_some() {
data.input_data = on_chain_data.clone().into();

View File

@@ -858,6 +858,7 @@ pub fn new_op_from_onnx(
SupportedOp::Hybrid(HybridOp::Recip {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
eps: run_args.get_epsilon(),
})
}
@@ -903,6 +904,7 @@ pub fn new_op_from_onnx(
SupportedOp::Hybrid(HybridOp::Rsqrt {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
eps: run_args.get_epsilon(),
})
}
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
@@ -913,6 +915,7 @@ pub fn new_op_from_onnx(
if run_args.bounded_log_lookup {
SupportedOp::Hybrid(HybridOp::Ln {
scale: scale_to_multiplier(input_scales[0]).into(),
eps: run_args.get_epsilon(),
})
} else {
SupportedOp::Nonlinear(LookupOp::Ln {
@@ -1131,6 +1134,7 @@ pub fn new_op_from_onnx(
input_scale: scale_to_multiplier(in_scale).into(),
output_scale: scale_to_multiplier(max_scale).into(),
axes: softmax_op.axes.to_vec(),
eps: run_args.get_epsilon(),
})
}
"MaxPool" => {

View File

@@ -97,11 +97,11 @@ impl From<String> for EZKLError {
use std::str::FromStr;
use circuit::{table::Range, CheckMode};
use circuit::{CheckMode, table::Range};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use graph::{MAX_PUBLIC_SRS, Visibility};
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};
@@ -350,6 +350,16 @@ pub struct RunArgs {
arg(long, default_value = "false")
)]
pub ignore_range_check_inputs_outputs: bool,
/// Optional override for epsilon value
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
pub epsilon: Option<f64>,
}
impl RunArgs {
/// Returns the epsilon value
pub fn get_epsilon(&self) -> f64 {
self.epsilon.unwrap_or(f64::EPSILON)
}
}
impl Default for RunArgs {
@@ -376,6 +386,7 @@ impl Default for RunArgs {
decomp_base: 16384,
decomp_legs: 2,
ignore_range_check_inputs_outputs: false,
epsilon: None,
}
}
}

View File

@@ -160,7 +160,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 3, 4, 5, 6]), &[1, 3, 2]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 0, 6]), &[1, 3, 2]).unwrap();
@@ -168,7 +168,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 3, 0, 5, 6]), &[1, 3, 2]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let a = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6]),
@@ -188,7 +188,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0]), &[1, 2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 2, 3]).unwrap();
@@ -196,7 +196,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0]), &[1, 2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let a = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9]),
@@ -216,7 +216,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0, 7, 8, 9]), &[1, 3, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 8, 9]), &[1, 3, 3]).unwrap();
@@ -224,7 +224,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0, 7, 8, 0]), &[1, 3, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
/// ```
pub fn trilu<T: TensorType + std::marker::Send + std::marker::Sync>(
a: &Tensor<T>,
@@ -1859,14 +1859,14 @@ pub mod nonlinearities {
/// Some(&[4, 25, 8, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = rsqrt(&x, 1.0);
/// let result = rsqrt(&x, 1.0, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64, eps: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale_input;
let fout = scale_input / (kix.sqrt() + f64::EPSILON);
let fout = scale_input / (kix.sqrt() + eps);
let rounded = fout.round();
Ok::<_, TensorError>(rounded as IntegerRep)
})
@@ -2339,15 +2339,20 @@ pub mod nonlinearities {
/// &[2, 3],
/// ).unwrap();
/// let k = 2_f64;
/// let result = recip(&x, 1.0, k);
/// let result = recip(&x, 1.0, k, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn recip(a: &Tensor<IntegerRep>, input_scale: f64, out_scale: f64) -> Tensor<IntegerRep> {
pub fn recip(
a: &Tensor<IntegerRep>,
input_scale: f64,
out_scale: f64,
eps: f64,
) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let rescaled = (a_i as f64) / input_scale;
let denom = if rescaled == 0_f64 {
(1_f64) / (rescaled + f64::EPSILON)
(1_f64) / (rescaled + eps)
} else {
(1_f64) / (rescaled)
};
@@ -2366,16 +2371,16 @@ pub mod nonlinearities {
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::zero_recip;
/// let k = 2_f64;
/// let result = zero_recip(1.0);
/// let result = zero_recip(1.0, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[4503599627370496]), &[1]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn zero_recip(out_scale: f64) -> Tensor<IntegerRep> {
pub fn zero_recip(out_scale: f64, eps: f64) -> Tensor<IntegerRep> {
let a = Tensor::<IntegerRep>::new(Some(&[0]), &[1]).unwrap();
a.par_enum_map(|_, a_i| {
let rescaled = a_i as f64;
let denom = (1_f64) / (rescaled + f64::EPSILON);
let denom = (1_f64) / (rescaled + eps);
let d_inv_x = out_scale * denom;
Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
})

Binary file not shown.

View File

@@ -479,15 +479,15 @@ async def test_deploy_evm_reusable_and_vka():
res = await ezkl.deploy_evm(
addr_path_verifier,
sol_code_path,
anvil_url,
sol_code_path,
"verifier/reusable",
)
res = await ezkl.deploy_evm(
addr_path_vk,
vk_code_path,
anvil_url,
vk_code_path,
"vka",
)
@@ -506,8 +506,8 @@ async def test_deploy_evm():
res = await ezkl.deploy_evm(
addr_path,
sol_code_path,
anvil_url,
sol_code_path,
)
assert res == True
@@ -528,8 +528,8 @@ async def test_deploy_evm_with_private_key():
res = await ezkl.deploy_evm(
addr_path,
anvil_url,
sol_code_path,
rpc_url=anvil_url,
private_key=anvil_default_private_key
)
@@ -540,8 +540,8 @@ async def test_deploy_evm_with_private_key():
with pytest.raises(RuntimeError, match="Failed to run deploy_evm"):
res = await ezkl.deploy_evm(
addr_path,
anvil_url,
sol_code_path,
rpc_url=anvil_url,
private_key=custom_zero_balance_private_key
)
@@ -564,8 +564,8 @@ async def test_verify_evm():
res = await ezkl.verify_evm(
addr,
anvil_url,
proof_path,
rpc_url=anvil_url,
# sol_code_path
# optimizer_runs
)
@@ -604,8 +604,8 @@ async def test_verify_evm_separate_vk():
res = await ezkl.verify_evm(
addr_verifier,
anvil_url,
proof_path,
rpc_url=anvil_url,
addr_vk=addr_vk,
# sol_code_path
# optimizer_runs
@@ -831,8 +831,8 @@ async def test_evm_aggregate_and_verify_aggr():
res = await ezkl.deploy_evm(
addr_path,
anvil_url,
sol_code_path,
rpc_url=anvil_url,
)
# as a sanity check