mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
457196f9c1 | ||
|
|
a3c131dac0 |
@@ -1,4 +1,4 @@
|
||||
ezkl==0.0.0
|
||||
ezkl==15.2.0
|
||||
sphinx
|
||||
sphinx-rtd-theme
|
||||
sphinxcontrib-napoleon
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '0.0.0'
|
||||
release = '15.2.0'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -21,9 +21,9 @@ def main():
|
||||
torch_model = Circuit()
|
||||
# Input to the model
|
||||
shape = [3, 2, 3]
|
||||
w = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
x = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
y = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
w = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
x = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
y = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
torch_out = torch_model(w, x, y)
|
||||
# Export the model
|
||||
torch.onnx.export(torch_model, # model being run
|
||||
|
||||
@@ -1 +1 @@
|
||||
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.0025284828152507544, 0.04976580664515495, 0.025840921327471733, 0.0829394981265068, 0.09595223516225815, 0.08764562010765076, 0.06308566778898239, 0.062386948615312576, 0.08090643584728241, 0.09267748892307281, 0.07428313046693802, 0.08987367898225784, 0.005716216750442982, 0.0666426345705986, 0.012837404385209084, 0.05769496038556099, 0.05761152133345604, 0.08006472885608673], [0.007834953255951405, 0.011380612850189209, 0.08560049533843994, 0.022283583879470825, 0.07879520952701569, 0.04422441124916077, 0.030812596902251244, 0.006081616971641779, 0.011045408435165882, 0.08776585012674332, 0.044985152781009674, 0.015603715553879738, 0.07923348993062973, 0.04872611165046692, 0.0036642670165747404, 0.05142095685005188, 0.0963878259062767, 0.03225792199373245], [0.09952805936336517, 0.002214533044025302, 0.011696457862854004, 0.022422820329666138, 0.04151459410786629, 0.027647346258163452, 0.011919880285859108, 0.006539052817970514, 0.06569185107946396, 0.034328874200582504, 0.0032284557819366455, 0.004105025436729193, 0.022395813837647438, 0.07135921716690063, 0.07882415503263474, 0.09764843434095383, 0.05335796996951103, 0.0525360181927681]], "output_data": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]}
|
||||
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.6261028051376343, 0.49872446060180664, -0.04514765739440918, 0.5936200618743896, 0.9271858930587769, 0.6688600778579712, -0.20331168174743652, -0.7016235589981079, 0.025863051414489746, -0.19426143169403076, 0.9827852249145508, 0.4897397756576538, 0.2992602586746216, 0.7011144161224365, 0.9278832674026489, 0.5943725109100342, -0.573331356048584, 0.3675816059112549], [0.7803324460983276, -0.9616303443908691, 0.6070173978805542, -0.028337717056274414, -0.5080242156982422, -0.9280107021331787, 0.6150380373001099, 0.3865993022918701, -0.43668973445892334, 0.17152702808380127, 0.5144252777099609, -0.28881049156188965, 0.8932310342788696, 0.059034109115600586, 0.6865451335906982, 0.009820222854614258, 0.23011493682861328, -0.9492779970169067], [-0.21352827548980713, -0.16015326976776123, -0.38964390754699707, 0.13464701175689697, -0.8814496994018555, 0.5037975311279297, -0.804405927658081, 0.9858957529067993, 0.19567716121673584, 0.9777265787124634, 0.6151977777481079, 0.568595290184021, 0.10584986209869385, -0.8975653648376465, 0.6235959529876709, -0.547879695892334, 0.9289869070053101, 0.7567293643951416]], "output_data": [[1.0, 0.0, -0.0, 1.0, 1.0, 1.0, -0.0, -1.0, 0.0, -0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, -1.0, 0.0], [0.0, -1.0, 0.0, -1.0, -1.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0], [-0.0, -0.0, -0.0, 1.0, -0.0, 1.0, -0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -0.0, 1.0, -0.0, 1.0, 1.0]]}
|
||||
@@ -1,10 +1,11 @@
|
||||
pytorch2.0.1:â
|
||||
pytorch2.2.2:ă
|
||||
|
||||
woutput_w/Round"Round
|
||||
|
||||
xoutput_x/Floor"Floor
|
||||
|
||||
youtput_y/Ceil"Ceil torch_jitZ%
|
||||
youtput_y/Ceil"Ceil
|
||||
main_graphZ%
|
||||
w
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,18 @@ use serde::{Deserialize, Serialize};
|
||||
/// An enum representing the operations that consist of both lookups and arithmetic operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HybridOp {
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Floor {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Round {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
@@ -96,6 +108,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Max => format!("MAX"),
|
||||
HybridOp::Min => format!("MIN"),
|
||||
HybridOp::Recip {
|
||||
@@ -166,6 +181,15 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
HybridOp::Ceil { scale, legs } => {
|
||||
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Floor { scale, legs } => {
|
||||
layouts::floor(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Round { scale, legs } => {
|
||||
layouts::round(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::SumPool {
|
||||
|
||||
@@ -4155,8 +4155,40 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
Ok(assigned_argmin)
|
||||
}
|
||||
|
||||
/// max layout
|
||||
pub(crate) fn max_comp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Max layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 2]
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::max_comp;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let y = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 1, 1, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
///
|
||||
/// let result = max_comp::<Fp>(&dummy_config, &mut dummy_region, &[x, y]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[5, 2, 3, 1]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn max_comp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -4176,8 +4208,38 @@ pub(crate) fn max_comp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
pairwise(config, region, &[max_val_p1, max_val_p2], BaseOp::Add)
|
||||
}
|
||||
|
||||
/// min comp layout
|
||||
pub(crate) fn min_comp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Min comp layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 2]
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::min_comp;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let y = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 1, 1, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let result = min_comp::<Fp>(&dummy_config, &mut dummy_region, &[x, y]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[5, 1, 1, 0]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
pub fn min_comp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -4220,6 +4282,438 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
/// floor layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 1]
|
||||
/// * `scale` - utils::F32
|
||||
/// * `legs` - usize
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::floor;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, -3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let result = floor::<Fp>(&dummy_config, &mut dummy_region, &[x], 2.0.into(), 2).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, -2, -4, 0]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
pub fn floor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
|
||||
let negative_one = create_constant_tensor(integer_rep_to_felt(-1), 1);
|
||||
let assigned_negative_one = region.assign(&config.custom_gates.inputs[1], &negative_one)?;
|
||||
|
||||
region.increment(1);
|
||||
|
||||
let dims = decomposition.dims().to_vec();
|
||||
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();
|
||||
|
||||
let mut incremented_tensor = Tensor::new(None, &first_dims)?;
|
||||
|
||||
let cartesian_coord = first_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = decomposition.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
let last_elem = sliced_input.last()?;
|
||||
|
||||
let last_elem_is_zero = equals_zero(config, region, &[last_elem.clone()])?;
|
||||
let last_elem_is_not_zero = not(config, region, &[last_elem_is_zero.clone()])?;
|
||||
|
||||
let sign = sliced_input.first()?;
|
||||
let is_negative = equals(config, region, &[sign, assigned_negative_one.clone()])?;
|
||||
|
||||
let is_negative_and_not_zero = and(
|
||||
config,
|
||||
region,
|
||||
&[last_elem_is_not_zero.clone(), is_negative.clone()],
|
||||
)?;
|
||||
|
||||
// increment the penultimate element
|
||||
let incremented_elem = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?,
|
||||
is_negative_and_not_zero.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
let mut inner_tensor = sliced_input.get_inner_tensor()?.clone();
|
||||
inner_tensor[sliced_input.len() - 2] =
|
||||
incremented_elem.get_inner_tensor()?.clone()[0].clone();
|
||||
|
||||
// set the last elem to zero
|
||||
inner_tensor[sliced_input.len() - 1] = zero.clone();
|
||||
|
||||
Ok(inner_tensor.clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?;
|
||||
|
||||
let mut incremented_tensor = incremented_tensor.combine()?;
|
||||
incremented_tensor.reshape(&dims)?;
|
||||
|
||||
recompose(
|
||||
config,
|
||||
region,
|
||||
&[incremented_tensor.into()],
|
||||
&(scale.0 as usize),
|
||||
)
|
||||
}
|
||||
|
||||
/// ceil layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 1]
|
||||
/// * `scale` - utils::F32
|
||||
/// * `legs` - usize
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::ceil;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let result = ceil::<Fp>(&dummy_config, &mut dummy_region, &[x], 2.0.into(), 2).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, -2, 4, 2]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn ceil<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
|
||||
let one = create_constant_tensor(integer_rep_to_felt(1), 1);
|
||||
let assigned_one = region.assign(&config.custom_gates.inputs[1], &one)?;
|
||||
|
||||
region.increment(1);
|
||||
|
||||
let dims = decomposition.dims().to_vec();
|
||||
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();
|
||||
|
||||
let mut incremented_tensor = Tensor::new(None, &first_dims)?;
|
||||
|
||||
let cartesian_coord = first_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = decomposition.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
let last_elem = sliced_input.last()?;
|
||||
|
||||
let last_elem_is_zero = equals_zero(config, region, &[last_elem.clone()])?;
|
||||
let last_elem_is_not_zero = not(config, region, &[last_elem_is_zero.clone()])?;
|
||||
|
||||
let sign = sliced_input.first()?;
|
||||
let is_positive = equals(config, region, &[sign, assigned_one.clone()])?;
|
||||
|
||||
let is_positive_and_not_zero = and(
|
||||
config,
|
||||
region,
|
||||
&[last_elem_is_not_zero.clone(), is_positive.clone()],
|
||||
)?;
|
||||
|
||||
// increment the penultimate element
|
||||
let incremented_elem = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?,
|
||||
is_positive_and_not_zero.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
let mut inner_tensor = sliced_input.get_inner_tensor()?.clone();
|
||||
inner_tensor[sliced_input.len() - 2] =
|
||||
incremented_elem.get_inner_tensor()?.clone()[0].clone();
|
||||
|
||||
// set the last elem to zero
|
||||
inner_tensor[sliced_input.len() - 1] = zero.clone();
|
||||
|
||||
Ok(inner_tensor.clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?;
|
||||
|
||||
let mut incremented_tensor = incremented_tensor.combine()?;
|
||||
incremented_tensor.reshape(&dims)?;
|
||||
|
||||
recompose(
|
||||
config,
|
||||
region,
|
||||
&[incremented_tensor.into()],
|
||||
&(scale.0 as usize),
|
||||
)
|
||||
}
|
||||
|
||||
/// round layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 1]
|
||||
/// * `scale` - utils::F32
|
||||
/// * `legs` - usize
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::round;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let result = round::<Fp>(&dummy_config, &mut dummy_region, &[x], 4.0.into(), 2).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, -4, 4, 0]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
|
||||
let one = create_constant_tensor(integer_rep_to_felt(1), 1);
|
||||
let assigned_one = region.assign(&config.custom_gates.inputs[1], &one)?;
|
||||
let negative_one = create_constant_tensor(integer_rep_to_felt(-1), 1);
|
||||
let assigned_negative_one = region.assign(&config.custom_gates.output, &negative_one)?;
|
||||
|
||||
region.increment(1);
|
||||
|
||||
// if scale is not exactly divisible by 2 we warn
|
||||
if scale.0 % 2.0 != 0.0 {
|
||||
log::warn!("Scale is not exactly divisible by 2.0, rounding may not be accurate");
|
||||
}
|
||||
|
||||
let midway_point: ValTensor<F> = create_constant_tensor(
|
||||
integer_rep_to_felt((scale.0 / 2.0).round() as IntegerRep),
|
||||
1,
|
||||
);
|
||||
let assigned_midway_point = region.assign(&config.custom_gates.inputs[1], &midway_point)?;
|
||||
|
||||
let dims = decomposition.dims().to_vec();
|
||||
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();
|
||||
|
||||
let mut incremented_tensor = Tensor::new(None, &first_dims)?;
|
||||
|
||||
let cartesian_coord = first_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = decomposition.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
let last_elem = sliced_input.last()?;
|
||||
|
||||
let sign = sliced_input.first()?;
|
||||
let is_positive = equals(config, region, &[sign.clone(), assigned_one.clone()])?;
|
||||
let is_negative = equals(config, region, &[sign, assigned_negative_one.clone()])?;
|
||||
|
||||
let is_greater_than_midway = greater_equal(
|
||||
config,
|
||||
region,
|
||||
&[last_elem.clone(), assigned_midway_point.clone()],
|
||||
)?;
|
||||
|
||||
// if greater than midway point and positive, increment
|
||||
let is_positive_and_more_than_midway = and(
|
||||
config,
|
||||
region,
|
||||
&[is_positive.clone(), is_greater_than_midway.clone()],
|
||||
)?;
|
||||
|
||||
// is less than midway point and negative, decrement
|
||||
let is_negative_and_more_than_midway = and(
|
||||
config,
|
||||
region,
|
||||
&[is_negative.clone(), is_greater_than_midway],
|
||||
)?;
|
||||
|
||||
let conditions_for_increment = or(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
is_positive_and_more_than_midway.clone(),
|
||||
is_negative_and_more_than_midway.clone(),
|
||||
],
|
||||
)?;
|
||||
|
||||
// increment the penultimate element
|
||||
let incremented_elem = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?,
|
||||
conditions_for_increment.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
let mut inner_tensor = sliced_input.get_inner_tensor()?.clone();
|
||||
inner_tensor[sliced_input.len() - 2] =
|
||||
incremented_elem.get_inner_tensor()?.clone()[0].clone();
|
||||
|
||||
// set the last elem to zero
|
||||
inner_tensor[sliced_input.len() - 1] = zero.clone();
|
||||
|
||||
Ok(inner_tensor.clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?;
|
||||
|
||||
let mut incremented_tensor = incremented_tensor.combine()?;
|
||||
incremented_tensor.reshape(&dims)?;
|
||||
|
||||
recompose(
|
||||
config,
|
||||
region,
|
||||
&[incremented_tensor.into()],
|
||||
&(scale.0 as usize),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn recompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
base: &usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let input = values[0].clone();
|
||||
|
||||
let first_dims = input.dims().to_vec()[..input.dims().len() - 1].to_vec();
|
||||
let n = input.dims().last().unwrap() - 1;
|
||||
|
||||
let is_assigned = !input.all_prev_assigned();
|
||||
|
||||
let bases: ValTensor<F> = Tensor::from(
|
||||
(0..n)
|
||||
.rev()
|
||||
.map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep))),
|
||||
)
|
||||
.into();
|
||||
|
||||
// multiply and sum the values
|
||||
let mut output: Tensor<Tensor<ValType<F>>> = Tensor::new(None, &first_dims)?;
|
||||
|
||||
let cartesian_coord = first_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = input.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
|
||||
if !is_assigned {
|
||||
sliced_input = region.assign(&config.custom_gates.inputs[0], &sliced_input)?;
|
||||
}
|
||||
|
||||
// get the sign bit and make sure it is valid
|
||||
let sign = sliced_input.first()?;
|
||||
let rest = sliced_input.get_slice(&[1..sliced_input.len()])?;
|
||||
|
||||
let prod_decomp = dot(config, region, &[rest, bases.clone()])?;
|
||||
|
||||
let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;
|
||||
|
||||
Ok(signed_decomp.get_inner_tensor()?.clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut output, inner_loop_function)?;
|
||||
|
||||
let mut combined_output = output.combine()?;
|
||||
|
||||
combined_output.reshape(&first_dims)?;
|
||||
|
||||
Ok(combined_output.into())
|
||||
}
|
||||
|
||||
pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
|
||||
@@ -17,9 +17,6 @@ use halo2curves::ff::PrimeField;
|
||||
pub enum LookupOp {
|
||||
Div { denom: utils::F32 },
|
||||
Cast { scale: utils::F32 },
|
||||
Ceil { scale: utils::F32 },
|
||||
Floor { scale: utils::F32 },
|
||||
Round { scale: utils::F32 },
|
||||
RoundHalfToEven { scale: utils::F32 },
|
||||
Sqrt { scale: utils::F32 },
|
||||
Rsqrt { scale: utils::F32 },
|
||||
@@ -54,9 +51,6 @@ impl LookupOp {
|
||||
/// as path
|
||||
pub fn as_path(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
|
||||
LookupOp::Floor { scale } => format!("floor_{}", scale),
|
||||
LookupOp::Round { scale } => format!("round_{}", scale),
|
||||
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
|
||||
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
|
||||
LookupOp::Div { denom } => format!("div_{}", denom),
|
||||
@@ -91,15 +85,6 @@ impl LookupOp {
|
||||
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
|
||||
let res =
|
||||
match &self {
|
||||
LookupOp::Ceil { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ceil(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Floor { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::floor(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Round { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::round(&x, scale.into()))
|
||||
}
|
||||
LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
|
||||
),
|
||||
@@ -186,9 +171,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
/// Returns the name of the operation
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale),
|
||||
LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale),
|
||||
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
|
||||
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
|
||||
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
|
||||
@@ -1083,14 +1083,17 @@ pub fn new_op_from_onnx(
|
||||
pool_dims: kernel_shape.to_vec(),
|
||||
})
|
||||
}
|
||||
"Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
|
||||
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Floor" => SupportedOp::Nonlinear(LookupOp::Floor {
|
||||
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Round" => SupportedOp::Nonlinear(LookupOp::Round {
|
||||
"Round" => SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
|
||||
@@ -27,7 +27,7 @@ pub fn get_rep(
|
||||
n: usize,
|
||||
) -> Result<Vec<IntegerRep>, DecompositionError> {
|
||||
// check if x is too large
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) {
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 {
|
||||
return Err(DecompositionError::TooLarge(*x, base, n));
|
||||
}
|
||||
let mut rep = vec![0; n + 1];
|
||||
@@ -1421,85 +1421,6 @@ pub fn slice<T: TensorType + Send + Sync>(
|
||||
pub mod nonlinearities {
|
||||
use super::*;
|
||||
|
||||
/// Ceiling operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
///
|
||||
/// use ezkl::tensor::ops::nonlinearities::ceil;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = ceil(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ceil(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.ceil() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Floor operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::floor;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = floor(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 2, 2, 4, 4, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn floor(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.floor() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Round operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::round;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = round(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn round(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.round() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Round half to even operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
@@ -1721,27 +1642,6 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies sign to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::sign;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[-2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = sign(&x);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[-1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn sign(a: &Tensor<IntegerRep>) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(a_i.signum()))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies square root to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -2225,101 +2125,6 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies leaky relu to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// * `slope` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::leakyrelu;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = leakyrelu(&x, 0.1);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 15, 2, 1, 1, -1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn leakyrelu(a: &Tensor<IntegerRep>, slope: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rounded = if a_i < 0 {
|
||||
let d_inv_x = (slope) * (a_i as f64);
|
||||
d_inv_x.round() as IntegerRep
|
||||
} else {
|
||||
let d_inv_x = a_i as f64;
|
||||
d_inv_x.round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies max to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - scalar
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::max;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = max(&x, 1.0, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 15, 2, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn max(a: &Tensor<IntegerRep>, scale_input: f64, threshold: f64) -> Tensor<IntegerRep> {
|
||||
// calculate value of output
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let d_inv_x = (a_i as f64) / scale_input;
|
||||
let rounded = if d_inv_x <= threshold {
|
||||
(threshold * scale_input).round() as IntegerRep
|
||||
} else {
|
||||
(d_inv_x * scale_input).round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies min to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - scalar
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::min;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = min(&x, 1.0, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 2, 1, 1, -5]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn min(a: &Tensor<IntegerRep>, scale_input: f64, threshold: f64) -> Tensor<IntegerRep> {
|
||||
// calculate value of output
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let d_inv_x = (a_i as f64) / scale_input;
|
||||
let rounded = if d_inv_x >= threshold {
|
||||
(threshold * scale_input).round() as IntegerRep
|
||||
} else {
|
||||
(d_inv_x * scale_input).round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise divides a tensor with a const integer element.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -2400,104 +2205,6 @@ pub mod nonlinearities {
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise greater than
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::greater_than;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
/// let result = greater_than(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn greater_than(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) > 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise greater than
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::greater_than_equal;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
/// let result = greater_than_equal(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 1, 1, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn greater_than_equal(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) >= 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise less than
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::less_than;
|
||||
///
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
///
|
||||
/// let result = less_than(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 1, 0, 0, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn less_than(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) < 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise less than
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::less_than_equal;
|
||||
///
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
///
|
||||
/// let result = less_than_equal(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 0, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn less_than_equal(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) <= 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Ops that return the transcript i.e intermediate calcs of an op
|
||||
|
||||
@@ -852,9 +852,11 @@ mod native_tests {
|
||||
fn kzg_prove_and_verify_tight_lookup_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let path = test_dir.into_path();
|
||||
let path = path.to_str().unwrap();
|
||||
crate::native_tests::mv_test_(path, test);
|
||||
prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, false, "single", Commitments::KZG, 1);
|
||||
test_dir.close().unwrap();
|
||||
// test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
@@ -1632,7 +1634,6 @@ mod native_tests {
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
.stdout(std::process::Stdio::null())
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
Reference in New Issue
Block a user