Compare commits

...

3 Commits

Author SHA1 Message Date
dante
0bc36a0d04 fix: eager exec of ok_or error prints 2025-01-11 11:37:25 -05:00
dante
c5354c382d refactor: range check sanity toggled by CHECKMODE (#902) 2025-01-10 22:58:52 +00:00
dante
bdcba5ca61 feat: add gen-random-data helpers func (#901) 2025-01-09 00:14:27 +00:00
12 changed files with 479 additions and 306 deletions

View File

@@ -6,17 +6,6 @@ on:
description: "Test scenario tags"
jobs:
bench_elgamal:
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
- name: Bench elgamal
run: cargo bench --verbose --bench elgamal
bench_poseidon:
runs-on: self-hosted

View File

@@ -1,279 +1,284 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Linear Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LinearRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Linear Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LinearRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n",
"\n",
"# note that you can also call the following function to generate random data for the model\n",
"# it is functionally equivalent to the code above\n",
"ezkl.gen_random_data()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -938,6 +938,45 @@ fn gen_settings(
Ok(true)
}
/// Generates random data for the model
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// output: str
/// Path to create the data file
///
/// seed: int
/// Random seed to use for generated data
///
/// variables
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
output=PathBuf::from(DEFAULT_SETTINGS),
variables=Vec::from([("batch_size".to_string(), 1)]),
seed=DEFAULT_SEED.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn gen_random_data(
model: PathBuf,
output: PathBuf,
variables: Vec<(String, usize)>,
seed: u64,
) -> Result<bool, PyErr> {
crate::execute::gen_random_data(model, output, variables, seed).map_err(|e| {
let err_str = format!("Failed to generate settings: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
}
/// Calibrates the circuit settings
///
/// Arguments
@@ -2055,6 +2094,7 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_srs, m)?)?;
m.add_function(wrap_pyfunction!(gen_witness, m)?)?;
m.add_function(wrap_pyfunction!(gen_settings, m)?)?;
m.add_function(wrap_pyfunction!(gen_random_data, m)?)?;
m.add_function(wrap_pyfunction!(calibrate_settings, m)?)?;
m.add_function(wrap_pyfunction!(aggregate, m)?)?;
m.add_function(wrap_pyfunction!(mock_aggregate, m)?)?;

View File

@@ -75,6 +75,16 @@ impl FromStr for CheckMode {
}
}
impl CheckMode {
/// Returns the value of the check mode
pub fn is_safe(&self) -> bool {
match self {
CheckMode::SAFE => true,
CheckMode::UNSAFE => false,
}
}
}
#[allow(missing_docs)]
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]

View File

@@ -4038,7 +4038,7 @@ pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd + std::hash::H
}
let is_assigned = !w.any_unknowns()?;
if is_assigned && region.check_range() {
if is_assigned && region.check_range() && config.check_mode.is_safe() {
// assert is within range
let int_values = w.int_evals()?;
for v in int_values.iter() {

View File

@@ -90,6 +90,8 @@ pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
/// Default commitment
pub const DEFAULT_COMMITMENT: &str = "kzg";
/// Default seed used to generate random data
pub const DEFAULT_SEED: &str = "21242";
#[cfg(feature = "python-bindings")]
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
@@ -422,7 +424,21 @@ pub enum Commands {
#[clap(flatten)]
args: RunArgs,
},
/// Generate random data for a model
GenRandomData {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to the .json data file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// Hand-written parser for graph variables, eg. batch_size=1
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = crate::parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
variables: Vec<(String, usize)>,
/// random seed for reproducibility (optional)
#[arg(long, value_hint = clap::ValueHint::Other, default_value = DEFAULT_SEED)]
seed: u64,
},
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
CalibrateSettings {
/// The path to the .json calibration data file.

View File

@@ -65,6 +65,8 @@ use std::str::FromStr;
use std::time::Duration;
use tabled::Tabled;
use thiserror::Error;
use tract_onnx::prelude::IntoTensor;
use tract_onnx::prelude::Tensor as TractTensor;
use lazy_static::lazy_static;
@@ -134,6 +136,17 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
args,
),
Commands::GenRandomData {
model,
data,
variables,
seed,
} => gen_random_data(
model.unwrap_or(DEFAULT_MODEL.into()),
data.unwrap_or(DEFAULT_DATA.into()),
variables,
seed,
),
Commands::CalibrateSettings {
model,
settings_path,
@@ -828,6 +841,71 @@ pub(crate) fn gen_circuit_settings(
Ok(String::new())
}
/// Generate a circuit settings file
pub(crate) fn gen_random_data(
model_path: PathBuf,
data_path: PathBuf,
variables: Vec<(String, usize)>,
seed: u64,
) -> Result<String, EZKLError> {
let mut file = std::fs::File::open(&model_path).map_err(|e| {
crate::graph::errors::GraphError::ReadWriteFileError(
model_path.display().to_string(),
e.to_string(),
)
})?;
let (tract_model, _symbol_values) = Model::load_onnx_using_tract(&mut file, &variables)?;
let input_facts = tract_model
.input_outlets()
.map_err(|e| EZKLError::from(e.to_string()))?
.iter()
.map(|&i| tract_model.outlet_fact(i))
.collect::<tract_onnx::prelude::TractResult<Vec<_>>>()
.map_err(|e| EZKLError::from(e.to_string()))?;
/// Generates a random tensor of a given size and type.
fn random(
sizes: &[usize],
datum_type: tract_onnx::prelude::DatumType,
seed: u64,
) -> TractTensor {
use rand::{Rng, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
let slice = tensor.as_slice_mut::<f32>().unwrap();
slice.iter_mut().for_each(|x| *x = rng.gen());
tensor.cast_to_dt(datum_type).unwrap().into_owned()
}
fn tensor_for_fact(fact: &tract_onnx::prelude::TypedFact, seed: u64) -> TractTensor {
if let Some(value) = &fact.konst {
return value.clone().into_tensor();
}
random(
fact.shape
.as_concrete()
.expect("Expected concrete shape, found: {fact:?}"),
fact.datum_type,
seed,
)
}
let generated = input_facts
.iter()
.map(|v| tensor_for_fact(v, seed))
.collect_vec();
let data = GraphData::from_tract_data(&generated)?;
data.save(data_path)?;
Ok(String::new())
}
// not for wasm targets
pub(crate) fn init_spinner() -> ProgressBar {
let pb = indicatif::ProgressBar::new_spinner();

View File

@@ -557,6 +557,34 @@ impl GraphData {
Ok(inputs)
}
// not wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Convert the tract data to tract data
pub fn from_tract_data(tensors: &[TractTensor]) -> Result<Self, GraphError> {
use tract_onnx::prelude::DatumType;
let mut input_data = vec![];
for tensor in tensors {
match tensor.datum_type() {
tract_onnx::prelude::DatumType::Bool => {
let tensor = tensor.to_array_view::<bool>()?;
let tensor = tensor.iter().map(|e| FileSourceInner::Bool(*e)).collect();
input_data.push(tensor);
}
_ => {
let cast_tensor = tensor.cast_to_dt(DatumType::F64)?;
let tensor = cast_tensor.to_array_view::<f64>()?;
let tensor = tensor.iter().map(|e| FileSourceInner::Float(*e)).collect();
input_data.push(tensor);
}
}
}
Ok(GraphData {
input_data: DataSource::File(input_data),
output_data: None,
})
}
///
pub fn new(input_data: DataSource) -> Self {
GraphData {

View File

@@ -281,7 +281,7 @@ impl GraphWitness {
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let witness: GraphWitness =
serde_json::from_reader(reader).map_err(|e| Into::<GraphError>::into(e))?;
serde_json::from_reader(reader).map_err(Into::<GraphError>::into)?;
// check versions match
crate::check_version_string_matches(witness.version.as_deref().unwrap_or(""));

View File

@@ -621,16 +621,16 @@ impl Model {
/// * `scale` - The scale to use for quantization.
/// * `public_params` - Whether to make the params public.
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn load_onnx_using_tract(
pub(crate) fn load_onnx_using_tract(
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
variables: &[(String, usize)],
) -> Result<TractResult, GraphError> {
use tract_onnx::tract_hir::internal::GenericFactoid;
let mut model = tract_onnx::onnx().model_for_read(reader)?;
let variables: std::collections::HashMap<String, usize> =
std::collections::HashMap::from_iter(run_args.variables.clone());
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
for (i, id) in model.clone().inputs.iter().enumerate() {
let input = model.node_mut(id.node);
@@ -655,7 +655,7 @@ impl Model {
}
let mut symbol_values = SymbolValues::default();
for (symbol, value) in run_args.variables.iter() {
for (symbol, value) in variables.iter() {
let symbol = model.symbols.sym(symbol);
symbol_values = symbol_values.with(&symbol, *value as i64);
debug!("set {} to {}", symbol, value);
@@ -683,7 +683,7 @@ impl Model {
) -> Result<ParsedNodes, GraphError> {
let start_time = instant::Instant::now();
let (model, symbol_values) = Self::load_onnx_using_tract(reader, run_args)?;
let (model, symbol_values) = Self::load_onnx_using_tract(reader, &run_args.variables)?;
let scales = VarScales::from_args(run_args);
let nodes = Self::nodes_from_graph(
@@ -964,7 +964,7 @@ impl Model {
GraphError::ReadWriteFileError(model_path.display().to_string(), e.to_string())
})?;
let (model, _) = Model::load_onnx_using_tract(&mut file, run_args)?;
let (model, _) = Model::load_onnx_using_tract(&mut file, &run_args.variables)?;
let datum_types: Vec<DatumType> = model
.input_outlets()?

View File

@@ -558,7 +558,7 @@ impl VarTensor {
// duplicates every nth element to adjust for column overflow
let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
let mut res: ValTensor<F> = {
let mut res: ValTensor<F> =
v.enum_map(|coord, k| {
let step = self.num_inner_cols();
@@ -579,12 +579,18 @@ impl VarTensor {
prev_cell = Some(cell.clone());
} else if coord > 0 && at_beginning_of_column {
if let Some(prev_cell) = prev_cell.as_ref() {
let cell = cell.cell().ok_or({
let cell = if let Some(cell) = cell.cell() {
cell
} else {
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
let prev_cell = prev_cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
return Err(halo2_proofs::plonk::Error::Synthesis);
};
let prev_cell = if let Some(prev_cell) = prev_cell.cell() {
prev_cell
} else {
error!("Error getting prev cell: {:?}", (x,y));
return Err(halo2_proofs::plonk::Error::Synthesis);
};
region.constrain_equal(prev_cell,cell)?;
} else {
error!("Previous cell was not set");
@@ -594,7 +600,8 @@ impl VarTensor {
Ok(cell)
})?.into()};
})?.into();
let total_used_len = res.len();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();

File diff suppressed because one or more lines are too long