mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-14 00:38:15 -05:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ea535e2ecd | ||
|
|
f8aa91ed08 | ||
|
|
a59e3780b2 | ||
|
|
345fb5672a | ||
|
|
70daaff2e4 |
@@ -17,7 +17,6 @@ pub enum BaseOp {
|
||||
Sub,
|
||||
SumInit,
|
||||
Sum,
|
||||
IsZero,
|
||||
IsBoolean,
|
||||
}
|
||||
|
||||
@@ -35,7 +34,6 @@ impl BaseOp {
|
||||
BaseOp::Add => a + b,
|
||||
BaseOp::Sub => a - b,
|
||||
BaseOp::Mult => a * b,
|
||||
BaseOp::IsZero => b,
|
||||
BaseOp::IsBoolean => b,
|
||||
_ => panic!("nonaccum_f called on accumulating operation"),
|
||||
}
|
||||
@@ -76,7 +74,6 @@ impl BaseOp {
|
||||
BaseOp::Mult => "MULT",
|
||||
BaseOp::Sum => "SUM",
|
||||
BaseOp::SumInit => "SUMINIT",
|
||||
BaseOp::IsZero => "ISZERO",
|
||||
BaseOp::IsBoolean => "ISBOOLEAN",
|
||||
}
|
||||
}
|
||||
@@ -93,7 +90,6 @@ impl BaseOp {
|
||||
BaseOp::Mult => (0, 1),
|
||||
BaseOp::Sum => (-1, 2),
|
||||
BaseOp::SumInit => (0, 1),
|
||||
BaseOp::IsZero => (0, 1),
|
||||
BaseOp::IsBoolean => (0, 1),
|
||||
}
|
||||
}
|
||||
@@ -110,7 +106,6 @@ impl BaseOp {
|
||||
BaseOp::Mult => 2,
|
||||
BaseOp::Sum => 1,
|
||||
BaseOp::SumInit => 1,
|
||||
BaseOp::IsZero => 0,
|
||||
BaseOp::IsBoolean => 0,
|
||||
}
|
||||
}
|
||||
@@ -127,7 +122,6 @@ impl BaseOp {
|
||||
BaseOp::SumInit => 0,
|
||||
BaseOp::CumProd => 1,
|
||||
BaseOp::CumProdInit => 0,
|
||||
BaseOp::IsZero => 0,
|
||||
BaseOp::IsBoolean => 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -387,7 +387,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::IsZero, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
|
||||
}
|
||||
}
|
||||
@@ -432,12 +431,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
|
||||
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
|
||||
}
|
||||
BaseOp::IsZero => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("non accum: output query failed");
|
||||
vec![expected_output[base_op.constraint_idx()].clone()]
|
||||
}
|
||||
_ => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -60,8 +60,6 @@ pub enum PolyOp {
|
||||
len_prod: usize,
|
||||
},
|
||||
Pow(u32),
|
||||
Pack(u32, u32),
|
||||
GlobalSumPool,
|
||||
Concat {
|
||||
axis: usize,
|
||||
},
|
||||
@@ -110,8 +108,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::Sum { .. } => "SUM".into(),
|
||||
PolyOp::Prod { .. } => "PROD".into(),
|
||||
PolyOp::Pow(_) => "POW".into(),
|
||||
PolyOp::Pack(_, _) => "PACK".into(),
|
||||
PolyOp::GlobalSumPool => "GLOBALSUMPOOL".into(),
|
||||
PolyOp::Conv { .. } => "CONV".into(),
|
||||
PolyOp::DeConv { .. } => "DECONV".into(),
|
||||
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
|
||||
@@ -181,13 +177,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
output_padding,
|
||||
stride,
|
||||
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
|
||||
PolyOp::Pack(base, scale) => {
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch("pack inputs".to_string()));
|
||||
}
|
||||
|
||||
tensor::ops::pack(&inputs[0], F::from(*base as u64), *scale)
|
||||
}
|
||||
PolyOp::Pow(u) => {
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch("pow inputs".to_string()));
|
||||
@@ -206,7 +195,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
}
|
||||
tensor::ops::prod_axes(&inputs[0], axes)
|
||||
}
|
||||
PolyOp::GlobalSumPool => unreachable!(),
|
||||
PolyOp::Concat { axis } => {
|
||||
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
|
||||
}
|
||||
@@ -288,7 +276,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
} else {
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
|
||||
}
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
@@ -334,10 +322,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
input
|
||||
}
|
||||
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
|
||||
PolyOp::Pack(base, scale) => {
|
||||
layouts::pack(config, region, values[..].try_into()?, *base, *scale)?
|
||||
}
|
||||
PolyOp::GlobalSumPool => unreachable!(),
|
||||
PolyOp::Concat { axis } => layouts::concat(values[..].try_into()?, axis)?,
|
||||
PolyOp::Slice { axis, start, end } => {
|
||||
layouts::slice(config, region, values[..].try_into()?, axis, start, end)?
|
||||
|
||||
@@ -2243,75 +2243,6 @@ mod pow {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod pack {
|
||||
use super::*;
|
||||
|
||||
const K: usize = 8;
|
||||
const LEN: usize = 4;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 1],
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for MyCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
Box::new(PolyOp::Pack(2, 1)),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn packcircuit() {
|
||||
// parameters
|
||||
let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
|
||||
|
||||
let circuit = MyCircuit::<F> {
|
||||
inputs: [ValTensor::from(a)],
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod matmul_relu {
|
||||
use super::*;
|
||||
|
||||
@@ -402,9 +402,6 @@ pub enum Commands {
|
||||
/// Number of logrows to use for srs. Overrides settings_path if specified.
|
||||
#[arg(long, default_value = None)]
|
||||
logrows: Option<u32>,
|
||||
/// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
|
||||
#[arg(long, default_value = DEFAULT_CHECKMODE)]
|
||||
check: CheckMode,
|
||||
},
|
||||
/// Loads model and input and runs mock prover (for testing)
|
||||
Mock {
|
||||
|
||||
@@ -159,8 +159,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
srs_path,
|
||||
settings_path,
|
||||
logrows,
|
||||
check,
|
||||
} => get_srs_cmd(srs_path, settings_path, logrows, check).await,
|
||||
} => get_srs_cmd(srs_path, settings_path, logrows).await,
|
||||
Commands::Table { model, args } => table(model, args),
|
||||
#[cfg(feature = "render")]
|
||||
Commands::RenderCircuit {
|
||||
@@ -492,23 +491,28 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
|
||||
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
|
||||
use std::io::Read;
|
||||
|
||||
let path = get_srs_path(logrows, srs_path);
|
||||
let file = std::fs::File::open(path.clone())?;
|
||||
let file = std::fs::File::open(path)?;
|
||||
let mut reader = std::io::BufReader::new(file);
|
||||
let mut buffer = vec![];
|
||||
let mut reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
let bytes_read = reader.read_to_end(&mut buffer)?;
|
||||
|
||||
info!(
|
||||
"read {} bytes from SRS file (vector of len = {})",
|
||||
"read {} bytes from file (vector of len = {})",
|
||||
bytes_read,
|
||||
buffer.len()
|
||||
);
|
||||
|
||||
let hash = sha256::digest(buffer);
|
||||
info!("SRS hash: {}", hash);
|
||||
info!("file hash: {}", hash);
|
||||
|
||||
Ok(hash)
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
|
||||
let path = get_srs_path(logrows, srs_path);
|
||||
let hash = get_file_hash(&path)?;
|
||||
|
||||
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
|
||||
Some(h) => h,
|
||||
@@ -532,7 +536,6 @@ pub(crate) async fn get_srs_cmd(
|
||||
srs_path: Option<PathBuf>,
|
||||
settings_path: Option<PathBuf>,
|
||||
logrows: Option<u32>,
|
||||
check_mode: CheckMode,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
// logrows overrides settings
|
||||
|
||||
@@ -560,21 +563,20 @@ pub(crate) async fn get_srs_cmd(
|
||||
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
|
||||
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
|
||||
// check the SRS
|
||||
if matches!(check_mode, CheckMode::SAFE) {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let pb = init_spinner();
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
ParamsKZG::<Bn256>::read(&mut reader)?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.finish_with_message("SRS validated");
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let pb = init_spinner();
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.finish_with_message("SRS validated.");
|
||||
|
||||
info!("Saving SRS to disk...");
|
||||
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone()))?;
|
||||
|
||||
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
|
||||
buffer.write_all(reader.get_ref())?;
|
||||
buffer.flush()?;
|
||||
params.write(&mut buffer)?;
|
||||
|
||||
info!("Saved SRS to disk.");
|
||||
|
||||
info!("SRS downloaded");
|
||||
} else {
|
||||
@@ -902,7 +904,12 @@ pub(crate) fn calibrate(
|
||||
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
|
||||
));
|
||||
|
||||
let key = (input_scale, param_scale, scale_rebase_multiplier);
|
||||
let key = (
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
);
|
||||
forward_pass_res.insert(key, vec![]);
|
||||
|
||||
let local_run_args = RunArgs {
|
||||
@@ -971,29 +978,21 @@ pub(crate) fn calibrate(
|
||||
#[cfg(unix)]
|
||||
drop(_g);
|
||||
|
||||
let min_lookup_range = forward_pass_res
|
||||
.get(&key)
|
||||
.unwrap()
|
||||
let result = forward_pass_res.get(&key).ok_or("key not found")?;
|
||||
|
||||
let min_lookup_range = result
|
||||
.iter()
|
||||
.map(|x| x.min_lookup_inputs)
|
||||
.min()
|
||||
.unwrap_or(0);
|
||||
|
||||
let max_lookup_range = forward_pass_res
|
||||
.get(&key)
|
||||
.unwrap()
|
||||
let max_lookup_range = result
|
||||
.iter()
|
||||
.map(|x| x.max_lookup_inputs)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
let max_range_size = forward_pass_res
|
||||
.get(&key)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|x| x.max_range_size)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
let max_range_size = result.iter().map(|x| x.max_range_size).max().unwrap_or(0);
|
||||
|
||||
let res = circuit.calc_min_logrows(
|
||||
(min_lookup_range, max_lookup_range),
|
||||
@@ -1114,6 +1113,7 @@ pub(crate) fn calibrate(
|
||||
best_params.run_args.input_scale,
|
||||
best_params.run_args.param_scale,
|
||||
best_params.run_args.scale_rebase_multiplier,
|
||||
best_params.run_args.div_rebasing,
|
||||
))
|
||||
.ok_or("no params found")?
|
||||
.iter()
|
||||
@@ -1697,7 +1697,7 @@ pub(crate) fn fuzz(
|
||||
let logrows = circuit.settings().run_args.logrows;
|
||||
|
||||
info!("setting up tests");
|
||||
|
||||
#[cfg(unix)]
|
||||
let _r = Gag::stdout()?;
|
||||
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
|
||||
|
||||
@@ -1715,6 +1715,7 @@ pub(crate) fn fuzz(
|
||||
let public_inputs = circuit.prepare_public_inputs(&data)?;
|
||||
|
||||
let strategy = KZGSingleStrategy::new(¶ms);
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_r);
|
||||
|
||||
info!("starting fuzzing");
|
||||
@@ -1905,6 +1906,7 @@ pub(crate) fn run_fuzz_fn(
|
||||
passed: &AtomicBool,
|
||||
) {
|
||||
let num_failures = AtomicI64::new(0);
|
||||
#[cfg(unix)]
|
||||
let _r = Gag::stdout().unwrap();
|
||||
|
||||
let pb = init_bar(num_runs as u64);
|
||||
@@ -1918,6 +1920,7 @@ pub(crate) fn run_fuzz_fn(
|
||||
pb.inc(1);
|
||||
});
|
||||
pb.finish_with_message("Done.");
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_r);
|
||||
info!(
|
||||
"num failures: {} out of {}",
|
||||
|
||||
@@ -484,7 +484,6 @@ fn get_srs(
|
||||
srs_path,
|
||||
settings_path,
|
||||
logrows,
|
||||
CheckMode::SAFE,
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to get srs: {}", e);
|
||||
|
||||
@@ -4,6 +4,37 @@ use super::{
|
||||
};
|
||||
use halo2_proofs::{arithmetic::Field, plonk::Instance};
|
||||
|
||||
pub(crate) fn create_constant_tensor<
|
||||
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
|
||||
>(
|
||||
val: F,
|
||||
len: usize,
|
||||
) -> ValTensor<F> {
|
||||
let mut constant = Tensor::from(vec![ValType::Constant(val); len].into_iter());
|
||||
constant.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
ValTensor::from(constant)
|
||||
}
|
||||
|
||||
pub(crate) fn create_unit_tensor<
|
||||
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
|
||||
>(
|
||||
len: usize,
|
||||
) -> ValTensor<F> {
|
||||
let mut unit = Tensor::from(vec![ValType::Constant(F::ONE); len].into_iter());
|
||||
unit.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
ValTensor::from(unit)
|
||||
}
|
||||
|
||||
pub(crate) fn create_zero_tensor<
|
||||
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
|
||||
>(
|
||||
len: usize,
|
||||
) -> ValTensor<F> {
|
||||
let mut zero = Tensor::from(vec![ValType::Constant(F::ZERO); len].into_iter());
|
||||
zero.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
ValTensor::from(zero)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// A [ValType] is a wrapper around Halo2 value(s).
|
||||
pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> {
|
||||
@@ -463,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
_ => return Err(Box::new(TensorError::WrongMethod)),
|
||||
};
|
||||
Ok(integer_evals.into_iter().into())
|
||||
let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
|
||||
match tensor.reshape(self.dims()) {
|
||||
_ => {}
|
||||
};
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
/// Calls `pad_to_zero_rem` on the inner tensor.
|
||||
|
||||
Reference in New Issue
Block a user