refactor: lookup-less sign relu abs (#845)

This commit is contained in:
dante
2024-09-17 11:58:58 -04:00
committed by GitHub
parent c9f9d17f16
commit 64fbc8a1c9
44 changed files with 1098 additions and 1019 deletions

View File

@@ -65,40 +65,40 @@ jobs:
- name: Library tests (original lookup)
run: cargo nextest run --lib --verbose --no-default-features --features ezkl
ultra-overflow-tests-gpu:
runs-on: GPU
env:
ENABLE_ICICLE_GPU: true
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@v2
with:
wasmtime-version: "3.0.1"
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: Install cargo-wasi
run: cargo install cargo-wasi
# - name: Matmul overflow (wasi)
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# - name: Conv overflow (wasi)
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
- name: lookup overflow
run: cargo nextest run lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Matmul overflow
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Conv overflow
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Conv + relu overflow
run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# ultra-overflow-tests-gpu:
# runs-on: GPU
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - uses: mwilliamson/setup-wasmtime-action@v2
# with:
# wasmtime-version: "3.0.1"
# - name: Install wasm32-wasi
# run: rustup target add wasm32-wasi
# - name: Install cargo-wasi
# run: cargo install cargo-wasi
# # - name: Matmul overflow (wasi)
# # run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# # - name: Conv overflow (wasi)
# # run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
# - name: lookup overflow
# run: cargo nextest run lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Matmul overflow
# run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Conv overflow
# run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Conv + relu overflow
# run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
ultra-overflow-tests_og-lookup:
runs-on: non-gpu
@@ -184,7 +184,6 @@ jobs:
wasm32-tests:
runs-on: ubuntu-latest
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -244,7 +243,7 @@ jobs:
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
- name: kzg inputs
run: cargo nextest run --release --verbose tests::mock_kzg_input_::t --test-threads 32
- name: kzg params

View File

@@ -162,6 +162,10 @@ harness = false
name = "relu"
harness = false
[[bench]]
name = "relu_lookupless"
harness = false
[[bench]]
name = "accum_matmul_relu"
harness = false

View File

@@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -57,7 +57,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -57,7 +57,15 @@ impl Circuit<Fr> for MyCircuit {
// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, BITS, K, &LookupOp::ReLU)
.configure_lookup(
cs,
&b,
&output,
&a,
BITS,
K,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();
MyConfig { base_config }
@@ -75,14 +83,18 @@ impl Circuit<Fr> for MyCircuit {
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
};
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
Ok(())
},

View File

@@ -58,7 +58,15 @@ impl Circuit<Fr> for MyCircuit {
// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, BITS, k, &LookupOp::ReLU)
.configure_lookup(
cs,
&b,
&output,
&a,
BITS,
k,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();
MyConfig { base_config }
@@ -76,14 +84,18 @@ impl Circuit<Fr> for MyCircuit {
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
};
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
Ok(())
},

View File

@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -59,7 +59,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
.unwrap();

View File

@@ -56,7 +56,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
.unwrap();

View File

@@ -42,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();
let nl = LookupOp::ReLU;
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
let mut config = Config::default();
@@ -63,9 +63,13 @@ impl Circuit<Fr> for NLCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[self.input.clone()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
Ok(())
},

143
benches/relu_lookupless.rs Normal file
View File

@@ -0,0 +1,143 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use rand::Rng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
static mut LEN: usize = 4;
const K: usize = 16;
#[derive(Clone)]
struct NLCircuit {
pub input: ValTensor<Fr>,
}
impl Circuit<Fr> for NLCircuit {
type Config = Config<Fr>;
type FloorPlanner = SimpleFloorPlanner;
type Params = ();
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unsafe {
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();
let mut config = Config::default();
config
.configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K)
.unwrap();
config
.configure_range_check(cs, &advices[0], &advices[1], (0, 1023), K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
config
}
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>, // layouter is our 'write buffer' for the circuit
) -> Result<(), Error> {
config.layout_range_checks(&mut layouter).unwrap();
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
.unwrap();
Ok(())
},
)?;
Ok(())
}
}
fn runrelu(c: &mut Criterion) {
let mut group = c.benchmark_group("relu");
let mut rng = rand::thread_rng();
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
for &len in [4, 8].iter() {
unsafe {
LEN = len;
};
let input: Tensor<Value<Fr>> =
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();
let circuit = NLCircuit {
input: ValTensor::from(input.clone()),
};
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
NLCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();
});
});
}
group.finish();
}
criterion_group! {
name = benches;
config = Criterion::default().with_plots();
targets = runrelu
}
criterion_main!(benches);

View File

@@ -163,7 +163,7 @@ where
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::ReLU,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();
@@ -199,7 +199,7 @@ where
.assign_region(
|| "mlp_4d",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 1024, 2);
let op = PolyOp::Conv {
padding: vec![(PADDING, PADDING); 2],
@@ -221,7 +221,11 @@ where
let x = config
.layer_config
.layout(&mut region, &[x.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
let mut x = config

View File

@@ -69,7 +69,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::ReLU,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();
@@ -108,7 +108,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.assign_region(
|| "mlp_4d",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let x = config
.layer_config
.layout(
@@ -141,7 +141,11 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
println!("x shape: {:?}", x.dims());
let mut x = config
.layer_config
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap()
.unwrap();
println!("3");
@@ -177,7 +181,11 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
println!("x shape: {:?}", x.dims());
let x = config
.layer_config
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
println!("6");
println!("offset: {}", region.row());

File diff suppressed because one or more lines are too long

View File

@@ -232,7 +232,7 @@
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n",
"run_args.logrows = 15\n",
"\n",
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)"
]
@@ -404,7 +404,7 @@
"run_args.output_visibility = \"polycommit\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n"
"run_args.logrows = 15\n"
]
},
{
@@ -466,7 +466,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.5"
},
"orig_nbformat": 4
},

View File

@@ -71,12 +71,17 @@ pub enum HybridOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
HybridOp::Greater { .. }
| HybridOp::Less { .. }
| HybridOp::Equals { .. }
| HybridOp::GreaterEqual { .. }
| HybridOp::LessEqual { .. } => {
vec![0, 1]
}
_ => vec![],
}
}
@@ -135,10 +140,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for Hybri
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Greater => "GREATER".into(),
HybridOp::GreaterEqual => "GREATEREQUAL".into(),
HybridOp::Less => "LESS".into(),
HybridOp::LessEqual => "LESSEQUAL".into(),
HybridOp::Greater => "GREATER".to_string(),
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
HybridOp::Less => "LESS".to_string(),
HybridOp::LessEqual => "LESSEQUAL".to_string(),
HybridOp::Equals => "EQUALS".into(),
HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim),
HybridOp::TopK { k, dim, largest } => {

View File

@@ -240,7 +240,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::circuit::layouts::dot;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
@@ -374,7 +374,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// // matmul case
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
@@ -827,7 +827,7 @@ fn _select_topk<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
@@ -1760,7 +1760,7 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
@@ -1864,7 +1864,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
@@ -2027,7 +2027,7 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
@@ -2063,7 +2063,7 @@ pub fn prod_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
@@ -2099,7 +2099,7 @@ pub fn sum_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
@@ -2141,7 +2141,7 @@ pub fn argmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
/// &[2, 3],
@@ -2177,7 +2177,7 @@ pub fn max_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
@@ -2223,7 +2223,7 @@ pub fn argmin_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
@@ -2415,7 +2415,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
@@ -2472,7 +2472,7 @@ pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
///
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 12, 6, 4, 5, 6]),
@@ -2500,12 +2500,9 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?;
nonlinearity(
config,
region,
&[diff],
&LookupOp::GreaterThan { a: utils::F32(0.) },
)
let sign = sign(config, region, &[diff])?;
equals(config, region, &[sign, create_unit_tensor(1)])
}
/// Greater equals than operation.
@@ -2524,7 +2521,7 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
///
///
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
@@ -2544,21 +2541,17 @@ pub fn greater_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, CircuitError> {
let (mut lhs, mut rhs) = (values[0].clone(), values[1].clone());
let (lhs, rhs) = (values[0].clone(), values[1].clone());
let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?;
lhs.expand(&broadcasted_shape)?;
rhs.expand(&broadcasted_shape)?;
let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?;
nonlinearity(
// add 1 to lhs
let lhs_plus_one = pairwise(
config,
region,
&[diff],
&LookupOp::GreaterThanEqual { a: utils::F32(0.) },
)
&[lhs.clone(), create_unit_tensor(1)],
BaseOp::Add,
)?;
greater(config, region, &[lhs_plus_one, rhs])
}
/// Less than to operation.
@@ -2578,7 +2571,7 @@ pub fn greater_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
///
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 0, 5, 4, 5, 1]),
@@ -2619,7 +2612,7 @@ pub fn less<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 0, 5, 4, 5, 1]),
@@ -2660,7 +2653,7 @@ pub fn less_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1, 1, 1, 1, 0]),
@@ -2704,7 +2697,7 @@ pub fn and<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1, 1, 1, 1, 0]),
@@ -2752,7 +2745,7 @@ pub fn or<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1, 1, 1, 1, 0]),
@@ -2826,7 +2819,7 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd + std::hash::H
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1, 1, 1, 1, 0]),
@@ -2882,7 +2875,7 @@ pub fn xor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1, 1, 1, 1, 0]),
@@ -2925,7 +2918,7 @@ pub fn not<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let mask = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 0, 1, 0, 1, 0]),
@@ -2983,7 +2976,7 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
@@ -3015,7 +3008,7 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
@@ -3097,7 +3090,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
@@ -3198,7 +3191,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
@@ -3426,7 +3419,7 @@ pub fn deconv<
/// use ezkl::circuit::BaseConfig;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
@@ -4155,6 +4148,128 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
.map_err(|e| e.into())
}
pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
base: &usize,
n: &usize,
) -> Result<ValTensor<F>, CircuitError> {
let input = values[0].clone();
let is_assigned = !input.all_prev_assigned();
let bases: ValTensor<F> = Tensor::from(
(0..*n)
.rev()
.map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep))),
)
.into();
let cartesian_coord = input
.dims()
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut output: Tensor<Tensor<ValType<F>>> = Tensor::new(None, input.dims())?;
let inner_loop_function =
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
let coord = cartesian_coord[i].clone();
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
if !is_assigned {
sliced_input = region.assign(&config.custom_gates.inputs[0], &sliced_input)?;
}
let mut claimed_output_slice = if region.witness_gen() {
sliced_input.decompose(*base, *n)?
} else {
Tensor::from(vec![ValType::Value(Value::unknown()); *n + 1].into_iter()).into()
};
claimed_output_slice =
region.assign(&config.custom_gates.inputs[1], &claimed_output_slice)?;
claimed_output_slice.flatten();
region.increment(claimed_output_slice.len());
// get the sign bit and make sure it is valid
let sign = claimed_output_slice.first()?;
let sign = range_check(config, region, &[sign], &(-1, 1))?;
// get the rest of the thing and make sure it is in the correct range
let rest = claimed_output_slice.get_slice(&[1..claimed_output_slice.len()])?;
let rest = range_check(config, region, &[rest], &(0, (base - 1) as i128))?;
let prod_decomp = dot(config, region, &[rest, bases.clone()])?;
let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;
enforce_equality(config, region, &[sliced_input, signed_decomp])?;
Ok(claimed_output_slice.get_inner_tensor()?.clone())
};
region.apply_in_loop(&mut output, inner_loop_function)?;
let mut combined_output = output.combine()?;
let mut output_dims = input.dims().to_vec();
output_dims.push(*n + 1);
combined_output.reshape(&output_dims)?;
Ok(combined_output.into())
}
pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError> {
let mut decomp = decompose(config, region, values, &region.base(), &region.legs())?;
// get every n elements now, which correspond to the sign bit
decomp.get_every_n(region.legs() + 1)?;
decomp.reshape(values[0].dims())?;
Ok(decomp)
}
pub(crate) fn abs<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError> {
let sign = sign(config, region, values)?;
pairwise(config, region, &[values[0].clone(), sign], BaseOp::Mult)
}
pub(crate) fn relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError> {
let sign = sign(config, region, values)?;
let mut unit = create_unit_tensor(sign.len());
unit.reshape(sign.dims())?;
let relu_mask = equals(config, region, &[sign, unit])?;
pairwise(
config,
region,
&[values[0].clone(), relu_mask],
BaseOp::Mult,
)
}
fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
@@ -4324,7 +4439,7 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
/// use ezkl::circuit::BaseConfig;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 2, 3, 2, 2, 0]),
@@ -4372,7 +4487,7 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::circuit::BaseConfig;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[100, 200, 300, 400, 500, 600]),

View File

@@ -15,14 +15,12 @@ use halo2curves::ff::PrimeField;
/// An enum representing the operations that can be used to express more complex operations via accumulation
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum LookupOp {
Abs,
Div {
denom: utils::F32,
},
Cast {
scale: utils::F32,
},
ReLU,
Max {
scale: utils::F32,
a: utils::F32,
@@ -104,19 +102,6 @@ pub enum LookupOp {
Erf {
scale: utils::F32,
},
GreaterThan {
a: utils::F32,
},
LessThan {
a: utils::F32,
},
GreaterThanEqual {
a: utils::F32,
},
LessThanEqual {
a: utils::F32,
},
Sign,
KroneckerDelta,
Pow {
scale: utils::F32,
@@ -138,7 +123,6 @@ impl LookupOp {
/// as path
pub fn as_path(&self) -> String {
match self {
LookupOp::Abs => "abs".into(),
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
LookupOp::Floor { scale } => format!("floor_{}", scale),
LookupOp::Round { scale } => format!("round_{}", scale),
@@ -147,18 +131,12 @@ impl LookupOp {
LookupOp::KroneckerDelta => "kronecker_delta".into(),
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
LookupOp::Sign => "sign".into(),
LookupOp::LessThan { a } => format!("less_than_{}", a),
LookupOp::LessThanEqual { a } => format!("less_than_equal_{}", a),
LookupOp::GreaterThan { a } => format!("greater_than_{}", a),
LookupOp::GreaterThanEqual { a } => format!("greater_than_equal_{}", a),
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Cast { scale } => format!("cast_{}", scale),
LookupOp::Recip {
input_scale,
output_scale,
} => format!("recip_{}_{}", input_scale, output_scale),
LookupOp::ReLU => "relu".to_string(),
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
@@ -188,91 +166,107 @@ impl LookupOp {
x: &[Tensor<F>],
) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
let res = match &self {
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
LookupOp::Floor { scale } => Ok(tensor::ops::nonlinearities::floor(&x, scale.into())),
LookupOp::Round { scale } => Ok(tensor::ops::nonlinearities::round(&x, scale.into())),
LookupOp::RoundHalfToEven { scale } => Ok(
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
),
LookupOp::Pow { scale, a } => Ok(tensor::ops::nonlinearities::pow(
&x,
scale.0.into(),
a.0.into(),
)),
LookupOp::KroneckerDelta => Ok(tensor::ops::nonlinearities::kronecker_delta(&x)),
LookupOp::Max { scale, a } => Ok(tensor::ops::nonlinearities::max(
&x,
scale.0.into(),
a.0.into(),
)),
LookupOp::Min { scale, a } => Ok(tensor::ops::nonlinearities::min(
&x,
scale.0.into(),
a.0.into(),
)),
LookupOp::Sign => Ok(tensor::ops::nonlinearities::sign(&x)),
LookupOp::LessThan { a } => Ok(tensor::ops::nonlinearities::less_than(
&x,
f32::from(*a).into(),
)),
LookupOp::LessThanEqual { a } => Ok(tensor::ops::nonlinearities::less_than_equal(
&x,
f32::from(*a).into(),
)),
LookupOp::GreaterThan { a } => Ok(tensor::ops::nonlinearities::greater_than(
&x,
f32::from(*a).into(),
)),
LookupOp::GreaterThanEqual { a } => Ok(
tensor::ops::nonlinearities::greater_than_equal(&x, f32::from(*a).into()),
),
LookupOp::Div { denom } => Ok(tensor::ops::nonlinearities::const_div(
&x,
f32::from(*denom).into(),
)),
LookupOp::Cast { scale } => Ok(tensor::ops::nonlinearities::const_div(
&x,
f32::from(*scale).into(),
)),
LookupOp::Recip {
input_scale,
output_scale,
} => Ok(tensor::ops::nonlinearities::recip(
&x,
input_scale.into(),
output_scale.into(),
)),
LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)),
LookupOp::LeakyReLU { slope: a } => {
Ok(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
}
LookupOp::Sigmoid { scale } => {
Ok(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
LookupOp::Sqrt { scale } => Ok(tensor::ops::nonlinearities::sqrt(&x, scale.into())),
LookupOp::Rsqrt { scale } => Ok(tensor::ops::nonlinearities::rsqrt(&x, scale.into())),
LookupOp::Erf { scale } => Ok(tensor::ops::nonlinearities::erffunc(&x, scale.into())),
LookupOp::Exp { scale } => Ok(tensor::ops::nonlinearities::exp(&x, scale.into())),
LookupOp::Ln { scale } => Ok(tensor::ops::nonlinearities::ln(&x, scale.into())),
LookupOp::Cos { scale } => Ok(tensor::ops::nonlinearities::cos(&x, scale.into())),
LookupOp::ACos { scale } => Ok(tensor::ops::nonlinearities::acos(&x, scale.into())),
LookupOp::Cosh { scale } => Ok(tensor::ops::nonlinearities::cosh(&x, scale.into())),
LookupOp::ACosh { scale } => Ok(tensor::ops::nonlinearities::acosh(&x, scale.into())),
LookupOp::Sin { scale } => Ok(tensor::ops::nonlinearities::sin(&x, scale.into())),
LookupOp::ASin { scale } => Ok(tensor::ops::nonlinearities::asin(&x, scale.into())),
LookupOp::Sinh { scale } => Ok(tensor::ops::nonlinearities::sinh(&x, scale.into())),
LookupOp::ASinh { scale } => Ok(tensor::ops::nonlinearities::asinh(&x, scale.into())),
LookupOp::Tan { scale } => Ok(tensor::ops::nonlinearities::tan(&x, scale.into())),
LookupOp::ATan { scale } => Ok(tensor::ops::nonlinearities::atan(&x, scale.into())),
LookupOp::ATanh { scale } => Ok(tensor::ops::nonlinearities::atanh(&x, scale.into())),
LookupOp::Tanh { scale } => Ok(tensor::ops::nonlinearities::tanh(&x, scale.into())),
LookupOp::HardSwish { scale } => {
Ok(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
}
}?;
let res =
match &self {
LookupOp::Ceil { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ceil(&x, scale.into()))
}
LookupOp::Floor { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::floor(&x, scale.into()))
}
LookupOp::Round { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::round(&x, scale.into()))
}
LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
),
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
),
LookupOp::KroneckerDelta => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::kronecker_delta(&x))
}
LookupOp::Max { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::max(&x, scale.0.into(), a.0.into()),
),
LookupOp::Min { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::min(&x, scale.0.into(), a.0.into()),
),
LookupOp::Div { denom } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
),
LookupOp::Cast { scale } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()),
),
LookupOp::Recip {
input_scale,
output_scale,
} => Ok::<_, TensorError>(tensor::ops::nonlinearities::recip(
&x,
input_scale.into(),
output_scale.into(),
)),
LookupOp::LeakyReLU { slope: a } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
}
LookupOp::Sigmoid { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
LookupOp::Sqrt { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sqrt(&x, scale.into()))
}
LookupOp::Rsqrt { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::rsqrt(&x, scale.into()))
}
LookupOp::Erf { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::erffunc(&x, scale.into()))
}
LookupOp::Exp { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::exp(&x, scale.into()))
}
LookupOp::Ln { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
}
LookupOp::Cos { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
}
LookupOp::ACos { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::acos(&x, scale.into()))
}
LookupOp::Cosh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::cosh(&x, scale.into()))
}
LookupOp::ACosh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::acosh(&x, scale.into()))
}
LookupOp::Sin { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sin(&x, scale.into()))
}
LookupOp::ASin { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::asin(&x, scale.into()))
}
LookupOp::Sinh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sinh(&x, scale.into()))
}
LookupOp::ASinh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::asinh(&x, scale.into()))
}
LookupOp::Tan { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::tan(&x, scale.into()))
}
LookupOp::ATan { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::atan(&x, scale.into()))
}
LookupOp::ATanh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::atanh(&x, scale.into()))
}
LookupOp::Tanh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::tanh(&x, scale.into()))
}
LookupOp::HardSwish { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
}
}?;
let output = res.map(|x| integer_rep_to_felt(x));
@@ -289,7 +283,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
/// Returns the name of the operation
fn as_string(&self) -> String {
match self {
LookupOp::Abs => "ABS".into(),
LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale),
LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale),
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
@@ -298,11 +291,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
LookupOp::KroneckerDelta => "K_DELTA".into(),
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
LookupOp::Sign => "SIGN".into(),
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
LookupOp::Recip {
input_scale,
output_scale,
@@ -313,7 +301,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
LookupOp::ReLU => "RELU".to_string(),
LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
@@ -358,12 +345,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
in_scale + multiplier_to_scale(1. / scale.0 as f64)
}
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
LookupOp::Sign
| LookupOp::GreaterThan { .. }
| LookupOp::LessThan { .. }
| LookupOp::GreaterThanEqual { .. }
| LookupOp::LessThanEqual { .. }
| LookupOp::KroneckerDelta => 0,
LookupOp::KroneckerDelta => 0,
_ => inputs_scale[0],
};
Ok(scale)

View File

@@ -9,6 +9,9 @@ use super::{base::BaseOp, *};
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PolyOp {
ReLU,
Abs,
Sign,
GatherElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
@@ -99,8 +102,7 @@ impl<
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>
,
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
@@ -110,6 +112,9 @@ impl<
fn as_string(&self) -> String {
match &self {
PolyOp::Abs => "ABS".to_string(),
PolyOp::Sign => "SIGN".to_string(),
PolyOp::ReLU => "RELU".to_string(),
PolyOp::GatherElements { dim, constant_idx } => format!(
"GATHERELEMENTS (dim={}, constant_idx{})",
dim,
@@ -191,6 +196,9 @@ impl<
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
PolyOp::ReLU => layouts::relu(config, region, values[..].try_into()?)?,
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
}
@@ -368,6 +376,7 @@ impl<
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
PolyOp::Sign { .. } => 0,
_ => in_scales[0],
};
Ok(scale)

View File

@@ -93,6 +93,10 @@ pub struct RegionSettings {
pub witness_gen: bool,
/// whether we should check range checks for validity
pub check_range: bool,
/// base for decompositions
pub base: usize,
/// number of legs for decompositions
pub legs: usize,
}
#[allow(unsafe_code)]
@@ -102,26 +106,32 @@ unsafe impl Send for RegionSettings {}
impl RegionSettings {
/// Create a new region settings
pub fn new(witness_gen: bool, check_range: bool) -> RegionSettings {
pub fn new(witness_gen: bool, check_range: bool, base: usize, legs: usize) -> RegionSettings {
RegionSettings {
witness_gen,
check_range,
base,
legs,
}
}
/// Create a new region settings with all true
pub fn all_true() -> RegionSettings {
pub fn all_true(base: usize, legs: usize) -> RegionSettings {
RegionSettings {
witness_gen: true,
check_range: true,
base,
legs,
}
}
/// Create a new region settings with all false
pub fn all_false() -> RegionSettings {
pub fn all_false(base: usize, legs: usize) -> RegionSettings {
RegionSettings {
witness_gen: false,
check_range: false,
base,
legs,
}
}
}
@@ -173,6 +183,16 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
}
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
/// get the region's decomposition base
pub fn base(&self) -> usize {
self.settings.base
}
/// get the region's decomposition legs
pub fn legs(&self) -> usize {
self.settings.legs
}
#[cfg(not(target_arch = "wasm32"))]
///
pub fn debug_report(&self) {
@@ -234,7 +254,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// Create a new region context
pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
pub fn new(
region: Region<'a, F>,
row: usize,
num_inner_cols: usize,
decomp_base: usize,
decomp_legs: usize,
) -> RegionCtx<'a, F> {
let region = Some(RefCell::new(region));
let linear_coord = row * num_inner_cols;
@@ -246,7 +272,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
statistics: RegionStatistics::default(),
settings: RegionSettings::all_true(),
settings: RegionSettings::all_true(decomp_base, decomp_legs),
assigned_constants: HashMap::new(),
}
}
@@ -256,9 +282,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
region: Region<'a, F>,
row: usize,
num_inner_cols: usize,
decomp_base: usize,
decomp_legs: usize,
constants: ConstantsMap<F>,
) -> RegionCtx<'a, F> {
let mut new_self = Self::new(region, row, num_inner_cols);
let mut new_self = Self::new(region, row, num_inner_cols, decomp_base, decomp_legs);
new_self.assigned_constants = constants;
new_self
}

View File

@@ -8,6 +8,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::Fr as F;
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
use ops::lookup::LookupOp;
use ops::region::RegionCtx;
use rand::rngs::OsRng;
@@ -55,7 +56,7 @@ mod matmul {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -132,7 +133,7 @@ mod matmul_col_overflow_double_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
config
.layout(
&mut region,
@@ -206,7 +207,7 @@ mod matmul_col_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -290,7 +291,7 @@ mod matmul_col_ultra_overflow_double_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
config
.layout(
&mut region,
@@ -407,7 +408,7 @@ mod matmul_col_ultra_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -518,7 +519,7 @@ mod dot {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -595,7 +596,7 @@ mod dot_col_overflow_triple_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 3);
let mut region = RegionCtx::new(region, 0, 3, 128, 2);
config
.layout(
&mut region,
@@ -668,7 +669,7 @@ mod dot_col_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -741,7 +742,7 @@ mod sum {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -811,7 +812,7 @@ mod sum_col_overflow_double_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
config
.layout(
&mut region,
@@ -880,7 +881,7 @@ mod sum_col_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -951,7 +952,7 @@ mod composition {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let _ = config
.layout(
&mut region,
@@ -1042,7 +1043,7 @@ mod conv {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -1193,7 +1194,7 @@ mod conv_col_ultra_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -1297,7 +1298,7 @@ mod conv_relu_col_ultra_overflow {
use super::*;
const K: usize = 4;
const K: usize = 8;
const LEN: usize = 15;
#[derive(Clone)]
@@ -1317,15 +1318,23 @@ mod conv_relu_col_ultra_overflow {
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
let mut base_config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, (-3, 3), K, &LookupOp::ReLU)
.configure_range_check(cs, &a, &b, (-1, 1), K)
.unwrap();
base_config
.configure_range_check(cs, &a, &b, (0, 1), K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 8, false);
base_config.clone()
}
@@ -1334,12 +1343,12 @@ mod conv_relu_col_ultra_overflow {
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout_tables(&mut layouter).unwrap();
config.layout_range_checks(&mut layouter).unwrap();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
let output = config
.layout(
&mut region,
@@ -1355,7 +1364,7 @@ mod conv_relu_col_ultra_overflow {
.layout(
&mut region,
&[output.unwrap().unwrap()],
Box::new(LookupOp::ReLU),
Box::new(PolyOp::ReLU),
)
.unwrap();
Ok(())
@@ -1476,7 +1485,7 @@ mod add_w_shape_casting {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
@@ -1543,7 +1552,7 @@ mod add {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
@@ -1627,7 +1636,7 @@ mod dynamic_lookup {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
for i in 0..NUM_LOOP {
layouts::dynamic_lookup(
&config,
@@ -1769,7 +1778,7 @@ mod shuffle {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
for i in 0..NUM_LOOP {
layouts::shuffles(
&config,
@@ -1884,7 +1893,7 @@ mod add_with_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
@@ -1986,7 +1995,7 @@ mod add_with_overflow_and_poseidon {
layouter.assign_region(
|| "model",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.base
.layout(&mut region, &inputs, Box::new(PolyOp::Add))
@@ -2092,7 +2101,7 @@ mod sub {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub))
.map_err(|_| Error::Synthesis)
@@ -2159,7 +2168,7 @@ mod mult {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult))
.map_err(|_| Error::Synthesis)
@@ -2226,7 +2235,7 @@ mod pow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5)))
.map_err(|_| Error::Synthesis)
@@ -2258,7 +2267,6 @@ mod matmul_relu {
const K: usize = 18;
const LEN: usize = 32;
use crate::circuit::lookup::LookupOp;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
@@ -2288,11 +2296,17 @@ mod matmul_relu {
let mut base_config =
BaseConfig::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, &LookupOp::ReLU)
.configure_range_check(cs, &a, &b, (-1, 1), K)
.unwrap();
base_config
.configure_range_check(cs, &a, &b, (0, 1023), K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 8, false);
MyConfig { base_config }
}
@@ -2301,11 +2315,14 @@ mod matmul_relu {
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.base_config.layout_tables(&mut layouter).unwrap();
config
.base_config
.layout_range_checks(&mut layouter)
.unwrap();
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
};
@@ -2315,7 +2332,7 @@ mod matmul_relu {
.unwrap();
let _output = config
.base_config
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
.layout(&mut region, &[output.unwrap()], Box::new(PolyOp::ReLU))
.unwrap();
Ok(())
},
@@ -2354,6 +2371,8 @@ mod relu {
plonk::{Circuit, ConstraintSystem, Error},
};
const K: u32 = 8;
#[derive(Clone)]
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
pub input: ValTensor<F>,
@@ -2370,16 +2389,26 @@ mod relu {
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
.map(|_| VarTensor::new_advice(cs, 8, 1, 3))
.collect::<Vec<_>>();
let nl = LookupOp::ReLU;
let mut config = BaseConfig::default();
let mut config = BaseConfig::configure(
cs,
&[advices[0].clone(), advices[1].clone()],
&advices[2],
CheckMode::SAFE,
);
config
.configure_lookup(cs, &advices[0], &advices[1], &advices[2], (-6, 6), 4, &nl)
.configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K as usize)
.unwrap();
config
.configure_range_check(cs, &advices[0], &advices[1], (0, 1), K as usize)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K as usize, 8, false);
config
}
@@ -2388,15 +2417,15 @@ mod relu {
mut config: Self::Config,
mut layouter: impl Layouter<F>, // layouter is our 'write buffer' for the circuit
) -> Result<(), Error> {
config.layout_tables(&mut layouter).unwrap();
config.layout_range_checks(&mut layouter).unwrap();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
config
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
.map_err(|_| Error::Synthesis)
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
Ok(config
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
.unwrap())
},
)
.unwrap();
@@ -2414,7 +2443,7 @@ mod relu {
input: ValTensor::from(input),
};
let prover = MockProver::run(4_u32, &circuit, vec![]).unwrap();
let prover = MockProver::run(K, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
}
@@ -2453,7 +2482,7 @@ mod lookup_ultra_overflow {
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
.collect::<Vec<_>>();
let nl = LookupOp::ReLU;
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
let mut config = BaseConfig::default();
@@ -2481,9 +2510,13 @@ mod lookup_ultra_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[self.input.clone()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.map_err(|_| Error::Synthesis)
},
)

View File

@@ -141,23 +141,23 @@ mod tests {
#[test]
fn f32_eq() {
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
assert!(F32(std::f32::NAN) != F32(5.0));
assert!(F32(5.0) != F32(std::f32::NAN));
assert!(F32(f32::NAN) == F32(f32::NAN));
assert!(F32(f32::NAN) != F32(5.0));
assert!(F32(5.0) != F32(f32::NAN));
assert!(F32(0.0) == F32(-0.0));
}
#[test]
fn f32_cmp() {
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
assert!(F32(std::f32::NAN) < F32(5.0));
assert!(F32(5.0) > F32(std::f32::NAN));
assert!(F32(f32::NAN) == F32(f32::NAN));
assert!(F32(f32::NAN) < F32(5.0));
assert!(F32(5.0) > F32(f32::NAN));
assert!(F32(0.0) == F32(-0.0));
}
#[test]
fn f32_hash() {
assert!(calculate_hash(&F32(0.0)) == calculate_hash(&F32(-0.0)));
assert!(calculate_hash(&F32(std::f32::NAN)) == calculate_hash(&F32(-std::f32::NAN)));
assert!(calculate_hash(&F32(f32::NAN)) == calculate_hash(&F32(-f32::NAN)));
}
}

View File

@@ -786,7 +786,8 @@ pub(crate) async fn gen_witness(
let commitment: Commitments = settings.run_args.commitment.into();
let region_settings = RegionSettings::all_true();
let region_settings =
RegionSettings::all_true(settings.run_args.decomp_base, settings.run_args.decomp_legs);
let start_time = Instant::now();
let witness = if settings.module_requires_polycommit() {
@@ -1178,7 +1179,10 @@ pub(crate) async fn calibrate(
&mut data.clone(),
None,
None,
RegionSettings::all_true(),
RegionSettings::all_true(
settings.run_args.decomp_base,
settings.run_args.decomp_legs,
),
)
.map_err(|e| format!("failed to forward: {}", e))?;
@@ -1372,8 +1376,10 @@ pub(crate) async fn calibrate(
let module_log_row = best_params.module_constraint_logrows_with_blinding();
let instance_logrows = best_params.log2_total_instances_with_blinding();
let dynamic_lookup_logrows = best_params.dynamic_lookup_and_shuffle_logrows_with_blinding();
let range_check_logrows = best_params.range_check_log_rows_with_blinding();
let mut reduction = std::cmp::max(lookup_log_rows, module_log_row);
reduction = std::cmp::max(reduction, range_check_logrows);
reduction = std::cmp::max(reduction, instance_logrows);
reduction = std::cmp::max(reduction, dynamic_lookup_logrows);
reduction = std::cmp::max(reduction, crate::graph::MIN_LOGROWS);

View File

@@ -451,6 +451,18 @@ impl GraphSettings {
.ceil() as u32
}
/// Calc the number of rows required for the range checks
pub fn range_check_log_rows_with_blinding(&self) -> u32 {
let max_range = self
.required_range_checks
.iter()
.map(|x| x.1 - x.0)
.max()
.unwrap_or(0);
(max_range as f32).log2().ceil() as u32
}
fn model_constraint_logrows_with_blinding(&self) -> u32 {
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
.log2()

View File

@@ -547,7 +547,11 @@ impl Model {
})
.collect::<Result<Vec<_>, GraphError>>()?;
let res = self.dummy_layout(run_args, &inputs, RegionSettings::all_false())?;
let res = self.dummy_layout(
run_args,
&inputs,
RegionSettings::all_false(run_args.decomp_base, run_args.decomp_legs),
)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
@@ -883,16 +887,8 @@ impl Model {
);
}
None => {
let mut n = Node::new(
n.clone(),
&mut nodes,
scales,
&run_args.param_visibility,
i,
symbol_values,
run_args.div_rebasing,
run_args.rebase_frac_zero_constants,
)?;
let mut n =
Node::new(n.clone(), &mut nodes, scales, i, symbol_values, run_args)?;
if let Some(ref scales) = override_input_scales {
if let Some(inp) = n.opkind.get_input() {
let scale = scales[input_idx];
@@ -1112,6 +1108,8 @@ impl Model {
region,
0,
run_args.num_inner_cols,
run_args.decomp_base,
run_args.decomp_legs,
original_constants.clone(),
);
// we need to do this as this loop is called multiple times

View File

@@ -473,11 +473,9 @@ impl Node {
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
other_nodes: &mut BTreeMap<usize, super::NodeType>,
scales: &VarScales,
param_visibility: &Visibility,
idx: usize,
symbol_values: &SymbolValues,
div_rebasing: bool,
rebase_frac_zero_constants: bool,
run_args: &crate::RunArgs,
) -> Result<Self, GraphError> {
trace!("Create {:?}", node);
trace!("Create op {:?}", node.op);
@@ -517,11 +515,10 @@ impl Node {
let (mut opkind, deleted_indices) = new_op_from_onnx(
idx,
scales,
param_visibility,
node.clone(),
&mut inputs,
symbol_values,
rebase_frac_zero_constants,
run_args,
)?; // parses the op name
// we can only take the inputs as mutable once -- so we need to collect them first
@@ -569,7 +566,7 @@ impl Node {
rescale_const_with_single_use(
constant,
in_scales.clone(),
param_visibility,
&run_args.param_visibility,
input_node.num_uses(),
)?;
input_node.replace_opkind(constant.clone_dyn().into());
@@ -589,7 +586,7 @@ impl Node {
global_scale,
out_scale,
scales.rebase_multiplier,
div_rebasing,
run_args.div_rebasing,
);
out_scale = opkind.out_scale(in_scales)?;

View File

@@ -273,11 +273,10 @@ fn load_op<C: tract_onnx::prelude::Op + Clone>(
pub fn new_op_from_onnx(
idx: usize,
scales: &VarScales,
param_visibility: &Visibility,
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
inputs: &mut [super::NodeType],
symbol_values: &SymbolValues,
rebase_frac_zero_constants: bool,
run_args: &crate::RunArgs,
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
use tract_onnx::tract_core::ops::array::Trilu;
@@ -664,13 +663,16 @@ pub fn new_op_from_onnx(
// if all raw_values are round then set scale to 0
let all_round = raw_value.iter().all(|x| (x).fract() == 0.0);
if all_round && rebase_frac_zero_constants {
if all_round && run_args.rebase_frac_zero_constants {
constant_scale = 0;
}
// Quantize the raw value
let quantized_value =
quantize_tensor(raw_value.clone(), constant_scale, param_visibility)?;
let quantized_value = quantize_tensor(
raw_value.clone(),
constant_scale,
&run_args.param_visibility,
)?;
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
// Create a constant op
SupportedOp::Constant(c)
@@ -782,7 +784,7 @@ pub fn new_op_from_onnx(
deleted_indices.push(const_idx);
}
if unit == 0. {
SupportedOp::Nonlinear(LookupOp::ReLU)
SupportedOp::Linear(PolyOp::ReLU)
} else {
// get the non-constant index
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
@@ -871,7 +873,7 @@ pub fn new_op_from_onnx(
"QuantizeLinearU8" | "DequantizeLinearF32" => {
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
}
"Abs" => SupportedOp::Nonlinear(LookupOp::Abs),
"Abs" => SupportedOp::Linear(PolyOp::Abs),
"Neg" => SupportedOp::Linear(PolyOp::Neg),
"HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
@@ -1127,7 +1129,7 @@ pub fn new_op_from_onnx(
"RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Sign" => SupportedOp::Nonlinear(LookupOp::Sign),
"Sign" => SupportedOp::Linear(PolyOp::Sign),
"Pow" => {
// Extract the slope layer hyperparams from a const

View File

@@ -150,6 +150,7 @@ lazy_static! {
/// The serialization format for the keys
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
.unwrap_or("raw-bytes".to_string());
}
#[cfg(target_arch = "wasm32")]
@@ -277,6 +278,12 @@ pub struct RunArgs {
/// commitment scheme
#[arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other)]
pub commitment: Option<Commitments>,
/// the base used for decompositions
#[arg(long, default_value = "16384", value_hint = clap::ValueHint::Other)]
pub decomp_base: usize,
#[arg(long, default_value = "2", value_hint = clap::ValueHint::Other)]
/// the number of legs used for decompositions
pub decomp_legs: usize,
}
impl Default for RunArgs {
@@ -297,6 +304,8 @@ impl Default for RunArgs {
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
commitment: None,
decomp_base: 16384,
decomp_legs: 2,
}
}
}

View File

@@ -191,6 +191,12 @@ struct PyRunArgs {
#[pyo3(get, set)]
/// str: commitment type, accepts `kzg`, `ipa`
pub commitment: PyCommitments,
/// int: The base used for decomposition
#[pyo3(get, set)]
pub decomp_base: usize,
/// int: The number of legs used for decomposition
#[pyo3(get, set)]
pub decomp_legs: usize,
}
/// default instantiation of PyRunArgs
@@ -221,6 +227,8 @@ impl From<PyRunArgs> for RunArgs {
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: Some(py_run_args.commitment.into()),
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
}
}
}
@@ -243,6 +251,8 @@ impl Into<PyRunArgs> for RunArgs {
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
check_mode: self.check_mode,
commitment: self.commitment.into(),
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
}
}
}

View File

@@ -1,5 +1,7 @@
use thiserror::Error;
use super::ops::DecompositionError;
/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
pub enum TensorError {
@@ -33,4 +35,7 @@ pub enum TensorError {
/// File load error
#[error("load error: {0}")]
FileLoadError(String),
/// Decomposition error
#[error("decomposition error: {0}")]
DecompositionError(#[from] DecompositionError),
}

View File

@@ -420,7 +420,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
std::fs::File::create(path).map_err(|e| TensorError::FileSaveError(e.to_string()))?;
let mut buf_writer = std::io::BufWriter::new(writer);
self.inner.iter().map(|x| x.clone()).for_each(|x| {
self.inner.iter().copied().for_each(|x| {
let x = x.to_repr();
buf_writer.write_all(x.as_ref()).unwrap();
});
@@ -764,6 +764,53 @@ impl<T: Clone + TensorType> Tensor<T> {
index
}
/// Fetches every nth element
///
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5]), &[3]).unwrap();
/// assert_eq!(a.get_every_n(2).unwrap(), expected);
/// assert_eq!(a.get_every_n(1).unwrap(), a);
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 6]), &[2]).unwrap();
/// assert_eq!(a.get_every_n(5).unwrap(), expected);
///
/// ```
pub fn get_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.clone().into_iter().enumerate() {
if i % n == 0 {
inner.push(elem.clone());
}
}
Tensor::new(Some(&inner), &[inner.len()])
}
/// Excludes every nth element
///
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 4, 6]), &[3]).unwrap();
/// assert_eq!(a.exclude_every_n(2).unwrap(), expected);
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 3, 4, 5]), &[4]).unwrap();
/// assert_eq!(a.exclude_every_n(5).unwrap(), expected);
///
/// ```
pub fn exclude_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.clone().into_iter().enumerate() {
if i % n != 0 {
inner.push(elem.clone());
}
}
Tensor::new(Some(&inner), &[inner.len()])
}
/// Duplicates every nth element
///
/// ```
@@ -1217,6 +1264,31 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&[res]), &[1])
}
/// Get first elem from Tensor
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
/// let mut b = Tensor::<IntegerRep>::new(Some(&[1]), &[1]).unwrap();
///
/// assert_eq!(a.first().unwrap(), b);
/// ```
pub fn first(&self) -> Result<Tensor<T>, TensorError>
where
T: Send + Sync,
{
let res = match self.inner.first() {
Some(e) => e.clone(),
None => {
return Err(TensorError::DimError(
"Cannot get first element of empty tensor".to_string(),
))
}
};
Tensor::new(Some(&[res]), &[1])
}
/// Maps a function to tensors and enumerates in parallel
/// ```
/// use ezkl::tensor::{Tensor, TensorError};

View File

@@ -7,6 +7,131 @@ use itertools::Itertools;
use maybe_rayon::{iter::ParallelIterator, prelude::IntoParallelRefIterator};
pub use std::ops::{Add, Mul, Neg, Sub};
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
/// Decomposition error
pub enum DecompositionError {
/// Integer is too large to be represented by base and n
#[error("integer {} is too large to be represented by base {} and n {}", .0, .1, .2)]
TooLarge(IntegerRep, usize, usize),
}
/// Helper function to get the base decomp of an integer
/// # Arguments
/// * `x` - IntegerRep
/// * `n` - usize
/// * `base` - usize
///
pub fn get_rep(
x: &IntegerRep,
base: usize,
n: usize,
) -> Result<Vec<IntegerRep>, DecompositionError> {
// check if x is too large
if x.abs() > (base.pow(n as u32) as IntegerRep) {
return Err(DecompositionError::TooLarge(*x, base, n));
}
let mut rep = vec![0; n + 1];
// sign bit
rep[0] = if *x < 0 {
-1
} else if *x > 0 {
1
} else {
0
};
let mut x = x.abs();
//
for i in (1..rep.len()).rev() {
rep[i] = x % base as i128;
x /= base as i128;
}
Ok(rep)
}
/// Decompose a tensor of integers into a larger tensor with added dimension of size `n + 1` with the binary (or OTHER base) representation of the integer.
/// # Arguments
/// * `x` - Tensor
/// * `n` - usize
/// * `base` - usize
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::decompose;
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[0, 1, 2, -1]),
/// &[2, 2]).unwrap();
///
/// let result = decompose(&x, 2, 2).unwrap();
/// // result will have dims [2, 2, 3]
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0,
/// 1, 0, 1,
/// 1, 1, 0,
/// -1, 0, 1]), &[2, 2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// let result = decompose(&x, 3, 1).unwrap();
///
///
/// // result will have dims [2, 2, 2]
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0,
/// 1, 1,
/// 1, 2,
/// -1, 1]), &[2, 2, 2]).unwrap();
///
/// assert_eq!(result, expected);
///
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[0, 11, 23, -1]),
/// &[2, 2]).unwrap();
///
/// let result = decompose(&x, 2, 5).unwrap();
/// // result will have dims [2, 2, 6]
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 0, 0, 0,
/// 1, 0, 1, 0, 1, 1,
/// 1, 1, 0, 1, 1, 1,
/// -1, 0, 0, 0, 0, 1]), &[2, 2, 6]).unwrap();
/// assert_eq!(result, expected);
///
/// let result = decompose(&x, 16, 2).unwrap();
/// // result will have dims [2, 2, 3]
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0,
/// 1, 0, 11,
/// 1, 1, 7,
/// -1, 0, 1]), &[2, 2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
///
pub fn decompose(
x: &Tensor<IntegerRep>,
base: usize,
n: usize,
) -> Result<Tensor<IntegerRep>, TensorError> {
let mut dims = x.dims().to_vec();
dims.push(n + 1);
if n == 0 {
let mut x = x.clone();
x.reshape(&dims)?;
return Ok(x);
}
let resp = x
.par_iter()
.map(|val| get_rep(val, base, n))
// now collect the results into a Result<Vec<Vec<IntegerRep>>, DecompositionError>
.collect::<Result<Vec<Vec<IntegerRep>>, DecompositionError>>()?
.into_iter()
.flatten()
.collect::<Vec<IntegerRep>>();
let output = Tensor::<i128>::new(Some(&resp), &dims)?;
Ok(output)
}
/// Trilu operation.
/// # Arguments
/// * `a` - Tensor
@@ -429,7 +554,7 @@ pub fn downsample<T: TensorType + Send + Sync>(
output = output.par_enum_map(|i, _: T| {
let coord = indices[i].clone();
Ok(input.get(&coord))
Ok::<_, TensorError>(input.get(&coord))
})?;
Ok(output)
@@ -489,7 +614,7 @@ pub fn gather<T: TensorType + Send + Sync>(
.map(|(i, x)| if i == dim { index_val } else { *x })
.collect::<Vec<_>>();
Ok(input.get(&new_coord))
Ok::<_, TensorError>(input.get(&new_coord))
})?;
// Reshape the output tensor
@@ -613,7 +738,7 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
let val = input.get(&new_coord);
Ok(val)
Ok::<_, TensorError>(val)
})?;
// Reshape the output tensor
@@ -927,7 +1052,7 @@ pub fn scatter_nd<T: TensorType + Send + Sync>(
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let src_val = src.get_slice(&slice)?;
output.set_slice(&index_slice, &src_val)?;
Ok(())
Ok::<_, TensorError>(())
})
.collect::<Result<Vec<_>, _>>()?;
@@ -1234,7 +1359,7 @@ pub fn concat<T: TensorType + Send + Sync>(
index += x;
}
Ok(inputs[input_index].get(&input_coord))
Ok::<_, TensorError>(inputs[input_index].get(&input_coord))
})?;
// Reshape the output tensor

View File

@@ -520,6 +520,54 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
}
/// Get the sign of the inner values
pub fn sign(&self) -> Result<Self, TensorError> {
let evals = self.int_evals()?;
Ok(evals
.par_enum_map(|_, val| {
Ok::<_, TensorError>(ValType::Value(Value::known(integer_rep_to_felt(
val.signum(),
))))
})?
.into())
}
/// Decompose the inner values into base `base` and `n` legs.
pub fn decompose(&self, base: usize, n: usize) -> Result<Self, TensorError> {
let res = self
.get_inner()?
.par_iter()
.map(|x| {
let mut is_empty = true;
x.map(|_| is_empty = false);
if is_empty {
return Ok::<_, TensorError>(vec![Value::<F>::unknown(); n + 1]);
} else {
let mut res = vec![Value::unknown(); n + 1];
let mut int_rep = 0;
x.map(|f| {
int_rep = crate::fieldutils::felt_to_integer_rep(f);
});
let decompe = crate::tensor::ops::get_rep(&int_rep, base, n)?;
for (i, x) in decompe.iter().enumerate() {
res[i] = Value::known(crate::fieldutils::integer_rep_to_felt(*x));
}
Ok(res)
}
})
.collect::<Result<Vec<_>, _>>();
let mut tensor = Tensor::from(res?.into_iter().flatten().collect::<Vec<_>>().into_iter());
let mut dims = self.dims().to_vec();
dims.push(n + 1);
tensor.reshape(&dims)?;
Ok(tensor.into())
}
/// Calls `int_evals` on the inner tensor.
pub fn int_evals(&self) -> Result<Tensor<IntegerRep>, TensorError> {
// finally convert to vector of integers
@@ -574,7 +622,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// Calls `get_slice` on the inner tensor.
/// Calls `last` on the inner tensor.
pub fn last(&self) -> Result<ValTensor<F>, TensorError> {
let slice = match self {
ValTensor::Value {
@@ -595,6 +643,27 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(slice)
}
/// Calls `first`
pub fn first(&self) -> Result<ValTensor<F>, TensorError> {
let slice = match self {
ValTensor::Value {
inner: v,
dims: _,
scale,
} => {
let inner = v.first()?;
let dims = inner.dims().to_vec();
ValTensor::Value {
inner,
dims,
scale: *scale,
}
}
_ => return Err(TensorError::WrongMethod),
};
Ok(slice)
}
/// Calls `get_slice` on the inner tensor.
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, TensorError> {
if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims() {
@@ -775,6 +844,38 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// Calls `get_every_n` on the inner [Tensor].
pub fn get_every_n(&mut self, n: usize) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = v.get_every_n(n)?;
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(TensorError::WrongMethod);
}
}
Ok(())
}
/// Calls `exclude_every_n` on the inner [Tensor].
pub fn exclude_every_n(&mut self, n: usize) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = v.exclude_every_n(n)?;
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(TensorError::WrongMethod);
}
}
Ok(())
}
/// remove constant zero values constants
pub fn remove_const_zero_values(&mut self) {
match self {

View File

@@ -286,7 +286,15 @@ pub fn genWitness(
.map_err(|e| JsError::new(&format!("{}", e)))?;
let witness = circuit
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, RegionSettings::all_false())
.forward::<KZGCommitmentScheme<Bn256>>(
&mut input,
None,
None,
RegionSettings::all_true(
circuit.settings().run_args.decomp_base,
circuit.settings().run_args.decomp_legs,
),
)
.map_err(|e| JsError::new(&format!("{}", e)))?;
serde_json::to_vec(&witness)

View File

@@ -23,7 +23,7 @@ examples_path = os.path.abspath(
srs_path = os.path.join(folder_path, 'kzg_test.params')
params_k17_path = os.path.join(folder_path, 'kzg_test_k17.params')
params_k20_path = os.path.join(folder_path, 'kzg_test_k20.params')
params_k21_path = os.path.join(folder_path, 'kzg_test_k21.params')
anvil_url = "http://localhost:3030"
@@ -104,8 +104,8 @@ def test_gen_srs():
ezkl.gen_srs(params_k17_path, 17)
assert os.path.isfile(params_k17_path)
ezkl.gen_srs(params_k20_path, 20)
assert os.path.isfile(params_k20_path)
ezkl.gen_srs(params_k21_path, 21)
assert os.path.isfile(params_k21_path)
@@ -677,7 +677,7 @@ async def test_aggregate_and_verify_aggr():
)
# mock aggregate
res = ezkl.mock_aggregate([proof_path], 20)
res = ezkl.mock_aggregate([proof_path], 21)
assert res == True
aggregate_proof_path = os.path.join(folder_path, 'aggr_1l_relu.pf')
@@ -688,8 +688,8 @@ async def test_aggregate_and_verify_aggr():
[proof_path],
aggregate_vk_path,
aggregate_pk_path,
20,
srs_path=params_k20_path,
21,
srs_path=params_k21_path,
)
res = ezkl.gen_vk_from_pk_aggr(aggregate_pk_path, aggregate_vk_path)
@@ -701,9 +701,9 @@ async def test_aggregate_and_verify_aggr():
aggregate_proof_path,
aggregate_pk_path,
"poseidon",
20,
21,
"unsafe",
srs_path=params_k20_path,
srs_path=params_k21_path,
)
assert res == True
@@ -713,8 +713,8 @@ async def test_aggregate_and_verify_aggr():
res = ezkl.verify_aggr(
aggregate_proof_path,
aggregate_vk_path,
20,
srs_path=params_k20_path,
21,
srs_path=params_k21_path,
)
assert res == True
@@ -793,8 +793,8 @@ async def test_evm_aggregate_and_verify_aggr():
[proof_path],
aggregate_vk_path,
aggregate_pk_path,
20,
srs_path=params_k20_path,
21,
srs_path=params_k21_path,
)
res = ezkl.aggregate(
@@ -802,9 +802,9 @@ async def test_evm_aggregate_and_verify_aggr():
aggregate_proof_path,
aggregate_pk_path,
"evm",
20,
21,
"unsafe",
srs_path=params_k20_path,
srs_path=params_k21_path,
)
assert res == True
@@ -819,8 +819,8 @@ async def test_evm_aggregate_and_verify_aggr():
aggregate_vk_path,
sol_code_path,
abi_path,
logrows=20,
srs_path=params_k20_path,
logrows=21,
srs_path=params_k21_path,
)
assert res == True
@@ -838,8 +838,8 @@ async def test_evm_aggregate_and_verify_aggr():
res = ezkl.verify_aggr(
aggregate_proof_path,
aggregate_vk_path,
20,
srs_path=params_k20_path,
21,
srs_path=params_k21_path,
)
assert res == True

View File

@@ -16,14 +16,12 @@ mod wasm32 {
srsValidation, u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
};
use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::commitment::CommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::Bn256;
use halo2curves::bn256::{Fr, G1Affine};
use snark_verifier::util::arithmetic::PrimeField;
use wasm_bindgen::JsError;
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;
use wasm_bindgen_test::*;

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -8,7 +8,7 @@
"param_scale": 0,
"scale_rebase_multiplier": 10,
"lookup_range": [
-2,
0,
0
],
"logrows": 6,
@@ -24,15 +24,18 @@
"param_visibility": "Private",
"div_rebasing": false,
"rebase_frac_zero_constants": false,
"check_mode": "UNSAFE"
"check_mode": "UNSAFE",
"commitment": "KZG",
"decomp_base": 128,
"decomp_legs": 2
},
"num_rows": 16,
"num_rows": 46,
"total_assignments": 92,
"total_const_size": 3,
"total_dynamic_col_size": 0,
"num_dynamic_lookups": 0,
"num_shuffles": 0,
"total_shuffle_col_size": 0,
"total_assignments": 32,
"total_const_size": 8,
"model_instance_shapes": [
[
1,
@@ -54,12 +57,19 @@
]
]
},
"required_lookups": [
"ReLU"
"required_lookups": [],
"required_range_checks": [
[
-1,
1
],
[
0,
127
]
],
"required_range_checks": [],
"check_mode": "UNSAFE",
"version": "0.0.0",
"num_blinding_factors": null,
"timestamp": 1702474230544
"timestamp": 1726429587279
}

Binary file not shown.

View File

@@ -1 +1 @@
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_size":0}
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":0,"max_range_size":127}