mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 16:27:59 -05:00
Compare commits
26 Commits
v12.0.0
...
ac/small-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d21b43502d | ||
|
|
9f1495bf7f | ||
|
|
355aed3b1f | ||
|
|
ca7fe53687 | ||
|
|
1b49f8774f | ||
|
|
ec6be275cf | ||
|
|
51ace92b73 | ||
|
|
e4332ebb14 | ||
|
|
186515448c | ||
|
|
c70fd153e2 | ||
|
|
66d733d0ce | ||
|
|
89c5238130 | ||
|
|
e88f362596 | ||
|
|
069ac6ee6e | ||
|
|
c6cba69bc3 | ||
|
|
eec19d6058 | ||
|
|
685f462766 | ||
|
|
9b027fc908 | ||
|
|
2e211d1314 | ||
|
|
feae28042e | ||
|
|
885cd880d2 | ||
|
|
b88bb6ccda | ||
|
|
7b32a99856 | ||
|
|
e80eae41f0 | ||
|
|
f4d2fc0ccb | ||
|
|
e7f76c5ae6 |
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
@@ -260,8 +260,6 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
|
||||
- name: hashed params
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
|
||||
- name: hashed params public inputs
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
|
||||
- name: hashed outputs
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
|
||||
- name: hashed inputs + params + outputs
|
||||
|
||||
@@ -72,7 +72,6 @@ impl Circuit<Fr> for MyCircuit {
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(0, 0)],
|
||||
stride: vec![1; 2],
|
||||
group: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -2,7 +2,6 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
|
||||
use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
|
||||
use ezkl::fieldutils::IntegerRep;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
@@ -85,7 +84,7 @@ fn runrelu(c: &mut Criterion) {
|
||||
};
|
||||
|
||||
let input: Tensor<Value<Fr>> =
|
||||
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();
|
||||
Tensor::<i32>::from((0..len).map(|_| rng.gen_range(0..10))).into();
|
||||
|
||||
let circuit = NLCircuit {
|
||||
input: ValTensor::from(input.clone()),
|
||||
|
||||
@@ -2,7 +2,8 @@ use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::{
|
||||
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
|
||||
};
|
||||
use ezkl::fieldutils::{self, integer_rep_to_felt, IntegerRep};
|
||||
use ezkl::fieldutils;
|
||||
use ezkl::fieldutils::i32_to_felt;
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::dev::MockProver;
|
||||
use halo2_proofs::poly::commitment::Params;
|
||||
@@ -41,8 +42,8 @@ const NUM_INNER_COLS: usize = 1;
|
||||
struct Config<
|
||||
const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable
|
||||
const CLASSES: usize,
|
||||
const LOOKUP_MIN: IntegerRep,
|
||||
const LOOKUP_MAX: IntegerRep,
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
// Convolution
|
||||
const KERNEL_HEIGHT: usize,
|
||||
const KERNEL_WIDTH: usize,
|
||||
@@ -65,8 +66,8 @@ struct Config<
|
||||
struct MyCircuit<
|
||||
const LEN: usize, //LEN = CHOUT x OH x OW flattened
|
||||
const CLASSES: usize,
|
||||
const LOOKUP_MIN: IntegerRep,
|
||||
const LOOKUP_MAX: IntegerRep,
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
// Convolution
|
||||
const KERNEL_HEIGHT: usize,
|
||||
const KERNEL_WIDTH: usize,
|
||||
@@ -89,8 +90,8 @@ struct MyCircuit<
|
||||
impl<
|
||||
const LEN: usize,
|
||||
const CLASSES: usize,
|
||||
const LOOKUP_MIN: IntegerRep,
|
||||
const LOOKUP_MAX: IntegerRep,
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
// Convolution
|
||||
const KERNEL_HEIGHT: usize,
|
||||
const KERNEL_WIDTH: usize,
|
||||
@@ -204,7 +205,6 @@ where
|
||||
let op = PolyOp::Conv {
|
||||
padding: vec![(PADDING, PADDING); 2],
|
||||
stride: vec![STRIDE; 2],
|
||||
group: 1,
|
||||
};
|
||||
let x = config
|
||||
.layer_config
|
||||
@@ -315,11 +315,7 @@ pub fn runconv() {
|
||||
.test_set_length(10_000)
|
||||
.finalize();
|
||||
|
||||
let mut train_data = Tensor::from(
|
||||
trn_img
|
||||
.iter()
|
||||
.map(|x| integer_rep_to_felt::<F>(*x as IntegerRep / 16)),
|
||||
);
|
||||
let mut train_data = Tensor::from(trn_img.iter().map(|x| i32_to_felt::<F>(*x as i32 / 16)));
|
||||
train_data.reshape(&[50_000, 28, 28]).unwrap();
|
||||
|
||||
let mut train_labels = Tensor::from(trn_lbl.iter().map(|x| *x as f32));
|
||||
@@ -347,8 +343,8 @@ pub fn runconv() {
|
||||
.map(|fl| {
|
||||
let dx = fl * 32_f32;
|
||||
let rounded = dx.round();
|
||||
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::integer_rep_to_felt(integral)
|
||||
let integral: i32 = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::i32_to_felt(integral)
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -359,8 +355,7 @@ pub fn runconv() {
|
||||
|
||||
let l0_kernels = l0_kernels.try_into().unwrap();
|
||||
|
||||
let mut l0_bias =
|
||||
Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::integer_rep_to_felt(0)));
|
||||
let mut l0_bias = Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::i32_to_felt(0)));
|
||||
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
|
||||
let l0_bias = l0_bias.try_into().unwrap();
|
||||
@@ -368,8 +363,8 @@ pub fn runconv() {
|
||||
let mut l2_biases = Tensor::<F>::from(myparams.biases.into_iter().map(|fl| {
|
||||
let dx = fl * 32_f32;
|
||||
let rounded = dx.round();
|
||||
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::integer_rep_to_felt(integral)
|
||||
let integral: i32 = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::i32_to_felt(integral)
|
||||
}));
|
||||
l2_biases.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
l2_biases.reshape(&[l2_biases.len(), 1]).unwrap();
|
||||
@@ -379,8 +374,8 @@ pub fn runconv() {
|
||||
let mut l2_weights = Tensor::<F>::from(myparams.weights.into_iter().flatten().map(|fl| {
|
||||
let dx = fl * 32_f32;
|
||||
let rounded = dx.round();
|
||||
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::integer_rep_to_felt(integral)
|
||||
let integral: i32 = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::i32_to_felt(integral)
|
||||
}));
|
||||
l2_weights.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
l2_weights.reshape(&[CLASSES, LEN]).unwrap();
|
||||
@@ -406,13 +401,13 @@ pub fn runconv() {
|
||||
l2_params: [l2_weights, l2_biases],
|
||||
};
|
||||
|
||||
let public_input: Tensor<IntegerRep> = vec![
|
||||
-25124, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
|
||||
let public_input: Tensor<i32> = vec![
|
||||
-25124i32, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
|
||||
]
|
||||
.into_iter()
|
||||
.into();
|
||||
|
||||
let pi_inner: Tensor<F> = public_input.map(integer_rep_to_felt::<F>);
|
||||
let pi_inner: Tensor<F> = public_input.map(i32_to_felt::<F>);
|
||||
|
||||
println!("MOCK PROVING");
|
||||
let now = Instant::now();
|
||||
|
||||
@@ -2,7 +2,7 @@ use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::{
|
||||
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
|
||||
};
|
||||
use ezkl::fieldutils::{integer_rep_to_felt, IntegerRep};
|
||||
use ezkl::fieldutils::i32_to_felt;
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::dev::MockProver;
|
||||
use halo2_proofs::{
|
||||
@@ -23,8 +23,8 @@ struct MyConfig {
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<
|
||||
const LEN: usize, //LEN = CHOUT x OH x OW flattened
|
||||
const LOOKUP_MIN: IntegerRep,
|
||||
const LOOKUP_MAX: IntegerRep,
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
> {
|
||||
// Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers.
|
||||
// Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
|
||||
@@ -34,7 +34,7 @@ struct MyCircuit<
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRep> Circuit<F>
|
||||
impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
|
||||
for MyCircuit<LEN, LOOKUP_MIN, LOOKUP_MAX>
|
||||
{
|
||||
type Config = MyConfig;
|
||||
@@ -215,33 +215,33 @@ pub fn runmlp() {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
env_logger::init();
|
||||
// parameters
|
||||
let mut l0_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
|
||||
let mut l0_kernel: Tensor<F> = Tensor::<i32>::new(
|
||||
Some(&[10, 0, 0, -1, 0, 10, 1, 0, 0, 1, 10, 0, 1, 0, 0, 10]),
|
||||
&[4, 4],
|
||||
)
|
||||
.unwrap()
|
||||
.map(integer_rep_to_felt);
|
||||
.map(i32_to_felt);
|
||||
l0_kernel.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
|
||||
let mut l0_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
|
||||
let mut l0_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
|
||||
.unwrap()
|
||||
.map(integer_rep_to_felt);
|
||||
.map(i32_to_felt);
|
||||
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
|
||||
let mut l2_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
|
||||
let mut l2_kernel: Tensor<F> = Tensor::<i32>::new(
|
||||
Some(&[0, 3, 10, -1, 0, 10, 1, 0, 0, 1, 0, 12, 1, -2, 32, 0]),
|
||||
&[4, 4],
|
||||
)
|
||||
.unwrap()
|
||||
.map(integer_rep_to_felt);
|
||||
.map(i32_to_felt);
|
||||
l2_kernel.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
// input data, with 1 padding to allow for bias
|
||||
let input: Tensor<Value<F>> = Tensor::<IntegerRep>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
|
||||
let input: Tensor<Value<F>> = Tensor::<i32>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
|
||||
.unwrap()
|
||||
.into();
|
||||
let mut l2_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
|
||||
let mut l2_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
|
||||
.unwrap()
|
||||
.map(integer_rep_to_felt);
|
||||
.map(i32_to_felt);
|
||||
l2_bias.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
|
||||
let circuit = MyCircuit::<4, -8192, 8192> {
|
||||
@@ -251,12 +251,12 @@ pub fn runmlp() {
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
let public_input: Vec<IntegerRep> = unsafe {
|
||||
let public_input: Vec<i32> = unsafe {
|
||||
vec![
|
||||
(531f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
|
||||
(103f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
|
||||
(4469f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
|
||||
(2849f32 / 128f32).to_int_unchecked::<IntegerRep>(),
|
||||
(531f32 / 128f32).round().to_int_unchecked::<i32>(),
|
||||
(103f32 / 128f32).round().to_int_unchecked::<i32>(),
|
||||
(4469f32 / 128f32).round().to_int_unchecked::<i32>(),
|
||||
(2849f32 / 128f32).to_int_unchecked::<i32>(),
|
||||
]
|
||||
};
|
||||
|
||||
@@ -265,10 +265,7 @@ pub fn runmlp() {
|
||||
let prover = MockProver::run(
|
||||
K as u32,
|
||||
&circuit,
|
||||
vec![public_input
|
||||
.iter()
|
||||
.map(|x| integer_rep_to_felt::<F>(*x))
|
||||
.collect()],
|
||||
vec![public_input.iter().map(|x| i32_to_felt::<F>(*x)).collect()],
|
||||
)
|
||||
.unwrap();
|
||||
prover.assert_satisfied();
|
||||
|
||||
@@ -22,7 +22,7 @@ use crate::{
|
||||
table::{Range, RangeCheck, Table},
|
||||
utils,
|
||||
},
|
||||
tensor::{Tensor, TensorType, ValTensor, VarTensor},
|
||||
tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
|
||||
};
|
||||
use std::{collections::BTreeMap, marker::PhantomData};
|
||||
|
||||
@@ -327,7 +327,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseConfig<F> {
|
||||
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
|
||||
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
|
||||
Self {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::convert::Infallible;
|
||||
|
||||
use crate::{fieldutils::IntegerRep, tensor::TensorError};
|
||||
use crate::tensor::TensorError;
|
||||
use halo2_proofs::plonk::Error as PlonkError;
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -57,7 +57,7 @@ pub enum CircuitError {
|
||||
InvalidConversion(#[from] Infallible),
|
||||
/// Invalid min/max lookup range
|
||||
#[error("invalid min/max lookup range: min: {0}, max: {1}")]
|
||||
InvalidMinMaxRange(IntegerRep, IntegerRep),
|
||||
InvalidMinMaxRange(i64, i64),
|
||||
/// Missing product in einsum
|
||||
#[error("missing product in einsum")]
|
||||
MissingEinsumProduct,
|
||||
@@ -81,7 +81,7 @@ pub enum CircuitError {
|
||||
MissingSelectors(String),
|
||||
/// Table lookup error
|
||||
#[error("value ({0}) out of range: ({1}, {2})")]
|
||||
TableOOR(IntegerRep, IntegerRep, IntegerRep),
|
||||
TableOOR(i64, i64, i64),
|
||||
/// Loookup not configured
|
||||
#[error("lookup not configured: {0}")]
|
||||
LookupNotConfigured(String),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::integer_rep_to_felt,
|
||||
fieldutils::i64_to_felt,
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
@@ -71,7 +71,7 @@ pub enum HybridOp {
|
||||
},
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for HybridOp {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for HybridOp {
|
||||
///
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
match self {
|
||||
@@ -184,8 +184,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for Hybri
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
i64_to_felt(input_scale.0 as i64),
|
||||
i64_to_felt(output_scale.0 as i64),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for Hybri
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(denom.0 as i128),
|
||||
i64_to_felt(denom.0 as i64),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
circuit::{layouts, table::Range, utils},
|
||||
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
|
||||
fieldutils::{felt_to_i64, i64_to_felt},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
tensor::{self, IntoI64, Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
use super::Op;
|
||||
@@ -131,16 +131,16 @@ impl LookupOp {
|
||||
/// Returns the range of values that can be represented by the table
|
||||
pub fn bit_range(max_len: usize) -> Range {
|
||||
let range = (max_len - 1) as f64 / 2_f64;
|
||||
let range = range as IntegerRep;
|
||||
let range = range as i64;
|
||||
(-range, range)
|
||||
}
|
||||
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
|
||||
&self,
|
||||
x: &[Tensor<F>],
|
||||
) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
|
||||
let x = x[0].clone().map(|x| felt_to_i64(x));
|
||||
let res = match &self {
|
||||
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
|
||||
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
|
||||
@@ -227,13 +227,13 @@ impl LookupOp {
|
||||
}
|
||||
}?;
|
||||
|
||||
let output = res.map(|x| integer_rep_to_felt(x));
|
||||
let output = res.map(|x| i64_to_felt(x));
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for LookupOp {
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
|
||||
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
graph::quantize_tensor,
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
tensor::{self, IntoI64, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
|
||||
@@ -31,12 +31,12 @@ pub use errors::CircuitError;
|
||||
|
||||
/// A struct representing the result of a forward pass.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
|
||||
pub(crate) output: Tensor<F>,
|
||||
}
|
||||
|
||||
/// A trait representing operations that can be represented as constraints in a circuit.
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>:
|
||||
std::fmt::Debug + Send + Sync + Any
|
||||
{
|
||||
/// Returns a string representation of the operation.
|
||||
@@ -75,7 +75,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
|
||||
fn as_any(&self) -> &dyn Any;
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Clone for Box<dyn Op<F>> {
|
||||
fn clone(&self) -> Self {
|
||||
self.clone_dyn()
|
||||
}
|
||||
@@ -142,7 +142,7 @@ pub struct Input {
|
||||
pub datum_type: InputType,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
Ok(self.scale)
|
||||
}
|
||||
@@ -197,7 +197,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Unknown;
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
Ok(0)
|
||||
}
|
||||
@@ -224,7 +224,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
|
||||
|
||||
///
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
|
||||
///
|
||||
pub quantized_values: Tensor<F>,
|
||||
///
|
||||
@@ -234,7 +234,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
pub pre_assigned_val: Option<ValTensor<F>>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Constant<F> {
|
||||
///
|
||||
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
|
||||
Self {
|
||||
@@ -267,7 +267,8 @@ impl<
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
+ for<'de> Deserialize<'de>
|
||||
+ IntoI64,
|
||||
> Op<F> for Constant<F>
|
||||
{
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
|
||||
@@ -33,7 +33,6 @@ pub enum PolyOp {
|
||||
Conv {
|
||||
padding: Vec<(usize, usize)>,
|
||||
stride: Vec<usize>,
|
||||
group: usize,
|
||||
},
|
||||
Downsample {
|
||||
axis: usize,
|
||||
@@ -44,7 +43,6 @@ pub enum PolyOp {
|
||||
padding: Vec<(usize, usize)>,
|
||||
output_padding: Vec<usize>,
|
||||
stride: Vec<usize>,
|
||||
group: usize,
|
||||
},
|
||||
Add,
|
||||
Sub,
|
||||
@@ -100,7 +98,7 @@ impl<
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>
|
||||
,
|
||||
+ IntoI64,
|
||||
> Op<F> for PolyOp
|
||||
{
|
||||
/// Returns a reference to the Any trait.
|
||||
@@ -150,25 +148,17 @@ impl<
|
||||
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
|
||||
PolyOp::Prod { .. } => "PROD".into(),
|
||||
PolyOp::Pow(_) => "POW".into(),
|
||||
PolyOp::Conv {
|
||||
stride,
|
||||
padding,
|
||||
group,
|
||||
} => {
|
||||
format!(
|
||||
"CONV (stride={:?}, padding={:?}, group={})",
|
||||
stride, padding, group
|
||||
)
|
||||
PolyOp::Conv { stride, padding } => {
|
||||
format!("CONV (stride={:?}, padding={:?})", stride, padding)
|
||||
}
|
||||
PolyOp::DeConv {
|
||||
stride,
|
||||
padding,
|
||||
output_padding,
|
||||
group,
|
||||
} => {
|
||||
format!(
|
||||
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
|
||||
stride, padding, output_padding, group
|
||||
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
|
||||
stride, padding, output_padding
|
||||
)
|
||||
}
|
||||
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
|
||||
@@ -222,18 +212,9 @@ impl<
|
||||
PolyOp::Prod { axes, .. } => {
|
||||
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
|
||||
}
|
||||
PolyOp::Conv {
|
||||
padding,
|
||||
stride,
|
||||
group,
|
||||
} => layouts::conv(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
padding,
|
||||
stride,
|
||||
*group,
|
||||
)?,
|
||||
PolyOp::Conv { padding, stride } => {
|
||||
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
|
||||
}
|
||||
PolyOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
@@ -280,7 +261,6 @@ impl<
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
group,
|
||||
} => layouts::deconv(
|
||||
config,
|
||||
region,
|
||||
@@ -288,7 +268,6 @@ impl<
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
*group,
|
||||
)?,
|
||||
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
|
||||
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use crate::{
|
||||
circuit::table::Range,
|
||||
fieldutils::IntegerRep,
|
||||
tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
|
||||
};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -12,6 +11,7 @@ use halo2_proofs::{
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use maybe_rayon::iter::ParallelExtend;
|
||||
use portable_atomic::AtomicI64 as AtomicInt;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
collections::{HashMap, HashSet},
|
||||
@@ -86,78 +86,6 @@ impl ShuffleIndex {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Some settings for a region to differentiate it across the different phases of proof generation
|
||||
pub struct RegionSettings {
|
||||
/// whether we are in witness generation mode
|
||||
pub witness_gen: bool,
|
||||
/// whether we should check range checks for validity
|
||||
pub check_range: bool,
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl Sync for RegionSettings {}
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl Send for RegionSettings {}
|
||||
|
||||
impl RegionSettings {
|
||||
/// Create a new region settings
|
||||
pub fn new(witness_gen: bool, check_range: bool) -> RegionSettings {
|
||||
RegionSettings {
|
||||
witness_gen,
|
||||
check_range,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region settings with all true
|
||||
pub fn all_true() -> RegionSettings {
|
||||
RegionSettings {
|
||||
witness_gen: true,
|
||||
check_range: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region settings with all false
|
||||
pub fn all_false() -> RegionSettings {
|
||||
RegionSettings {
|
||||
witness_gen: false,
|
||||
check_range: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
/// Region statistics
|
||||
pub struct RegionStatistics {
|
||||
/// the current maximum value of the lookup inputs
|
||||
pub max_lookup_inputs: IntegerRep,
|
||||
/// the current minimum value of the lookup inputs
|
||||
pub min_lookup_inputs: IntegerRep,
|
||||
/// the current maximum value of the range size
|
||||
pub max_range_size: IntegerRep,
|
||||
/// the current set of used lookups
|
||||
pub used_lookups: HashSet<LookupOp>,
|
||||
/// the current set of used range checks
|
||||
pub used_range_checks: HashSet<Range>,
|
||||
}
|
||||
|
||||
impl RegionStatistics {
|
||||
/// update the statistics with another set of statistics
|
||||
pub fn update(&mut self, other: &RegionStatistics) {
|
||||
self.max_lookup_inputs = self.max_lookup_inputs.max(other.max_lookup_inputs);
|
||||
self.min_lookup_inputs = self.min_lookup_inputs.min(other.min_lookup_inputs);
|
||||
self.max_range_size = self.max_range_size.max(other.max_range_size);
|
||||
self.used_lookups.extend(other.used_lookups.clone());
|
||||
self.used_range_checks
|
||||
.extend(other.used_range_checks.clone());
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl Sync for RegionStatistics {}
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl Send for RegionStatistics {}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// A context for a region
|
||||
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
@@ -167,8 +95,13 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
num_inner_cols: usize,
|
||||
dynamic_lookup_index: DynamicLookupIndex,
|
||||
shuffle_index: ShuffleIndex,
|
||||
statistics: RegionStatistics,
|
||||
settings: RegionSettings,
|
||||
used_lookups: HashSet<LookupOp>,
|
||||
used_range_checks: HashSet<Range>,
|
||||
max_lookup_inputs: i64,
|
||||
min_lookup_inputs: i64,
|
||||
max_range_size: i64,
|
||||
witness_gen: bool,
|
||||
check_lookup_range: bool,
|
||||
assigned_constants: ConstantsMap<F>,
|
||||
}
|
||||
|
||||
@@ -220,17 +153,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
|
||||
///
|
||||
pub fn witness_gen(&self) -> bool {
|
||||
self.settings.witness_gen
|
||||
self.witness_gen
|
||||
}
|
||||
|
||||
///
|
||||
pub fn check_range(&self) -> bool {
|
||||
self.settings.check_range
|
||||
}
|
||||
|
||||
///
|
||||
pub fn statistics(&self) -> &RegionStatistics {
|
||||
&self.statistics
|
||||
pub fn check_lookup_range(&self) -> bool {
|
||||
self.check_lookup_range
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
@@ -245,8 +173,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
linear_coord,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
statistics: RegionStatistics::default(),
|
||||
settings: RegionSettings::all_true(),
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen: true,
|
||||
check_lookup_range: true,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
@@ -262,12 +195,39 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
new_self.assigned_constants = constants;
|
||||
new_self
|
||||
}
|
||||
/// Create a new region context from a wrapped region
|
||||
pub fn from_wrapped_region(
|
||||
region: Option<RefCell<Region<'a, F>>>,
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
dynamic_lookup_index: DynamicLookupIndex,
|
||||
shuffle_index: ShuffleIndex,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let linear_coord = row * num_inner_cols;
|
||||
RegionCtx {
|
||||
region,
|
||||
num_inner_cols,
|
||||
linear_coord,
|
||||
row,
|
||||
dynamic_lookup_index,
|
||||
shuffle_index,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen: false,
|
||||
check_lookup_range: false,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new_dummy(
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
settings: RegionSettings,
|
||||
witness_gen: bool,
|
||||
check_lookup_range: bool,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
let linear_coord = row * num_inner_cols;
|
||||
@@ -279,8 +239,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
row,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
statistics: RegionStatistics::default(),
|
||||
settings,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen,
|
||||
check_lookup_range,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
@@ -290,7 +255,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
row: usize,
|
||||
linear_coord: usize,
|
||||
num_inner_cols: usize,
|
||||
settings: RegionSettings,
|
||||
witness_gen: bool,
|
||||
check_lookup_range: bool,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
RegionCtx {
|
||||
@@ -300,8 +266,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
row,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
statistics: RegionStatistics::default(),
|
||||
settings,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen,
|
||||
check_lookup_range,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
@@ -352,9 +323,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
) -> Result<(), CircuitError> {
|
||||
let row = AtomicUsize::new(self.row());
|
||||
let linear_coord = AtomicUsize::new(self.linear_coord());
|
||||
let statistics = Arc::new(Mutex::new(self.statistics.clone()));
|
||||
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
|
||||
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
|
||||
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
|
||||
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
|
||||
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
|
||||
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
|
||||
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
|
||||
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
|
||||
|
||||
*output = output.par_enum_map(|idx, _| {
|
||||
@@ -368,7 +342,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
starting_offset,
|
||||
starting_linear_coord,
|
||||
self.num_inner_cols,
|
||||
self.settings.clone(),
|
||||
self.witness_gen,
|
||||
self.check_lookup_range,
|
||||
);
|
||||
let res = inner_loop_function(idx, &mut local_reg);
|
||||
// we update the offset and constants
|
||||
@@ -378,9 +353,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
Ordering::SeqCst,
|
||||
);
|
||||
|
||||
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
|
||||
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
|
||||
// update the lookups
|
||||
let mut statistics = statistics.lock().unwrap();
|
||||
statistics.update(local_reg.statistics());
|
||||
let mut lookups = lookups.lock().unwrap();
|
||||
lookups.extend(local_reg.used_lookups());
|
||||
// update the range checks
|
||||
let mut range_checks = range_checks.lock().unwrap();
|
||||
range_checks.extend(local_reg.used_range_checks());
|
||||
// update the dynamic lookup index
|
||||
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
|
||||
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
|
||||
@@ -394,11 +374,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
res
|
||||
})?;
|
||||
self.linear_coord = linear_coord.into_inner();
|
||||
#[allow(trivial_numeric_casts)]
|
||||
{
|
||||
self.max_lookup_inputs = max_lookup_inputs.into_inner();
|
||||
self.min_lookup_inputs = min_lookup_inputs.into_inner();
|
||||
}
|
||||
self.row = row.into_inner();
|
||||
self.statistics = Arc::try_unwrap(statistics)
|
||||
self.used_lookups = Arc::try_unwrap(lookups)
|
||||
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?;
|
||||
self.used_range_checks = Arc::try_unwrap(range_checks)
|
||||
.map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?;
|
||||
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
|
||||
.map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
@@ -422,11 +411,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
) -> Result<(), CircuitError> {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in inputs {
|
||||
max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
|
||||
min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
|
||||
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
|
||||
min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
|
||||
}
|
||||
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
|
||||
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
|
||||
self.max_lookup_inputs = self.max_lookup_inputs.max(max);
|
||||
self.min_lookup_inputs = self.min_lookup_inputs.min(min);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -438,7 +427,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
|
||||
let range_size = (range.1 - range.0).abs();
|
||||
|
||||
self.statistics.max_range_size = self.statistics.max_range_size.max(range_size);
|
||||
self.max_range_size = self.max_range_size.max(range_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -453,13 +442,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
lookup: LookupOp,
|
||||
inputs: &[ValTensor<F>],
|
||||
) -> Result<(), CircuitError> {
|
||||
self.statistics.used_lookups.insert(lookup);
|
||||
self.used_lookups.insert(lookup);
|
||||
self.update_max_min_lookup_inputs(inputs)
|
||||
}
|
||||
|
||||
/// add used range check
|
||||
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), CircuitError> {
|
||||
self.statistics.used_range_checks.insert(range);
|
||||
self.used_range_checks.insert(range);
|
||||
self.update_max_min_lookup_range(range)
|
||||
}
|
||||
|
||||
@@ -500,27 +489,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
|
||||
/// get used lookups
|
||||
pub fn used_lookups(&self) -> HashSet<LookupOp> {
|
||||
self.statistics.used_lookups.clone()
|
||||
self.used_lookups.clone()
|
||||
}
|
||||
|
||||
/// get used range checks
|
||||
pub fn used_range_checks(&self) -> HashSet<Range> {
|
||||
self.statistics.used_range_checks.clone()
|
||||
self.used_range_checks.clone()
|
||||
}
|
||||
|
||||
/// max lookup inputs
|
||||
pub fn max_lookup_inputs(&self) -> IntegerRep {
|
||||
self.statistics.max_lookup_inputs
|
||||
pub fn max_lookup_inputs(&self) -> i64 {
|
||||
self.max_lookup_inputs
|
||||
}
|
||||
|
||||
/// min lookup inputs
|
||||
pub fn min_lookup_inputs(&self) -> IntegerRep {
|
||||
self.statistics.min_lookup_inputs
|
||||
pub fn min_lookup_inputs(&self) -> i64 {
|
||||
self.min_lookup_inputs
|
||||
}
|
||||
|
||||
/// max range check
|
||||
pub fn max_range_size(&self) -> IntegerRep {
|
||||
self.statistics.max_range_size
|
||||
pub fn max_range_size(&self) -> i64 {
|
||||
self.max_range_size
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor
|
||||
|
||||
@@ -11,17 +11,17 @@ use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
|
||||
|
||||
use crate::{
|
||||
circuit::CircuitError,
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
tensor::{Tensor, TensorType},
|
||||
fieldutils::i64_to_felt,
|
||||
tensor::{IntoI64, Tensor, TensorType},
|
||||
};
|
||||
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
|
||||
/// The range of the lookup table.
|
||||
pub type Range = (IntegerRep, IntegerRep);
|
||||
pub type Range = (i64, i64);
|
||||
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: IntegerRep = 2;
|
||||
pub const RANGE_MULTIPLIER: i64 = 2;
|
||||
/// The safety factor offset for the number of rows in the lookup table.
|
||||
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
|
||||
|
||||
@@ -96,22 +96,21 @@ pub struct Table<F: PrimeField> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
|
||||
/// get column index given input
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
|
||||
/ (self.col_size as IntegerRep);
|
||||
let chunk =
|
||||
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
|
||||
|
||||
integer_rep_to_felt(chunk)
|
||||
i64_to_felt(chunk)
|
||||
}
|
||||
|
||||
/// get first_element of column
|
||||
pub fn get_first_element(&self, chunk: usize) -> (F, F) {
|
||||
let chunk = chunk as IntegerRep;
|
||||
let chunk = chunk as i64;
|
||||
// we index from 1 to prevent soundness issues
|
||||
let first_element =
|
||||
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0);
|
||||
let first_element = i64_to_felt(chunk * (self.col_size as i64) + self.range.0);
|
||||
let op_f = self
|
||||
.nonlinearity
|
||||
.f(&[Tensor::from(vec![first_element].into_iter())])
|
||||
@@ -131,12 +130,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
|
||||
pub fn num_cols_required(range_len: i64, col_size: usize) -> usize {
|
||||
// number of cols needed to store the range
|
||||
(range_len / (col_size as IntegerRep)) as usize + 1
|
||||
(range_len / (col_size as i64)) as usize + 1
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
@@ -203,7 +202,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x));
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
|
||||
let evals = self.nonlinearity.f(&[inputs.clone()])?;
|
||||
let chunked_inputs = inputs.chunks(self.col_size);
|
||||
|
||||
@@ -273,12 +272,12 @@ pub struct RangeCheck<F: PrimeField> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
|
||||
/// get first_element of column
|
||||
pub fn get_first_element(&self, chunk: usize) -> F {
|
||||
let chunk = chunk as IntegerRep;
|
||||
let chunk = chunk as i64;
|
||||
// we index from 1 to prevent soundness issues
|
||||
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
|
||||
i64_to_felt(chunk * (self.col_size as i64) + self.range.0)
|
||||
}
|
||||
|
||||
///
|
||||
@@ -294,14 +293,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
/// get column index given input
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
|
||||
/ (self.col_size as IntegerRep);
|
||||
let chunk =
|
||||
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
|
||||
|
||||
integer_rep_to_felt(chunk)
|
||||
i64_to_felt(chunk)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
|
||||
log::debug!("range check range: {:?}", range);
|
||||
@@ -351,7 +350,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x));
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
|
||||
let chunked_inputs = inputs.chunks(self.col_size);
|
||||
|
||||
self.is_assigned = true;
|
||||
|
||||
@@ -1050,7 +1050,6 @@ mod conv {
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
group: 1,
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1201,7 +1200,6 @@ mod conv_col_ultra_overflow {
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
group: 1,
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1347,7 +1345,6 @@ mod conv_relu_col_ultra_overflow {
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
group: 1,
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis);
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use crate::circuit::region::RegionSettings;
|
||||
use crate::circuit::CheckMode;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::commands::CalibrationTarget;
|
||||
@@ -786,8 +785,6 @@ pub(crate) async fn gen_witness(
|
||||
|
||||
let commitment: Commitments = settings.run_args.commitment.into();
|
||||
|
||||
let region_settings = RegionSettings::all_true();
|
||||
|
||||
let start_time = Instant::now();
|
||||
let witness = if settings.module_requires_polycommit() {
|
||||
if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() {
|
||||
@@ -802,7 +799,8 @@ pub(crate) async fn gen_witness(
|
||||
&mut input,
|
||||
vk.as_ref(),
|
||||
Some(&srs),
|
||||
region_settings,
|
||||
true,
|
||||
true,
|
||||
)?
|
||||
}
|
||||
Commitments::IPA => {
|
||||
@@ -816,7 +814,8 @@ pub(crate) async fn gen_witness(
|
||||
&mut input,
|
||||
vk.as_ref(),
|
||||
Some(&srs),
|
||||
region_settings,
|
||||
true,
|
||||
true,
|
||||
)?
|
||||
}
|
||||
}
|
||||
@@ -826,16 +825,12 @@ pub(crate) async fn gen_witness(
|
||||
&mut input,
|
||||
vk.as_ref(),
|
||||
None,
|
||||
region_settings,
|
||||
true,
|
||||
true,
|
||||
)?
|
||||
}
|
||||
} else {
|
||||
circuit.forward::<KZGCommitmentScheme<Bn256>>(
|
||||
&mut input,
|
||||
vk.as_ref(),
|
||||
None,
|
||||
region_settings,
|
||||
)?
|
||||
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true, true)?
|
||||
};
|
||||
|
||||
// print each variable tuple (symbol, value) as symbol=value
|
||||
@@ -1028,8 +1023,6 @@ pub(crate) async fn calibrate(
|
||||
use std::collections::HashMap;
|
||||
use tabled::Table;
|
||||
|
||||
use crate::fieldutils::IntegerRep;
|
||||
|
||||
let data = GraphData::from_path(data)?;
|
||||
// load the pre-generated settings
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
@@ -1138,7 +1131,7 @@ pub(crate) async fn calibrate(
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
lookup_range: (IntegerRep::MIN, IntegerRep::MAX),
|
||||
lookup_range: (i64::MIN, i64::MAX),
|
||||
..settings.run_args.clone()
|
||||
};
|
||||
|
||||
@@ -1178,7 +1171,8 @@ pub(crate) async fn calibrate(
|
||||
&mut data.clone(),
|
||||
None,
|
||||
None,
|
||||
RegionSettings::all_true(),
|
||||
true,
|
||||
false,
|
||||
)
|
||||
.map_err(|e| format!("failed to forward: {}", e))?;
|
||||
|
||||
|
||||
@@ -2,11 +2,17 @@ use halo2_proofs::arithmetic::Field;
|
||||
/// Utilities for converting from Halo2 PrimeField types to integers (and vice-versa).
|
||||
use halo2curves::ff::PrimeField;
|
||||
|
||||
/// Integer representation of a PrimeField element.
|
||||
pub type IntegerRep = i128;
|
||||
/// Converts an i32 to a PrimeField element.
|
||||
pub fn i32_to_felt<F: PrimeField>(x: i32) -> F {
|
||||
if x >= 0 {
|
||||
F::from(x as u64)
|
||||
} else {
|
||||
-F::from(x.unsigned_abs() as u64)
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts an i64 to a PrimeField element.
|
||||
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
|
||||
pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
|
||||
if x >= 0 {
|
||||
F::from_u128(x as u128)
|
||||
} else {
|
||||
@@ -14,9 +20,24 @@ pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a PrimeField element to an i32.
|
||||
pub fn felt_to_i32<F: PrimeField + PartialOrd + Field>(x: F) -> i32 {
|
||||
if x > F::from(i32::MAX as u64) {
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_32 = u32::from_le_bytes(negtmp[..4].try_into().unwrap());
|
||||
-(lower_32 as i32)
|
||||
} else {
|
||||
let rep = (x).to_repr();
|
||||
let tmp: &[u8] = rep.as_ref();
|
||||
let lower_32 = u32::from_le_bytes(tmp[..4].try_into().unwrap());
|
||||
lower_32 as i32
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a PrimeField element to an f64.
|
||||
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
if x > F::from_u128(IntegerRep::MAX as u128) {
|
||||
if x > F::from_u128(i64::MAX as u128) {
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
|
||||
@@ -30,17 +51,17 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
}
|
||||
|
||||
/// Converts a PrimeField element to an i64.
|
||||
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
|
||||
if x > F::from_u128(IntegerRep::MAX as u128) {
|
||||
pub fn felt_to_i64<F: PrimeField + PartialOrd + Field>(x: F) -> i64 {
|
||||
if x > F::from_u128(i64::MAX as u128) {
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
|
||||
-(lower_128 as IntegerRep)
|
||||
-(lower_128 as i64)
|
||||
} else {
|
||||
let rep = (x).to_repr();
|
||||
let tmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap());
|
||||
lower_128 as IntegerRep
|
||||
lower_128 as i64
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,24 +73,33 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn test_conv() {
|
||||
let res: F = integer_rep_to_felt(-15);
|
||||
let res: F = i32_to_felt(-15i32);
|
||||
assert_eq!(res, -F::from(15));
|
||||
|
||||
let res: F = integer_rep_to_felt(2_i128.pow(17));
|
||||
let res: F = i32_to_felt(2_i32.pow(17));
|
||||
assert_eq!(res, F::from(131072));
|
||||
|
||||
let res: F = integer_rep_to_felt(-15);
|
||||
let res: F = i64_to_felt(-15i64);
|
||||
assert_eq!(res, -F::from(15));
|
||||
|
||||
let res: F = integer_rep_to_felt(2_i128.pow(17));
|
||||
let res: F = i64_to_felt(2_i64.pow(17));
|
||||
assert_eq!(res, F::from(131072));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttointegerrep() {
|
||||
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
|
||||
fn felttoi32() {
|
||||
for x in -(2i32.pow(16))..(2i32.pow(16)) {
|
||||
let fieldx: F = i32_to_felt::<F>(x);
|
||||
let xf: i32 = felt_to_i32::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttoi64() {
|
||||
for x in -(2i64.pow(20))..(2i64.pow(20)) {
|
||||
let fieldx: F = i64_to_felt::<F>(x);
|
||||
let xf: i64 = felt_to_i64::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::errors::GraphError;
|
||||
use super::quantize_float;
|
||||
use crate::circuit::InputType;
|
||||
use crate::fieldutils::integer_rep_to_felt;
|
||||
use crate::fieldutils::i64_to_felt;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::postgres::Client;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -128,7 +128,7 @@ impl FileSourceInner {
|
||||
/// Convert to a field element
|
||||
pub fn to_field(&self, scale: crate::Scale) -> Fp {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap()),
|
||||
FileSourceInner::Float(f) => i64_to_felt(quantize_float(f, 0.0, scale).unwrap()),
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
Fp::one()
|
||||
@@ -150,7 +150,7 @@ impl FileSourceInner {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_i64(*f) as f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,10 +34,10 @@ use self::input::{FileSource, GraphData};
|
||||
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
use crate::circuit::region::{ConstantsMap, RegionSettings};
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::{CheckMode, InputType};
|
||||
use crate::fieldutils::{felt_to_f64, IntegerRep};
|
||||
use crate::fieldutils::felt_to_f64;
|
||||
use crate::pfsys::PrettyElements;
|
||||
use crate::tensor::{Tensor, ValTensor};
|
||||
use crate::{RunArgs, EZKL_BUF_CAPACITY};
|
||||
@@ -69,14 +69,13 @@ pub use vars::*;
|
||||
use crate::pfsys::field_to_string;
|
||||
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: IntegerRep = 2;
|
||||
pub const RANGE_MULTIPLIER: i64 = 2;
|
||||
|
||||
/// The maximum number of columns in a lookup table.
|
||||
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
|
||||
|
||||
/// Max representation of a lookup table input
|
||||
pub const MAX_LOOKUP_ABS: IntegerRep =
|
||||
(MAX_NUM_LOOKUP_COLS as IntegerRep) * 2_i128.pow(MAX_PUBLIC_SRS);
|
||||
pub const MAX_LOOKUP_ABS: i64 = (MAX_NUM_LOOKUP_COLS as i64) * 2_i64.pow(MAX_PUBLIC_SRS);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
lazy_static! {
|
||||
@@ -127,11 +126,11 @@ pub struct GraphWitness {
|
||||
/// Any hashes of outputs generated during the forward pass
|
||||
pub processed_outputs: Option<ModuleForwardResult>,
|
||||
/// max lookup input
|
||||
pub max_lookup_inputs: IntegerRep,
|
||||
pub max_lookup_inputs: i64,
|
||||
/// max lookup input
|
||||
pub min_lookup_inputs: IntegerRep,
|
||||
pub min_lookup_inputs: i64,
|
||||
/// max range check size
|
||||
pub max_range_size: IntegerRep,
|
||||
pub max_range_size: i64,
|
||||
}
|
||||
|
||||
impl GraphWitness {
|
||||
@@ -805,26 +804,18 @@ impl GraphCircuit {
|
||||
// the ordering here is important, we want the inputs to come before the outputs
|
||||
// as they are configured in that order as Column<Instances>
|
||||
let mut public_inputs: Vec<Fp> = vec![];
|
||||
|
||||
// we first process the inputs
|
||||
if let Some(processed_inputs) = &data.processed_inputs {
|
||||
if self.settings().run_args.input_visibility.is_public() {
|
||||
public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
|
||||
} else if let Some(processed_inputs) = &data.processed_inputs {
|
||||
public_inputs.extend(processed_inputs.get_instances().into_iter().flatten());
|
||||
}
|
||||
|
||||
// we then process the params
|
||||
if let Some(processed_params) = &data.processed_params {
|
||||
public_inputs.extend(processed_params.get_instances().into_iter().flatten());
|
||||
}
|
||||
|
||||
// if the inputs are public, we add them to the public inputs AFTER the processed params as they are configured in that order as Column<Instances>
|
||||
if self.settings().run_args.input_visibility.is_public() {
|
||||
public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
|
||||
}
|
||||
|
||||
// if the outputs are public, we add them to the public inputs
|
||||
if self.settings().run_args.output_visibility.is_public() {
|
||||
public_inputs.extend(self.graph_witness.outputs.clone().into_iter().flatten());
|
||||
// if the outputs are processed, we add the processed outputs to the public inputs
|
||||
} else if let Some(processed_outputs) = &data.processed_outputs {
|
||||
public_inputs.extend(processed_outputs.get_instances().into_iter().flatten());
|
||||
}
|
||||
@@ -1045,12 +1036,12 @@ impl GraphCircuit {
|
||||
|
||||
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: f64) -> Range {
|
||||
(
|
||||
(lookup_safety_margin * min_max_lookup.0 as f64).floor() as IntegerRep,
|
||||
(lookup_safety_margin * min_max_lookup.1 as f64).ceil() as IntegerRep,
|
||||
(lookup_safety_margin * min_max_lookup.0 as f64).floor() as i64,
|
||||
(lookup_safety_margin * min_max_lookup.1 as f64).ceil() as i64,
|
||||
)
|
||||
}
|
||||
|
||||
fn calc_num_cols(range_len: IntegerRep, max_logrows: u32) -> usize {
|
||||
fn calc_num_cols(range_len: i64, max_logrows: u32) -> usize {
|
||||
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
|
||||
num_cols_required(range_len, max_col_size)
|
||||
}
|
||||
@@ -1058,7 +1049,7 @@ impl GraphCircuit {
|
||||
fn table_size_logrows(
|
||||
&self,
|
||||
safe_lookup_range: Range,
|
||||
max_range_size: IntegerRep,
|
||||
max_range_size: i64,
|
||||
) -> Result<u32, GraphError> {
|
||||
// pick the range with the largest absolute size safe_lookup_range or max_range_size
|
||||
let safe_range = std::cmp::max(
|
||||
@@ -1077,7 +1068,7 @@ impl GraphCircuit {
|
||||
pub fn calc_min_logrows(
|
||||
&mut self,
|
||||
min_max_lookup: Range,
|
||||
max_range_size: IntegerRep,
|
||||
max_range_size: i64,
|
||||
max_logrows: Option<u32>,
|
||||
lookup_safety_margin: f64,
|
||||
) -> Result<(), GraphError> {
|
||||
@@ -1095,7 +1086,7 @@ impl GraphCircuit {
|
||||
(safe_lookup_range.1.saturating_sub(safe_lookup_range.0)).saturating_abs();
|
||||
// check if has overflowed max lookup input
|
||||
|
||||
if lookup_size > (MAX_LOOKUP_ABS as f64 / lookup_safety_margin).floor() as IntegerRep {
|
||||
if lookup_size > (MAX_LOOKUP_ABS as f64 / lookup_safety_margin).floor() as i64 {
|
||||
return Err(GraphError::LookupRangeTooLarge(
|
||||
lookup_size.unsigned_abs() as usize
|
||||
));
|
||||
@@ -1175,7 +1166,7 @@ impl GraphCircuit {
|
||||
&self,
|
||||
k: u32,
|
||||
safe_lookup_range: Range,
|
||||
max_range_size: IntegerRep,
|
||||
max_range_size: i64,
|
||||
) -> bool {
|
||||
// if num cols is too large then the extended k is too large
|
||||
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
|
||||
@@ -1233,7 +1224,8 @@ impl GraphCircuit {
|
||||
inputs: &mut [Tensor<Fp>],
|
||||
vk: Option<&VerifyingKey<G1Affine>>,
|
||||
srs: Option<&Scheme::ParamsProver>,
|
||||
region_settings: RegionSettings,
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
) -> Result<GraphWitness, GraphError> {
|
||||
let original_inputs = inputs.to_vec();
|
||||
|
||||
@@ -1282,7 +1274,7 @@ impl GraphCircuit {
|
||||
|
||||
let mut model_results =
|
||||
self.model()
|
||||
.forward(inputs, &self.settings().run_args, region_settings)?;
|
||||
.forward(inputs, &self.settings().run_args, witness_gen, check_lookup)?;
|
||||
|
||||
if visibility.output.requires_processing() {
|
||||
let module_outlets = visibility.output.overwrites_inputs();
|
||||
|
||||
@@ -7,12 +7,10 @@ use super::GraphSettings;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::circuit::region::RegionCtx;
|
||||
use crate::circuit::region::RegionSettings;
|
||||
use crate::circuit::table::Range;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::Unknown;
|
||||
use crate::fieldutils::IntegerRep;
|
||||
use crate::tensor::ValType;
|
||||
use crate::{
|
||||
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
|
||||
@@ -66,11 +64,11 @@ pub struct ForwardResult {
|
||||
/// The outputs of the forward pass.
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
/// The maximum value of any input to a lookup operation.
|
||||
pub max_lookup_inputs: IntegerRep,
|
||||
pub max_lookup_inputs: i64,
|
||||
/// The minimum value of any input to a lookup operation.
|
||||
pub min_lookup_inputs: IntegerRep,
|
||||
pub min_lookup_inputs: i64,
|
||||
/// The max range check size
|
||||
pub max_range_size: IntegerRep,
|
||||
pub max_range_size: i64,
|
||||
}
|
||||
|
||||
impl From<DummyPassRes> for ForwardResult {
|
||||
@@ -118,11 +116,11 @@ pub struct DummyPassRes {
|
||||
/// range checks
|
||||
pub range_checks: HashSet<Range>,
|
||||
/// max lookup inputs
|
||||
pub max_lookup_inputs: IntegerRep,
|
||||
pub max_lookup_inputs: i64,
|
||||
/// min lookup inputs
|
||||
pub min_lookup_inputs: IntegerRep,
|
||||
pub min_lookup_inputs: i64,
|
||||
/// min range check
|
||||
pub max_range_size: IntegerRep,
|
||||
pub max_range_size: i64,
|
||||
/// outputs
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
}
|
||||
@@ -547,7 +545,7 @@ impl Model {
|
||||
})
|
||||
.collect::<Result<Vec<_>, GraphError>>()?;
|
||||
|
||||
let res = self.dummy_layout(run_args, &inputs, RegionSettings::all_false())?;
|
||||
let res = self.dummy_layout(run_args, &inputs, false, false)?;
|
||||
|
||||
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
|
||||
|
||||
@@ -590,13 +588,14 @@ impl Model {
|
||||
&self,
|
||||
model_inputs: &[Tensor<Fp>],
|
||||
run_args: &RunArgs,
|
||||
region_settings: RegionSettings,
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
) -> Result<ForwardResult, GraphError> {
|
||||
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
|
||||
.iter()
|
||||
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
|
||||
.collect();
|
||||
let res = self.dummy_layout(run_args, &valtensor_inputs, region_settings)?;
|
||||
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen, check_lookup)?;
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
@@ -1391,7 +1390,8 @@ impl Model {
|
||||
&self,
|
||||
run_args: &RunArgs,
|
||||
inputs: &[ValTensor<Fp>],
|
||||
region_settings: RegionSettings,
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
) -> Result<DummyPassRes, GraphError> {
|
||||
debug!("calculating num of constraints using dummy model layout...");
|
||||
|
||||
@@ -1410,7 +1410,8 @@ impl Model {
|
||||
vars: ModelVars::new_dummy(),
|
||||
};
|
||||
|
||||
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, region_settings);
|
||||
let mut region =
|
||||
RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen, check_lookup);
|
||||
|
||||
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ use crate::circuit::lookup::LookupOp;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::Op;
|
||||
use crate::fieldutils::IntegerRep;
|
||||
use crate::tensor::{Tensor, TensorError, TensorType};
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
use halo2curves::ff::PrimeField;
|
||||
@@ -51,20 +50,16 @@ use tract_onnx::tract_hir::{
|
||||
/// * `dims` - the dimensionality of the resulting [Tensor].
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
pub fn quantize_float(
|
||||
elem: &f64,
|
||||
shift: f64,
|
||||
scale: crate::Scale,
|
||||
) -> Result<IntegerRep, TensorError> {
|
||||
pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i64, TensorError> {
|
||||
let mult = scale_to_multiplier(scale);
|
||||
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
let max_value = ((i64::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
|
||||
if *elem > max_value {
|
||||
return Err(TensorError::SigBitTruncationError);
|
||||
}
|
||||
|
||||
// we parallelize the quantization process as it seems to be quite slow at times
|
||||
let scaled = (mult * *elem + shift).round() as IntegerRep;
|
||||
let scaled = (mult * *elem + shift).round() as i64;
|
||||
|
||||
Ok(scaled)
|
||||
}
|
||||
@@ -75,7 +70,7 @@ pub fn quantize_float(
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 {
|
||||
let int_rep = crate::fieldutils::felt_to_integer_rep(felt);
|
||||
let int_rep = crate::fieldutils::felt_to_i64(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
int_rep as f64 / multiplier - shift
|
||||
}
|
||||
@@ -288,7 +283,10 @@ pub fn new_op_from_onnx(
|
||||
.flat_map(|x| x.out_scales())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let input_dims = inputs.iter().flat_map(|x| x.out_dims()).collect::<Vec<_>>();
|
||||
let input_dims = inputs
|
||||
.iter()
|
||||
.flat_map(|x| x.out_dims())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut replace_const = |scale: crate::Scale,
|
||||
index: usize,
|
||||
@@ -1194,13 +1192,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
|
||||
let group = conv_node.group;
|
||||
|
||||
SupportedOp::Linear(PolyOp::Conv {
|
||||
padding,
|
||||
stride,
|
||||
group,
|
||||
})
|
||||
SupportedOp::Linear(PolyOp::Conv { padding, stride })
|
||||
}
|
||||
"Not" => SupportedOp::Linear(PolyOp::Not),
|
||||
"And" => SupportedOp::Linear(PolyOp::And),
|
||||
@@ -1255,7 +1247,6 @@ pub fn new_op_from_onnx(
|
||||
padding,
|
||||
output_padding: deconv_node.adjustments.to_vec(),
|
||||
stride,
|
||||
group: deconv_node.group,
|
||||
})
|
||||
}
|
||||
"Downsample" => {
|
||||
@@ -1429,7 +1420,7 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
|
||||
visibility: &Visibility,
|
||||
) -> Result<Tensor<F>, TensorError> {
|
||||
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
|
||||
Ok::<_, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
|
||||
Ok::<_, TensorError>(crate::fieldutils::i64_to_felt::<F>(quantize_float(
|
||||
&(x).into(),
|
||||
0.0,
|
||||
scale,
|
||||
|
||||
@@ -85,7 +85,6 @@ use std::str::FromStr;
|
||||
|
||||
use circuit::{table::Range, CheckMode, Tolerance};
|
||||
use clap::Args;
|
||||
use fieldutils::IntegerRep;
|
||||
use graph::Visibility;
|
||||
use halo2_proofs::poly::{
|
||||
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
|
||||
@@ -244,7 +243,7 @@ pub struct RunArgs {
|
||||
#[arg(long, default_value = "1", value_hint = clap::ValueHint::Other)]
|
||||
pub scale_rebase_multiplier: u32,
|
||||
/// The min and max elements in the lookup table input column
|
||||
#[arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768")]
|
||||
#[arg(short = 'B', long, value_parser = parse_key_val::<i64, i64>, default_value = "-32768->32768")]
|
||||
pub lookup_range: Range,
|
||||
/// The log_2 number of rows
|
||||
#[arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other)]
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::circuit::modules::poseidon::{
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::{CheckMode, Tolerance};
|
||||
use crate::commands::*;
|
||||
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
use crate::fieldutils::{felt_to_i64, i64_to_felt};
|
||||
use crate::graph::modules::POSEIDON_LEN_GRAPH;
|
||||
use crate::graph::TestDataSource;
|
||||
use crate::graph::{
|
||||
@@ -331,9 +331,9 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
|
||||
#[pyfunction(signature = (
|
||||
felt,
|
||||
))]
|
||||
fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
|
||||
fn felt_to_int(felt: PyFelt) -> PyResult<i64> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
let int_rep = felt_to_i64(felt);
|
||||
Ok(int_rep)
|
||||
}
|
||||
|
||||
@@ -357,7 +357,7 @@ fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
|
||||
))]
|
||||
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
let int_rep = felt_to_i64(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
let float_rep = int_rep as f64 / multiplier;
|
||||
Ok(float_rep)
|
||||
@@ -385,7 +385,7 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
|
||||
let int_rep = quantize_float(&input, 0.0, scale)
|
||||
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
|
||||
let felt = integer_rep_to_felt(int_rep);
|
||||
let felt = i64_to_felt(int_rep);
|
||||
Ok(crate::pfsys::field_to_string::<Fr>(&felt))
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ pub mod var;
|
||||
|
||||
pub use errors::TensorError;
|
||||
|
||||
use halo2curves::ff::PrimeField;
|
||||
use halo2curves::{bn256::Fr, ff::PrimeField};
|
||||
use maybe_rayon::{
|
||||
prelude::{
|
||||
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
|
||||
@@ -26,7 +26,7 @@ use instant::Instant;
|
||||
|
||||
use crate::{
|
||||
circuit::utils,
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
fieldutils::{felt_to_i32, felt_to_i64, i32_to_felt, i64_to_felt},
|
||||
graph::Visibility,
|
||||
};
|
||||
|
||||
@@ -108,7 +108,7 @@ impl TensorType for f32 {
|
||||
Some(0.0)
|
||||
}
|
||||
|
||||
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
|
||||
// f32 doesnt impl Ord so we cant just use max like we can for i32, usize.
|
||||
// A comparison between f32s needs to handle NAN values.
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
match (self.is_nan(), other.is_nan()) {
|
||||
@@ -131,7 +131,7 @@ impl TensorType for f64 {
|
||||
Some(0.0)
|
||||
}
|
||||
|
||||
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
|
||||
// f32 doesnt impl Ord so we cant just use max like we can for i32, usize.
|
||||
// A comparison between f32s needs to handle NAN values.
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
match (self.is_nan(), other.is_nan()) {
|
||||
@@ -150,7 +150,8 @@ impl TensorType for f64 {
|
||||
}
|
||||
|
||||
tensor_type!(bool, Bool, false, true);
|
||||
tensor_type!(IntegerRep, IntegerRep, 0, 1);
|
||||
tensor_type!(i64, Int64, 0, 1);
|
||||
tensor_type!(i32, Int32, 0, 1);
|
||||
tensor_type!(usize, USize, 0, 1);
|
||||
tensor_type!((), Empty, (), ());
|
||||
tensor_type!(utils::F32, F32, utils::F32(0.0), utils::F32(1.0));
|
||||
@@ -315,6 +316,92 @@ impl<T: TensorType> DerefMut for Tensor<T> {
|
||||
self.inner.deref_mut()
|
||||
}
|
||||
}
|
||||
/// Convert to i64 trait
|
||||
pub trait IntoI64 {
|
||||
/// Convert to i64
|
||||
fn into_i64(self) -> i64;
|
||||
|
||||
/// From i64
|
||||
fn from_i64(i: i64) -> Self;
|
||||
}
|
||||
|
||||
impl IntoI64 for i64 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self
|
||||
}
|
||||
fn from_i64(i: i64) -> i64 {
|
||||
i
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for i32 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as i32
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for usize {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as usize
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for f32 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as f32
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for f64 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as f64
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for () {
|
||||
fn into_i64(self) -> i64 {
|
||||
0
|
||||
}
|
||||
fn from_i64(_: i64) -> Self {}
|
||||
}
|
||||
|
||||
impl IntoI64 for Fr {
|
||||
fn into_i64(self) -> i64 {
|
||||
felt_to_i64(self)
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i64_to_felt::<Fr>(i)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + IntoI64> IntoI64 for Value<F> {
|
||||
fn into_i64(self) -> i64 {
|
||||
let mut res = vec![];
|
||||
self.map(|x| res.push(x.into_i64()));
|
||||
|
||||
if res.is_empty() {
|
||||
0
|
||||
} else {
|
||||
res[0]
|
||||
}
|
||||
}
|
||||
|
||||
fn from_i64(i: i64) -> Self {
|
||||
Value::known(F::from_i64(i))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialEq + TensorType> PartialEq for Tensor<T> {
|
||||
fn eq(&self, other: &Tensor<T>) -> bool {
|
||||
@@ -344,6 +431,42 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<Assigned<F>, F>>>
|
||||
for Tensor<i32>
|
||||
{
|
||||
fn from(value: Tensor<AssignedCell<Assigned<F>, F>>) -> Tensor<i32> {
|
||||
let mut output = Vec::new();
|
||||
value.map(|x| {
|
||||
x.evaluate().value().map(|y| {
|
||||
let e = felt_to_i32(*y);
|
||||
output.push(e);
|
||||
e
|
||||
})
|
||||
});
|
||||
Tensor::new(Some(&output), value.dims()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>>
|
||||
for Tensor<i32>
|
||||
{
|
||||
fn from(value: Tensor<AssignedCell<F, F>>) -> Tensor<i32> {
|
||||
let mut output = Vec::new();
|
||||
value.map(|x| {
|
||||
let mut i = 0;
|
||||
x.value().map(|y| {
|
||||
let e = felt_to_i32(*y);
|
||||
output.push(e);
|
||||
i += 1;
|
||||
});
|
||||
if i == 0 {
|
||||
output.push(0);
|
||||
}
|
||||
});
|
||||
Tensor::new(Some(&output), value.dims()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<Assigned<F>, F>>>
|
||||
for Tensor<Value<F>>
|
||||
{
|
||||
@@ -356,6 +479,24 @@ impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<A
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>> for Tensor<i32> {
|
||||
fn from(t: Tensor<Value<F>>) -> Tensor<i32> {
|
||||
let mut output = Vec::new();
|
||||
t.map(|x| {
|
||||
let mut i = 0;
|
||||
x.map(|y| {
|
||||
let e = felt_to_i32(y);
|
||||
output.push(e);
|
||||
i += 1;
|
||||
});
|
||||
if i == 0 {
|
||||
output.push(0);
|
||||
}
|
||||
});
|
||||
Tensor::new(Some(&output), t.dims()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>>
|
||||
for Tensor<Value<Assigned<F>>>
|
||||
{
|
||||
@@ -367,10 +508,20 @@ impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>>
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + Clone> From<Tensor<IntegerRep>> for Tensor<Value<F>> {
|
||||
fn from(t: Tensor<IntegerRep>) -> Tensor<Value<F>> {
|
||||
impl<F: PrimeField + TensorType + Clone> From<Tensor<i32>> for Tensor<Value<F>> {
|
||||
fn from(t: Tensor<i32>) -> Tensor<Value<F>> {
|
||||
let mut ta: Tensor<Value<F>> =
|
||||
Tensor::from((0..t.len()).map(|i| Value::known(integer_rep_to_felt::<F>(t[i]))));
|
||||
Tensor::from((0..t.len()).map(|i| Value::known(i32_to_felt::<F>(t[i]))));
|
||||
// safe to unwrap as we know the dims are correct
|
||||
ta.reshape(t.dims()).unwrap();
|
||||
ta
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + Clone> From<Tensor<i64>> for Tensor<Value<F>> {
|
||||
fn from(t: Tensor<i64>) -> Tensor<Value<F>> {
|
||||
let mut ta: Tensor<Value<F>> =
|
||||
Tensor::from((0..t.len()).map(|i| Value::known(i64_to_felt::<F>(t[i]))));
|
||||
// safe to unwrap as we know the dims are correct
|
||||
ta.reshape(t.dims()).unwrap();
|
||||
ta
|
||||
@@ -482,8 +633,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(None, &[3, 3, 3]).unwrap();
|
||||
///
|
||||
/// a.set(&[0, 0, 1], 10);
|
||||
/// assert_eq!(a[0 + 0 + 1], 10);
|
||||
@@ -500,8 +650,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
|
||||
///
|
||||
/// a[1*15 + 1*5 + 1] = 5;
|
||||
/// assert_eq!(a.get(&[1, 1, 1]), 5);
|
||||
@@ -515,8 +664,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
|
||||
///
|
||||
/// a[1*15 + 1*5 + 1] = 5;
|
||||
/// assert_eq!(a.get(&[1, 1, 1]), 5);
|
||||
@@ -536,12 +684,11 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Pad to a length that is divisible by n
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
|
||||
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
|
||||
///
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
|
||||
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
|
||||
/// ```
|
||||
pub fn pad_to_zero_rem(&self, n: usize, pad: T) -> Result<Tensor<T>, TensorError> {
|
||||
@@ -557,8 +704,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
|
||||
///
|
||||
/// let flat_index = 1*15 + 1*5 + 1;
|
||||
/// a[1*15 + 1*5 + 1] = 5;
|
||||
@@ -585,9 +731,8 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Get a slice from the Tensor.
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
/// let mut b = Tensor::<IntegerRep>::new(Some(&[1, 2]), &[2]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
/// let mut b = Tensor::<i32>::new(Some(&[1, 2]), &[2]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.get_slice(&[0..2]).unwrap(), b);
|
||||
/// ```
|
||||
@@ -637,10 +782,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Set a slice of the Tensor.
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
/// let b = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
|
||||
/// a.set_slice(&[1..2], &Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
/// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
|
||||
/// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
|
||||
/// assert_eq!(a, b);
|
||||
/// ```
|
||||
pub fn set_slice(
|
||||
@@ -701,7 +845,6 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.get_index(&[2, 2, 2]), 26);
|
||||
@@ -725,13 +868,12 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 4, 5, 5, 6]), &[8]).unwrap();
|
||||
/// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 4, 5, 5, 6]), &[8]).unwrap();
|
||||
/// assert_eq!(a.duplicate_every_n(3, 1, 0).unwrap(), expected);
|
||||
/// assert_eq!(a.duplicate_every_n(7, 1, 0).unwrap(), a);
|
||||
///
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 2, 3, 3, 4, 5, 5, 6]), &[9]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 1, 2, 3, 3, 4, 5, 5, 6]), &[9]).unwrap();
|
||||
/// assert_eq!(a.duplicate_every_n(3, 1, 2).unwrap(), expected);
|
||||
///
|
||||
/// ```
|
||||
@@ -758,9 +900,8 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 4, 5, 6, 6]), &[8]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 5, 6, 6]), &[7]).unwrap();
|
||||
/// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 4, 5, 6, 6]), &[8]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 5, 6, 6]), &[7]).unwrap();
|
||||
/// assert_eq!(a.remove_every_n(4, 1, 0).unwrap(), expected);
|
||||
///
|
||||
///
|
||||
@@ -794,15 +935,14 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// WARN: assumes indices are in ascending order for speed
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
|
||||
/// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
|
||||
/// let mut indices = vec![3, 4];
|
||||
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), expected);
|
||||
///
|
||||
///
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
|
||||
/// let b = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
|
||||
/// let a = Tensor::<i32>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
|
||||
/// let b = Tensor::<i32>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
|
||||
/// let mut indices = vec![63, 67];
|
||||
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), b);
|
||||
/// ```
|
||||
@@ -832,7 +972,6 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///Reshape the tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// a.reshape(&[9, 3]);
|
||||
/// assert_eq!(a.dims(), &[9, 3]);
|
||||
@@ -867,23 +1006,22 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Move axis of the tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// let b = a.move_axis(0, 2).unwrap();
|
||||
/// assert_eq!(b.dims(), &[3, 3, 3]);
|
||||
///
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6]), &[1, 2, 3]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6]), &[1, 2, 3]).unwrap();
|
||||
/// let b = a.move_axis(0, 2).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
///
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let b = a.move_axis(1, 2).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
///
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let b = a.move_axis(2, 1).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
/// ```
|
||||
@@ -948,23 +1086,22 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Swap axes of the tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// let b = a.swap_axes(0, 2).unwrap();
|
||||
/// assert_eq!(b.dims(), &[3, 3, 3]);
|
||||
///
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6]), &[2, 1, 3]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6]), &[2, 1, 3]).unwrap();
|
||||
/// let b = a.swap_axes(0, 2).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
///
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let b = a.swap_axes(1, 2).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
///
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let b = a.swap_axes(2, 1).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
/// ```
|
||||
@@ -1011,10 +1148,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Broadcasts the tensor to a given shape
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
|
||||
///
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 2, 2, 2, 3, 3, 3]), &[3, 3]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 1, 1, 2, 2, 2, 3, 3, 3]), &[3, 3]).unwrap();
|
||||
/// assert_eq!(a.expand(&[3, 3]).unwrap(), expected);
|
||||
///
|
||||
/// ```
|
||||
@@ -1068,7 +1204,6 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///Flatten the tensor shape
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// a.flatten();
|
||||
/// assert_eq!(a.dims(), &[27]);
|
||||
@@ -1082,9 +1217,8 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Maps a function to tensors
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.map(|x| IntegerRep::pow(x,2));
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.map(|x| i32::pow(x,2));
|
||||
/// assert_eq!(c, Tensor::from([1, 16].into_iter()))
|
||||
/// ```
|
||||
pub fn map<F: FnMut(T) -> G, G: TensorType>(&self, mut f: F) -> Tensor<G> {
|
||||
@@ -1097,9 +1231,8 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Maps a function to tensors and enumerates
|
||||
/// ```
|
||||
/// use ezkl::tensor::{Tensor, TensorError};
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
|
||||
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
|
||||
/// ```
|
||||
pub fn enum_map<F: FnMut(usize, T) -> Result<G, E>, G: TensorType, E: Error>(
|
||||
@@ -1121,9 +1254,8 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Maps a function to tensors and enumerates in parallel
|
||||
/// ```
|
||||
/// use ezkl::tensor::{Tensor, TensorError};
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
|
||||
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
|
||||
/// ```
|
||||
pub fn par_enum_map<
|
||||
@@ -1152,9 +1284,8 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Get last elem from Tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
/// let mut b = Tensor::<IntegerRep>::new(Some(&[3]), &[1]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
/// let mut b = Tensor::<i32>::new(Some(&[3]), &[1]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.last().unwrap(), b);
|
||||
/// ```
|
||||
@@ -1177,9 +1308,8 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Maps a function to tensors and enumerates in parallel
|
||||
/// ```
|
||||
/// use ezkl::tensor::{Tensor, TensorError};
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
|
||||
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
|
||||
/// ```
|
||||
pub fn par_enum_map_mut_filtered<
|
||||
@@ -1202,13 +1332,103 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
#[allow(unsafe_code)]
|
||||
/// Perform a tensor operation on the GPU using Metal.
|
||||
pub fn metal_tensor_op<T: Clone + TensorType + IntoI64 + Send + Sync>(
|
||||
v: &Tensor<T>,
|
||||
w: &Tensor<T>,
|
||||
op: &str,
|
||||
) -> Tensor<T> {
|
||||
assert_eq!(v.dims(), w.dims());
|
||||
|
||||
log::trace!("------------------------------------------------");
|
||||
|
||||
let start = Instant::now();
|
||||
let v = v
|
||||
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
|
||||
.unwrap();
|
||||
let w = w
|
||||
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
|
||||
.unwrap();
|
||||
log::trace!("Time to map tensors: {:?}", start.elapsed());
|
||||
|
||||
objc::rc::autoreleasepool(|| {
|
||||
// create function pipeline.
|
||||
// this compiles the function, so a pipline can't be created in performance sensitive code.
|
||||
|
||||
let pipeline = &PIPELINES[op];
|
||||
|
||||
let length = v.len() as u64;
|
||||
let size = length * core::mem::size_of::<i64>() as u64;
|
||||
assert_eq!(v.len(), w.len());
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
let buffer_a = DEVICE.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(v.as_ptr()) },
|
||||
size,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
let buffer_b = DEVICE.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(w.as_ptr()) },
|
||||
size,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
let buffer_result = DEVICE.new_buffer(
|
||||
size, // the operation will return an array with the same size.
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
log::trace!("Time to load buffers: {:?}", start.elapsed());
|
||||
|
||||
// for sending commands, a command buffer is needed.
|
||||
let start = Instant::now();
|
||||
let command_buffer = QUEUE.new_command_buffer();
|
||||
log::trace!("Time to load command buffer: {:?}", start.elapsed());
|
||||
// to write commands into a buffer an encoder is needed, in our case a compute encoder.
|
||||
let start = Instant::now();
|
||||
let compute_encoder = command_buffer.new_compute_command_encoder();
|
||||
compute_encoder.set_compute_pipeline_state(&pipeline);
|
||||
compute_encoder.set_buffers(
|
||||
0,
|
||||
&[Some(&buffer_a), Some(&buffer_b), Some(&buffer_result)],
|
||||
&[0; 3],
|
||||
);
|
||||
log::trace!("Time to load compute encoder: {:?}", start.elapsed());
|
||||
|
||||
// specify thread count and organization
|
||||
let start = Instant::now();
|
||||
let grid_size = MTLSize::new(length, 1, 1);
|
||||
let threadgroup_size = MTLSize::new(length, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_size, threadgroup_size);
|
||||
log::trace!("Time to dispatch threads: {:?}", start.elapsed());
|
||||
|
||||
// end encoding and execute commands
|
||||
let start = Instant::now();
|
||||
compute_encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
|
||||
command_buffer.wait_until_completed();
|
||||
log::trace!("Time to commit: {:?}", start.elapsed());
|
||||
|
||||
let start = Instant::now();
|
||||
let ptr = buffer_result.contents() as *const i64;
|
||||
let len = buffer_result.length() as usize / std::mem::size_of::<i64>();
|
||||
let slice = unsafe { core::slice::from_raw_parts(ptr, len) };
|
||||
let res = Tensor::new(Some(&slice.to_vec()), &v.dims()).unwrap();
|
||||
log::trace!("Time to get result: {:?}", start.elapsed());
|
||||
|
||||
res.map(|x| T::from_i64(x))
|
||||
})
|
||||
}
|
||||
|
||||
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
|
||||
/// Flattens a tensor of tensors
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
/// let mut b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
/// let mut b = Tensor::<i32>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
/// let mut c = Tensor::new(Some(&[a,b]), &[2]).unwrap();
|
||||
/// let mut d = c.combine().unwrap();
|
||||
/// assert_eq!(d.dims(), &[8]);
|
||||
@@ -1224,7 +1444,9 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Add for Tensor<T> {
|
||||
impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Add
|
||||
for Tensor<T>
|
||||
{
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
/// Adds tensors.
|
||||
/// # Arguments
|
||||
@@ -1234,43 +1456,42 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Ad
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Add;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// Some(&[2, 3, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.add(k).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Now test 1D casting
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// Some(&[2]),
|
||||
/// &[1]).unwrap();
|
||||
/// let result = x.add(k).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
///
|
||||
/// // Now test 2D casting
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// Some(&[2, 3]),
|
||||
/// &[2]).unwrap();
|
||||
/// let result = x.add(k).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 4, 4, 4]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 4, 4, 4]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
@@ -1304,14 +1525,13 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Neg;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.neg();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn neg(self) -> Self {
|
||||
@@ -1324,7 +1544,9 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Sub for Tensor<T> {
|
||||
impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Sub
|
||||
for Tensor<T>
|
||||
{
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
/// Subtracts tensors.
|
||||
/// # Arguments
|
||||
@@ -1334,44 +1556,43 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Sub;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// Some(&[2, 3, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.sub(k).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Now test 1D sub
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// Some(&[2]),
|
||||
/// &[1],
|
||||
/// ).unwrap();
|
||||
/// let result = x.sub(k).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Now test 2D sub
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// Some(&[2, 3]),
|
||||
/// &[2],
|
||||
/// ).unwrap();
|
||||
/// let result = x.sub(k).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -1, 0, -2, -2, -2]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[0, -1, 0, -2, -2, -2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
@@ -1397,7 +1618,9 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mul for Tensor<T> {
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Mul
|
||||
for Tensor<T>
|
||||
{
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
/// Elementwise multiplies tensors.
|
||||
/// # Arguments
|
||||
@@ -1407,42 +1630,41 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Mul;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// Some(&[2, 3, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.mul(k).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Now test 1D mult
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// Some(&[2]),
|
||||
/// &[1]).unwrap();
|
||||
/// let result = x.mul(k).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Now test 2D mult
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// Some(&[2, 2]),
|
||||
/// &[2]).unwrap();
|
||||
/// let result = x.mul(k).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
@@ -1468,7 +1690,7 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Tensor<T> {
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Tensor<T> {
|
||||
/// Elementwise raise a tensor to the nth power.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -1477,14 +1699,13 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Te
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Mul;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.pow(3).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[8, 3375, 8, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[8, 3375, 8, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn pow(&self, mut exp: u32) -> Result<Self, TensorError> {
|
||||
@@ -1518,31 +1739,30 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Div;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[4, 1, 4, 1, 1, 4]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let y = Tensor::<IntegerRep>::new(
|
||||
/// let y = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.div(y).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // test 1D casting
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[4, 1, 4, 1, 1, 4]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let y = Tensor::<IntegerRep>::new(
|
||||
/// let y = Tensor::<i32>::new(
|
||||
/// Some(&[2]),
|
||||
/// &[1],
|
||||
/// ).unwrap();
|
||||
/// let result = x.div(y).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 0, 2, 0, 0, 2]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[2, 0, 2, 0, 0, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn div(self, rhs: Self) -> Self::Output {
|
||||
@@ -1569,18 +1789,17 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Rem;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[4, 1, 4, 1, 1, 4]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let y = Tensor::<IntegerRep>::new(
|
||||
/// let y = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.rem(y).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn rem(self, rhs: Self) -> Self::Output {
|
||||
@@ -1664,25 +1883,25 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn tensor_clone() {
|
||||
let x = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
let x = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
assert_eq!(x, x.clone());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_eq() {
|
||||
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
let mut b = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
|
||||
let a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
let mut b = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
|
||||
b.reshape(&[3]).unwrap();
|
||||
let c = Tensor::<IntegerRep>::new(Some(&[1, 2, 4]), &[3]).unwrap();
|
||||
let d = Tensor::<IntegerRep>::new(Some(&[1, 2, 4]), &[3, 1]).unwrap();
|
||||
let c = Tensor::<i32>::new(Some(&[1, 2, 4]), &[3]).unwrap();
|
||||
let d = Tensor::<i32>::new(Some(&[1, 2, 4]), &[3, 1]).unwrap();
|
||||
assert_eq!(a, b);
|
||||
assert_ne!(a, c);
|
||||
assert_ne!(a, d);
|
||||
}
|
||||
#[test]
|
||||
fn tensor_slice() {
|
||||
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
let b = Tensor::<i32>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
use crate::{circuit::region::ConstantsMap, fieldutils::felt_to_integer_rep};
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use maybe_rayon::slice::Iter;
|
||||
|
||||
use super::{
|
||||
@@ -54,44 +54,6 @@ pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::S
|
||||
AssignedConstant(AssignedCell<F, F>, F),
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> From<ValType<F>> for IntegerRep {
|
||||
fn from(val: ValType<F>) -> Self {
|
||||
match val {
|
||||
ValType::Value(v) => {
|
||||
let mut output = 0;
|
||||
let mut i = 0;
|
||||
v.map(|y| {
|
||||
let e = felt_to_integer_rep(y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::AssignedValue(v) => {
|
||||
let mut output = 0;
|
||||
let mut i = 0;
|
||||
v.evaluate().map(|y| {
|
||||
let e = felt_to_integer_rep(y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
let mut output = 0;
|
||||
let mut i = 0;
|
||||
v.value().map(|y| {
|
||||
let e = felt_to_integer_rep(*y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::Constant(v) => felt_to_integer_rep(v),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> ValType<F> {
|
||||
/// Returns the inner cell of the [ValType].
|
||||
pub fn cell(&self) -> Option<Cell> {
|
||||
@@ -159,6 +121,44 @@ impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + Partia
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> From<ValType<F>> for i32 {
|
||||
fn from(val: ValType<F>) -> Self {
|
||||
match val {
|
||||
ValType::Value(v) => {
|
||||
let mut output = 0_i32;
|
||||
let mut i = 0;
|
||||
v.map(|y| {
|
||||
let e = felt_to_i32(y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::AssignedValue(v) => {
|
||||
let mut output = 0_i32;
|
||||
let mut i = 0;
|
||||
v.evaluate().map(|y| {
|
||||
let e = felt_to_i32(y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
let mut output = 0_i32;
|
||||
let mut i = 0;
|
||||
v.value().map(|y| {
|
||||
let e = felt_to_i32(*y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::Constant(v) => felt_to_i32(v),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> From<F> for ValType<F> {
|
||||
fn from(t: F) -> ValType<F> {
|
||||
ValType::Constant(t)
|
||||
@@ -317,8 +317,8 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
/// Allocate a new [ValTensor::Value] from the given [Tensor] of [i64].
|
||||
pub fn from_integer_rep_tensor(t: Tensor<IntegerRep>) -> ValTensor<F> {
|
||||
let inner = t.map(|x| ValType::Value(Value::known(integer_rep_to_felt(x))));
|
||||
pub fn from_i64_tensor(t: Tensor<i64>) -> ValTensor<F> {
|
||||
let inner = t.map(|x| ValType::Value(Value::known(i64_to_felt(x))));
|
||||
inner.into()
|
||||
}
|
||||
|
||||
@@ -521,9 +521,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Calls `int_evals` on the inner tensor.
|
||||
pub fn int_evals(&self) -> Result<Tensor<IntegerRep>, TensorError> {
|
||||
pub fn get_int_evals(&self) -> Result<Tensor<i64>, TensorError> {
|
||||
// finally convert to vector of integers
|
||||
let mut integer_evals: Vec<IntegerRep> = vec![];
|
||||
let mut integer_evals: Vec<i64> = vec![];
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: _, ..
|
||||
@@ -531,26 +531,25 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
// we have to push to an externally created vector or else vaf.map() returns an evaluation wrapped in Value<> (which we don't want)
|
||||
let _ = v.map(|vaf| match vaf {
|
||||
ValType::Value(v) => v.map(|f| {
|
||||
integer_evals.push(crate::fieldutils::felt_to_integer_rep(f));
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(f));
|
||||
}),
|
||||
ValType::AssignedValue(v) => v.map(|f| {
|
||||
integer_evals.push(crate::fieldutils::felt_to_integer_rep(f.evaluate()));
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
|
||||
}),
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
v.value_field().map(|f| {
|
||||
integer_evals
|
||||
.push(crate::fieldutils::felt_to_integer_rep(f.evaluate()));
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
|
||||
})
|
||||
}
|
||||
ValType::Constant(v) => {
|
||||
integer_evals.push(crate::fieldutils::felt_to_integer_rep(v));
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(v));
|
||||
Value::unknown()
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => return Err(TensorError::WrongMethod),
|
||||
};
|
||||
let mut tensor: Tensor<IntegerRep> = integer_evals.into_iter().into();
|
||||
let mut tensor: Tensor<i64> = integer_evals.into_iter().into();
|
||||
match tensor.reshape(self.dims()) {
|
||||
_ => {}
|
||||
};
|
||||
@@ -1003,22 +1002,25 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
/// A [String] representation of the [ValTensor] for display, for example in showing intermediate values in a computational graph.
|
||||
pub fn show(&self) -> String {
|
||||
let r = match self.int_evals() {
|
||||
Ok(v) => v,
|
||||
Err(_) => return "ValTensor not PrevAssigned".into(),
|
||||
};
|
||||
|
||||
if r.len() > 10 {
|
||||
let start = r[..5].to_vec();
|
||||
let end = r[r.len() - 5..].to_vec();
|
||||
// print the two split by ... in the middle
|
||||
format!(
|
||||
"[{} ... {}]",
|
||||
start.iter().map(|x| format!("{}", x)).join(", "),
|
||||
end.iter().map(|x| format!("{}", x)).join(", ")
|
||||
)
|
||||
} else {
|
||||
format!("{:?}", r)
|
||||
match self.clone() {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: _, ..
|
||||
} => {
|
||||
let r: Tensor<i32> = v.map(|x| x.into());
|
||||
if r.len() > 10 {
|
||||
let start = r[..5].to_vec();
|
||||
let end = r[r.len() - 5..].to_vec();
|
||||
// print the two split by ... in the middle
|
||||
format!(
|
||||
"[{} ... {}]",
|
||||
start.iter().map(|x| format!("{}", x)).join(", "),
|
||||
end.iter().map(|x| format!("{}", x)).join(", ")
|
||||
)
|
||||
} else {
|
||||
format!("{:?}", r)
|
||||
}
|
||||
}
|
||||
_ => "ValTensor not PrevAssigned".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -368,7 +368,7 @@ impl VarTensor {
|
||||
.sum::<usize>();
|
||||
let dims = &dims[*idx];
|
||||
// this should never ever fail
|
||||
let t: Tensor<IntegerRep> = Tensor::new(None, dims).unwrap();
|
||||
let t: Tensor<i32> = Tensor::new(None, dims).unwrap();
|
||||
Ok(t.enum_map(|coord, _| {
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord);
|
||||
region.assign_advice_from_instance(
|
||||
@@ -497,7 +497,7 @@ impl VarTensor {
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord * step);
|
||||
if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
|
||||
// assert that duplication occurred correctly
|
||||
assert_eq!(Into::<IntegerRep>::into(k.clone()), Into::<IntegerRep>::into(v[coord - 1].clone()));
|
||||
assert_eq!(Into::<i32>::into(k.clone()), Into::<i32>::into(v[coord - 1].clone()));
|
||||
};
|
||||
|
||||
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
|
||||
@@ -533,14 +533,13 @@ impl VarTensor {
|
||||
if matches!(check_mode, CheckMode::SAFE) {
|
||||
// during key generation this will be 0 so we use this as a flag to check
|
||||
// TODO: this isn't very safe and would be better to get the phase directly
|
||||
let res_evals = res.int_evals().unwrap();
|
||||
let is_assigned = res_evals
|
||||
let is_assigned = !Into::<Tensor<i32>>::into(res.clone().get_inner().unwrap())
|
||||
.iter()
|
||||
.all(|&x| x == 0);
|
||||
if !is_assigned {
|
||||
if is_assigned {
|
||||
assert_eq!(
|
||||
values.int_evals().unwrap(),
|
||||
res_evals
|
||||
Into::<Tensor<i32>>::into(values.get_inner().unwrap()),
|
||||
Into::<Tensor<i32>>::into(res.get_inner().unwrap())
|
||||
)};
|
||||
}
|
||||
|
||||
|
||||
89
src/wasm.rs
89
src/wasm.rs
@@ -1,52 +1,41 @@
|
||||
use crate::{
|
||||
circuit::{
|
||||
modules::{
|
||||
polycommit::PolyCommitChip,
|
||||
poseidon::{
|
||||
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
|
||||
PoseidonChip,
|
||||
},
|
||||
Module,
|
||||
},
|
||||
region::RegionSettings,
|
||||
},
|
||||
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
|
||||
graph::{
|
||||
modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
|
||||
GraphSettings,
|
||||
},
|
||||
pfsys::{
|
||||
create_proof_circuit,
|
||||
evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript},
|
||||
verify_proof_circuit, TranscriptType,
|
||||
},
|
||||
tensor::TensorType,
|
||||
CheckMode, Commitments,
|
||||
};
|
||||
use crate::circuit::modules::polycommit::PolyCommitChip;
|
||||
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
|
||||
use crate::circuit::modules::poseidon::PoseidonChip;
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::fieldutils::felt_to_i64;
|
||||
use crate::fieldutils::i64_to_felt;
|
||||
use crate::graph::modules::POSEIDON_LEN_GRAPH;
|
||||
use crate::graph::quantize_float;
|
||||
use crate::graph::scale_to_multiplier;
|
||||
use crate::graph::{GraphCircuit, GraphSettings};
|
||||
use crate::pfsys::create_proof_circuit;
|
||||
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
|
||||
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
|
||||
use crate::pfsys::verify_proof_circuit;
|
||||
use crate::pfsys::TranscriptType;
|
||||
use crate::tensor::TensorType;
|
||||
use crate::CheckMode;
|
||||
use crate::Commitments;
|
||||
use console_error_panic_hook;
|
||||
use halo2_proofs::{
|
||||
plonk::*,
|
||||
poly::{
|
||||
commitment::{CommitmentScheme, ParamsProver},
|
||||
ipa::{
|
||||
commitment::{IPACommitmentScheme, ParamsIPA},
|
||||
multiopen::{ProverIPA, VerifierIPA},
|
||||
strategy::SingleStrategy as IPASingleStrategy,
|
||||
},
|
||||
kzg::{
|
||||
commitment::{KZGCommitmentScheme, ParamsKZG},
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
strategy::SingleStrategy as KZGSingleStrategy,
|
||||
},
|
||||
VerificationStrategy,
|
||||
},
|
||||
use halo2_proofs::plonk::*;
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
|
||||
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
|
||||
use halo2_proofs::poly::ipa::{
|
||||
commitment::{IPACommitmentScheme, ParamsIPA},
|
||||
strategy::SingleStrategy as IPASingleStrategy,
|
||||
};
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::{
|
||||
commitment::{KZGCommitmentScheme, ParamsKZG},
|
||||
strategy::SingleStrategy as KZGSingleStrategy,
|
||||
};
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
use halo2curves::{
|
||||
bn256::{Bn256, Fr, G1Affine},
|
||||
ff::{FromUniformBytes, PrimeField},
|
||||
};
|
||||
use snark_verifier::{loader::native::NativeLoader, system::halo2::transcript::evm::EvmTranscript};
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
use halo2curves::ff::{FromUniformBytes, PrimeField};
|
||||
use snark_verifier::loader::native::NativeLoader;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::str::FromStr;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
|
||||
@@ -124,7 +113,7 @@ pub fn feltToInt(
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
|
||||
Ok(wasm_bindgen::Clamped(
|
||||
serde_json::to_vec(&felt_to_integer_rep(felt))
|
||||
serde_json::to_vec(&felt_to_i64(felt))
|
||||
.map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?,
|
||||
))
|
||||
}
|
||||
@@ -138,7 +127,7 @@ pub fn feltToFloat(
|
||||
) -> Result<f64, JsError> {
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
let int_rep = felt_to_i64(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
Ok(int_rep as f64 / multiplier)
|
||||
}
|
||||
@@ -152,7 +141,7 @@ pub fn floatToFelt(
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
let int_rep =
|
||||
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
let felt = integer_rep_to_felt(int_rep);
|
||||
let felt = i64_to_felt(int_rep);
|
||||
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
|
||||
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|
||||
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
|
||||
@@ -286,7 +275,7 @@ pub fn genWitness(
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
let witness = circuit
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, RegionSettings::all_false())
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, false, false)
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
serde_json::to_vec(&witness)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
mod native_tests {
|
||||
|
||||
use ezkl::circuit::Tolerance;
|
||||
use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
use ezkl::fieldutils::{felt_to_i64, i64_to_felt};
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
|
||||
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
|
||||
@@ -646,15 +646,6 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_hashed_params_public_inputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_fixed_inputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
@@ -1422,10 +1413,10 @@ mod native_tests {
|
||||
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
|
||||
halo2curves::bn256::Fr::zero()
|
||||
} else {
|
||||
integer_rep_to_felt(
|
||||
(felt_to_integer_rep(*v) as f32
|
||||
i64_to_felt(
|
||||
(felt_to_i64(*v) as f32
|
||||
* (rand::thread_rng().gen_range(-0.01..0.01) * tolerance))
|
||||
as IntegerRep,
|
||||
as i64,
|
||||
)
|
||||
};
|
||||
|
||||
@@ -1445,10 +1436,10 @@ mod native_tests {
|
||||
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
|
||||
halo2curves::bn256::Fr::from(2)
|
||||
} else {
|
||||
integer_rep_to_felt(
|
||||
(felt_to_integer_rep(*v) as f32
|
||||
i64_to_felt(
|
||||
(felt_to_i64(*v) as f32
|
||||
* (rand::thread_rng().gen_range(0.02..0.1) * tolerance))
|
||||
as IntegerRep,
|
||||
as i64,
|
||||
)
|
||||
};
|
||||
*v + perturbation
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user