mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
04d7b5feaa | ||
|
|
45fd12a04f | ||
|
|
bc7c33190f |
5
.github/workflows/rust.yml
vendored
5
.github/workflows/rust.yml
vendored
@@ -313,6 +313,8 @@ jobs:
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
|
||||
- name: KZG prove and verify tests (EVM + VK rendered seperately)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg all)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg inputs)
|
||||
@@ -601,6 +603,8 @@ jobs:
|
||||
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; maturin develop --features python-bindings --release
|
||||
- name: Div rebase
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
|
||||
- name: Public inputs
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
|
||||
- name: fixed params
|
||||
@@ -653,4 +657,3 @@ jobs:
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --no-capture
|
||||
# - name: Postgres tutorials
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::circuit::*;
|
||||
|
||||
use ezkl::circuit::lookup::LookupOp;
|
||||
@@ -16,7 +17,7 @@ use halo2_proofs::{
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
const BITS: (i128, i128) = (-32768, 32768);
|
||||
const BITS: Range = (-32768, 32768);
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 16;
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ use ezkl::circuit::*;
|
||||
|
||||
use ezkl::circuit::lookup::LookupOp;
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
@@ -16,7 +17,7 @@ use halo2_proofs::{
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
const BITS: (i128, i128) = (-8180, 8180);
|
||||
const BITS: Range = (-8180, 8180);
|
||||
static mut LEN: usize = 4;
|
||||
static mut K: usize = 16;
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
@@ -14,7 +15,7 @@ use halo2_proofs::{
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::Rng;
|
||||
|
||||
const BITS: (i128, i128) = (-32768, 32768);
|
||||
const BITS: Range = (-32768, 32768);
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 16;
|
||||
|
||||
|
||||
@@ -19,7 +19,10 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
circuit::ops::base::BaseOp,
|
||||
circuit::{table::Table, utils},
|
||||
circuit::{
|
||||
table::{Range, RangeCheck, Table},
|
||||
utils,
|
||||
},
|
||||
tensor::{Tensor, TensorType, ValTensor, VarTensor},
|
||||
};
|
||||
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
|
||||
@@ -176,6 +179,10 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
|
||||
///
|
||||
pub tables: BTreeMap<LookupOp, Table<F>>,
|
||||
///
|
||||
pub range_checks: BTreeMap<Range, RangeCheck<F>>,
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
|
||||
pub range_check_selectors: BTreeMap<(Range, usize, usize), Selector>,
|
||||
/// Activate sanity checks
|
||||
pub check_mode: CheckMode,
|
||||
_marker: PhantomData<F>,
|
||||
@@ -194,7 +201,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
lookup_index: dummy_var,
|
||||
selectors: BTreeMap::new(),
|
||||
lookup_selectors: BTreeMap::new(),
|
||||
range_check_selectors: BTreeMap::new(),
|
||||
tables: BTreeMap::new(),
|
||||
range_checks: BTreeMap::new(),
|
||||
check_mode: CheckMode::SAFE,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
@@ -325,11 +334,13 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
Self {
|
||||
selectors,
|
||||
lookup_selectors: BTreeMap::new(),
|
||||
range_check_selectors: BTreeMap::new(),
|
||||
inputs: inputs.to_vec(),
|
||||
lookup_input: VarTensor::Empty,
|
||||
lookup_output: VarTensor::Empty,
|
||||
lookup_index: VarTensor::Empty,
|
||||
tables: BTreeMap::new(),
|
||||
range_checks: BTreeMap::new(),
|
||||
output: output.clone(),
|
||||
check_mode,
|
||||
_marker: PhantomData,
|
||||
@@ -344,7 +355,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
input: &VarTensor,
|
||||
output: &VarTensor,
|
||||
index: &VarTensor,
|
||||
lookup_range: (i128, i128),
|
||||
lookup_range: Range,
|
||||
logrows: usize,
|
||||
nl: &LookupOp,
|
||||
) -> Result<(), Box<dyn Error>>
|
||||
@@ -482,6 +493,74 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configures and creates lookup selectors
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn configure_range_check(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
input: &VarTensor,
|
||||
range: Range,
|
||||
) -> Result<(), Box<dyn Error>>
|
||||
where
|
||||
F: Field,
|
||||
{
|
||||
let mut selectors = BTreeMap::new();
|
||||
|
||||
if !input.is_advice() {
|
||||
return Err("wrong input type for lookup input".into());
|
||||
}
|
||||
|
||||
// we borrow mutably twice so we need to do this dance
|
||||
|
||||
let range_check = if !self.range_checks.contains_key(&range) {
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let range_check = RangeCheck::<F>::configure(cs, range);
|
||||
self.range_checks.insert(range, range_check.clone());
|
||||
range_check
|
||||
} else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
for x in 0..input.num_blocks() {
|
||||
for y in 0..input.num_inner_cols() {
|
||||
let single_col_sel = cs.complex_selector();
|
||||
|
||||
cs.lookup("", |cs| {
|
||||
let mut res = vec![];
|
||||
let sel = cs.query_selector(single_col_sel);
|
||||
|
||||
let input_query = match &input {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let default_x = range_check.get_first_element();
|
||||
|
||||
let not_sel = Expression::Constant(F::ONE) - sel.clone();
|
||||
|
||||
res.extend([(
|
||||
sel.clone() * input_query.clone()
|
||||
+ not_sel.clone() * Expression::Constant(default_x),
|
||||
range_check.input,
|
||||
)]);
|
||||
|
||||
res
|
||||
});
|
||||
selectors.insert((range, x, y), single_col_sel);
|
||||
}
|
||||
}
|
||||
self.range_check_selectors.extend(selectors);
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if let VarTensor::Empty = self.lookup_input {
|
||||
debug!("assigning lookup input");
|
||||
self.lookup_input = input.clone();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// layout_tables must be called before layout.
|
||||
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
|
||||
for (i, table) in self.tables.values_mut().enumerate() {
|
||||
@@ -500,6 +579,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// layout_range_checks must be called before layout.
|
||||
pub fn layout_range_checks(
|
||||
&mut self,
|
||||
layouter: &mut impl Layouter<F>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
for range_check in self.range_checks.values_mut() {
|
||||
if !range_check.is_assigned {
|
||||
debug!("laying out range check for {:?}", range_check.range);
|
||||
range_check.layout(layouter)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Assigns variables to the regions created when calling `configure`.
|
||||
/// # Arguments
|
||||
/// * `values` - The explicit values to the operations.
|
||||
|
||||
@@ -19,7 +19,7 @@ use super::{
|
||||
};
|
||||
use crate::{
|
||||
circuit::{ops::base::BaseOp, utils},
|
||||
fieldutils::i128_to_felt,
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
tensor::{
|
||||
get_broadcasted_shape,
|
||||
ops::{accumulated, add, mult, sub},
|
||||
@@ -51,6 +51,66 @@ pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usi
|
||||
total_len
|
||||
}
|
||||
|
||||
/// Div accumulated layout
|
||||
pub fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
div: F,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let input = value[0].clone();
|
||||
let input_dims = input.dims();
|
||||
|
||||
let range_check_bracket = felt_to_i128(div) - 1;
|
||||
|
||||
let mut divisor = Tensor::from(vec![ValType::Constant(div)].into_iter());
|
||||
divisor.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let divisor = region.assign(&config.inputs[1], &divisor.into())?;
|
||||
region.increment(divisor.len());
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !divisor.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.get_int_evals()?;
|
||||
let divisor_evals = divisor.get_int_evals()?;
|
||||
tensor::ops::div(&[input_evals.clone(), divisor_evals.clone()])?
|
||||
.iter()
|
||||
.map(|x| Ok(Value::known(i128_to_felt(*x))))
|
||||
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); input.len()]),
|
||||
&[input.len()],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
|
||||
let product = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), divisor.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let diff_with_input = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[product.clone(), input.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[diff_with_input],
|
||||
&(-range_check_bracket, range_check_bracket),
|
||||
)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
/// Dot product accumulated layout
|
||||
pub fn dot<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -2304,6 +2364,50 @@ pub fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// layout for range check.
|
||||
pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
range: &crate::circuit::table::Range,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// time the entire operation
|
||||
let timer = instant::Instant::now();
|
||||
|
||||
let x = values[0].clone();
|
||||
|
||||
let w = region.assign(&config.lookup_input, &x)?;
|
||||
|
||||
let assigned_len = x.len();
|
||||
|
||||
let is_dummy = region.is_dummy();
|
||||
|
||||
if !is_dummy {
|
||||
(0..assigned_len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = config
|
||||
.lookup_input
|
||||
.cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.range_check_selectors.get(&(range.clone(), x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(assigned_len);
|
||||
|
||||
let elapsed = timer.elapsed();
|
||||
trace!(
|
||||
"range check {:?} layout took {:?}, row: {:?}",
|
||||
range,
|
||||
elapsed,
|
||||
region.row()
|
||||
);
|
||||
|
||||
Ok(w)
|
||||
}
|
||||
|
||||
/// layout for nonlinearity check.
|
||||
pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
|
||||
@@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
|
||||
use crate::{
|
||||
circuit::{layouts, utils},
|
||||
circuit::{layouts, table::Range, utils},
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
graph::{multiplier_to_scale, scale_to_multiplier},
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
@@ -57,7 +57,7 @@ pub enum LookupOp {
|
||||
|
||||
impl LookupOp {
|
||||
/// Returns the range of values that can be represented by the table
|
||||
pub fn bit_range(max_len: usize) -> (i128, i128) {
|
||||
pub fn bit_range(max_len: usize) -> Range {
|
||||
let range = (max_len - 1) as f64 / 2_f64;
|
||||
let range = range as i128;
|
||||
(-range, range)
|
||||
|
||||
@@ -10,6 +10,8 @@ use halo2curves::ff::PrimeField;
|
||||
|
||||
use self::{lookup::LookupOp, region::RegionCtx};
|
||||
|
||||
use super::table::Range;
|
||||
|
||||
///
|
||||
pub mod base;
|
||||
///
|
||||
@@ -60,6 +62,11 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns the range checks required by the operation.
|
||||
fn required_range_checks(&self) -> Vec<Range> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns true if the operation is an input.
|
||||
fn is_input(&self) -> bool {
|
||||
false
|
||||
|
||||
@@ -237,7 +237,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
}
|
||||
PolyOp::Neg => layouts::neg(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Iff => layouts::iff(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Einsum { equation } => layouts::einsum(config, region, &values, equation)?,
|
||||
PolyOp::Einsum { equation } => layouts::einsum(config, region, values, equation)?,
|
||||
PolyOp::Sum { axes } => {
|
||||
layouts::sum_axes(config, region, values[..].try_into()?, axes)?
|
||||
}
|
||||
@@ -290,12 +290,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
let scale = match self {
|
||||
PolyOp::MultiBroadcastTo { .. } => in_scales[0],
|
||||
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
|
||||
PolyOp::Neg => in_scales[0],
|
||||
PolyOp::MoveAxis { .. } => in_scales[0],
|
||||
PolyOp::Downsample { .. } => in_scales[0],
|
||||
PolyOp::Resize { .. } => in_scales[0],
|
||||
PolyOp::Iff => in_scales[1],
|
||||
PolyOp::Einsum { .. } => {
|
||||
let mut scale = in_scales[0];
|
||||
@@ -339,14 +334,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
scale += in_scales[1];
|
||||
scale
|
||||
}
|
||||
PolyOp::Identity => in_scales[0],
|
||||
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
|
||||
PolyOp::Pad(_) => in_scales[0],
|
||||
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
|
||||
PolyOp::Pack(_, _) => in_scales[0],
|
||||
PolyOp::GlobalSumPool => in_scales[0],
|
||||
PolyOp::Concat { axis: _ } => in_scales[0],
|
||||
PolyOp::Slice { .. } => in_scales[0],
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,9 @@ use crate::circuit::lookup::LookupOp;
|
||||
|
||||
use super::Op;
|
||||
|
||||
/// The range of the lookup table.
|
||||
pub type Range = (i128, i128);
|
||||
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: i128 = 2;
|
||||
/// The safety factor offset for the number of rows in the lookup table.
|
||||
@@ -91,7 +94,7 @@ pub struct Table<F: PrimeField> {
|
||||
/// Flags if table has been previously assigned to.
|
||||
pub is_assigned: bool,
|
||||
/// Number of bits used in lookup table.
|
||||
pub range: (i128, i128),
|
||||
pub range: Range,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
@@ -129,7 +132,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range: (i128, i128), col_size: usize) -> usize {
|
||||
pub fn num_cols_required(range: Range, col_size: usize) -> usize {
|
||||
// double it to be safe
|
||||
let range_len = range.1 - range.0;
|
||||
// number of cols needed to store the range
|
||||
@@ -141,7 +144,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
range: (i128, i128),
|
||||
range: Range,
|
||||
logrows: usize,
|
||||
nonlinearity: &LookupOp,
|
||||
preexisting_inputs: Option<Vec<TableColumn>>,
|
||||
@@ -257,3 +260,86 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Halo2 range check column
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RangeCheck<F: PrimeField> {
|
||||
/// Input to table.
|
||||
pub input: TableColumn,
|
||||
/// selector cn
|
||||
pub selector_constructor: SelectorConstructor<F>,
|
||||
/// Flags if table has been previously assigned to.
|
||||
pub is_assigned: bool,
|
||||
/// Number of bits used in lookup table.
|
||||
pub range: Range,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
|
||||
/// get first_element of column
|
||||
pub fn get_first_element(&self) -> F {
|
||||
i128_to_felt(self.range.0)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range) -> RangeCheck<F> {
|
||||
log::debug!("range check range: {:?}", range);
|
||||
|
||||
let inputs = cs.lookup_table_column();
|
||||
|
||||
RangeCheck {
|
||||
input: inputs,
|
||||
is_assigned: false,
|
||||
selector_constructor: SelectorConstructor::new(2),
|
||||
range,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns values to the constraints generated when calling `configure`.
|
||||
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
|
||||
if self.is_assigned {
|
||||
return Err(Box::new(CircuitError::TableAlreadyAssigned));
|
||||
}
|
||||
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
|
||||
|
||||
self.is_assigned = true;
|
||||
|
||||
layouter.assign_table(
|
||||
|| "range check table",
|
||||
|mut table| {
|
||||
let _ = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(row_offset, input)| {
|
||||
table.assign_cell(
|
||||
|| format!("rc_i_col row {}", row_offset),
|
||||
self.input,
|
||||
row_offset,
|
||||
|| Value::known(*input),
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,6 +59,8 @@ pub const DEFAULT_SOL_CODE_DA: &str = "evm_deploy_da.sol";
|
||||
pub const DEFAULT_CONTRACT_ADDRESS: &str = "contract.address";
|
||||
/// Default contract address for data attestation
|
||||
pub const DEFAULT_CONTRACT_ADDRESS_DA: &str = "contract_da.address";
|
||||
/// Default contract address for vk
|
||||
pub const DEFAULT_CONTRACT_ADDRESS_VK: &str = "contract_vk.address";
|
||||
/// Default check mode
|
||||
pub const DEFAULT_CHECKMODE: &str = "safe";
|
||||
/// Default calibration target
|
||||
@@ -75,6 +77,14 @@ pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
|
||||
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
|
||||
/// Default Compress selectors
|
||||
pub const DEFAULT_COMPRESS_SELECTORS: &str = "false";
|
||||
/// Default render vk seperately
|
||||
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
|
||||
/// Default VK sol path
|
||||
pub const DEFAULT_VK_SOL: &str = "vk.sol";
|
||||
/// Default VK abi path
|
||||
pub const DEFAULT_VK_ABI: &str = "vk.abi";
|
||||
/// Default scale rebase multipliers for calibration
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
|
||||
|
||||
impl std::fmt::Display for TranscriptType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
@@ -315,9 +325,20 @@ pub enum Commands {
|
||||
/// Optional scales to specifically try for calibration. Example, --scales 0,4
|
||||
#[arg(long, value_delimiter = ',', allow_hyphen_values = true)]
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
/// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale. Example, --scale-rebase-multipliers 0,4
|
||||
#[arg(
|
||||
long,
|
||||
value_delimiter = ',',
|
||||
allow_hyphen_values = true,
|
||||
default_value = DEFAULT_SCALE_REBASE_MULTIPLIERS
|
||||
)]
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
/// max logrows to use for calibration, 26 is the max public SRS size
|
||||
#[arg(long)]
|
||||
max_logrows: Option<u32>,
|
||||
// whether to fix the div_rebasing value truthiness during calibration. this changes how we rebase
|
||||
#[arg(long)]
|
||||
div_rebasing: Option<bool>,
|
||||
},
|
||||
|
||||
/// Generates a dummy SRS
|
||||
@@ -584,6 +605,31 @@ pub enum Commands {
|
||||
/// The path to output the Solidity verifier ABI
|
||||
#[arg(long, default_value = DEFAULT_VERIFIER_ABI)]
|
||||
abi_path: PathBuf,
|
||||
/// Whether the verifier key should be rendered as a separate contract.
|
||||
/// We recommend disabling selector compression if this is enabled.
|
||||
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
|
||||
render_vk_seperately: bool,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an EVM verifier for a single proof
|
||||
#[command(name = "create-evm-vk")]
|
||||
CreateEVMVK {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
/// The path to load the desired verification key file
|
||||
#[arg(long, default_value = DEFAULT_VK)]
|
||||
vk_path: PathBuf,
|
||||
/// The path to output the Solidity code
|
||||
#[arg(long, default_value = DEFAULT_VK_SOL)]
|
||||
sol_code_path: PathBuf,
|
||||
/// The path to output the Solidity verifier ABI
|
||||
#[arg(long, default_value = DEFAULT_VK_ABI)]
|
||||
abi_path: PathBuf,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an EVM verifier that attests to on-chain inputs for a single proof
|
||||
@@ -629,6 +675,11 @@ pub enum Commands {
|
||||
// logrows used for aggregation circuit
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
|
||||
logrows: u32,
|
||||
/// Whether the verifier key should be rendered as a separate contract.
|
||||
/// We recommend disabling selector compression if this is enabled.
|
||||
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
|
||||
render_vk_seperately: bool,
|
||||
},
|
||||
/// Verifies a proof, returning accept or reject
|
||||
Verify {
|
||||
@@ -680,6 +731,25 @@ pub enum Commands {
|
||||
private_key: Option<String>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys an evm verifier that is generated by ezkl
|
||||
DeployEvmVK {
|
||||
/// The path to the Solidity code (generated using the create-evm-verifier command)
|
||||
#[arg(long, default_value = DEFAULT_VK_SOL)]
|
||||
sol_code_path: PathBuf,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long)]
|
||||
rpc_url: Option<String>,
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK)]
|
||||
/// The path to output the contract address
|
||||
addr_path: PathBuf,
|
||||
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
|
||||
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
|
||||
optimizer_runs: usize,
|
||||
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
|
||||
#[arg(short = 'P', long)]
|
||||
private_key: Option<String>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys an evm verifier that allows for data attestation
|
||||
#[command(name = "deploy-evm-da")]
|
||||
DeployEvmDataAttestation {
|
||||
@@ -721,6 +791,9 @@ pub enum Commands {
|
||||
/// does the verifier use data attestation ?
|
||||
#[arg(long)]
|
||||
addr_da: Option<H160>,
|
||||
// is the vk rendered seperately, if so specify an address
|
||||
#[arg(long)]
|
||||
addr_vk: Option<H160>,
|
||||
},
|
||||
|
||||
/// Print the proof in hexadecimal
|
||||
|
||||
20
src/eth.rs
20
src/eth.rs
@@ -101,17 +101,18 @@ pub async fn setup_eth_backend(
|
||||
}
|
||||
|
||||
///
|
||||
pub async fn deploy_verifier_via_solidity(
|
||||
pub async fn deploy_contract_via_solidity(
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<&str>,
|
||||
runs: usize,
|
||||
private_key: Option<&str>,
|
||||
contract_name: &str,
|
||||
) -> Result<ethers::types::Address, Box<dyn Error>> {
|
||||
// anvil instance must be alive at least until the factory completes the deploy
|
||||
let (anvil, client) = setup_eth_backend(rpc_url, private_key).await?;
|
||||
|
||||
let (abi, bytecode, runtime_bytecode) =
|
||||
get_contract_artifacts(sol_code_path, "Halo2Verifier", runs)?;
|
||||
get_contract_artifacts(sol_code_path, contract_name, runs)?;
|
||||
|
||||
let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone())?;
|
||||
let contract = factory.deploy(())?.send().await?;
|
||||
@@ -335,11 +336,16 @@ pub async fn update_account_calls(
|
||||
pub async fn verify_proof_via_solidity(
|
||||
proof: Snark<Fr, G1Affine>,
|
||||
addr: ethers::types::Address,
|
||||
addr_vk: Option<H160>,
|
||||
rpc_url: Option<&str>,
|
||||
) -> Result<bool, Box<dyn Error>> {
|
||||
let flattened_instances = proof.instances.into_iter().flatten();
|
||||
|
||||
let encoded = encode_calldata(None, &proof.proof, &flattened_instances.collect::<Vec<_>>());
|
||||
let encoded = encode_calldata(
|
||||
addr_vk.as_ref().map(|x| x.0),
|
||||
&proof.proof,
|
||||
&flattened_instances.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
info!("encoded: {:#?}", hex::encode(&encoded));
|
||||
let (anvil, client) = setup_eth_backend(rpc_url, None).await?;
|
||||
@@ -439,6 +445,7 @@ pub async fn verify_proof_with_data_attestation(
|
||||
proof: Snark<Fr, G1Affine>,
|
||||
addr_verifier: ethers::types::Address,
|
||||
addr_da: ethers::types::Address,
|
||||
addr_vk: Option<H160>,
|
||||
rpc_url: Option<&str>,
|
||||
) -> Result<bool, Box<dyn Error>> {
|
||||
use ethers::abi::{Function, Param, ParamType, StateMutability, Token};
|
||||
@@ -452,8 +459,11 @@ pub async fn verify_proof_with_data_attestation(
|
||||
public_inputs.push(u);
|
||||
}
|
||||
|
||||
let encoded_verifier =
|
||||
encode_calldata(None, &proof.proof, &flattened_instances.collect::<Vec<_>>());
|
||||
let encoded_verifier = encode_calldata(
|
||||
addr_vk.as_ref().map(|x| x.0),
|
||||
&proof.proof,
|
||||
&flattened_instances.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
info!("encoded: {:#?}", hex::encode(&encoded_verifier));
|
||||
|
||||
|
||||
158
src/execute.rs
158
src/execute.rs
@@ -3,7 +3,7 @@ use crate::circuit::CheckMode;
|
||||
use crate::commands::CalibrationTarget;
|
||||
use crate::commands::Commands;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::eth::{deploy_da_verifier_via_solidity, deploy_verifier_via_solidity};
|
||||
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(unused_imports)]
|
||||
use crate::eth::{fix_da_sol, get_contract_artifacts, verify_proof_via_solidity};
|
||||
@@ -176,7 +176,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
max_logrows,
|
||||
div_rebasing,
|
||||
} => calibrate(
|
||||
model,
|
||||
data,
|
||||
@@ -184,6 +186,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
max_logrows,
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
@@ -204,7 +208,22 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
} => create_evm_verifier(vk_path, srs_path, settings_path, sol_code_path, abi_path),
|
||||
render_vk_seperately,
|
||||
} => create_evm_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
),
|
||||
Commands::CreateEVMVK {
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
} => create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEVMDataAttestation {
|
||||
settings_path,
|
||||
@@ -220,6 +239,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
} => create_evm_aggregate_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
@@ -227,6 +247,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
),
|
||||
Commands::CompileCircuit {
|
||||
model,
|
||||
@@ -368,6 +389,25 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2Verifier",
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::DeployEvmVK {
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
} => {
|
||||
deploy_evm(
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2VerifyingKey",
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -398,7 +438,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
addr_verifier,
|
||||
rpc_url,
|
||||
addr_da,
|
||||
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da).await,
|
||||
addr_vk,
|
||||
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk).await,
|
||||
Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path),
|
||||
}
|
||||
}
|
||||
@@ -448,7 +489,7 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
|
||||
let path = get_srs_path(logrows, srs_path);
|
||||
let hash = sha256::digest(&std::fs::read(path.clone())?);
|
||||
let hash = sha256::digest(std::fs::read(path.clone())?);
|
||||
info!("SRS hash: {}", hash);
|
||||
|
||||
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
|
||||
@@ -456,7 +497,7 @@ fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box
|
||||
None => return Err(format!("SRS (k={}) hash not found in public set", logrows).into()),
|
||||
};
|
||||
|
||||
if hash != predefined_hash.to_string() {
|
||||
if hash != *predefined_hash {
|
||||
// delete file
|
||||
warn!("removing SRS file at {}", path.display());
|
||||
std::fs::remove_file(path)?;
|
||||
@@ -678,8 +719,7 @@ impl AccuracyResults {
|
||||
let error = (original.clone() - calibrated.clone())?;
|
||||
let abs_error = error.map(|x| x.abs());
|
||||
let squared_error = error.map(|x| x.powi(2));
|
||||
let percentage_error =
|
||||
error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i].clone()))?;
|
||||
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
|
||||
let abs_percentage_error = percentage_error.map(|x| x.abs());
|
||||
|
||||
errors.extend(error.into_iter());
|
||||
@@ -695,29 +735,25 @@ impl AccuracyResults {
|
||||
abs_percentage_errors.iter().sum::<f32>() / abs_percentage_errors.len() as f32;
|
||||
let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
|
||||
let median_error = errors[errors.len() / 2];
|
||||
let max_error = errors
|
||||
let max_error = *errors
|
||||
.iter()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
let min_error = errors
|
||||
.unwrap();
|
||||
let min_error = *errors
|
||||
.iter()
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
.unwrap();
|
||||
|
||||
let mean_abs_error = abs_errors.iter().sum::<f32>() / abs_errors.len() as f32;
|
||||
let median_abs_error = abs_errors[abs_errors.len() / 2];
|
||||
let max_abs_error = abs_errors
|
||||
let max_abs_error = *abs_errors
|
||||
.iter()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
let min_abs_error = abs_errors
|
||||
.unwrap();
|
||||
let min_abs_error = *abs_errors
|
||||
.iter()
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
.unwrap();
|
||||
|
||||
let mean_squared_error = squared_errors.iter().sum::<f32>() / squared_errors.len() as f32;
|
||||
|
||||
@@ -747,6 +783,8 @@ pub(crate) fn calibrate(
|
||||
target: CalibrationTarget,
|
||||
lookup_safety_margin: i128,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
div_rebasing: Option<bool>,
|
||||
max_logrows: Option<u32>,
|
||||
) -> Result<GraphSettings, Box<dyn Error>> {
|
||||
use std::collections::HashMap;
|
||||
@@ -790,9 +828,13 @@ pub(crate) fn calibrate(
|
||||
}
|
||||
};
|
||||
|
||||
let mut found_params: Vec<GraphSettings> = vec![];
|
||||
let div_rebasing = if let Some(div_rebasing) = div_rebasing {
|
||||
vec![div_rebasing]
|
||||
} else {
|
||||
vec![true, false]
|
||||
};
|
||||
|
||||
let scale_rebase_multiplier = [1, 2, 10];
|
||||
let mut found_params: Vec<GraphSettings> = vec![];
|
||||
|
||||
// 2 x 2 grid
|
||||
let range_grid = range
|
||||
@@ -829,15 +871,21 @@ pub(crate) fn calibrate(
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
|
||||
|
||||
let range_grid = range_grid
|
||||
.iter()
|
||||
.cartesian_product(div_rebasing.iter())
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
|
||||
|
||||
let mut forward_pass_res = HashMap::new();
|
||||
|
||||
let pb = init_bar(range_grid.len() as u64);
|
||||
pb.set_message("calibrating...");
|
||||
|
||||
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
|
||||
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
|
||||
pb.set_message(format!(
|
||||
"input scale: {}, param scale: {}, scale rebase multiplier: {}",
|
||||
input_scale, param_scale, scale_rebase_multiplier
|
||||
"input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
|
||||
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
|
||||
));
|
||||
|
||||
#[cfg(unix)]
|
||||
@@ -857,6 +905,7 @@ pub(crate) fn calibrate(
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
..settings.run_args.clone()
|
||||
};
|
||||
|
||||
@@ -931,6 +980,7 @@ pub(crate) fn calibrate(
|
||||
let found_run_args = RunArgs {
|
||||
input_scale: new_settings.run_args.input_scale,
|
||||
param_scale: new_settings.run_args.param_scale,
|
||||
div_rebasing: new_settings.run_args.div_rebasing,
|
||||
lookup_range: new_settings.run_args.lookup_range,
|
||||
logrows: new_settings.run_args.logrows,
|
||||
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
|
||||
@@ -940,6 +990,7 @@ pub(crate) fn calibrate(
|
||||
let found_settings = GraphSettings {
|
||||
run_args: found_run_args,
|
||||
required_lookups: new_settings.required_lookups,
|
||||
required_range_checks: new_settings.required_range_checks,
|
||||
model_output_scales: new_settings.model_output_scales,
|
||||
model_input_scales: new_settings.model_input_scales,
|
||||
num_rows: new_settings.num_rows,
|
||||
@@ -1157,6 +1208,7 @@ pub(crate) fn create_evm_verifier(
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let circuit_settings = GraphSettings::load(&settings_path)?;
|
||||
@@ -1174,7 +1226,11 @@ pub(crate) fn create_evm_verifier(
|
||||
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
|
||||
num_instance,
|
||||
);
|
||||
let verifier_solidity = generator.render()?;
|
||||
let verifier_solidity = if render_vk_seperately {
|
||||
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
|
||||
} else {
|
||||
generator.render()?
|
||||
};
|
||||
|
||||
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
|
||||
|
||||
@@ -1186,6 +1242,43 @@ pub(crate) fn create_evm_verifier(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn create_evm_vk(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let circuit_settings = GraphSettings::load(&settings_path)?;
|
||||
let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?;
|
||||
|
||||
let num_instance = circuit_settings.total_instances();
|
||||
let num_instance: usize = num_instance.iter().sum::<usize>();
|
||||
|
||||
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings)?;
|
||||
trace!("params computed");
|
||||
|
||||
let generator = halo2_solidity_verifier::SolidityGenerator::new(
|
||||
¶ms,
|
||||
&vk,
|
||||
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
|
||||
num_instance,
|
||||
);
|
||||
|
||||
let vk_solidity = generator.render_separately()?.1;
|
||||
|
||||
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
|
||||
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingKey", 0)?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn create_evm_data_attestation(
|
||||
settings_path: PathBuf,
|
||||
@@ -1281,13 +1374,15 @@ pub(crate) async fn deploy_evm(
|
||||
addr_path: PathBuf,
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
contract_name: &str,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let contract_address = deploy_verifier_via_solidity(
|
||||
let contract_address = deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
contract_name,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -1304,6 +1399,7 @@ pub(crate) async fn verify_evm(
|
||||
addr_verifier: H160,
|
||||
rpc_url: Option<String>,
|
||||
addr_da: Option<H160>,
|
||||
addr_vk: Option<H160>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
use crate::eth::verify_proof_with_data_attestation;
|
||||
check_solc_requirement();
|
||||
@@ -1315,11 +1411,12 @@ pub(crate) async fn verify_evm(
|
||||
proof.clone(),
|
||||
addr_verifier,
|
||||
addr_da,
|
||||
addr_vk,
|
||||
rpc_url.as_deref(),
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
verify_proof_via_solidity(proof.clone(), addr_verifier, rpc_url.as_deref()).await?
|
||||
verify_proof_via_solidity(proof.clone(), addr_verifier, addr_vk, rpc_url.as_deref()).await?
|
||||
};
|
||||
|
||||
info!("Solidity verification result: {}", result);
|
||||
@@ -1339,6 +1436,7 @@ pub(crate) fn create_evm_aggregate_verifier(
|
||||
abi_path: PathBuf,
|
||||
circuit_settings: Vec<PathBuf>,
|
||||
logrows: u32,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let srs_path = get_srs_path(logrows, srs_path);
|
||||
@@ -1377,7 +1475,11 @@ pub(crate) fn create_evm_aggregate_verifier(
|
||||
|
||||
generator = generator.set_acc_encoding(Some(acc_encoding));
|
||||
|
||||
let verifier_solidity = generator.render()?;
|
||||
let verifier_solidity = if render_vk_seperately {
|
||||
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
|
||||
} else {
|
||||
generator.render()?
|
||||
};
|
||||
|
||||
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ use self::input::{FileSource, GraphData};
|
||||
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
use crate::circuit::table::{Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::table::{Range, Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::{CheckMode, InputType};
|
||||
use crate::fieldutils::felt_to_f64;
|
||||
use crate::pfsys::PrettyElements;
|
||||
@@ -431,6 +431,8 @@ pub struct GraphSettings {
|
||||
pub module_sizes: ModuleSizes,
|
||||
/// required_lookups
|
||||
pub required_lookups: Vec<LookupOp>,
|
||||
/// required range_checks
|
||||
pub required_range_checks: Vec<Range>,
|
||||
/// check mode
|
||||
pub check_mode: CheckMode,
|
||||
/// ezkl version used
|
||||
@@ -639,7 +641,7 @@ impl GraphCircuit {
|
||||
}
|
||||
|
||||
// dummy module settings, must load from GraphData after
|
||||
let mut settings = model.gen_params(run_args, CheckMode::UNSAFE)?;
|
||||
let mut settings = model.gen_params(run_args, run_args.check_mode)?;
|
||||
|
||||
let mut num_params = 0;
|
||||
if !model.const_shapes().is_empty() {
|
||||
@@ -763,18 +765,18 @@ impl GraphCircuit {
|
||||
if self.settings().run_args.input_visibility.is_public() {
|
||||
public_inputs.rescaled_inputs = elements.rescaled_inputs.clone();
|
||||
public_inputs.inputs = elements.inputs.clone();
|
||||
} else if let Some(_) = &data.processed_inputs {
|
||||
} else if data.processed_inputs.is_some() {
|
||||
public_inputs.processed_inputs = elements.processed_inputs.clone();
|
||||
}
|
||||
|
||||
if let Some(_) = &data.processed_params {
|
||||
if data.processed_params.is_some() {
|
||||
public_inputs.processed_params = elements.processed_params.clone();
|
||||
}
|
||||
|
||||
if self.settings().run_args.output_visibility.is_public() {
|
||||
public_inputs.rescaled_outputs = elements.rescaled_outputs.clone();
|
||||
public_inputs.outputs = elements.outputs.clone();
|
||||
} else if let Some(_) = &data.processed_outputs {
|
||||
} else if data.processed_outputs.is_some() {
|
||||
public_inputs.processed_outputs = elements.processed_outputs.clone();
|
||||
}
|
||||
|
||||
@@ -960,19 +962,20 @@ impl GraphCircuit {
|
||||
min_lookup_inputs: i128,
|
||||
max_lookup_inputs: i128,
|
||||
lookup_safety_margin: i128,
|
||||
) -> (i128, i128) {
|
||||
) -> Range {
|
||||
let mut margin = (
|
||||
lookup_safety_margin * min_lookup_inputs,
|
||||
lookup_safety_margin * max_lookup_inputs,
|
||||
);
|
||||
if lookup_safety_margin == 1 {
|
||||
margin.0 -= 1;
|
||||
margin.0 += 1;
|
||||
margin.1 += 1;
|
||||
}
|
||||
|
||||
margin
|
||||
}
|
||||
|
||||
fn calc_num_cols(safe_range: (i128, i128), max_logrows: u32) -> usize {
|
||||
fn calc_num_cols(safe_range: Range, max_logrows: u32) -> usize {
|
||||
let max_col_size = Table::<Fp>::cal_col_size(
|
||||
max_logrows as usize,
|
||||
Self::reserved_blinding_rows() as usize,
|
||||
@@ -1133,7 +1136,7 @@ impl GraphCircuit {
|
||||
|
||||
while (1 << extended_k) < (n * quotient_poly_degree) {
|
||||
extended_k += 1;
|
||||
if !(extended_k <= bn256::Fr::S) {
|
||||
if extended_k > bn256::Fr::S {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1456,6 +1459,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
params.run_args.lookup_range,
|
||||
params.run_args.logrows as usize,
|
||||
params.required_lookups,
|
||||
params.required_range_checks,
|
||||
params.check_mode,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -6,6 +6,7 @@ use super::GraphError;
|
||||
use super::GraphSettings;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::region::RegionCtx;
|
||||
use crate::circuit::table::Range;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::Unknown;
|
||||
@@ -240,6 +241,13 @@ impl NodeType {
|
||||
NodeType::SubGraph { model, .. } => model.required_lookups(),
|
||||
}
|
||||
}
|
||||
/// Returns the lookups required by a graph
|
||||
pub fn required_range_checks(&self) -> Vec<Range> {
|
||||
match self {
|
||||
NodeType::Node(n) => n.opkind.required_range_checks(),
|
||||
NodeType::SubGraph { model, .. } => model.required_range_checks(),
|
||||
}
|
||||
}
|
||||
/// Returns the scales of the node's output.
|
||||
pub fn out_scales(&self) -> Vec<crate::Scale> {
|
||||
match self {
|
||||
@@ -432,6 +440,15 @@ impl Model {
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
///
|
||||
fn required_range_checks(&self) -> Vec<Range> {
|
||||
self.graph
|
||||
.nodes
|
||||
.values()
|
||||
.flat_map(|n| n.required_range_checks())
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Creates a `Model` from a specified path to an Onnx file.
|
||||
/// # Arguments
|
||||
/// * `reader` - A reader for an Onnx file.
|
||||
@@ -489,6 +506,8 @@ impl Model {
|
||||
|
||||
// extract the requisite lookup ops from the model
|
||||
let mut lookup_ops: Vec<LookupOp> = self.required_lookups();
|
||||
// extract the requisite lookup ops from the model
|
||||
let mut range_checks: Vec<Range> = self.required_range_checks();
|
||||
|
||||
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
|
||||
|
||||
@@ -504,6 +523,9 @@ impl Model {
|
||||
let set: HashSet<_> = lookup_ops.drain(..).collect(); // dedup
|
||||
lookup_ops.extend(set.into_iter().sorted());
|
||||
|
||||
let set: HashSet<_> = range_checks.drain(..).collect(); // dedup
|
||||
range_checks.extend(set.into_iter().sorted());
|
||||
|
||||
Ok(GraphSettings {
|
||||
run_args: run_args.clone(),
|
||||
model_instance_shapes: instance_shapes,
|
||||
@@ -511,6 +533,7 @@ impl Model {
|
||||
num_rows,
|
||||
total_assignments: linear_coord,
|
||||
required_lookups: lookup_ops,
|
||||
required_range_checks: range_checks,
|
||||
model_output_scales: self.graph.get_output_scales()?,
|
||||
model_input_scales: self.graph.get_input_scales(),
|
||||
total_const_size,
|
||||
@@ -611,7 +634,7 @@ impl Model {
|
||||
debug!("intermediate min lookup inputs: {}", min);
|
||||
}
|
||||
debug!(
|
||||
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} ------------ scale: {}",
|
||||
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} \n ------------ scale: {}",
|
||||
idx,
|
||||
res.output.map(crate::fieldutils::felt_to_i32).show(),
|
||||
res.output
|
||||
@@ -1042,6 +1065,7 @@ impl Model {
|
||||
&run_args.param_visibility,
|
||||
i,
|
||||
symbol_values,
|
||||
run_args.div_rebasing,
|
||||
)?;
|
||||
if let Some(ref scales) = override_input_scales {
|
||||
if let Some(inp) = n.opkind.get_input() {
|
||||
@@ -1058,9 +1082,20 @@ impl Model {
|
||||
if scales.contains_key(&i) {
|
||||
let scale_diff = n.out_scale - scales[&i];
|
||||
n.opkind = if scale_diff > 0 {
|
||||
RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1)
|
||||
RebaseScale::rebase(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
1,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
} else {
|
||||
RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale)
|
||||
RebaseScale::rebase_up(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
};
|
||||
n.out_scale = scales[&i];
|
||||
}
|
||||
@@ -1155,9 +1190,10 @@ impl Model {
|
||||
pub fn configure(
|
||||
meta: &mut ConstraintSystem<Fp>,
|
||||
vars: &ModelVars<Fp>,
|
||||
lookup_range: (i128, i128),
|
||||
lookup_range: Range,
|
||||
logrows: usize,
|
||||
required_lookups: Vec<LookupOp>,
|
||||
required_range_checks: Vec<Range>,
|
||||
check_mode: CheckMode,
|
||||
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
|
||||
info!("configuring model");
|
||||
@@ -1176,6 +1212,10 @@ impl Model {
|
||||
base_gate.configure_lookup(meta, input, output, index, lookup_range, logrows, &op)?;
|
||||
}
|
||||
|
||||
for range in required_range_checks {
|
||||
base_gate.configure_range_check(meta, input, range)?;
|
||||
}
|
||||
|
||||
Ok(base_gate)
|
||||
}
|
||||
|
||||
@@ -1216,6 +1256,7 @@ impl Model {
|
||||
let instance_idx = vars.get_instance_idx();
|
||||
|
||||
config.base.layout_tables(layouter)?;
|
||||
config.base.layout_range_checks(layouter)?;
|
||||
|
||||
let mut num_rows = 0;
|
||||
let mut linear_coord = 0;
|
||||
|
||||
@@ -132,6 +132,8 @@ pub struct RebaseScale {
|
||||
pub target_scale: i32,
|
||||
/// The original scale of the operation's inputs.
|
||||
pub original_scale: i32,
|
||||
/// if true then the operation is a multiplicative division
|
||||
pub div_rebasing: bool,
|
||||
}
|
||||
|
||||
impl RebaseScale {
|
||||
@@ -141,6 +143,7 @@ impl RebaseScale {
|
||||
global_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
scale_rebase_multiplier: u32,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
|
||||
&& !inner.is_constant()
|
||||
@@ -154,6 +157,7 @@ impl RebaseScale {
|
||||
target_scale: op.target_scale,
|
||||
multiplier: op.multiplier * multiplier,
|
||||
original_scale: op.original_scale,
|
||||
div_rebasing,
|
||||
})
|
||||
} else {
|
||||
SupportedOp::RebaseScale(RebaseScale {
|
||||
@@ -161,6 +165,7 @@ impl RebaseScale {
|
||||
target_scale: global_scale * scale_rebase_multiplier as i32,
|
||||
multiplier,
|
||||
original_scale: op_out_scale,
|
||||
div_rebasing,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
@@ -173,6 +178,7 @@ impl RebaseScale {
|
||||
inner: SupportedOp,
|
||||
target_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
|
||||
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
|
||||
@@ -182,6 +188,7 @@ impl RebaseScale {
|
||||
target_scale: op.target_scale,
|
||||
multiplier: op.multiplier * multiplier,
|
||||
original_scale: op.original_scale,
|
||||
div_rebasing,
|
||||
})
|
||||
} else {
|
||||
SupportedOp::RebaseScale(RebaseScale {
|
||||
@@ -189,12 +196,22 @@ impl RebaseScale {
|
||||
target_scale,
|
||||
multiplier,
|
||||
original_scale: op_out_scale,
|
||||
div_rebasing,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
inner
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the require range bracket for the operation
|
||||
fn range_bracket(&self) -> i128 {
|
||||
if self.div_rebasing {
|
||||
0
|
||||
} else {
|
||||
self.multiplier as i128 - 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Op<Fp> for RebaseScale {
|
||||
@@ -203,19 +220,28 @@ impl Op<Fp> for RebaseScale {
|
||||
}
|
||||
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
|
||||
let mut res = Op::<Fp>::f(&*self.inner, x)?;
|
||||
let ri = res.output.map(felt_to_i128);
|
||||
let rescaled = crate::tensor::ops::nonlinearities::const_div(&ri, self.multiplier);
|
||||
res.output = rescaled.map(i128_to_felt);
|
||||
|
||||
res.intermediate_lookups.push(ri);
|
||||
if self.div_rebasing {
|
||||
let ri = res.output.map(felt_to_i128);
|
||||
let rescaled = crate::tensor::ops::nonlinearities::const_div(&ri, self.multiplier);
|
||||
res.output = rescaled.map(i128_to_felt);
|
||||
res.intermediate_lookups.push(ri);
|
||||
} else {
|
||||
let ri = res.output.map(felt_to_i128);
|
||||
let divisor = Tensor::from(vec![self.multiplier as i128].into_iter());
|
||||
let rescaled = crate::tensor::ops::div(&[ri, divisor.clone()])?;
|
||||
res.output = rescaled.map(i128_to_felt);
|
||||
res.intermediate_lookups.extend([-divisor.clone(), divisor]);
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
format!(
|
||||
"REBASED (div={:?}) ({})",
|
||||
"REBASED (div={:?}, div_r={}) ({})",
|
||||
self.multiplier,
|
||||
self.div_rebasing,
|
||||
self.inner.as_string()
|
||||
)
|
||||
}
|
||||
@@ -225,13 +251,24 @@ impl Op<Fp> for RebaseScale {
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
let mut lookups = self.inner.required_lookups();
|
||||
lookups.push(LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(self.multiplier as f32),
|
||||
});
|
||||
let mut lookups: Vec<LookupOp> = self.inner.required_lookups();
|
||||
if self.div_rebasing {
|
||||
lookups.push(LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(self.multiplier as f32),
|
||||
});
|
||||
}
|
||||
lookups
|
||||
}
|
||||
|
||||
fn required_range_checks(&self) -> Vec<crate::circuit::table::Range> {
|
||||
let mut range_checks = self.inner.required_range_checks();
|
||||
if !self.div_rebasing {
|
||||
let bracket = self.range_bracket();
|
||||
range_checks.push((-bracket, bracket));
|
||||
}
|
||||
range_checks
|
||||
}
|
||||
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -243,14 +280,23 @@ impl Op<Fp> for RebaseScale {
|
||||
.layout(config, region, values)?
|
||||
.ok_or("no layout")?;
|
||||
|
||||
Ok(Some(crate::circuit::layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[original_res],
|
||||
&LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(self.multiplier as f32),
|
||||
},
|
||||
)?))
|
||||
if !self.div_rebasing {
|
||||
Ok(Some(crate::circuit::layouts::div(
|
||||
config,
|
||||
region,
|
||||
&[original_res],
|
||||
Fp::from(self.multiplier as u64),
|
||||
)?))
|
||||
} else {
|
||||
Ok(Some(crate::circuit::layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[original_res],
|
||||
&LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(self.multiplier as f32),
|
||||
},
|
||||
)?))
|
||||
}
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
@@ -437,6 +483,10 @@ impl Op<Fp> for SupportedOp {
|
||||
self.as_op().required_lookups()
|
||||
}
|
||||
|
||||
fn required_range_checks(&self) -> Vec<crate::circuit::table::Range> {
|
||||
self.as_op().required_range_checks()
|
||||
}
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
self.as_op().out_scale(in_scales)
|
||||
}
|
||||
@@ -477,6 +527,7 @@ impl Tabled for Node {
|
||||
"inputs",
|
||||
"out_dims",
|
||||
"required_lookups",
|
||||
"required_range_checks",
|
||||
] {
|
||||
headers.push(std::borrow::Cow::Borrowed(i));
|
||||
}
|
||||
@@ -498,6 +549,10 @@ impl Tabled for Node {
|
||||
.map(<LookupOp as Op<Fp>>::as_string)
|
||||
.collect_vec()
|
||||
)));
|
||||
fields.push(std::borrow::Cow::Owned(format!(
|
||||
"{:?}",
|
||||
self.opkind.required_range_checks()
|
||||
)));
|
||||
fields
|
||||
}
|
||||
}
|
||||
@@ -527,6 +582,7 @@ impl Node {
|
||||
param_visibility: &Visibility,
|
||||
idx: usize,
|
||||
symbol_values: &SymbolValues,
|
||||
div_rebasing: bool,
|
||||
) -> Result<Self, Box<dyn Error>> {
|
||||
use log::warn;
|
||||
|
||||
@@ -631,7 +687,13 @@ impl Node {
|
||||
let mut out_scale = opkind.out_scale(in_scales.clone())?;
|
||||
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
|
||||
let global_scale = scales.get_max();
|
||||
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
|
||||
opkind = RebaseScale::rebase(
|
||||
opkind,
|
||||
global_scale,
|
||||
out_scale,
|
||||
scales.rebase_multiplier,
|
||||
div_rebasing,
|
||||
);
|
||||
|
||||
out_scale = opkind.out_scale(in_scales)?;
|
||||
|
||||
|
||||
@@ -71,8 +71,7 @@ pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i12
|
||||
pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 {
|
||||
let int_rep = crate::fieldutils::felt_to_i128(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
let float_rep = int_rep as f64 / multiplier - shift;
|
||||
float_rep
|
||||
int_rep as f64 / multiplier - shift
|
||||
}
|
||||
|
||||
/// Converts a scale (log base 2) to a fixed point multiplier.
|
||||
@@ -717,8 +716,8 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
// Extract the slope layer hyperparams
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
// If the input scale is larger than the params scale
|
||||
let scale_diff = std::cmp::max(scales.input, scales.params) - inputs[0].out_scales()[0];
|
||||
let additional_scale = if scale_diff > 0 {
|
||||
scale_to_multiplier(scale_diff)
|
||||
|
||||
12
src/lib.rs
12
src/lib.rs
@@ -29,7 +29,7 @@
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
|
||||
use circuit::Tolerance;
|
||||
use circuit::{table::Range, CheckMode, Tolerance};
|
||||
use clap::Args;
|
||||
use graph::Visibility;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -91,7 +91,7 @@ pub struct RunArgs {
|
||||
pub scale_rebase_multiplier: u32,
|
||||
/// The min and max elements in the lookup table input column
|
||||
#[arg(short = 'B', long, value_parser = parse_tuple::<i128>, default_value = "(-32768,32768)")]
|
||||
pub lookup_range: (i128, i128),
|
||||
pub lookup_range: Range,
|
||||
/// The log_2 number of rows
|
||||
#[arg(short = 'K', long, default_value = "17")]
|
||||
pub logrows: u32,
|
||||
@@ -110,6 +110,12 @@ pub struct RunArgs {
|
||||
/// Flags whether params are public, private, hashed
|
||||
#[arg(long, default_value = "private")]
|
||||
pub param_visibility: Visibility,
|
||||
#[arg(long, default_value = "false")]
|
||||
/// Multiplicative division
|
||||
pub div_rebasing: bool,
|
||||
/// check mode (safe, unsafe, etc)
|
||||
#[arg(long, default_value = "unsafe")]
|
||||
pub check_mode: CheckMode,
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
@@ -126,6 +132,8 @@ impl Default for RunArgs {
|
||||
input_visibility: Visibility::Private,
|
||||
output_visibility: Visibility::Public,
|
||||
param_visibility: Visibility::Private,
|
||||
div_rebasing: false,
|
||||
check_mode: CheckMode::UNSAFE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ struct PyRunArgs {
|
||||
#[pyo3(get, set)]
|
||||
pub scale_rebase_multiplier: u32,
|
||||
#[pyo3(get, set)]
|
||||
pub lookup_range: (i128, i128),
|
||||
pub lookup_range: crate::circuit::table::Range,
|
||||
#[pyo3(get, set)]
|
||||
pub logrows: u32,
|
||||
#[pyo3(get, set)]
|
||||
@@ -159,6 +159,10 @@ struct PyRunArgs {
|
||||
pub param_visibility: Visibility,
|
||||
#[pyo3(get, set)]
|
||||
pub variables: Vec<(String, usize)>,
|
||||
#[pyo3(get, set)]
|
||||
pub div_rebasing: bool,
|
||||
#[pyo3(get, set)]
|
||||
pub check_mode: CheckMode,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -185,6 +189,8 @@ impl From<PyRunArgs> for RunArgs {
|
||||
output_visibility: py_run_args.output_visibility,
|
||||
param_visibility: py_run_args.param_visibility,
|
||||
variables: py_run_args.variables,
|
||||
div_rebasing: py_run_args.div_rebasing,
|
||||
check_mode: py_run_args.check_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -203,6 +209,8 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
output_visibility: self.output_visibility,
|
||||
param_visibility: self.param_visibility,
|
||||
variables: self.variables,
|
||||
div_rebasing: self.div_rebasing,
|
||||
check_mode: self.check_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -511,7 +519,9 @@ fn gen_settings(
|
||||
target = CalibrationTarget::default(), // default is "resources
|
||||
lookup_safety_margin = DEFAULT_LOOKUP_SAFETY_MARGIN.parse().unwrap(),
|
||||
scales = None,
|
||||
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
|
||||
max_logrows = None,
|
||||
div_rebasing = None,
|
||||
))]
|
||||
fn calibrate_settings(
|
||||
data: PathBuf,
|
||||
@@ -520,7 +530,9 @@ fn calibrate_settings(
|
||||
target: CalibrationTarget,
|
||||
lookup_safety_margin: i128,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
max_logrows: Option<u32>,
|
||||
div_rebasing: Option<bool>,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::calibrate(
|
||||
model,
|
||||
@@ -529,6 +541,8 @@ fn calibrate_settings(
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
max_logrows,
|
||||
)
|
||||
.map_err(|e| {
|
||||
@@ -807,6 +821,7 @@ fn verify_aggr(
|
||||
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
|
||||
srs_path=None,
|
||||
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
|
||||
))]
|
||||
fn create_evm_verifier(
|
||||
vk_path: PathBuf,
|
||||
@@ -814,12 +829,20 @@ fn create_evm_verifier(
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::create_evm_verifier(vk_path, srs_path, settings_path, sol_code_path, abi_path)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
crate::execute::create_evm_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
@@ -885,7 +908,7 @@ fn setup_test_evm_witness(
|
||||
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
|
||||
rpc_url=None,
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None
|
||||
private_key=None,
|
||||
))]
|
||||
fn deploy_evm(
|
||||
addr_path: PathBuf,
|
||||
@@ -902,6 +925,39 @@ fn deploy_evm(
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2Verifier",
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run deploy_evm: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[pyfunction(signature = (
|
||||
addr_path,
|
||||
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
|
||||
rpc_url=None,
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None,
|
||||
))]
|
||||
fn deploy_vk_evm(
|
||||
addr_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> Result<bool, PyErr> {
|
||||
Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::deploy_evm(
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2VerifyingKey",
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run deploy_evm: {}", e);
|
||||
@@ -953,12 +1009,14 @@ fn deploy_da_evm(
|
||||
proof_path=PathBuf::from(DEFAULT_PROOF),
|
||||
rpc_url=None,
|
||||
addr_da = None,
|
||||
addr_vk = None,
|
||||
))]
|
||||
fn verify_evm(
|
||||
addr_verifier: &str,
|
||||
proof_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
addr_da: Option<&str>,
|
||||
addr_vk: Option<&str>,
|
||||
) -> Result<bool, PyErr> {
|
||||
let addr_verifier = H160::from_str(addr_verifier).map_err(|e| {
|
||||
let err_str = format!("address is invalid: {}", e);
|
||||
@@ -973,6 +1031,15 @@ fn verify_evm(
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let addr_vk = if let Some(addr_vk) = addr_vk {
|
||||
let addr_vk = H160::from_str(addr_vk).map_err(|e| {
|
||||
let err_str = format!("address is invalid: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Some(addr_vk)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Runtime::new()
|
||||
.unwrap()
|
||||
@@ -981,6 +1048,7 @@ fn verify_evm(
|
||||
addr_verifier,
|
||||
rpc_url,
|
||||
addr_da,
|
||||
addr_vk,
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run verify_evm: {}", e);
|
||||
@@ -998,6 +1066,7 @@ fn verify_evm(
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
|
||||
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
|
||||
srs_path=None,
|
||||
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
|
||||
))]
|
||||
fn create_evm_verifier_aggr(
|
||||
aggregation_settings: Vec<PathBuf>,
|
||||
@@ -1006,6 +1075,7 @@ fn create_evm_verifier_aggr(
|
||||
abi_path: PathBuf,
|
||||
logrows: u32,
|
||||
srs_path: Option<PathBuf>,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::create_evm_aggregate_verifier(
|
||||
vk_path,
|
||||
@@ -1014,6 +1084,7 @@ fn create_evm_verifier_aggr(
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier_aggr: {}", e);
|
||||
@@ -1067,6 +1138,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(print_proof_hex, m)?)?;
|
||||
|
||||
@@ -30,11 +30,11 @@ use halo2_proofs::{
|
||||
poly::Rotation,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use std::cmp::max;
|
||||
use std::error::Error;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Iterator;
|
||||
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
|
||||
use std::{cmp::max, ops::Rem};
|
||||
use thiserror::Error;
|
||||
/// A wrapper for tensor related errors.
|
||||
#[derive(Debug, Error)]
|
||||
@@ -1452,6 +1452,43 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
|
||||
}
|
||||
}
|
||||
|
||||
// implement remainder
|
||||
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Rem for Tensor<T> {
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
|
||||
/// Elementwise remainder of a tensor with another tensor.
|
||||
/// # Arguments
|
||||
/// * `self` - Tensor
|
||||
/// * `rhs` - Tensor
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use std::ops::Rem;
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[4, 1, 4, 1, 1, 4]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let y = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.rem(y).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn rem(self, rhs: Self) -> Self::Output {
|
||||
let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap();
|
||||
let mut lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
|
||||
*o = o.clone() % r;
|
||||
});
|
||||
|
||||
Ok(lhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the broadcasted shape of two tensors
|
||||
/// ```
|
||||
/// use ezkl::tensor::get_broadcasted_shape;
|
||||
|
||||
@@ -950,8 +950,7 @@ pub fn neg<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sy
|
||||
/// Elementwise multiplies multiple tensors.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Tensor
|
||||
/// * `t` - Tensors
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
@@ -993,6 +992,45 @@ pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::S
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Divides multiple tensors.
|
||||
/// # Arguments
|
||||
/// * `t` - Tensors
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::tensor::ops::div;
|
||||
/// let x = Tensor::<i128>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<i128>::new(
|
||||
/// Some(&[2, 3, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = div(&[x, k]).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 0, 1, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn div<
|
||||
T: TensorType
|
||||
+ Div<Output = T>
|
||||
+ Mul<Output = T>
|
||||
+ From<u64>
|
||||
+ std::marker::Send
|
||||
+ std::marker::Sync,
|
||||
>(
|
||||
t: &[Tensor<T>],
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
// calculate value of output
|
||||
let mut output: Tensor<T> = t[0].clone();
|
||||
|
||||
for e in t[1..].iter() {
|
||||
output = (output / e.clone())?;
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Rescale a tensor with a const integer (similar to const_mult).
|
||||
/// # Arguments
|
||||
///
|
||||
|
||||
@@ -512,13 +512,23 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_div_rebase_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, true);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_public_outputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -528,7 +538,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6 , false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -538,7 +548,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -549,7 +559,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -826,10 +836,10 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
env_logger::init();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, Some(vec![0,1]), true, "single");
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm");
|
||||
test_dir.close().unwrap();
|
||||
// test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(WASM_TESTS[N])])*
|
||||
@@ -839,7 +849,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
env_logger::init();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, Some(vec![0,1]), true, "single");
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, None, true, "single");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm");
|
||||
test_dir.close().unwrap();
|
||||
@@ -855,7 +865,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, Some(vec![0,6]), false, "single");
|
||||
kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, None, false, "single");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -865,7 +875,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", Some(vec![0,6]));
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -882,7 +892,8 @@ mod native_tests {
|
||||
use crate::native_tests::TESTS_EVM_AGGR;
|
||||
use test_case::test_case;
|
||||
use crate::native_tests::kzg_evm_prove_and_verify;
|
||||
use crate::native_tests::run_js_tests;
|
||||
use crate::native_tests::kzg_evm_prove_and_verify_render_seperately;
|
||||
|
||||
use crate::native_tests::kzg_evm_on_chain_input_prove_and_verify;
|
||||
use crate::native_tests::kzg_evm_aggr_prove_and_verify;
|
||||
use crate::native_tests::kzg_fuzz;
|
||||
@@ -988,6 +999,19 @@ mod native_tests {
|
||||
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS_EVM[N])])*
|
||||
fn kzg_evm_prove_and_verify_render_seperately_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify_render_seperately(2, path, test.to_string(), "private", "private", "public");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
|
||||
|
||||
#(#[test_case(TESTS_EVM[N])])*
|
||||
fn kzg_evm_hashed_input_prove_and_verify_(test: &str) {
|
||||
@@ -1259,6 +1283,7 @@ mod native_tests {
|
||||
cal_target,
|
||||
scales_to_use,
|
||||
2,
|
||||
false,
|
||||
);
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
@@ -1285,22 +1310,29 @@ mod native_tests {
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
num_inner_columns: usize,
|
||||
div_rebasing: bool,
|
||||
) {
|
||||
let mut args = vec![
|
||||
"gen-settings".to_string(),
|
||||
"-M".to_string(),
|
||||
format!("{}/{}/network.onnx", test_dir, example_name),
|
||||
format!(
|
||||
"--settings-path={}/{}/settings.json",
|
||||
test_dir, example_name
|
||||
),
|
||||
format!("--variables=batch_size={}", batch_size),
|
||||
format!("--input-visibility={}", input_visibility),
|
||||
format!("--param-visibility={}", param_visibility),
|
||||
format!("--output-visibility={}", output_visibility),
|
||||
format!("--num-inner-cols={}", num_inner_columns),
|
||||
];
|
||||
|
||||
if div_rebasing {
|
||||
args.push("--div-rebasing".to_string());
|
||||
};
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"gen-settings",
|
||||
"-M",
|
||||
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
|
||||
&format!(
|
||||
"--settings-path={}/{}/settings.json",
|
||||
test_dir, example_name
|
||||
),
|
||||
&format!("--variables=batch_size={}", batch_size),
|
||||
&format!("--input-visibility={}", input_visibility),
|
||||
&format!("--param-visibility={}", param_visibility),
|
||||
&format!("--output-visibility={}", output_visibility),
|
||||
&format!("--num-inner-cols={}", num_inner_columns),
|
||||
])
|
||||
.args(args)
|
||||
.stdout(std::process::Stdio::null())
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -1378,6 +1410,7 @@ mod native_tests {
|
||||
batch_size: usize,
|
||||
cal_target: &str,
|
||||
target_perc: f32,
|
||||
div_rebasing: bool,
|
||||
) {
|
||||
gen_circuit_settings_and_witness(
|
||||
test_dir,
|
||||
@@ -1389,6 +1422,7 @@ mod native_tests {
|
||||
cal_target,
|
||||
None,
|
||||
2,
|
||||
div_rebasing,
|
||||
);
|
||||
|
||||
println!(
|
||||
@@ -1647,6 +1681,7 @@ mod native_tests {
|
||||
target_str,
|
||||
scales_to_use,
|
||||
num_inner_columns,
|
||||
false,
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
@@ -1723,6 +1758,7 @@ mod native_tests {
|
||||
"resources",
|
||||
None,
|
||||
2,
|
||||
false,
|
||||
);
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
@@ -1833,6 +1869,137 @@ mod native_tests {
|
||||
assert!(!status.success());
|
||||
}
|
||||
|
||||
// prove-serialize-verify, the usual full path
|
||||
fn kzg_evm_prove_and_verify_render_seperately(
|
||||
num_inner_columns: usize,
|
||||
test_dir: &str,
|
||||
example_name: String,
|
||||
input_visibility: &str,
|
||||
param_visibility: &str,
|
||||
output_visibility: &str,
|
||||
) {
|
||||
let anvil_url = ANVIL_URL.as_str();
|
||||
|
||||
kzg_prove_and_verify(
|
||||
test_dir,
|
||||
example_name.clone(),
|
||||
"safe",
|
||||
input_visibility,
|
||||
param_visibility,
|
||||
output_visibility,
|
||||
num_inner_columns,
|
||||
None,
|
||||
false,
|
||||
"single",
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
init_params(settings_path.clone().into());
|
||||
|
||||
let vk_arg = format!("{}/{}/key.vk", test_dir, example_name);
|
||||
let rpc_arg = format!("--rpc-url={}", anvil_url);
|
||||
let addr_path_arg = format!("--addr-path={}/{}/addr.txt", test_dir, example_name);
|
||||
let settings_arg = format!("--settings-path={}", settings_path);
|
||||
let sol_arg = format!("--sol-code-path={}/{}/kzg.sol", test_dir, example_name);
|
||||
|
||||
// create the verifier
|
||||
let args = vec![
|
||||
"create-evm-verifier",
|
||||
"--vk-path",
|
||||
&vk_arg,
|
||||
&settings_arg,
|
||||
&sol_arg,
|
||||
"--render-vk-seperately",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let addr_path_arg_vk = format!("--addr-path={}/{}/addr_vk.txt", test_dir, example_name);
|
||||
let sol_arg_vk = format!("--sol-code-path={}/{}/vk.sol", test_dir, example_name);
|
||||
// create the verifier
|
||||
let args = vec![
|
||||
"create-evm-vk",
|
||||
"--vk-path",
|
||||
&vk_arg,
|
||||
&settings_arg,
|
||||
&sol_arg_vk,
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// deploy the verifier
|
||||
let args = vec![
|
||||
"deploy-evm-verifier",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
sol_arg.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// read in the address
|
||||
let addr = std::fs::read_to_string(format!("{}/{}/addr.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
|
||||
let deployed_addr_arg = format!("--addr-verifier={}", addr);
|
||||
|
||||
// deploy the vk
|
||||
let args = vec![
|
||||
"deploy-evm-vk",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg_vk.as_str(),
|
||||
sol_arg_vk.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// read in the address
|
||||
let addr_vk = std::fs::read_to_string(format!("{}/{}/addr_vk.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
|
||||
let deployed_addr_arg_vk = format!("--addr-vk={}", addr_vk);
|
||||
|
||||
// now verify the proof
|
||||
let pf_arg = format!("{}/{}/proof.pf", test_dir, example_name);
|
||||
let mut args = vec![
|
||||
"verify-evm",
|
||||
"--proof-path",
|
||||
pf_arg.as_str(),
|
||||
rpc_arg.as_str(),
|
||||
deployed_addr_arg.as_str(),
|
||||
deployed_addr_arg_vk.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
// As sanity check, add example that should fail.
|
||||
args[2] = PF_FAILURE;
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(!status.success());
|
||||
}
|
||||
|
||||
// run js browser evm verify tests for a given example
|
||||
fn run_js_tests(test_dir: &str, example_name: String, js_test: &str) {
|
||||
let status = Command::new("pnpm")
|
||||
@@ -1867,6 +2034,7 @@ mod native_tests {
|
||||
// we need the accuracy
|
||||
Some(vec![7, 8]),
|
||||
1,
|
||||
false,
|
||||
);
|
||||
|
||||
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
|
||||
|
||||
Binary file not shown.
@@ -1 +1,60 @@
|
||||
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":0,"param_scale":0,"scale_rebase_multiplier":10,"lookup_range":[-2,0],"logrows":6,"num_inner_cols":2,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_rows":16,"total_assignments":32,"total_const_size":8,"model_instance_shapes":[[1,4]],"model_output_scales":[0],"model_input_scales":[0],"module_sizes":{"kzg":[],"poseidon":[0,[0]]},"required_lookups":["ReLU"],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null,"timestamp":1702474230544}
|
||||
{
|
||||
"run_args": {
|
||||
"tolerance": {
|
||||
"val": 0.0,
|
||||
"scale": 1.0
|
||||
},
|
||||
"input_scale": 0,
|
||||
"param_scale": 0,
|
||||
"scale_rebase_multiplier": 10,
|
||||
"lookup_range": [
|
||||
-2,
|
||||
0
|
||||
],
|
||||
"logrows": 6,
|
||||
"num_inner_cols": 2,
|
||||
"variables": [
|
||||
[
|
||||
"batch_size",
|
||||
1
|
||||
]
|
||||
],
|
||||
"input_visibility": "Private",
|
||||
"output_visibility": "Public",
|
||||
"param_visibility": "Private",
|
||||
"div_rebasing": false,
|
||||
"check_mode": "UNSAFE"
|
||||
},
|
||||
"num_rows": 16,
|
||||
"total_assignments": 32,
|
||||
"total_const_size": 8,
|
||||
"model_instance_shapes": [
|
||||
[
|
||||
1,
|
||||
4
|
||||
]
|
||||
],
|
||||
"model_output_scales": [
|
||||
0
|
||||
],
|
||||
"model_input_scales": [
|
||||
0
|
||||
],
|
||||
"module_sizes": {
|
||||
"kzg": [],
|
||||
"poseidon": [
|
||||
0,
|
||||
[
|
||||
0
|
||||
]
|
||||
]
|
||||
},
|
||||
"required_lookups": [
|
||||
"ReLU"
|
||||
],
|
||||
"required_range_checks": [],
|
||||
"check_mode": "UNSAFE",
|
||||
"version": "0.0.0",
|
||||
"num_blinding_factors": null,
|
||||
"timestamp": 1702474230544
|
||||
}
|
||||
Reference in New Issue
Block a user