mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 16:27:59 -05:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
abcd5380db | ||
|
|
076b737108 | ||
|
|
97d9832591 | ||
|
|
e0771683a6 | ||
|
|
319c222307 |
45
Cargo.lock
generated
45
Cargo.lock
generated
@@ -1867,7 +1867,7 @@ dependencies = [
|
||||
"halo2_gadgets",
|
||||
"halo2_proofs",
|
||||
"halo2_solidity_verifier",
|
||||
"halo2curves 0.6.0",
|
||||
"halo2curves 0.6.1",
|
||||
"hex",
|
||||
"indicatif",
|
||||
"instant",
|
||||
@@ -2262,7 +2262,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_gadgets"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#6a2b9ada9804807ddba03bbadaf6e63822cec275"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#ca603c14eb57030739b252e580a979023fa59040"
|
||||
dependencies = [
|
||||
"arrayvec 0.7.4",
|
||||
"bitvec 1.0.1",
|
||||
@@ -2279,13 +2279,13 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_proofs"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#6a2b9ada9804807ddba03bbadaf6e63822cec275"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#ca603c14eb57030739b252e580a979023fa59040"
|
||||
dependencies = [
|
||||
"blake2b_simd",
|
||||
"env_logger",
|
||||
"ff",
|
||||
"group",
|
||||
"halo2curves 0.6.0",
|
||||
"halo2curves 0.6.1",
|
||||
"icicle",
|
||||
"log",
|
||||
"maybe-rayon",
|
||||
@@ -2375,6 +2375,31 @@ dependencies = [
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2curves"
|
||||
version = "0.6.1"
|
||||
source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=9fff22c#9fff22c5f72cc54fac1ef3a844e1072b08cfecdf"
|
||||
dependencies = [
|
||||
"blake2b_simd",
|
||||
"ff",
|
||||
"group",
|
||||
"hex",
|
||||
"lazy_static",
|
||||
"num-bigint",
|
||||
"num-traits",
|
||||
"pairing",
|
||||
"pasta_curves",
|
||||
"paste",
|
||||
"rand 0.8.5",
|
||||
"rand_core 0.6.4",
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde_arrays",
|
||||
"static_assertions",
|
||||
"subtle",
|
||||
"unroll",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2wrong"
|
||||
version = "0.1.0"
|
||||
@@ -3461,8 +3486,10 @@ dependencies = [
|
||||
"blake2b_simd",
|
||||
"ff",
|
||||
"group",
|
||||
"hex",
|
||||
"lazy_static",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"static_assertions",
|
||||
"subtle",
|
||||
]
|
||||
@@ -5686,6 +5713,16 @@ version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
|
||||
|
||||
[[package]]
|
||||
name = "unroll"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5ad948c1cb799b1a70f836077721a92a35ac177d4daddf4c20a633786d4cf618"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.7.1"
|
||||
|
||||
@@ -17,7 +17,7 @@ crate-type = ["cdylib", "rlib"]
|
||||
[dependencies]
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
|
||||
halo2curves = { version = "0.6.0", features = ["derive_serde"] }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev="9fff22c", features=["derive_serde"] }
|
||||
rand = { version = "0.8", default_features = false }
|
||||
itertools = { version = "0.10.3", default_features = false }
|
||||
clap = { version = "4.3.3", features = ["derive"]}
|
||||
|
||||
@@ -41,7 +41,7 @@ pub struct KZGChip {
|
||||
}
|
||||
|
||||
impl KZGChip {
|
||||
/// Returns the number of inputs to the hash function
|
||||
/// Commit to the message using the KZG commitment scheme
|
||||
pub fn commit(
|
||||
message: Vec<Fp>,
|
||||
degree: u32,
|
||||
|
||||
@@ -128,13 +128,7 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2;
|
||||
|
||||
let mut scaled_unit =
|
||||
Tensor::from(vec![ValType::Constant(output_scale * input_scale)].into_iter());
|
||||
scaled_unit.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let scaled_unit = region.assign(&config.inputs[1], &scaled_unit.into())?;
|
||||
region.increment(scaled_unit.len());
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !scaled_unit.any_unknowns()?;
|
||||
let is_assigned = !input.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.get_int_evals()?;
|
||||
@@ -166,27 +160,14 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
log::debug!("product: {:?}", product.get_int_evals()?);
|
||||
|
||||
// this is now of scale 2 * scale hence why we rescaled the unit scale
|
||||
let diff_with_input = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[product.clone(), scaled_unit.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
log::debug!("scaled_unit: {:?}", scaled_unit.get_int_evals()?);
|
||||
|
||||
// debug print the diff
|
||||
log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?);
|
||||
|
||||
log::debug!("range_check_bracket: {:?}", range_check_bracket);
|
||||
|
||||
// at most the error should be in the original unit scale's range
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[diff_with_input],
|
||||
&(-range_check_bracket, range_check_bracket),
|
||||
&[product],
|
||||
&(range_check_bracket, 3 * range_check_bracket),
|
||||
)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
|
||||
@@ -826,10 +826,7 @@ pub(crate) fn calibrate(
|
||||
let range = if let Some(scales) = scales {
|
||||
scales
|
||||
} else {
|
||||
match target {
|
||||
CalibrationTarget::Resources { .. } => (8..10).collect::<Vec<crate::Scale>>(),
|
||||
CalibrationTarget::Accuracy => (10..14).collect::<Vec<crate::Scale>>(),
|
||||
}
|
||||
(10..14).collect::<Vec<crate::Scale>>()
|
||||
};
|
||||
|
||||
let div_rebasing = if let Some(div_rebasing) = div_rebasing {
|
||||
|
||||
@@ -35,6 +35,8 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::{self, Bn256, Fr as Fp, G1Affine};
|
||||
use halo2curves::ff::PrimeField;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use lazy_static::lazy_static;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
|
||||
pub use model::*;
|
||||
@@ -61,6 +63,20 @@ pub const RANGE_MULTIPLIER: i128 = 2;
|
||||
/// Max representation of a lookup table input
|
||||
pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
lazy_static! {
|
||||
/// Max circuit area
|
||||
pub static ref EZKL_MAX_CIRCUIT_AREA: Option<usize> =
|
||||
if let Ok(max_circuit_area) = std::env::var("EZKL_MAX_CIRCUIT_AREA") {
|
||||
Some(max_circuit_area.parse().unwrap_or(0))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
const EZKL_MAX_CIRCUIT_AREA: Option<usize> = None;
|
||||
|
||||
/// circuit related errors.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GraphError {
|
||||
@@ -530,6 +546,7 @@ impl GraphSettings {
|
||||
pub struct GraphConfig {
|
||||
model_config: ModelConfig,
|
||||
module_configs: ModuleConfigs,
|
||||
circuit_size: CircuitSize,
|
||||
}
|
||||
|
||||
/// Defines the circuit for a computational graph / model loaded from a `.onnx` file.
|
||||
@@ -1366,7 +1383,6 @@ impl GraphCircuit {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
struct CircuitSize {
|
||||
num_instances: usize,
|
||||
@@ -1374,20 +1390,22 @@ struct CircuitSize {
|
||||
num_fixed: usize,
|
||||
num_challenges: usize,
|
||||
num_selectors: usize,
|
||||
logrows: u32,
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl CircuitSize {
|
||||
pub fn from_cs(cs: &ConstraintSystem<Fp>) -> Self {
|
||||
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
|
||||
CircuitSize {
|
||||
num_instances: cs.num_instance_columns(),
|
||||
num_advice_columns: cs.num_advice_columns(),
|
||||
num_fixed: cs.num_fixed_columns(),
|
||||
num_challenges: cs.num_challenges(),
|
||||
num_selectors: cs.num_selectors(),
|
||||
logrows,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Export the ezkl configuration as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
@@ -1398,6 +1416,25 @@ impl CircuitSize {
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
|
||||
/// number of columns
|
||||
pub fn num_columns(&self) -> usize {
|
||||
self.num_instances + self.num_advice_columns + self.num_fixed
|
||||
}
|
||||
|
||||
/// area of the circuit
|
||||
pub fn area(&self) -> usize {
|
||||
self.num_columns() * (1 << self.logrows)
|
||||
}
|
||||
|
||||
/// area less than max
|
||||
pub fn area_less_than_max(&self) -> bool {
|
||||
if EZKL_MAX_CIRCUIT_AREA.is_some() {
|
||||
self.area() < EZKL_MAX_CIRCUIT_AREA.unwrap()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Circuit<Fp> for GraphCircuit {
|
||||
@@ -1472,10 +1509,12 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
(cs.degree() as f32).log2().ceil()
|
||||
);
|
||||
|
||||
let circuit_size = CircuitSize::from_cs(cs, params.run_args.logrows);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
info!(
|
||||
"circuit size: \n {}",
|
||||
CircuitSize::from_cs(cs)
|
||||
circuit_size
|
||||
.as_json()
|
||||
.unwrap()
|
||||
.to_colored_json_auto()
|
||||
@@ -1485,6 +1524,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
GraphConfig {
|
||||
model_config,
|
||||
module_configs,
|
||||
circuit_size,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1497,6 +1537,16 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
config: Self::Config,
|
||||
mut layouter: impl Layouter<Fp>,
|
||||
) -> Result<(), PlonkError> {
|
||||
// check if the circuit area is less than the max
|
||||
if !config.circuit_size.area_less_than_max() {
|
||||
error!(
|
||||
"circuit area {} is larger than the max allowed area {}",
|
||||
config.circuit_size.area(),
|
||||
EZKL_MAX_CIRCUIT_AREA.unwrap()
|
||||
);
|
||||
return Err(PlonkError::Synthesis);
|
||||
}
|
||||
|
||||
trace!("Setting input in synthesize");
|
||||
let input_vis = &self.settings().run_args.input_visibility;
|
||||
let output_vis = &self.settings().run_args.output_visibility;
|
||||
|
||||
@@ -1174,18 +1174,29 @@ impl Model {
|
||||
};
|
||||
|
||||
debug!(
|
||||
"laying out {}: {}, row:{}, coord:{}, total_constants: {}",
|
||||
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
|
||||
idx,
|
||||
node.as_str(),
|
||||
region.row(),
|
||||
region.linear_coord(),
|
||||
region.total_constants()
|
||||
region.total_constants(),
|
||||
region.max_lookup_inputs(),
|
||||
region.min_lookup_inputs()
|
||||
);
|
||||
debug!("dims: {:?}", node.out_dims());
|
||||
debug!(
|
||||
"input_dims {:?}",
|
||||
values.iter().map(|v| v.dims()).collect_vec()
|
||||
);
|
||||
debug!("output scales: {:?}", node.out_scales());
|
||||
debug!("input indices: {:?}", node.inputs());
|
||||
debug!(
|
||||
"input scales: {:?}",
|
||||
node.inputs()
|
||||
.iter()
|
||||
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
|
||||
.collect_vec()
|
||||
);
|
||||
|
||||
match &node {
|
||||
NodeType::Node(n) => {
|
||||
|
||||
@@ -40,8 +40,24 @@ use std::ops::Deref;
|
||||
use std::path::PathBuf;
|
||||
use thiserror::Error as thisError;
|
||||
|
||||
// not wasm
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
// Buf writer capacity
|
||||
lazy_static! {
|
||||
static ref EZKL_BUF_CAPACITY: usize = std::env::var("EZKL_BUF_CAPACITY")
|
||||
.unwrap_or("8000".to_string())
|
||||
.parse()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
const EZKL_BUF_CAPACITY: &usize = &8000;
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(
|
||||
ValueEnum, Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd,
|
||||
@@ -315,7 +331,7 @@ where
|
||||
/// Saves the Proof to a specified `proof_path`.
|
||||
pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
|
||||
let file = std::fs::File::create(proof_path)?;
|
||||
let mut writer = BufWriter::new(file);
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::to_writer(&mut writer, &self)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -678,7 +694,7 @@ where
|
||||
info!("loading verification key from {:?}", path);
|
||||
let f =
|
||||
File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
|
||||
let mut reader = BufReader::new(f);
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
VerifyingKey::<Scheme::Curve>::read::<_, C>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
@@ -700,7 +716,7 @@ where
|
||||
info!("loading proving key from {:?}", path);
|
||||
let f =
|
||||
File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
|
||||
let mut reader = BufReader::new(f);
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
ProvingKey::<Scheme::Curve>::read::<_, C>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
@@ -720,7 +736,7 @@ where
|
||||
{
|
||||
info!("saving proving key 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::new(f);
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
|
||||
writer.flush()?;
|
||||
Ok(())
|
||||
@@ -737,7 +753,7 @@ where
|
||||
{
|
||||
info!("saving verification key 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::new(f);
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
|
||||
writer.flush()?;
|
||||
Ok(())
|
||||
@@ -750,7 +766,7 @@ pub fn save_params<Scheme: CommitmentScheme>(
|
||||
) -> Result<(), io::Error> {
|
||||
info!("saving parameters 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::new(f);
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
params.write(&mut writer)?;
|
||||
writer.flush()?;
|
||||
Ok(())
|
||||
|
||||
@@ -560,7 +560,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0, false);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
@@ -113,13 +113,13 @@ def test_calibrate_over_user_range():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'input.json'
|
||||
)
|
||||
model_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'network.onnx'
|
||||
)
|
||||
output_path = os.path.join(
|
||||
@@ -147,13 +147,13 @@ def test_calibrate():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'input.json'
|
||||
)
|
||||
model_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'network.onnx'
|
||||
)
|
||||
output_path = os.path.join(
|
||||
@@ -183,7 +183,7 @@ def test_model_compile():
|
||||
model_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'network.onnx'
|
||||
)
|
||||
compiled_model_path = os.path.join(
|
||||
@@ -205,7 +205,7 @@ def test_forward():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'input.json'
|
||||
)
|
||||
model_path = os.path.join(
|
||||
|
||||
Reference in New Issue
Block a user