feat: dictionary of reusable constants (#754)

This commit is contained in:
dante
2024-03-26 13:12:09 +00:00
committed by GitHub
parent 3be988a6a0
commit 7fe179b8d4
25 changed files with 672 additions and 396 deletions

View File

@@ -139,6 +139,20 @@ jobs:
target: ${{ matrix.target }}
manylinux: auto
args: --release --out dist --features python-bindings
before-script-linux: |
# If we're running on rhel centos, install needed packages.
if command -v yum &> /dev/null; then
yum update -y && yum install -y perl-core openssl openssl-devel pkgconfig libatomic
# If we're running on i686 we need to symlink libatomic
# in order to build openssl with -latomic flag.
if [[ ! -d "/usr/lib64" ]]; then
ln -s /usr/lib/libatomic.so.1 /usr/lib/libatomic.so
fi
else
# If we're running on debian-based system.
apt update -y && apt-get install -y libssl-dev openssl pkg-config
fi
- name: Install built wheel
if: matrix.target == 'x86_64'

View File

@@ -1,3 +1,5 @@
use std::collections::HashMap;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use ezkl::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
@@ -48,7 +50,7 @@ impl Circuit<Fr> for MyCircuit {
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.image.clone()], 0)?;
chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
Ok(())
}
}

View File

@@ -15,6 +15,8 @@ pub use planner::*;
use crate::tensor::{TensorType, ValTensor};
use super::region::ConstantsMap;
/// Module trait used to extend ezkl functionality
pub trait Module<F: PrimeField + TensorType + PartialOrd> {
/// Config
@@ -39,6 +41,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
&self,
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
constants: &mut ConstantsMap<F>,
) -> Result<Self::InputAssignments, Error>;
/// Layout
fn layout(
@@ -46,6 +49,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
row_offset: usize,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, Error>;
/// Number of instance values the module uses every time it is applied
fn instance_increment_input(&self) -> Vec<usize>;

View File

@@ -4,6 +4,8 @@ is already implemented in halo2_gadgets, there is no wrapper chip that makes it
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/
use std::collections::HashMap;
// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::poly::commitment::{Blind, CommitmentScheme, Params};
@@ -13,6 +15,7 @@ use halo2curves::group::prime::PrimeCurveAffine;
use halo2curves::group::Curve;
use halo2curves::CurveAffine;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
use super::Module;
@@ -107,6 +110,7 @@ impl Module<Fp> for PolyCommitChip {
&self,
_: &mut impl Layouter<Fp>,
_: &[ValTensor<Fp>],
_: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
Ok(())
}
@@ -119,11 +123,24 @@ impl Module<Fp> for PolyCommitChip {
layouter: &mut impl Layouter<Fp>,
input: &[ValTensor<Fp>],
_: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
assert_eq!(input.len(), 1);
let local_constants = constants.clone();
layouter.assign_region(
|| "PolyCommit",
|mut region| self.config.inputs.assign(&mut region, 0, &input[0]),
|mut region| {
let mut local_inner_constants = local_constants.clone();
let res = self.config.inputs.assign(
&mut region,
0,
&input[0],
&mut local_inner_constants,
)?;
*constants = local_inner_constants;
Ok(res)
},
)
}
@@ -184,7 +201,12 @@ mod tests {
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let polycommit_chip = PolyCommitChip::new(config);
polycommit_chip.layout(&mut layouter, &[self.message.clone()], 0);
polycommit_chip.layout(
&mut layouter,
&[self.message.clone()],
0,
&mut HashMap::new(),
);
Ok(())
}

View File

@@ -18,6 +18,7 @@ use maybe_rayon::slice::ParallelSlice;
use std::marker::PhantomData;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType};
use super::Module;
@@ -172,12 +173,15 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
&self,
layouter: &mut impl Layouter<Fp>,
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
assert_eq!(message.len(), 1);
let message = message[0].clone();
let start_time = instant::Instant::now();
let local_constants = constants.clone();
let res = layouter.assign_region(
|| "load message",
|mut region| {
@@ -199,12 +203,26 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
Ok(v.clone())
}
ValType::Constant(f) => region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
),
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants.get(f).unwrap().assigned_cell().ok_or({
log::error!("constant not previously assigned");
Error::Synthesis
})?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
constants
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
Ok(res)
}
}
e => {
log::error!(
"wrong input type {:?}, must be previously assigned",
@@ -270,8 +288,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
layouter: &mut impl Layouter<Fp>,
input: &[ValTensor<Fp>],
row_offset: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input)?;
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
input_cells.iter().map(|e| ValType::from(e.clone())).into();
@@ -434,7 +453,7 @@ mod tests {
*,
};
use std::marker::PhantomData;
use std::{collections::HashMap, marker::PhantomData};
use halo2_gadgets::poseidon::primitives::Spec;
use halo2_proofs::{
@@ -477,7 +496,12 @@ mod tests {
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.message.clone()], 0)?;
chip.layout(
&mut layouter,
&[self.message.clone()],
0,
&mut HashMap::new(),
)?;
Ok(())
}

View File

@@ -345,7 +345,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
Self {

View File

@@ -46,7 +46,8 @@ pub enum HybridOp {
dim: usize,
},
Softmax {
scale: utils::F32,
input_scale: utils::F32,
output_scale: utils::F32,
axes: Vec<usize>,
},
RangeCheck(Tolerance),
@@ -70,7 +71,7 @@ pub enum HybridOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
@@ -130,9 +131,16 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
kernel_shape,
normalized,
} => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
HybridOp::Softmax { scale, axes } => {
tensor::ops::nonlinearities::softmax_axes(&x, scale.into(), axes)
}
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => tensor::ops::nonlinearities::softmax_axes(
&x,
input_scale.into(),
output_scale.into(),
axes,
),
HybridOp::RangeCheck(tol) => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
@@ -203,8 +211,15 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
HybridOp::Softmax { scale, axes } => {
format!("SOFTMAX (scale={}, axes={:?})", scale, axes)
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => {
format!(
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
input_scale, output_scale, axes
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Greater => "GREATER".into(),
@@ -324,9 +339,18 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::ReduceArgMin { dim } => {
layouts::argmin_axes(config, region, values[..].try_into()?, *dim)?
}
HybridOp::Softmax { scale, axes } => {
layouts::softmax_axes(config, region, values[..].try_into()?, *scale, axes)?
}
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => layouts::softmax_axes(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
axes,
)?,
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
config,
region,
@@ -359,8 +383,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { .. } => 2 * in_scales[0],
HybridOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
multiplier_to_scale(output_scale.0 as f64)
}
_ => in_scales[0],
};
Ok(scale)

View File

@@ -9,6 +9,7 @@ use halo2curves::ff::PrimeField;
use itertools::Itertools;
use log::{error, trace};
use maybe_rayon::{
iter::IntoParallelRefIterator,
prelude::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
slice::ParallelSliceMut,
};
@@ -33,7 +34,7 @@ use super::*;
use crate::circuit::ops::lookup::LookupOp;
/// Same as div but splits the division into N parts
pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
@@ -68,7 +69,7 @@ pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd>(
}
/// Div accumulated layout
pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
@@ -93,9 +94,9 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
tensor::ops::nonlinearities::const_div(&input_evals.clone(), felt_to_i128(div) as f64)
.iter()
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
Tensor::new(
@@ -133,7 +134,7 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
}
/// recip accumulated layout
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
@@ -166,9 +167,9 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
felt_to_i128(input_scale) as f64,
felt_to_i128(output_scale) as f64,
)
.iter()
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
Tensor::new(
@@ -226,7 +227,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
}
/// Dot product accumulated layout
pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -337,7 +338,7 @@ pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd>(
}
/// Einsum
pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
inputs: &[ValTensor<F>],
@@ -524,14 +525,7 @@ pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
// Compute the product of all input tensors
for pair in input_pairs {
let product_across_pair = prod(
config,
region,
&[pair.try_into().map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?],
)?;
let product_across_pair = prod(config, region, &[pair.into()])?;
if let Some(product) = prod_res {
prod_res = Some(
@@ -563,7 +557,7 @@ pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -578,8 +572,8 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
.get_int_evals()?
.iter()
.sorted_by(|a, b| a.cmp(b))
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
@@ -607,7 +601,7 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
}
///
fn _select_topk<F: PrimeField + TensorType + PartialOrd>(
fn _select_topk<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -622,7 +616,7 @@ fn _select_topk<F: PrimeField + TensorType + PartialOrd>(
}
/// Select top k elements
pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -642,11 +636,12 @@ pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
fn select<F: PrimeField + TensorType + PartialOrd>(
fn select<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let start = instant::Instant::now();
let (mut input, index) = (values[0].clone(), values[1].clone());
input.flatten();
@@ -656,12 +651,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()?;
let output: ValTensor<F> = if is_assigned {
let output: ValTensor<F> = if is_assigned && region.witness_gen() {
let felt_evals = input.get_felt_evals()?;
index
.get_int_evals()?
.iter()
.map(|x| Ok(Value::known(input.get_felt_evals()?.get(&[*x as usize]))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.par_iter()
.map(|x| Value::known(felt_evals.get(&[*x as usize])))
.collect::<Tensor<Value<F>>>()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); index.len()]),
@@ -673,10 +669,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
let (_, assigned_output) =
dynamic_lookup(config, region, &[index, output], &[dim_indices, input])?;
let end = start.elapsed();
trace!("select took: {:?}", end);
Ok(assigned_output)
}
fn one_hot<F: PrimeField + TensorType + PartialOrd>(
fn one_hot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -692,7 +691,7 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
let output: ValTensor<F> = if is_assigned {
let int_evals = input.get_int_evals()?;
let res = tensor::ops::one_hot(&int_evals, num_classes, 1)?;
res.iter()
res.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<_>>()
} else {
@@ -728,12 +727,13 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
}
/// Dynamic lookup
pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
lookups: &[ValTensor<F>; 2],
tables: &[ValTensor<F>; 2],
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
let start = instant::Instant::now();
// if not all lookups same length err
if lookups[0].len() != lookups[1].len() {
return Err("lookups must be same length".into());
@@ -802,11 +802,14 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
region.increment_dynamic_lookup_index(1);
region.increment(lookup_len);
let end = start.elapsed();
trace!("dynamic lookup took: {:?}", end);
Ok((lookup_0, lookup_1))
}
/// Shuffle arg
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
input: &[ValTensor<F>; 1],
@@ -869,7 +872,7 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd>(
}
/// One hot accumulated layout
pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -922,7 +925,7 @@ pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -950,7 +953,7 @@ pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -973,7 +976,7 @@ pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1024,7 +1027,7 @@ pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
/// Takes a tensor representing a multi-dimensional index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// FOr instance if the dims is [3,5,2], the linearized index of [2] at dim 1 is 2*5 + 3 = 13
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1032,6 +1035,7 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
dim: usize,
is_flat_index: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let start_time = instant::Instant::now();
let index = values[0].clone();
if !is_flat_index {
assert_eq!(index.dims().len(), dims.len());
@@ -1105,6 +1109,9 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
region.apply_in_loop(&mut output, inner_loop_function)?;
let elapsed = start_time.elapsed();
trace!("linearize_element_index took: {:?}", elapsed);
Ok(output.into())
}
@@ -1125,7 +1132,7 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
/// If indices_shape[-1] == r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b, where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
/// Let us think of each such r-b ranked tensor as indices_slice. Each scalar value corresponding to data[0:b-1,indices_slice] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Example 1 below)
/// If indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b. Let us think of each such tensors as indices_slice. Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Examples 2, 3, 4 and 5 below)
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1232,7 +1239,6 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
}
let const_offset = create_constant_tensor(const_offset, 1);
let mut results = vec![];
@@ -1250,16 +1256,18 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
let res = sum(config, region, &[res])?;
results.push(res.get_inner_tensor()?.clone());
// assert than res is less than the product of the dims
assert!(
if region.witness_gen() {
assert!(
res.get_int_evals()?
.iter()
.all(|x| *x < dims.iter().product::<usize>() as i128),
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
dims.iter().product::<usize>(),
index_val.show(),
index_val.show(),
index_dim_multiplier.show(),
res.show()
);
}
}
let result_tensor = Tensor::from(results.into_iter());
@@ -1273,7 +1281,9 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
Ok(output.into())
}
pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn get_missing_set_elements<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1304,7 +1314,7 @@ pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
}
fullset_evals
.iter()
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
@@ -1337,7 +1347,7 @@ pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
@@ -1354,14 +1364,14 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
let claimed_output: ValTensor<F> = if is_assigned {
let claimed_output: ValTensor<F> = if is_assigned && region.witness_gen() {
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
let res = tensor::ops::scatter(&input_inner, &index_inner, &src_inner, dim)?;
res.iter()
res.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
@@ -1419,7 +1429,7 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
}
/// Scatter Nd
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
@@ -1433,14 +1443,14 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
let claimed_output: ValTensor<F> = if is_assigned {
let claimed_output: ValTensor<F> = if is_assigned && region.witness_gen() {
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
let res = tensor::ops::scatter_nd(&input_inner, &index_inner, &src_inner)?;
res.iter()
res.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
@@ -1457,7 +1467,6 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
region.increment(claimed_output.len());
claimed_output.reshape(input.dims())?;
// scatter elements is the inverse of gather elements
let (gather_src, linear_index) =
gather_nd(config, region, &[claimed_output.clone(), index.clone()], 0)?;
@@ -1498,7 +1507,7 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
}
/// sum accumulated layout
pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1581,7 +1590,7 @@ pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd>(
}
/// product accumulated layout
pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1661,7 +1670,7 @@ pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd>(
}
/// Axes wise op wrapper
fn axes_wise_op<F: PrimeField + TensorType + PartialOrd>(
fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1722,7 +1731,7 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd>(
}
/// Sum accumulated layout
pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1733,7 +1742,7 @@ pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Sum accumulated layout
pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1744,7 +1753,7 @@ pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// argmax layout
pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1762,7 +1771,7 @@ pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Max accumulated layout
pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1774,7 +1783,7 @@ pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Argmin layout
pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1792,7 +1801,7 @@ pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Min accumulated layout
pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1804,7 +1813,7 @@ pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Pairwise (elementwise) op layout
pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1959,7 +1968,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd>(
}
/// expand the tensor to the given shape
pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1972,7 +1981,7 @@ pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1995,7 +2004,7 @@ pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2018,7 +2027,7 @@ pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn less<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn less<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2028,7 +2037,7 @@ pub(crate) fn less<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2038,7 +2047,7 @@ pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd>(
}
/// And boolean operation
pub(crate) fn and<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn and<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2049,7 +2058,7 @@ pub(crate) fn and<F: PrimeField + TensorType + PartialOrd>(
}
/// Or boolean operation
pub(crate) fn or<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn or<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2065,7 +2074,7 @@ pub(crate) fn or<F: PrimeField + TensorType + PartialOrd>(
}
/// Equality boolean operation
pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2075,7 +2084,7 @@ pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd>(
}
/// Equality boolean operation
pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2109,7 +2118,7 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
}
/// Xor boolean operation
pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2135,7 +2144,7 @@ pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd>(
}
/// Not boolean operation
pub(crate) fn not<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn not<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2151,7 +2160,7 @@ pub(crate) fn not<F: PrimeField + TensorType + PartialOrd>(
}
/// Iff
pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
@@ -2175,7 +2184,7 @@ pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd>(
}
/// Negation operation accumulated layout
pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2185,7 +2194,7 @@ pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd>(
}
/// Sumpool accumulated layout
pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
@@ -2239,7 +2248,7 @@ pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
}
/// Convolution accumulated layout
pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2304,7 +2313,7 @@ pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd>(
/// DeConvolution accumulated layout
pub(crate) fn deconv<
F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync,
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
@@ -2397,7 +2406,7 @@ pub(crate) fn deconv<
/// Convolution accumulated layout
pub(crate) fn conv<
F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync,
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
@@ -2578,7 +2587,7 @@ pub(crate) fn conv<
}
/// Power accumulated layout
pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2594,7 +2603,7 @@ pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd>(
}
/// Rescaled op accumulated layout
pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
@@ -2616,7 +2625,7 @@ pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd>(
}
/// Dummy (no contraints) reshape layout
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
new_dims: &[usize],
) -> Result<ValTensor<F>, Box<dyn Error>> {
@@ -2626,7 +2635,7 @@ pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd>(
}
/// Dummy (no contraints) move_axis layout
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
source: usize,
destination: usize,
@@ -2637,7 +2646,7 @@ pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd>(
}
/// resize layout
pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2651,7 +2660,7 @@ pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd>(
}
/// Slice layout
pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2674,7 +2683,7 @@ pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd>(
}
/// Trilu layout
pub(crate) fn trilu<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn trilu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2698,7 +2707,7 @@ pub(crate) fn trilu<F: PrimeField + TensorType + PartialOrd>(
}
/// Concat layout
pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>],
axis: &usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
@@ -2710,7 +2719,7 @@ pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd>(
}
/// Identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2725,7 +2734,7 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
}
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2761,7 +2770,7 @@ pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
}
/// Downsample layout
pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2778,7 +2787,7 @@ pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd>(
}
/// layout for enforcing two sets of cells to be equal
pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2804,7 +2813,7 @@ pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
}
/// layout for range check.
pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2860,12 +2869,13 @@ pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
if region.throw_range_check_error() {
let is_assigned = !w.any_unknowns()?;
if is_assigned && region.witness_gen() {
// assert is within range
let int_values = w.get_int_evals()?;
for v in int_values {
if v < range.0 || v > range.1 {
log::debug!("Value ({:?}) out of range: {:?}", v, range);
for v in int_values.iter() {
if v < &range.0 || v > &range.1 {
log::error!("Value ({:?}) out of range: {:?}", v, range);
return Err(Box::new(TensorError::TableLookupError));
}
}
@@ -2885,7 +2895,7 @@ pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
}
/// layout for nonlinearity check.
pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2987,7 +2997,7 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
}
/// Argmax
pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3023,7 +3033,7 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
}
/// Argmin
pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3059,7 +3069,7 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
}
/// max layout
pub(crate) fn max<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn max<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3069,7 +3079,7 @@ pub(crate) fn max<F: PrimeField + TensorType + PartialOrd>(
}
/// min layout
pub(crate) fn min<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3077,7 +3087,7 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd>(
_sort_ascending(config, region, values)?.get_slice(&[0..1])
}
fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd>(
fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3180,18 +3190,19 @@ fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd>(
}
/// softmax layout
pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: utils::F32,
input_scale: utils::F32,
output_scale: utils::F32,
axes: &[usize],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let soft_max_at_scale = move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> {
softmax(config, region, values, scale)
softmax(config, region, values, input_scale, output_scale)
};
let output = multi_dim_axes_op(config, region, values, axes, soft_max_at_scale)?;
@@ -3199,33 +3210,62 @@ pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// softmax func
pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd>(
/// percent func
pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: utils::F32,
input_scale: utils::F32,
output_scale: utils::F32,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let is_assigned = values[0].all_prev_assigned();
let mut input = values[0].clone();
if !is_assigned {
input = region.assign(&config.custom_gates.inputs[0], &values[0])?;
region.increment(input.len());
};
// sum of exps
let denom = sum(config, region, &[input.clone()])?;
let input_felt_scale = F::from(input_scale.0 as u64);
let output_felt_scale = F::from(output_scale.0 as u64);
let inv_denom = recip(
config,
region,
&[denom],
input_felt_scale,
output_felt_scale,
)?;
// product of num * (1 / denom) = 2*output_scale
let percent = pairwise(config, region, &[input, inv_denom], BaseOp::Mult)?;
// rebase the percent to 2x the scale
loop_div(config, region, &[percent], input_felt_scale)
}
/// softmax func
pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
input_scale: utils::F32,
output_scale: utils::F32,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// elementwise exponential
let ex = nonlinearity(config, region, values, &LookupOp::Exp { scale })?;
let ex = nonlinearity(
config,
region,
values,
&LookupOp::Exp { scale: input_scale },
)?;
// sum of exps
let denom = sum(config, region, &[ex.clone()])?;
// get the inverse
let felt_scale = F::from(scale.0 as u64);
let inv_denom = recip(config, region, &[denom], felt_scale, felt_scale)?;
// product of num * (1 / denom) = 2*output_scale
let softmax = pairwise(config, region, &[ex, inv_denom], BaseOp::Mult)?;
Ok(softmax)
percent(config, region, &[ex.clone()], input_scale, output_scale)
}
/// Checks that the percent error between the expected public output and the actual output value
/// is within the percent error expressed by the `tol` input, where `tol == 1.0` means the percent
/// error tolerance is 1 percent.
pub(crate) fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn range_check_percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],

View File

@@ -137,7 +137,7 @@ impl LookupOp {
}
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self

View File

@@ -27,12 +27,14 @@ pub mod region;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub(crate) output: Tensor<F>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send + Sync + Any {
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
std::fmt::Debug + Send + Sync + Any
{
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
/// Returns a string representation of the operation.
@@ -98,7 +100,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
}
}
impl<F: PrimeField + TensorType + PartialOrd> Clone for Box<dyn Op<F>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
fn clone(&self) -> Self {
self.clone_dyn()
}
@@ -165,7 +167,7 @@ pub struct Input {
pub datum_type: InputType,
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(self.scale)
}
@@ -226,7 +228,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(0)
}
@@ -256,7 +258,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
///
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
///
pub quantized_values: Tensor<F>,
///
@@ -266,7 +268,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
pub pre_assigned_val: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
Self {
@@ -293,8 +295,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
}
}
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
for Constant<F>
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self

View File

@@ -89,8 +89,14 @@ pub enum PolyOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
for PolyOp
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {

View File

@@ -2,24 +2,28 @@ use crate::{
circuit::table::Range,
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
use halo2_proofs::{
circuit::Region,
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use portable_atomic::AtomicI128 as AtomicInt;
use std::{
cell::RefCell,
collections::HashSet,
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
};
use portable_atomic::AtomicI128 as AtomicInt;
use super::lookup::LookupOp;
/// Constants map
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
@@ -120,12 +124,11 @@ impl From<Box<dyn std::error::Error>> for RegionError {
#[derive(Debug)]
/// A context for a region
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
region: Option<RefCell<Region<'a, F>>>,
row: usize,
linear_coord: usize,
num_inner_cols: usize,
total_constants: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
used_lookups: HashSet<LookupOp>,
@@ -133,13 +136,34 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
max_lookup_inputs: i128,
min_lookup_inputs: i128,
max_range_size: i128,
throw_range_check_error: bool,
witness_gen: bool,
assigned_constants: ConstantsMap<F>,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
#[cfg(not(target_arch = "wasm32"))]
///
pub fn increment_total_constants(&mut self, n: usize) {
self.total_constants += n;
pub fn debug_report(&self) {
log::debug!(
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={})",
self.row().to_string().blue(),
self.linear_coord().to_string().yellow(),
self.total_constants().to_string().red(),
self.max_lookup_inputs().to_string().green(),
self.min_lookup_inputs().to_string().green(),
self.max_range_size().to_string().green(),
self.dynamic_lookup_col_coord().to_string().green(),
self.shuffle_col_coord().to_string().green());
}
///
pub fn assigned_constants(&self) -> &ConstantsMap<F> {
&self.assigned_constants
}
///
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
self.assigned_constants.extend(constants.into_iter());
}
///
@@ -163,8 +187,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
///
pub fn throw_range_check_error(&self) -> bool {
self.throw_range_check_error
pub fn witness_gen(&self) -> bool {
self.witness_gen
}
/// Create a new region context
@@ -177,7 +201,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
row,
linear_coord,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -185,9 +208,22 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
witness_gen: true,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_with_constants(
region: Region<'a, F>,
row: usize,
num_inner_cols: usize,
constants: ConstantsMap<F>,
) -> RegionCtx<'a, F> {
let mut new_self = Self::new(region, row, num_inner_cols);
new_self.assigned_constants = constants;
new_self
}
/// Create a new region context from a wrapped region
pub fn from_wrapped_region(
region: Option<RefCell<Region<'a, F>>>,
@@ -202,7 +238,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index,
shuffle_index,
used_lookups: HashSet::new(),
@@ -210,16 +245,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
witness_gen: false,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
pub fn new_dummy(row: usize, num_inner_cols: usize, witness_gen: bool) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -228,7 +260,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -236,17 +267,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
witness_gen,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy_with_constants(
pub fn new_dummy_with_linear_coord(
row: usize,
linear_coord: usize,
total_constants: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
witness_gen: bool,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -254,7 +285,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -262,7 +292,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
witness_gen,
assigned_constants: HashMap::new(),
}
}
@@ -312,29 +343,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
) -> Result<(), RegionError> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let constants = AtomicUsize::new(self.total_constants());
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
*output = output
.par_enum_map(|idx, _| {
// we kick off the loop with the current offset
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
let starting_constants = constants.load(Ordering::SeqCst);
// get inner value of the locked lookups
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_constants(
let mut local_reg = Self::new_dummy_with_linear_coord(
starting_offset,
starting_linear_coord,
starting_constants,
self.num_inner_cols,
self.throw_range_check_error,
self.witness_gen,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -343,10 +372,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
local_reg.linear_coord() - starting_linear_coord,
Ordering::SeqCst,
);
constants.fetch_add(
local_reg.total_constants() - starting_constants,
Ordering::SeqCst,
);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
@@ -362,11 +387,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants.into_iter());
res
})
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
@@ -410,6 +437,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?;
self.assigned_constants = Arc::try_unwrap(constants)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?;
Ok(())
}
@@ -435,7 +470,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
range: Range,
) -> Result<(), Box<dyn std::error::Error>> {
if range.0 > range.1 {
return Err("update_max_min_lookup_range: invalid range".into());
return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
}
let range_size = (range.1 - range.0).abs();
@@ -477,7 +512,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
/// Get the total number of constants
pub fn total_constants(&self) -> usize {
self.total_constants
self.assigned_constants.len()
}
/// Get the dynamic lookup index
@@ -525,26 +560,22 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.max_range_size
}
/// Assign a constant value
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
self.total_constants += 1;
if let Some(region) = &self.region {
let cell = var.assign_constant(&mut region.borrow_mut(), self.linear_coord, value)?;
Ok(cell.into())
} else {
Ok(value.into())
}
}
/// Assign a valtensor to a vartensor
pub fn assign(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(&mut region.borrow_mut(), self.linear_coord, values)
var.assign(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)
} else {
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
Ok(values.clone())
}
}
@@ -560,14 +591,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
&mut self.assigned_constants,
)
} else {
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
Ok(values.clone())
}
}
@@ -594,13 +627,21 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)
} else {
self.total_constants += values.num_constants();
let mut values_map = values.create_constants_map();
let inner_tensor = values.get_inner_tensor().unwrap();
for o in ommissions {
self.total_constants -= inner_tensor.get_flat_index(**o).is_constant() as usize;
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
values_map.remove(&value);
}
}
self.assigned_constants.extend(values_map);
Ok(values.clone())
}
}
@@ -615,24 +656,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len, total_assigned_constants) = var.assign_with_duplication(
let (res, len) = var.assign_with_duplication(
&mut region.borrow_mut(),
self.row,
self.linear_coord,
values,
check_mode,
single_inner_col,
&mut self.assigned_constants,
)?;
self.total_constants += total_assigned_constants;
Ok((res, len))
} else {
let (_, len, total_assigned_constants) = var.dummy_assign_with_duplication(
let (_, len) = var.dummy_assign_with_duplication(
self.row,
self.linear_coord,
values,
single_inner_col,
&mut self.assigned_constants,
)?;
self.total_constants += total_assigned_constants;
Ok((values.clone(), len))
}
}
@@ -699,9 +740,4 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
Ok(())
}
/// increment constants
pub fn increment_constants(&mut self, n: usize) {
self.total_constants += n
}
}

View File

@@ -98,7 +98,7 @@ pub struct Table<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
@@ -138,7 +138,7 @@ pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
(range_len / (col_size as i128)) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// Configures the table.
pub fn configure(
cs: &mut ConstraintSystem<F>,
@@ -275,7 +275,7 @@ pub struct RangeCheck<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
@@ -303,7 +303,7 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
}
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);

View File

@@ -1911,6 +1911,8 @@ mod add_with_overflow {
#[cfg(test)]
mod add_with_overflow_and_poseidon {
use std::collections::HashMap;
use halo2curves::bn256::Fr;
use crate::circuit::modules::{
@@ -1969,8 +1971,10 @@ mod add_with_overflow_and_poseidon {
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
PoseidonChip::new(config.poseidon.clone());
let assigned_inputs_a = poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0)?;
let assigned_inputs_b = poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1)?;
let assigned_inputs_a =
poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0, &mut HashMap::new())?;
let assigned_inputs_b =
poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1, &mut HashMap::new())?;
layouter.assign_region(|| "_new_module", |_| Ok(()))?;

View File

@@ -24,6 +24,8 @@ use crate::pfsys::{
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::{Commitments, RunArgs};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
@@ -538,7 +540,7 @@ fn check_srs_hash(
let path = get_srs_path(logrows, srs_path, commitment);
let hash = get_file_hash(&path)?;
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
let predefined_hash = match crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) {
Some(h) => h,
None => return Err(format!("SRS (k={}) hash not found in public set", logrows).into()),
};
@@ -684,7 +686,7 @@ pub(crate) async fn gen_witness(
&mut input,
vk.as_ref(),
Some(&srs),
false,
true,
)?
}
Commitments::IPA => {
@@ -698,16 +700,16 @@ pub(crate) async fn gen_witness(
&mut input,
vk.as_ref(),
Some(&srs),
false,
true,
)?
}
}
} else {
warn!("SRS for poly commit does not exist (will be ignored)");
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, false)?
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
}
} else {
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, false)?
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
};
// print each variable tuple (symbol, value) as symbol=value
@@ -819,7 +821,15 @@ impl AccuracyResults {
let error = (original.clone() - calibrated.clone())?;
let abs_error = error.map(|x| x.abs());
let squared_error = error.map(|x| x.powi(2));
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
let percentage_error = error.enum_map(|i, x| {
// if everything is 0 then we can't divide by 0 so we just return 0
let res = if original[i] == 0.0 && x == 0.0 {
0.0
} else {
x / original[i]
};
Ok::<f32, TensorError>(res)
})?;
let abs_percentage_error = percentage_error.map(|x| x.abs());
errors.extend(error);
@@ -888,6 +898,7 @@ pub(crate) fn calibrate(
only_range_check_rebase: bool,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use log::error;
use std::collections::HashMap;
use tabled::Table;
@@ -900,9 +911,9 @@ pub(crate) fn calibrate(
let model = Model::from_run_args(&settings.run_args, &model_path)?;
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
info!("num of calibration batches: {}", chunks.len());
debug!("num of calibration batches: {}", chunks.len());
info!("running onnx predictions...");
debug!("running onnx predictions...");
let original_predictions = Model::run_onnx_predictions(
&settings.run_args,
&model_path,
@@ -970,10 +981,18 @@ pub(crate) fn calibrate(
let pb = init_bar(range_grid.len() as u64);
pb.set_message("calibrating...");
let mut num_failed = 0;
let mut num_passed = 0;
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
pb.set_message(format!(
"input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
"i-scale: {}, p-scale: {}, rebase-(x): {}, div-rebase: {}, fail: {}, pass: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().blue(),
div_rebasing.to_string().yellow(),
num_failed.to_string().red(),
num_passed.to_string().green()
));
let key = (
@@ -1007,7 +1026,9 @@ pub(crate) fn calibrate(
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
debug!("circuit creation from run args failed: {:?}", e);
error!("circuit creation from run args failed: {:?}", e);
pb.inc(1);
num_failed += 1;
continue;
}
};
@@ -1039,7 +1060,9 @@ pub(crate) fn calibrate(
Ok(_) => (),
// typically errors will be due to the circuit overflowing the i128 limit
Err(e) => {
debug!("forward pass failed: {:?}", e);
error!("forward pass failed: {:?}", e);
pb.inc(1);
num_failed += 1;
continue;
}
}
@@ -1104,8 +1127,10 @@ pub(crate) fn calibrate(
"found settings: \n {}",
found_settings.as_json()?.to_colored_json_auto()?
);
num_passed += 1;
} else {
debug!("calibration failed {}", res.err().unwrap());
error!("calibration failed {}", res.err().unwrap());
num_failed += 1;
}
pb.inc(1);
@@ -1879,7 +1904,9 @@ pub(crate) fn mock_aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG").into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}
@@ -1922,7 +1949,9 @@ pub(crate) fn setup_aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG",).into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}
@@ -1983,7 +2012,9 @@ pub(crate) fn aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG").into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}

View File

@@ -26,6 +26,7 @@ use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::ConstantsMap;
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
@@ -155,7 +156,7 @@ use std::cell::RefCell;
thread_local!(
/// This is a global variable that holds the settings for the graph
/// This is used to pass settings to the layouter and other parts of the circuit without needing to heavily modify the Halo2 API in a new fork
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = RefCell::new(None)
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = const { RefCell::new(None) }
);
/// Result from a forward pass
@@ -1051,12 +1052,10 @@ impl GraphCircuit {
}
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
let margin = (
(
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
);
margin
)
}
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
@@ -1240,7 +1239,7 @@ impl GraphCircuit {
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&Scheme::ParamsProver>,
throw_range_check_error: bool,
witness_gen: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let original_inputs = inputs.to_vec();
@@ -1289,7 +1288,7 @@ impl GraphCircuit {
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, throw_range_check_error)?;
.forward(inputs, &self.settings().run_args, witness_gen)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1604,6 +1603,8 @@ impl Circuit<Fp> for GraphCircuit {
let output_vis = &self.settings().run_args.output_visibility;
let mut graph_modules = GraphModules::new();
let mut constants = ConstantsMap::new();
let mut config = config.clone();
let mut inputs = self
@@ -1649,6 +1650,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut input_outlets,
input_visibility,
&mut instance_offset,
&mut constants,
)?;
// replace inputs with the outlets
for (i, outlet) in outlets.iter().enumerate() {
@@ -1661,6 +1663,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut inputs,
input_visibility,
&mut instance_offset,
&mut constants,
)?;
}
@@ -1697,6 +1700,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut flattened_params,
param_visibility,
&mut instance_offset,
&mut constants,
)?;
let shapes = self.model().const_shapes();
@@ -1725,6 +1729,7 @@ impl Circuit<Fp> for GraphCircuit {
&inputs,
&mut vars,
&outputs,
&mut constants,
)
.map_err(|e| {
log::error!("{}", e);
@@ -1749,6 +1754,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut output_outlets,
&self.settings().run_args.output_visibility,
&mut instance_offset,
&mut constants,
)?;
// replace outputs with the outlets
@@ -1762,6 +1768,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut outputs,
&self.settings().run_args.output_visibility,
&mut instance_offset,
&mut constants,
)?;
}

View File

@@ -5,6 +5,7 @@ use super::vars::*;
use super::GraphError;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
use crate::circuit::table::Range;
use crate::circuit::Input;
@@ -404,7 +405,7 @@ impl ParsedNodes {
.get(input)
.ok_or(GraphError::MissingNode(*input))?;
let input_dims = node.out_dims();
let input_dim = input_dims.get(0).ok_or(GraphError::MissingNode(*input))?;
let input_dim = input_dims.first().ok_or(GraphError::MissingNode(*input))?;
inputs.push(input_dim.clone());
}
@@ -514,21 +515,24 @@ impl Model {
instance_shapes.len().to_string().blue(),
"instances".blue()
);
// this is the total number of variables we will need to allocate
// for the circuit
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let inputs: Vec<ValTensor<Fp>> = self
.graph
.input_shapes()?
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
let len = shape.iter().product();
let mut t: ValTensor<Fp> = (0..len)
.map(|_| {
if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::random(&mut rand::thread_rng()))
}
})
.collect::<Vec<_>>()
.into();
t.reshape(shape)?;
Ok(t)
})
@@ -577,13 +581,13 @@ impl Model {
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
throw_range_check_error: bool,
witness_gen: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen)?;
Ok(res.into())
}
@@ -1071,6 +1075,8 @@ impl Model {
/// * `layouter` - Halo2 Layouter.
/// * `inputs` - The values to feed into the circuit.
/// * `vars` - The variables for the circuit.
/// * `witnessed_outputs` - The values to compare against.
/// * `constants` - The constants for the circuit.
pub fn layout(
&self,
mut config: ModelConfig,
@@ -1079,6 +1085,7 @@ impl Model {
inputs: &[ValTensor<Fp>],
vars: &mut ModelVars<Fp>,
witnessed_outputs: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
info!("model layout...");
@@ -1104,14 +1111,12 @@ impl Model {
config.base.layout_tables(layouter)?;
config.base.layout_range_checks(layouter)?;
let mut num_rows = 0;
let mut linear_coord = 0;
let mut total_const_size = 0;
let original_constants = constants.clone();
let outputs = layouter.assign_region(
|| "model",
|region| {
let mut thread_safe_region = RegionCtx::new(region, 0, run_args.num_inner_cols);
let mut thread_safe_region = RegionCtx::new_with_constants(region, 0, run_args.num_inner_cols, original_constants.clone());
// we need to do this as this loop is called multiple times
vars.set_instance_idx(instance_idx);
@@ -1157,29 +1162,17 @@ impl Model {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
} else if !run_args.output_visibility.is_private() {
for output in &outputs {
thread_safe_region.increment_total_constants(output.num_constants());
}
}
num_rows = thread_safe_region.row();
linear_coord = thread_safe_region.linear_coord();
total_const_size = thread_safe_region.total_constants();
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
thread_safe_region.debug_report();
*constants = thread_safe_region.assigned_constants().clone();
Ok(outputs)
},
)?;
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
num_rows.to_string().blue(),
"rows".blue(),
linear_coord.to_string().yellow(),
total_const_size.to_string().red()
);
)?;
let duration = start_time.elapsed();
trace!("model layout took: {:?}", duration);
@@ -1213,16 +1206,10 @@ impl Model {
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
debug!(
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
idx,
node.as_str(),
region.row(),
region.linear_coord(),
region.total_constants(),
region.max_lookup_inputs(),
region.min_lookup_inputs()
);
debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("dims: {:?}", node.out_dims());
debug!(
"input_dims {:?}",
@@ -1380,7 +1367,7 @@ impl Model {
&self,
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
throw_range_check_error: bool,
witness_gen: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
debug!("calculating num of constraints using dummy model layout...");
@@ -1399,29 +1386,31 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
let default_value = if !self.visibility.output.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
let mut comparator: ValTensor<Fp> = (0..output.len())
.map(|_| {
if !self.visibility.output.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::random(&mut rand::thread_rng()))
}
})
.collect::<Vec<_>>()
.into();
comparator.reshape(output.dims())?;
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let mut comparator: ValTensor<Fp> =
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
comparator.reshape(output.dims())?;
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
@@ -1432,7 +1421,7 @@ impl Model {
res?;
} else if !self.visibility.output.is_private() {
for output in &outputs {
region.increment_total_constants(output.num_constants());
region.update_constants(output.create_constants_map());
}
}
@@ -1441,14 +1430,7 @@ impl Model {
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
region.row().to_string().blue(),
"rows".blue(),
region.linear_coord().to_string().yellow(),
region.total_constants().to_string().red()
);
region.debug_report();
let outputs = outputs
.iter()

View File

@@ -2,6 +2,7 @@ use crate::circuit::modules::polycommit::{PolyCommitChip, PolyCommitConfig};
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use crate::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
use crate::circuit::modules::Module;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor};
use halo2_proofs::circuit::Layouter;
use halo2_proofs::plonk::{Column, ConstraintSystem, Error, Instance, VerifyingKey};
@@ -211,12 +212,13 @@ impl GraphModules {
layouter: &mut impl Layouter<Fp>,
x: &mut Vec<ValTensor<Fp>>,
instance_offset: &mut usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<(), Error> {
// reserve module 0 for ... modules
// hash the input and replace the constrained cells in the input
let cloned_x = (*x).clone();
x[0] = module
.layout(layouter, &cloned_x, instance_offset.to_owned())
.layout(layouter, &cloned_x, instance_offset.to_owned(), constants)
.unwrap();
for inc in module.instance_increment_input().iter() {
// increment the instance offset to make way for future module layouts
@@ -234,6 +236,7 @@ impl GraphModules {
values: &mut [ValTensor<Fp>],
element_visibility: &Visibility,
instance_offset: &mut usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<(), Error> {
if element_visibility.is_polycommit() && !values.is_empty() {
// concat values and sk to get the inputs
@@ -248,7 +251,7 @@ impl GraphModules {
layouter
.assign_region(|| format!("_enter_module_{}", module_offset), |_| Ok(()))
.unwrap();
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
// increment the current index
self.polycommit_idx += 1;
});
@@ -270,7 +273,7 @@ impl GraphModules {
let mut inputs = values.iter_mut().map(|x| vec![x.clone()]).collect_vec();
// layout the module
inputs.iter_mut().for_each(|x| {
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
});
// replace the inputs with the outputs
values.iter_mut().enumerate().for_each(|(i, x)| {

View File

@@ -1072,8 +1072,12 @@ pub fn new_op_from_onnx(
}
};
let in_scale = inputs[0].out_scales()[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Softmax {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
input_scale: scale_to_multiplier(in_scale).into(),
output_scale: scale_to_multiplier(max_scale).into(),
axes: softmax_op.axes.to_vec(),
})
}

View File

@@ -346,7 +346,7 @@ pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
pub instance: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
/// Get instance col
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
if let Some(instance) = &self.instance {

View File

@@ -23,7 +23,6 @@
)]
// we allow this for our dynamic range based indexing scheme
#![allow(clippy::single_range_in_vec_init)]
#![feature(round_ties_even)]
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!

View File

@@ -378,7 +378,7 @@ impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<A
{
fn from(value: Tensor<AssignedCell<Assigned<F>, F>>) -> Tensor<Value<F>> {
let mut output = Vec::new();
for (_, x) in value.iter().enumerate() {
for x in value.iter() {
output.push(x.value_field().evaluate());
}
Tensor::new(Some(&output), value.dims()).unwrap()
@@ -434,6 +434,18 @@ impl<F: PrimeField + TensorType + Clone> From<Tensor<i128>> for Tensor<Value<F>>
}
}
impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
maybe_rayon::iter::FromParallelIterator<T> for Tensor<T>
{
fn from_par_iter<I>(par_iter: I) -> Self
where
I: maybe_rayon::iter::IntoParallelIterator<Item = T>,
{
let inner: Vec<T> = par_iter.into_par_iter().collect();
Tensor::new(Some(&inner), &[inner.len()]).unwrap()
}
}
impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
maybe_rayon::iter::IntoParallelIterator for Tensor<T>
{

View File

@@ -1875,11 +1875,7 @@ pub fn topk<T: TensorType + PartialOrd>(
let mut indexed_a = a.clone();
indexed_a.flatten();
let mut indexed_a = a
.iter()
.enumerate()
.map(|(i, x)| (i, x))
.collect::<Vec<_>>();
let mut indexed_a = a.iter().enumerate().collect::<Vec<_>>();
if largest {
indexed_a.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
@@ -3532,12 +3528,17 @@ pub mod nonlinearities {
}
/// softmax layout
pub fn softmax_axes(a: &Tensor<i128>, scale: f64, axes: &[usize]) -> Tensor<i128> {
pub fn softmax_axes(
a: &Tensor<i128>,
input_scale: f64,
output_scale: f64,
axes: &[usize],
) -> Tensor<i128> {
// we want this to be as small as possible so we set the output scale to 1
let dims = a.dims();
if dims.len() == 1 {
return softmax(a, scale);
return softmax(a, input_scale, output_scale);
}
let cartesian_coord = dims[..dims.len() - 1]
@@ -3560,7 +3561,7 @@ pub mod nonlinearities {
let softmax_input = a.get_slice(&sum_dims).unwrap();
let res = softmax(&softmax_input, scale);
let res = softmax(&softmax_input, input_scale, output_scale);
outputs.push(res);
}
@@ -3587,20 +3588,25 @@ pub mod nonlinearities {
/// Some(&[2, 2, 3, 2, 2, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = softmax(&x, 128.0);
/// let result = softmax(&x, 128.0, 128.0 * 128.0);
/// // doubles the scale of the input
/// let expected = Tensor::<i128>::new(Some(&[2730, 2730, 2751, 2730, 2730, 2688]), &[2, 3]).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2734, 2734, 2755, 2734, 2734, 2692]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn softmax(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
pub fn softmax(a: &Tensor<i128>, input_scale: f64, output_scale: f64) -> Tensor<i128> {
// the more accurate calculation is commented out and we implement as below so it matches the steps in layout
let exp = exp(a, scale);
let exp = exp(a, input_scale);
let sum = sum(&exp).unwrap();
let inv_denom = recip(&sum, scale, scale);
(exp * inv_denom).unwrap()
let inv_denom = recip(&sum, input_scale, output_scale);
let mut res = (exp * inv_denom).unwrap();
res = res
.iter()
.map(|x| ((*x as f64) / input_scale).round() as i128)
.collect();
res.reshape(a.dims()).unwrap();
res
}
/// Applies range_check_percent

View File

@@ -1,8 +1,10 @@
use crate::circuit::region::ConstantsMap;
use super::{
ops::{intercalate_values, pad, resize},
*,
};
use halo2_proofs::{arithmetic::Field, plonk::Instance};
use halo2_proofs::{arithmetic::Field, circuit::Cell, plonk::Instance};
pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
@@ -51,6 +53,24 @@ pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::S
}
impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> ValType<F> {
/// Returns the inner cell of the [ValType].
pub fn cell(&self) -> Option<Cell> {
match self {
ValType::PrevAssigned(cell) => Some(cell.cell()),
ValType::AssignedConstant(cell, _) => Some(cell.cell()),
_ => None,
}
}
/// Returns the assigned cell of the [ValType].
pub fn assigned_cell(&self) -> Option<AssignedCell<F, F>> {
match self {
ValType::PrevAssigned(cell) => Some(cell.clone()),
ValType::AssignedConstant(cell, _) => Some(cell.clone()),
_ => None,
}
}
/// Returns true if the value is previously assigned.
pub fn is_prev_assigned(&self) -> bool {
matches!(
@@ -293,7 +313,7 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
}
}
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`.
pub fn new_instance(
cs: &mut ConstraintSystem<F>,
@@ -435,6 +455,22 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
}
/// Returns the number of constants in the [ValTensor].
pub fn create_constants_map(&self) -> ConstantsMap<F> {
match self {
ValTensor::Value { inner, .. } => {
let map = inner.iter().fold(ConstantsMap::new(), |mut acc, x| {
if let ValType::Constant(c) = x {
acc.insert(*c, x.clone());
}
acc
});
map
}
ValTensor::Instance { .. } => ConstantsMap::new(),
}
}
/// Fetch the underlying [Tensor] of field elements.
pub fn get_felt_evals(&self) -> Result<Tensor<F>, Box<dyn Error>> {
let mut felt_evals: Vec<F> = vec![];

View File

@@ -2,7 +2,7 @@ use std::collections::HashSet;
use log::{debug, error, warn};
use crate::circuit::CheckMode;
use crate::circuit::{region::ConstantsMap, CheckMode};
use super::*;
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
@@ -289,9 +289,10 @@ impl VarTensor {
&self,
region: &mut Region<F>,
offset: usize,
coord: usize,
constant: F,
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error> {
let (x, y, z) = self.cartesian_coord(offset);
let (x, y, z) = self.cartesian_coord(offset + coord);
match &self {
VarTensor::Advice { inner: advices, .. } => {
region.assign_advice_from_constant(|| "constant", advices[x][y], z, constant)
@@ -304,33 +305,28 @@ impl VarTensor {
}
/// Assigns [ValTensor] to the columns of the inner tensor.
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd>(
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
omissions: &HashSet<&usize>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut assigned_coord = 0;
let mut res: ValTensor<F> = match values {
ValTensor::Instance { .. } => {
unimplemented!("cannot assign instance to advice columns with omissions")
}
ValTensor::Value { inner: v, .. } => Ok::<_, halo2_proofs::plonk::Error>(
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
v.enum_map(|coord, k| {
if omissions.contains(&coord) {
return Ok(k);
return Ok::<_, halo2_proofs::plonk::Error>(k);
}
let cell = self.assign_value(region, offset, k.clone(), assigned_coord)?;
let cell =
self.assign_value(region, offset, k.clone(), assigned_coord, constants)?;
assigned_coord += 1;
match k {
ValType::Constant(f) => Ok::<ValType<F>, halo2_proofs::plonk::Error>(
ValType::AssignedConstant(cell, f),
),
ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)),
_ => Ok(ValType::PrevAssigned(cell)),
}
Ok::<_, halo2_proofs::plonk::Error>(cell)
})?
.into(),
),
@@ -340,11 +336,12 @@ impl VarTensor {
}
/// Assigns [ValTensor] to the columns of the inner tensor.
pub fn assign<F: PrimeField + TensorType + PartialOrd>(
pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut res: ValTensor<F> = match values {
ValTensor::Instance {
@@ -382,14 +379,7 @@ impl VarTensor {
},
ValTensor::Value { inner: v, .. } => Ok(v
.enum_map(|coord, k| {
let cell = self.assign_value(region, offset, k.clone(), coord)?;
match k {
ValType::Constant(f) => Ok::<ValType<F>, halo2_proofs::plonk::Error>(
ValType::AssignedConstant(cell, f),
),
ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)),
_ => Ok(ValType::PrevAssigned(cell)),
}
self.assign_value(region, offset, k.clone(), coord, constants)
})?
.into()),
}?;
@@ -399,13 +389,16 @@ impl VarTensor {
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
pub fn dummy_assign_with_duplication<F: PrimeField + TensorType + PartialOrd>(
pub fn dummy_assign_with_duplication<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
&self,
row: usize,
offset: usize,
values: &ValTensor<F>,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize, usize), halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
match values {
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
ValTensor::Value { inner: v, dims , ..} => {
@@ -430,21 +423,24 @@ impl VarTensor {
// duplicates every nth element to adjust for column overflow
let mut res: ValTensor<F> = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap().into();
let constants_map = res.create_constants_map();
constants.extend(constants_map);
let total_used_len = res.len();
let total_constants = res.num_constants();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
res.reshape(dims).unwrap();
res.set_scale(values.scale());
Ok((res, total_used_len, total_constants))
Ok((res, total_used_len))
}
}
}
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd>(
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
row: usize,
@@ -452,7 +448,8 @@ impl VarTensor {
values: &ValTensor<F>,
check_mode: &CheckMode,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize, usize), halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
let mut prev_cell = None;
match values {
@@ -494,7 +491,7 @@ impl VarTensor {
assert_eq!(Into::<i32>::into(k.clone()), Into::<i32>::into(v[coord - 1].clone()));
};
let cell = self.assign_value(region, offset, k.clone(), coord * step)?;
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
if single_inner_col {
if z == 0 {
@@ -502,28 +499,23 @@ impl VarTensor {
prev_cell = Some(cell.clone());
} else if coord > 0 && z == 0 && single_inner_col {
if let Some(prev_cell) = prev_cell.as_ref() {
region.constrain_equal(prev_cell.cell(),cell.cell())?;
let cell = cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
let prev_cell = prev_cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
region.constrain_equal(prev_cell,cell)?;
} else {
error!("Error copy-constraining previous value: {:?}", (x,y));
return Err(halo2_proofs::plonk::Error::Synthesis);
}
}}
match k {
ValType::Constant(f) => {
Ok(ValType::AssignedConstant(cell, f))
},
ValType::AssignedConstant(_, f) => {
Ok(ValType::AssignedConstant(cell, f))
},
_ => {
Ok(ValType::PrevAssigned(cell))
}
}
Ok(cell)
})?.into()};
let total_used_len = res.len();
let total_constants = res.num_constants();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
res.reshape(dims).unwrap();
@@ -542,42 +534,61 @@ impl VarTensor {
)};
}
Ok((res, total_used_len, total_constants))
Ok((res, total_used_len))
}
}
}
fn assign_value<F: PrimeField + TensorType + PartialOrd>(
fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
k: ValType<F>,
coord: usize,
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
let (x, y, z) = self.cartesian_coord(offset + coord);
match k {
let res = match k {
ValType::Value(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
region.assign_advice(|| "k", advices[x][y], z, || v)
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
}
_ => unimplemented!(),
},
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => match &self {
ValType::PrevAssigned(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
v.copy_advice(|| "k", region, advices[x][y], z)
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
}
_ => {
error!("PrevAssigned is only supported for advice columns");
Err(halo2_proofs::plonk::Error::Synthesis)
_ => unimplemented!(),
},
ValType::AssignedConstant(v, val) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
}
_ => unimplemented!(),
},
ValType::AssignedValue(v) => match &self {
VarTensor::Advice { inner: advices, .. } => region
.assign_advice(|| "k", advices[x][y], z, || v)
.map(|a| a.evaluate()),
VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
region
.assign_advice(|| "k", advices[x][y], z, || v)?
.evaluate(),
),
_ => unimplemented!(),
},
ValType::Constant(v) => self.assign_constant(region, offset + coord, v),
}
ValType::Constant(v) => {
if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
let value = ValType::AssignedConstant(
self.assign_constant(region, offset, coord, v)?,
v,
);
e.insert(value.clone());
value
} else {
let cell = constants.get(&v).unwrap();
self.assign_value(region, offset, cell.clone(), coord, constants)?
}
}
};
Ok(res)
}
}