mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-10 06:48:01 -05:00
feat: don't require model file for verifier (#216)
This commit is contained in:
12
README.md
12
README.md
@@ -107,14 +107,14 @@ ezkl -K=17 gen-srs --params-path=kzg.params
|
||||
|
||||
|
||||
```bash
|
||||
ezkl --bits=16 -K=17 prove -D ./examples/onnx/1l_relu/input.json -M ./examples/onnx/1l_relu/network.onnx --proof-path 1l_relu.pf --vk-path 1l_relu.vk --params-path=kzg.params
|
||||
ezkl --bits=16 -K=17 prove -D ./examples/onnx/1l_relu/input.json -M ./examples/onnx/1l_relu/network.onnx --proof-path 1l_relu.pf --vk-path 1l_relu.vk --params-path=kzg.params --circuit-params-path=circuit.params
|
||||
```
|
||||
|
||||
This command generates a proof that the model was correctly run on private inputs (this is the default setting). It then outputs the resulting proof at the path specfifed by `--proof-path`, parameters that can be used for subsequent verification at `--params-path` and the verifier key at `--vk-path`.
|
||||
Luckily `ezkl` also provides command to verify the generated proofs:
|
||||
|
||||
```bash
|
||||
ezkl --bits=16 -K=17 verify -M ./examples/onnx/1l_relu/network.onnx --proof-path 1l_relu.pf --vk-path 1l_relu.vk --params-path=kzg.params
|
||||
ezkl --bits=16 -K=17 verify --proof-path 1l_relu.pf --vk-path 1l_relu.vk --params-path=kzg.params --circuit-params-path=circuit.params
|
||||
```
|
||||
|
||||
To display a table of the loaded onnx nodes, their associated parameters, set `RUST_LOG=DEBUG` or run:
|
||||
@@ -131,11 +131,11 @@ Note that the above prove and verify stats can also be run with an EVM verifier.
|
||||
|
||||
```bash
|
||||
# gen proof
|
||||
ezkl --bits=16 -K=17 prove -D ./examples/onnx/1l_relu/input.json -M ./examples/onnx/1l_relu/network.onnx --proof-path 1l_relu.pf --vk-path 1l_relu.vk --params-path=kzg.params --transcript=evm
|
||||
ezkl --bits=16 -K=17 prove -D ./examples/onnx/1l_relu/input.json -M ./examples/onnx/1l_relu/network.onnx --proof-path 1l_relu.pf --vk-path 1l_relu.vk --params-path=kzg.params --transcript=evm --circuit-params-path=circuit.params
|
||||
```
|
||||
```bash
|
||||
# gen evm verifier
|
||||
ezkl -K=17 --bits=16 create-evm-verifier -D ./examples/onnx/1l_relu/input.json -M ./examples/onnx/1l_relu/network.onnx --deployment-code-path 1l_relu.code --params-path=kzg.params --vk-path 1l_relu.vk --sol-code-path 1l_relu.sol
|
||||
ezkl -K=17 --bits=16 create-evm-verifier --deployment-code-path 1l_relu.code --params-path=kzg.params --vk-path 1l_relu.vk --sol-code-path 1l_relu.sol --circuit-params-path=circuit.params
|
||||
```
|
||||
```bash
|
||||
# Verify (EVM)
|
||||
@@ -153,12 +153,12 @@ ezkl -K=20 gen-srs --params-path=kzg.params
|
||||
|
||||
```bash
|
||||
# Single proof -> single proof we are going to feed into aggregation circuit. (Mock)-verifies + verifies natively as sanity check
|
||||
ezkl -K=17 --bits=16 prove --transcript=poseidon --strategy=accum -D ./examples/onnx/1l_relu/input.json -M ./examples/onnx/1l_relu/network.onnx --proof-path 1l_relu.pf --params-path=kzg.params --vk-path=1l_relu.vk
|
||||
ezkl -K=17 --bits=16 prove --transcript=poseidon --strategy=accum -D ./examples/onnx/1l_relu/input.json -M ./examples/onnx/1l_relu/network.onnx --proof-path 1l_relu.pf --params-path=kzg.params --vk-path=1l_relu.vk --circuit-params-path=circuit.params
|
||||
```
|
||||
|
||||
```bash
|
||||
# Aggregate -> generates aggregate proof and also (mock)-verifies + verifies natively as sanity check
|
||||
ezkl -K=20 --bits=16 aggregate --app-logrows=17 --transcript=evm -M ./examples/onnx/1l_relu/network.onnx --aggregation-snarks=1l_relu.pf --aggregation-vk-paths 1l_relu.vk --vk-path aggr_1l_relu.vk --proof-path aggr_1l_relu.pf --params-path=kzg.params
|
||||
ezkl -K=20 --bits=16 aggregate --app-logrows=17 --transcript=evm --circuit-params-paths=circuit.params --aggregation-snarks=1l_relu.pf --aggregation-vk-paths 1l_relu.vk --vk-path aggr_1l_relu.vk --proof-path aggr_1l_relu.pf --params-path=kzg.params
|
||||
```
|
||||
|
||||
```bash
|
||||
|
||||
@@ -268,9 +268,9 @@ pub enum Commands {
|
||||
/// Aggregates proofs :)
|
||||
#[command(arg_required_else_help = true)]
|
||||
Aggregate {
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long)]
|
||||
model: PathBuf,
|
||||
/// The path to the params files.
|
||||
#[arg(long)]
|
||||
circuit_params_paths: Vec<PathBuf>,
|
||||
///the logrows used when generating the snarks we're aggregating
|
||||
#[arg(long)]
|
||||
app_logrows: u32,
|
||||
@@ -315,9 +315,12 @@ pub enum Commands {
|
||||
/// The path to the desired output file
|
||||
#[arg(long)]
|
||||
proof_path: PathBuf,
|
||||
/// The transcript type
|
||||
/// The parameter path
|
||||
#[arg(long)]
|
||||
params_path: PathBuf,
|
||||
/// The path to save circuit params to
|
||||
#[arg(long)]
|
||||
circuit_params_path: PathBuf,
|
||||
#[arg(
|
||||
long,
|
||||
require_equals = true,
|
||||
@@ -341,15 +344,12 @@ pub enum Commands {
|
||||
/// Creates an EVM verifier for a single proof
|
||||
#[command(name = "create-evm-verifier", arg_required_else_help = true)]
|
||||
CreateEVMVerifier {
|
||||
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
#[arg(short = 'D', long)]
|
||||
data: String,
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long)]
|
||||
model: PathBuf,
|
||||
/// The path to load the desired params file
|
||||
#[arg(long)]
|
||||
params_path: PathBuf,
|
||||
/// The path to save circuit params to
|
||||
#[arg(long)]
|
||||
circuit_params_path: PathBuf,
|
||||
/// The path to load the desired verfication key file
|
||||
#[arg(long)]
|
||||
vk_path: PathBuf,
|
||||
@@ -394,7 +394,6 @@ pub enum Commands {
|
||||
/// The path to output the Solidity code (optional) supercedes deployment_code_path in priority
|
||||
#[arg(long)]
|
||||
sol_code_path: Option<PathBuf>,
|
||||
// todo, optionally allow supplying proving key
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -421,9 +420,9 @@ pub enum Commands {
|
||||
/// Verifies a proof, returning accept or reject
|
||||
#[command(arg_required_else_help = true)]
|
||||
Verify {
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long)]
|
||||
model: PathBuf,
|
||||
/// The path to save circuit params to
|
||||
#[arg(long)]
|
||||
circuit_params_path: PathBuf,
|
||||
/// The path to the proof file
|
||||
#[arg(long)]
|
||||
proof_path: PathBuf,
|
||||
|
||||
@@ -43,6 +43,7 @@ use std::fs::File;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tabled::Table;
|
||||
@@ -90,16 +91,15 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
|
||||
Commands::Mock { data, model: _ } => mock(data, cli.args.logrows),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEVMVerifier {
|
||||
data,
|
||||
model: _,
|
||||
vk_path,
|
||||
params_path,
|
||||
circuit_params_path,
|
||||
deployment_code_path,
|
||||
sol_code_path,
|
||||
} => create_evm_verifier(
|
||||
data,
|
||||
vk_path,
|
||||
params_path,
|
||||
circuit_params_path,
|
||||
deployment_code_path,
|
||||
sol_code_path,
|
||||
cli.args.logrows,
|
||||
@@ -116,6 +116,7 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
|
||||
vk_path,
|
||||
proof_path,
|
||||
params_path,
|
||||
circuit_params_path,
|
||||
transcript,
|
||||
strategy,
|
||||
} => prove(
|
||||
@@ -123,13 +124,14 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
|
||||
vk_path,
|
||||
proof_path,
|
||||
params_path,
|
||||
circuit_params_path,
|
||||
transcript,
|
||||
strategy,
|
||||
cli.args.logrows,
|
||||
cli.args.check_mode,
|
||||
),
|
||||
Commands::Aggregate {
|
||||
model: _,
|
||||
circuit_params_paths,
|
||||
proof_path,
|
||||
aggregation_snarks,
|
||||
aggregation_vk_paths,
|
||||
@@ -140,6 +142,7 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
|
||||
} => aggregate(
|
||||
proof_path,
|
||||
aggregation_snarks,
|
||||
circuit_params_paths,
|
||||
aggregation_vk_paths,
|
||||
vk_path,
|
||||
params_path,
|
||||
@@ -149,13 +152,14 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
|
||||
cli.args.check_mode,
|
||||
),
|
||||
Commands::Verify {
|
||||
model: _,
|
||||
proof_path,
|
||||
circuit_params_path,
|
||||
vk_path,
|
||||
params_path,
|
||||
transcript,
|
||||
} => verify(
|
||||
proof_path,
|
||||
circuit_params_path,
|
||||
vk_path,
|
||||
params_path,
|
||||
transcript,
|
||||
@@ -411,19 +415,20 @@ fn render(data: String, output: String, logrows: u32) -> Result<(), Box<dyn Erro
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn create_evm_verifier(
|
||||
data: String,
|
||||
vk_path: PathBuf,
|
||||
params_path: PathBuf,
|
||||
circuit_params_path: PathBuf,
|
||||
deployment_code_path: Option<PathBuf>,
|
||||
sol_code_path: Option<PathBuf>,
|
||||
logrows: u32,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
let data = prepare_data(data)?;
|
||||
let circuit = ModelCircuit::<Fr>::from_arg(&data)?;
|
||||
let public_inputs = circuit.prepare_public_inputs(&data)?;
|
||||
let num_instance = public_inputs.iter().map(|x| x.len()).collect();
|
||||
let params = load_params_cmd(params_path, logrows)?;
|
||||
let model_circuit_params = load_model_circuit_params();
|
||||
let model_circuit_params = ModelParams::load(&circuit_params_path);
|
||||
let num_instance = model_circuit_params
|
||||
.instance_shapes
|
||||
.iter()
|
||||
.map(|x| x.iter().product())
|
||||
.collect();
|
||||
|
||||
let vk =
|
||||
load_vk::<KZGCommitmentScheme<Bn256>, Fr, ModelCircuit<Fr>>(vk_path, model_circuit_params)?;
|
||||
@@ -491,6 +496,7 @@ fn prove(
|
||||
vk_path: PathBuf,
|
||||
proof_path: PathBuf,
|
||||
params_path: PathBuf,
|
||||
circuit_params_path: PathBuf,
|
||||
transcript: TranscriptType,
|
||||
strategy: StrategyType,
|
||||
logrows: u32,
|
||||
@@ -507,6 +513,7 @@ fn prove(
|
||||
trace!("params computed");
|
||||
|
||||
let now = Instant::now();
|
||||
let circuit_params = circuit.params.clone();
|
||||
// creates and verifies the proof
|
||||
let snark = match strategy {
|
||||
StrategyType::Single => {
|
||||
@@ -539,12 +546,16 @@ fn prove(
|
||||
|
||||
snark.save(&proof_path)?;
|
||||
save_vk::<KZGCommitmentScheme<Bn256>>(&vk_path, pk.get_vk())?;
|
||||
|
||||
circuit_params.save(&circuit_params_path);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn aggregate(
|
||||
proof_path: PathBuf,
|
||||
aggregation_snarks: Vec<PathBuf>,
|
||||
circuit_params_paths: Vec<PathBuf>,
|
||||
aggregation_vk_paths: Vec<PathBuf>,
|
||||
vk_path: PathBuf,
|
||||
params_path: PathBuf,
|
||||
@@ -560,9 +571,12 @@ fn aggregate(
|
||||
// the K used when generating the application snark proof. we assume K is homogenous across snarks to aggregate
|
||||
let params_app = load_params_cmd(params_path, app_logrows)?;
|
||||
|
||||
let model_circuit_params = load_model_circuit_params();
|
||||
|
||||
for (proof_path, vk_path) in aggregation_snarks.iter().zip(aggregation_vk_paths) {
|
||||
for ((proof_path, vk_path), circuit_params_path) in aggregation_snarks
|
||||
.iter()
|
||||
.zip(aggregation_vk_paths)
|
||||
.zip(circuit_params_paths)
|
||||
{
|
||||
let model_circuit_params = ModelParams::load(&circuit_params_path);
|
||||
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, ModelCircuit<Fr>>(
|
||||
vk_path.to_path_buf(),
|
||||
// safe to clone as the inner model is wrapped in an Arc
|
||||
@@ -602,6 +616,7 @@ fn aggregate(
|
||||
|
||||
fn verify(
|
||||
proof_path: PathBuf,
|
||||
circuit_params_path: PathBuf,
|
||||
vk_path: PathBuf,
|
||||
params_path: PathBuf,
|
||||
transcript: TranscriptType,
|
||||
@@ -610,7 +625,7 @@ fn verify(
|
||||
let params = load_params_cmd(params_path, logrows)?;
|
||||
|
||||
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path, None, None)?;
|
||||
let model_circuit_params = load_model_circuit_params();
|
||||
let model_circuit_params = ModelParams::load(&circuit_params_path);
|
||||
|
||||
let strategy = KZGSingleStrategy::new(params.verifier_params());
|
||||
let vk =
|
||||
@@ -648,22 +663,3 @@ pub fn load_params_cmd(params_path: PathBuf, logrows: u32) -> Result<ParamsKZG<B
|
||||
}
|
||||
Ok(params)
|
||||
}
|
||||
|
||||
fn load_model_circuit_params() -> ModelParams<Fr> {
|
||||
let model: Arc<Model<Fr>> = Arc::new(Model::from_arg().expect("model should load"));
|
||||
|
||||
let instance_shapes = model.instance_shapes();
|
||||
// this is the total number of variables we will need to allocate
|
||||
// for the circuit
|
||||
let num_constraints = if let Some(num_constraints) = model.run_args.allocated_constraints {
|
||||
num_constraints
|
||||
} else {
|
||||
model.dummy_layout(&model.input_shapes()).unwrap()
|
||||
};
|
||||
|
||||
ModelParams {
|
||||
model,
|
||||
instance_shapes,
|
||||
num_constraints,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
pub mod utilities;
|
||||
use halo2curves::ff::PrimeField;
|
||||
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub use utilities::*;
|
||||
/// Crate for defining a computational graph and building a ZK-circuit from it.
|
||||
pub mod model;
|
||||
@@ -10,7 +11,8 @@ pub mod node;
|
||||
/// Representations of a computational graph's variables.
|
||||
pub mod vars;
|
||||
|
||||
use crate::commands::Cli;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::commands::{Cli, RunArgs};
|
||||
use crate::fieldutils::i128_to_felt;
|
||||
use crate::pfsys::ModelInput;
|
||||
use crate::tensor::ops::pack;
|
||||
@@ -24,10 +26,8 @@ use halo2_proofs::{
|
||||
use log::{info, trace};
|
||||
pub use model::*;
|
||||
pub use node::*;
|
||||
// use std::fs::File;
|
||||
// use std::io::{BufReader, BufWriter, Read, Write};
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
// use std::path::PathBuf;
|
||||
use thiserror::Error;
|
||||
pub use vars::*;
|
||||
|
||||
@@ -79,23 +79,44 @@ pub enum GraphError {
|
||||
}
|
||||
|
||||
/// model parameters
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ModelParams<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// An onnx model quantized and configured for zkSNARKs
|
||||
pub model: Arc<Model<F>>,
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub struct ModelParams {
|
||||
/// run args
|
||||
pub run_args: RunArgs,
|
||||
/// the visibility of the variables in the circuit
|
||||
pub visibility: VarVisibility,
|
||||
/// the potential number of constraints in the circuit
|
||||
pub num_constraints: usize,
|
||||
/// the shape of public inputs to the circuit (in order of appearance)
|
||||
pub instance_shapes: Vec<Vec<usize>>,
|
||||
/// required_lookups
|
||||
pub required_lookups: Vec<LookupOp>,
|
||||
}
|
||||
|
||||
impl ModelParams {
|
||||
///
|
||||
pub fn save(&self, path: &std::path::PathBuf) {
|
||||
let mut file = std::fs::File::create(path).unwrap();
|
||||
let encoded: Vec<u8> = bincode::serialize(&self).unwrap();
|
||||
file.write_all(&encoded).unwrap();
|
||||
}
|
||||
///
|
||||
pub fn load(path: &std::path::PathBuf) -> Self {
|
||||
let file = std::fs::File::open(path).unwrap();
|
||||
let decoded: Self = bincode::deserialize_from(file).unwrap();
|
||||
decoded
|
||||
}
|
||||
}
|
||||
|
||||
/// Defines the circuit for a computational graph / model loaded from a `.onnx` file.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ModelCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// The model / graph of computations.
|
||||
pub model: Arc<Model<F>>,
|
||||
/// Vector of input tensors to the model / graph of computations.
|
||||
pub inputs: Vec<Tensor<i128>>,
|
||||
///
|
||||
pub params: ModelParams<F>,
|
||||
pub params: ModelParams,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> ModelCircuit<F> {
|
||||
@@ -118,22 +139,11 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelCircuit<F> {
|
||||
inputs.push(t);
|
||||
}
|
||||
|
||||
let instance_shapes = model.instance_shapes();
|
||||
// this is the total number of variables we will need to allocate
|
||||
// for the circuit
|
||||
let num_constraints = if let Some(num_constraints) = model.run_args.allocated_constraints {
|
||||
num_constraints
|
||||
} else {
|
||||
model.dummy_layout(&model.input_shapes()).unwrap()
|
||||
};
|
||||
|
||||
let params = ModelParams {
|
||||
model,
|
||||
instance_shapes,
|
||||
num_constraints,
|
||||
};
|
||||
|
||||
Ok(ModelCircuit::<F> { inputs, params })
|
||||
Ok(ModelCircuit::<F> {
|
||||
model: model.clone(),
|
||||
inputs,
|
||||
params: model.gen_params()?,
|
||||
})
|
||||
}
|
||||
|
||||
///
|
||||
@@ -148,17 +158,17 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelCircuit<F> {
|
||||
&self,
|
||||
data: &ModelInput,
|
||||
) -> Result<Vec<Vec<F>>, Box<dyn std::error::Error>> {
|
||||
let out_scales = self.params.model.get_output_scales();
|
||||
let out_scales = self.model.get_output_scales();
|
||||
|
||||
// quantize the supplied data using the provided scale.
|
||||
// the ordering here is important, we want the inputs to come before the outputs
|
||||
// as they are configured in that order as Column<Instances>
|
||||
let mut public_inputs = vec![];
|
||||
if self.params.model.visibility.input.is_public() {
|
||||
if self.model.visibility.input.is_public() {
|
||||
for v in data.input_data.iter() {
|
||||
let t: Vec<i128> = v
|
||||
.par_iter()
|
||||
.map(|x| quantize_float(x, 0.0, self.params.model.run_args.scale).unwrap())
|
||||
.map(|x| quantize_float(x, 0.0, self.model.run_args.scale).unwrap())
|
||||
.collect();
|
||||
|
||||
let t: Tensor<i128> = t.into_iter().into();
|
||||
@@ -166,7 +176,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelCircuit<F> {
|
||||
public_inputs.push(t);
|
||||
}
|
||||
}
|
||||
if self.params.model.visibility.output.is_public() {
|
||||
if self.model.visibility.output.is_public() {
|
||||
for (idx, v) in data.output_data.iter().enumerate() {
|
||||
let t: Vec<i128> = v
|
||||
.par_iter()
|
||||
@@ -176,18 +186,16 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelCircuit<F> {
|
||||
let mut t: Tensor<i128> = t.into_iter().into();
|
||||
|
||||
let len = t.len();
|
||||
if self.params.model.run_args.pack_base > 1 {
|
||||
if self.model.run_args.pack_base > 1 {
|
||||
let max_exponent =
|
||||
(((len - 1) as u32) * (self.params.model.run_args.scale + 1)) as f64;
|
||||
if max_exponent
|
||||
> (i128::MAX as f64).log(self.params.model.run_args.pack_base as f64)
|
||||
{
|
||||
(((len - 1) as u32) * (self.model.run_args.scale + 1)) as f64;
|
||||
if max_exponent > (i128::MAX as f64).log(self.model.run_args.pack_base as f64) {
|
||||
return Err(Box::new(GraphError::PackingExponent));
|
||||
}
|
||||
t = pack(
|
||||
&t,
|
||||
self.params.model.run_args.pack_base as i128,
|
||||
self.params.model.run_args.scale,
|
||||
self.model.run_args.pack_base as i128,
|
||||
self.model.run_args.scale,
|
||||
)?;
|
||||
}
|
||||
public_inputs.push(t);
|
||||
@@ -214,7 +222,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelCircuit<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for ModelCircuit<F> {
|
||||
type Config = ModelConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = ModelParams<F>;
|
||||
type Params = ModelParams;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
@@ -228,14 +236,15 @@ impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for ModelCircuit<F> {
|
||||
fn configure_with_params(cs: &mut ConstraintSystem<F>, params: Self::Params) -> Self::Config {
|
||||
let mut vars = ModelVars::new(
|
||||
cs,
|
||||
params.model.run_args.logrows as usize,
|
||||
params.run_args.logrows as usize,
|
||||
params.num_constraints,
|
||||
params.instance_shapes.clone(),
|
||||
params.model.visibility.clone(),
|
||||
params.model.run_args.scale,
|
||||
params.visibility.clone(),
|
||||
params.run_args.scale,
|
||||
);
|
||||
|
||||
let base = params.model.configure(cs, &mut vars).unwrap();
|
||||
let base =
|
||||
Model::<F>::configure(cs, &mut vars, params.run_args, params.required_lookups).unwrap();
|
||||
|
||||
ModelConfig { base, vars }
|
||||
}
|
||||
@@ -256,8 +265,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for ModelCircuit<F> {
|
||||
.map(|i| ValTensor::from(<Tensor<i128> as Into<Tensor<Value<F>>>>::into(i.clone())))
|
||||
.collect::<Vec<ValTensor<F>>>();
|
||||
trace!("Laying out model");
|
||||
self.params
|
||||
.model
|
||||
self.model
|
||||
.layout(config.clone(), &mut layouter, &inputs, &config.vars)
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use super::node::*;
|
||||
use super::vars::*;
|
||||
use super::GraphError;
|
||||
use super::ModelParams;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::ops::poly::PolyOp;
|
||||
use crate::circuit::BaseConfig as PolyConfig;
|
||||
use crate::circuit::Op;
|
||||
@@ -31,6 +33,7 @@ use itertools::Itertools;
|
||||
use log::error;
|
||||
use log::{debug, info, trace};
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashSet;
|
||||
use std::error::Error;
|
||||
use std::path::Path;
|
||||
use tabled::Table;
|
||||
@@ -105,6 +108,37 @@ impl<F: PrimeField + TensorType + PartialOrd> Model<F> {
|
||||
Ok(om)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn gen_params(&self) -> Result<ModelParams, Box<dyn Error>> {
|
||||
let instance_shapes = self.instance_shapes();
|
||||
// this is the total number of variables we will need to allocate
|
||||
// for the circuit
|
||||
let num_constraints = if let Some(num_constraints) = self.run_args.allocated_constraints {
|
||||
num_constraints
|
||||
} else {
|
||||
self.dummy_layout(&self.input_shapes()).unwrap()
|
||||
};
|
||||
|
||||
// extract the requisite lookup ops from the model
|
||||
let mut lookup_ops: Vec<LookupOp> = self
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|(_, n)| n.opkind.required_lookups())
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
let set: HashSet<_> = lookup_ops.drain(..).collect(); // dedup
|
||||
lookup_ops.extend(set.into_iter().sorted());
|
||||
|
||||
Ok(ModelParams {
|
||||
run_args: self.run_args.clone(),
|
||||
visibility: self.visibility.clone(),
|
||||
instance_shapes,
|
||||
num_constraints,
|
||||
required_lookups: lookup_ops,
|
||||
})
|
||||
}
|
||||
|
||||
/// Runs a forward pass on sample data !
|
||||
/// # Arguments
|
||||
/// * `path` - A path to an Onnx file.
|
||||
@@ -350,15 +384,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Model<F> {
|
||||
Commands::Table { model } | Commands::Mock { model, .. } => {
|
||||
Model::new(model, cli.args, Mode::Mock, visibility)
|
||||
}
|
||||
Commands::Prove { model, .. }
|
||||
| Commands::Verify { model, .. }
|
||||
| Commands::Aggregate { model, .. } => {
|
||||
Model::new(model, cli.args, Mode::Prove, visibility)
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEVMVerifier { model, .. } => {
|
||||
Model::new(model, cli.args, Mode::Prove, visibility)
|
||||
}
|
||||
Commands::Prove { model, .. } => Model::new(model, cli.args, Mode::Prove, visibility),
|
||||
#[cfg(feature = "render")]
|
||||
Commands::RenderCircuit { model, .. } => {
|
||||
Model::new(model, cli.args, Mode::Table, visibility)
|
||||
@@ -380,9 +406,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Model<F> {
|
||||
/// * `meta` - Halo2 ConstraintSystem.
|
||||
/// * `advices` - A `VarTensor` holding columns of advices. Must be sufficiently large to configure all the nodes loaded in `self.nodes`.
|
||||
pub fn configure(
|
||||
&self,
|
||||
meta: &mut ConstraintSystem<F>,
|
||||
vars: &mut ModelVars<F>,
|
||||
run_args: RunArgs,
|
||||
required_lookups: Vec<LookupOp>,
|
||||
) -> Result<PolyConfig<F>, Box<dyn Error>> {
|
||||
info!("configuring model");
|
||||
|
||||
@@ -390,47 +417,19 @@ impl<F: PrimeField + TensorType + PartialOrd> Model<F> {
|
||||
meta,
|
||||
vars.advices[0..2].try_into()?,
|
||||
&vars.advices[2],
|
||||
self.run_args.check_mode,
|
||||
self.run_args.tolerance as i32,
|
||||
run_args.check_mode,
|
||||
run_args.tolerance as i32,
|
||||
);
|
||||
|
||||
let lookup_ops: BTreeMap<&usize, &Node<F>> = self
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|(_, n)| n.opkind.required_lookups().len() > 0)
|
||||
.collect();
|
||||
|
||||
for node in lookup_ops.values() {
|
||||
self.conf_lookup(&mut base_gate, node, meta, vars)?;
|
||||
for op in required_lookups {
|
||||
let input = &vars.advices[0];
|
||||
let output = &vars.advices[1];
|
||||
base_gate.configure_lookup(meta, input, output, run_args.bits, &op)?;
|
||||
}
|
||||
|
||||
Ok(base_gate)
|
||||
}
|
||||
|
||||
/// Configures a lookup table based operation. These correspond to operations that are represented in
|
||||
/// the `circuit::eltwise` module.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `node` - The [Node] must represent a lookup based op.
|
||||
/// * `meta` - Halo2 ConstraintSystem.
|
||||
/// * `vars` - [ModelVars] for the model.
|
||||
fn conf_lookup(
|
||||
&self,
|
||||
config: &mut PolyConfig<F>,
|
||||
node: &Node<F>,
|
||||
meta: &mut ConstraintSystem<F>,
|
||||
vars: &mut ModelVars<F>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
let input = &vars.advices[0];
|
||||
let output = &vars.advices[1];
|
||||
|
||||
for op in node.opkind.required_lookups() {
|
||||
config.configure_lookup(meta, input, output, self.run_args.bits, &op)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Assigns values to the regions created when calling `configure`.
|
||||
/// # Arguments
|
||||
///
|
||||
|
||||
@@ -565,6 +565,11 @@ fn kzg_aggr_prove_and_verify(example_name: String) {
|
||||
"--params-path={}/kzg23.params",
|
||||
TEST_DIR.path().to_str().unwrap()
|
||||
),
|
||||
&format!(
|
||||
"--circuit-params-path={}/{}.params",
|
||||
TEST_DIR.path().to_str().unwrap(),
|
||||
example_name
|
||||
),
|
||||
"--transcript=poseidon",
|
||||
"--strategy=accum",
|
||||
])
|
||||
@@ -577,8 +582,11 @@ fn kzg_aggr_prove_and_verify(example_name: String) {
|
||||
"-K=23",
|
||||
"aggregate",
|
||||
"--app-logrows=17",
|
||||
"-M",
|
||||
format!("./examples/onnx/{}/network.onnx", example_name).as_str(),
|
||||
&format!(
|
||||
"--circuit-params-paths={}/{}.params",
|
||||
TEST_DIR.path().to_str().unwrap(),
|
||||
example_name
|
||||
),
|
||||
"--aggregation-snarks",
|
||||
&format!("{}/{}.pf", TEST_DIR.path().to_str().unwrap(), example_name),
|
||||
"--aggregation-vk-paths",
|
||||
@@ -659,6 +667,11 @@ fn kzg_evm_aggr_prove_and_verify(example_name: String) {
|
||||
"--params-path={}/kzg23.params",
|
||||
TEST_DIR.path().to_str().unwrap()
|
||||
),
|
||||
&format!(
|
||||
"--circuit-params-path={}/{}.params",
|
||||
TEST_DIR.path().to_str().unwrap(),
|
||||
example_name
|
||||
),
|
||||
"--transcript=poseidon",
|
||||
"--strategy=accum",
|
||||
])
|
||||
@@ -671,8 +684,11 @@ fn kzg_evm_aggr_prove_and_verify(example_name: String) {
|
||||
"-K=23",
|
||||
"aggregate",
|
||||
"--app-logrows=17",
|
||||
"-M",
|
||||
format!("./examples/onnx/{}/network.onnx", example_name).as_str(),
|
||||
&format!(
|
||||
"--circuit-params-paths={}/{}.params",
|
||||
TEST_DIR.path().to_str().unwrap(),
|
||||
example_name
|
||||
),
|
||||
"--aggregation-snarks",
|
||||
&format!(
|
||||
"{}/{}_evm.pf",
|
||||
@@ -773,6 +789,11 @@ fn kzg_prove_and_verify(example_name: String) {
|
||||
"--params-path={}/kzg17.params",
|
||||
TEST_DIR.path().to_str().unwrap()
|
||||
),
|
||||
&format!(
|
||||
"--circuit-params-path={}/{}.params",
|
||||
TEST_DIR.path().to_str().unwrap(),
|
||||
example_name
|
||||
),
|
||||
"--transcript=blake",
|
||||
"--strategy=single",
|
||||
])
|
||||
@@ -784,8 +805,11 @@ fn kzg_prove_and_verify(example_name: String) {
|
||||
"--bits=16",
|
||||
"-K=17",
|
||||
"verify",
|
||||
"-M",
|
||||
format!("./examples/onnx/{}/network.onnx", example_name).as_str(),
|
||||
&format!(
|
||||
"--circuit-params-path={}/{}.params",
|
||||
TEST_DIR.path().to_str().unwrap(),
|
||||
example_name
|
||||
),
|
||||
"--proof-path",
|
||||
&format!("{}/{}.pf", TEST_DIR.path().to_str().unwrap(), example_name),
|
||||
"--vk-path",
|
||||
@@ -820,6 +844,11 @@ fn kzg_evm_prove_and_verify(example_name: String, with_solidity: bool) {
|
||||
"--params-path={}/kzg17.params",
|
||||
TEST_DIR.path().to_str().unwrap()
|
||||
),
|
||||
&format!(
|
||||
"--circuit-params-path={}/{}.params",
|
||||
TEST_DIR.path().to_str().unwrap(),
|
||||
example_name
|
||||
),
|
||||
"--transcript=evm",
|
||||
"--strategy=single",
|
||||
])
|
||||
@@ -827,8 +856,11 @@ fn kzg_evm_prove_and_verify(example_name: String, with_solidity: bool) {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let input_arg = format!("./examples/onnx/{}/input.json", example_name);
|
||||
let network_arg = format!("./examples/onnx/{}/network.onnx", example_name);
|
||||
let circuit_params = format!(
|
||||
"--circuit-params-path={}/{}.params",
|
||||
TEST_DIR.path().to_str().unwrap(),
|
||||
example_name
|
||||
);
|
||||
let code_arg = format!(
|
||||
"{}/{}.code",
|
||||
TEST_DIR.path().to_str().unwrap(),
|
||||
@@ -844,10 +876,7 @@ fn kzg_evm_prove_and_verify(example_name: String, with_solidity: bool) {
|
||||
"--bits=16",
|
||||
"-K=17",
|
||||
"create-evm-verifier",
|
||||
"-D",
|
||||
input_arg.as_str(),
|
||||
"-M",
|
||||
network_arg.as_str(),
|
||||
circuit_params.as_str(),
|
||||
"--deployment-code-path",
|
||||
code_arg.as_str(),
|
||||
param_arg.as_str(),
|
||||
|
||||
Reference in New Issue
Block a user