mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-14 00:38:15 -05:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00155e585f | ||
|
|
0876faa12c |
@@ -592,7 +592,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -648,10 +648,10 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,7 +271,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
42
examples/onnx/log/gen.py
Normal file
42
examples/onnx/log/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.log(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 3)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/log/input.json
Normal file
1
examples/onnx/log/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[1.9252371788024902, 1.8418371677398682, 0.8400403261184692, 2.083845853805542, 0.9760497808456421, 0.6940176486968994, 0.015579521656036377, 2.2689192295074463]]}
|
||||
14
examples/onnx/log/network.onnx
Normal file
14
examples/onnx/log/network.onnx
Normal file
@@ -0,0 +1,14 @@
|
||||
pytorch2.2.2:o
|
||||
|
||||
inputoutput/Log"Log
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
@@ -1 +1,148 @@
|
||||
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.6261028051376343, 0.49872446060180664, -0.04514765739440918, 0.5936200618743896, 0.9271858930587769, 0.6688600778579712, -0.20331168174743652, -0.7016235589981079, 0.025863051414489746, -0.19426143169403076, 0.9827852249145508, 0.4897397756576538, 0.2992602586746216, 0.7011144161224365, 0.9278832674026489, 0.5943725109100342, -0.573331356048584, 0.3675816059112549], [0.7803324460983276, -0.9616303443908691, 0.6070173978805542, -0.028337717056274414, -0.5080242156982422, -0.9280107021331787, 0.6150380373001099, 0.3865993022918701, -0.43668973445892334, 0.17152702808380127, 0.5144252777099609, -0.28881049156188965, 0.8932310342788696, 0.059034109115600586, 0.6865451335906982, 0.009820222854614258, 0.23011493682861328, -0.9492779970169067], [-0.21352827548980713, -0.16015326976776123, -0.38964390754699707, 0.13464701175689697, -0.8814496994018555, 0.5037975311279297, -0.804405927658081, 0.9858957529067993, 0.19567716121673584, 0.9777265787124634, 0.6151977777481079, 0.568595290184021, 0.10584986209869385, -0.8975653648376465, 0.6235959529876709, -0.547879695892334, 0.9289869070053101, 0.7567293643951416]], "output_data": [[1.0, 0.0, -0.0, 1.0, 1.0, 1.0, -0.0, -1.0, 0.0, -0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, -1.0, 0.0], [0.0, -1.0, 0.0, -1.0, -1.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0], [-0.0, -0.0, -0.0, 1.0, -0.0, 1.0, -0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -0.0, 1.0, -0.0, 1.0, 1.0]]}
|
||||
{
|
||||
"input_shapes": [
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
]
|
||||
],
|
||||
"input_data": [
|
||||
[
|
||||
0.5,
|
||||
1.5,
|
||||
-0.04514765739440918,
|
||||
0.5936200618743896,
|
||||
0.9271858930587769,
|
||||
0.6688600778579712,
|
||||
-0.20331168174743652,
|
||||
-0.7016235589981079,
|
||||
0.025863051414489746,
|
||||
-0.19426143169403076,
|
||||
0.9827852249145508,
|
||||
0.4897397756576538,
|
||||
-1.5,
|
||||
-0.5,
|
||||
0.9278832674026489,
|
||||
0.5943725109100342,
|
||||
-0.573331356048584,
|
||||
0.3675816059112549
|
||||
],
|
||||
[
|
||||
0.7803324460983276,
|
||||
-0.9616303443908691,
|
||||
0.6070173978805542,
|
||||
-0.028337717056274414,
|
||||
-0.5080242156982422,
|
||||
-0.9280107021331787,
|
||||
0.6150380373001099,
|
||||
0.3865993022918701,
|
||||
-0.43668973445892334,
|
||||
0.17152702808380127,
|
||||
0.5144252777099609,
|
||||
-0.28881049156188965,
|
||||
0.8932310342788696,
|
||||
0.059034109115600586,
|
||||
0.6865451335906982,
|
||||
0.009820222854614258,
|
||||
0.23011493682861328,
|
||||
-0.9492779970169067
|
||||
],
|
||||
[
|
||||
-0.21352827548980713,
|
||||
-0.16015326976776123,
|
||||
-0.38964390754699707,
|
||||
0.13464701175689697,
|
||||
-0.8814496994018555,
|
||||
0.5037975311279297,
|
||||
-0.804405927658081,
|
||||
0.9858957529067993,
|
||||
0.19567716121673584,
|
||||
0.9777265787124634,
|
||||
0.6151977777481079,
|
||||
0.568595290184021,
|
||||
0.10584986209869385,
|
||||
-0.8975653648376465,
|
||||
0.6235959529876709,
|
||||
-0.547879695892334,
|
||||
0.9289869070053101,
|
||||
0.7567293643951416
|
||||
]
|
||||
],
|
||||
"output_data": [
|
||||
[
|
||||
1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-1.0,
|
||||
0.0
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
-1.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0
|
||||
],
|
||||
[
|
||||
-0.0,
|
||||
-0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0
|
||||
]
|
||||
]
|
||||
}
|
||||
@@ -197,6 +197,9 @@ struct PyRunArgs {
|
||||
/// int: The number of legs used for decomposition
|
||||
#[pyo3(get, set)]
|
||||
pub decomp_legs: usize,
|
||||
/// bool: Should the circuit use unbounded lookups for log
|
||||
#[pyo3(get, set)]
|
||||
pub bounded_log_lookup: bool,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -212,6 +215,7 @@ impl PyRunArgs {
|
||||
impl From<PyRunArgs> for RunArgs {
|
||||
fn from(py_run_args: PyRunArgs) -> Self {
|
||||
RunArgs {
|
||||
bounded_log_lookup: py_run_args.bounded_log_lookup,
|
||||
tolerance: Tolerance::from(py_run_args.tolerance),
|
||||
input_scale: py_run_args.input_scale,
|
||||
param_scale: py_run_args.param_scale,
|
||||
@@ -236,6 +240,7 @@ impl From<PyRunArgs> for RunArgs {
|
||||
impl Into<PyRunArgs> for RunArgs {
|
||||
fn into(self) -> PyRunArgs {
|
||||
PyRunArgs {
|
||||
bounded_log_lookup: self.bounded_log_lookup,
|
||||
tolerance: self.tolerance.val,
|
||||
input_scale: self.input_scale,
|
||||
param_scale: self.param_scale,
|
||||
|
||||
@@ -13,6 +13,14 @@ use serde::{Deserialize, Serialize};
|
||||
/// An enum representing the operations that consist of both lookups and arithmetic operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HybridOp {
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
},
|
||||
|
||||
RoundHalfToEven {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
@@ -108,9 +116,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
|
||||
}
|
||||
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),
|
||||
|
||||
HybridOp::Max => format!("MAX"),
|
||||
HybridOp::Min => format!("MIN"),
|
||||
HybridOp::Recip {
|
||||
@@ -181,6 +194,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Ceil { scale, legs } => {
|
||||
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
@@ -316,6 +333,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
|
||||
multiplier_to_scale(output_scale.0 as f64)
|
||||
}
|
||||
HybridOp::Ln {
|
||||
scale: output_scale,
|
||||
} => 4 * multiplier_to_scale(output_scale.0 as f64),
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
|
||||
@@ -4507,6 +4507,332 @@ pub fn ceil<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
)
|
||||
}
|
||||
|
||||
/// integer ln layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 1]
|
||||
/// * `scale` - utils::F32
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::ln;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, 2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
///
|
||||
/// let result = ln::<Fp>(&dummy_config, &mut dummy_region, &[x], 2.0.into()).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 0, 4, -8]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
/// ```
|
||||
pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// first generate the claimed val
|
||||
|
||||
let mut input = values[0].clone();
|
||||
|
||||
println!("input {}", input.show());
|
||||
|
||||
let scale_as_felt = integer_rep_to_felt(scale.0.round() as IntegerRep);
|
||||
|
||||
let assigned_triple_scaled_as_felt_tensor = region.assign(
|
||||
&config.custom_gates.inputs[1],
|
||||
&create_constant_tensor(scale_as_felt * scale_as_felt * scale_as_felt, 1),
|
||||
)?;
|
||||
|
||||
// natural ln is log2(x) * ln(2)
|
||||
let ln2 = utils::F32::from(2.0_f32.ln());
|
||||
// now create a constant tensor for ln2 with scale
|
||||
let ln2_tensor: ValTensor<F> = create_constant_tensor(
|
||||
integer_rep_to_felt((ln2.0 * scale.0).round() as IntegerRep),
|
||||
1,
|
||||
);
|
||||
region.assign(&config.custom_gates.inputs[0], &ln2_tensor)?;
|
||||
let unit = create_constant_tensor(integer_rep_to_felt(1), 1);
|
||||
region.assign(&config.custom_gates.inputs[1], &unit)?;
|
||||
region.increment(1);
|
||||
|
||||
// 2. assign the image
|
||||
if !input.all_prev_assigned() {
|
||||
input = region.assign(&config.custom_gates.inputs[0], &input)?;
|
||||
// don't need to increment because the claimed output is assigned to output and incremented accordingly
|
||||
}
|
||||
|
||||
let is_assigned = !input.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.int_evals()?;
|
||||
// returns an integer with the base 2 logarithm
|
||||
tensor::ops::nonlinearities::ilog2(&input_evals.clone(), scale.0 as f64)
|
||||
.par_iter()
|
||||
.map(|x| Value::known(integer_rep_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); input.len()]),
|
||||
&[input.len()],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input.dims())?;
|
||||
region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
let pow2_of_claimed_output = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone()],
|
||||
&LookupOp::PowersOfTwo { scale },
|
||||
)?;
|
||||
|
||||
let num_bits = (std::mem::size_of::<IntegerRep>() * 8) as IntegerRep;
|
||||
|
||||
region.update_max_min_lookup_inputs_force(-num_bits, num_bits)?;
|
||||
|
||||
// now subtract 1 from the claimed output
|
||||
let claimed_output_minus_one = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), unit.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
// now add 1 to the claimed output
|
||||
let claimed_output_plus_one = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), unit.clone()],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
// prior power of 2 is less than claimed output
|
||||
let prior_pow2 = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output_minus_one],
|
||||
&LookupOp::PowersOfTwo { scale },
|
||||
)?;
|
||||
|
||||
// next power of 2 is greater than claimed output
|
||||
let next_pow2 = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output_plus_one],
|
||||
&LookupOp::PowersOfTwo { scale },
|
||||
)?;
|
||||
|
||||
// assert that the original input is closest to the claimed output than the prior power of 2 and the next power of 2
|
||||
let distance_to_prior = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[input.clone(), prior_pow2.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
// now take abs of the distance
|
||||
let distance_to_prior_l1 = abs(config, region, &[distance_to_prior.clone()])?;
|
||||
|
||||
let distance_to_next = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[input.clone(), next_pow2.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
// now take abs of the distance
|
||||
let distance_to_next_l1 = abs(config, region, &[distance_to_next.clone()])?;
|
||||
|
||||
let distance_to_claimed = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[input.clone(), pow2_of_claimed_output.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
// now take abs of the distance
|
||||
let distance_to_claimed_l1 = abs(config, region, &[distance_to_claimed.clone()])?;
|
||||
|
||||
// can be less than or equal because we round up
|
||||
let is_distance_to_prior_less = less_equal(
|
||||
config,
|
||||
region,
|
||||
&[distance_to_claimed_l1.clone(), distance_to_prior_l1.clone()],
|
||||
)?;
|
||||
|
||||
// should be striclty less because we round up
|
||||
let is_distance_to_next_less = less(
|
||||
config,
|
||||
region,
|
||||
&[distance_to_claimed_l1, distance_to_next_l1.clone()],
|
||||
)?;
|
||||
|
||||
let is_distance_to_prior_less_and_distance_to_next_less = and(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
is_distance_to_prior_less.clone(),
|
||||
is_distance_to_next_less.clone(),
|
||||
],
|
||||
)?;
|
||||
|
||||
let mut comparison_unit = create_constant_tensor(
|
||||
integer_rep_to_felt(1),
|
||||
is_distance_to_prior_less_and_distance_to_next_less.len(),
|
||||
);
|
||||
|
||||
comparison_unit.reshape(is_distance_to_prior_less_and_distance_to_next_less.dims())?;
|
||||
|
||||
// assigned unit
|
||||
let assigned_unit = region.assign(&config.custom_gates.inputs[1], &comparison_unit)?;
|
||||
region.increment(assigned_unit.len());
|
||||
|
||||
// assert that the values are truthy
|
||||
enforce_equality(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
is_distance_to_prior_less_and_distance_to_next_less,
|
||||
assigned_unit.clone(),
|
||||
],
|
||||
)?;
|
||||
|
||||
// get a linear interpolation now
|
||||
|
||||
let sign_of_distance_to_claimed = sign(config, region, &[distance_to_claimed.clone()])?;
|
||||
let sign_of_distance_to_claimed_is_positive = equals(
|
||||
config,
|
||||
region,
|
||||
&[sign_of_distance_to_claimed.clone(), assigned_unit.clone()],
|
||||
)?;
|
||||
|
||||
let sign_of_distance_to_claimed_is_negative = not(
|
||||
config,
|
||||
region,
|
||||
&[sign_of_distance_to_claimed_is_positive.clone()],
|
||||
)?;
|
||||
|
||||
let pow2_prior_to_claimed_distance = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[pow2_of_claimed_output.clone(), prior_pow2.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
let pow2_next_to_claimed_distance = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[next_pow2.clone(), pow2_of_claimed_output.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
let recip_pow2_prior_to_claimed_distance = recip(
|
||||
config,
|
||||
region,
|
||||
&[pow2_prior_to_claimed_distance],
|
||||
scale_as_felt,
|
||||
scale_as_felt * scale_as_felt,
|
||||
)?;
|
||||
|
||||
let interpolated_distance = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
recip_pow2_prior_to_claimed_distance.clone(),
|
||||
distance_to_claimed.clone(),
|
||||
],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let gated_prior_interpolated_distance = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
interpolated_distance.clone(),
|
||||
sign_of_distance_to_claimed_is_negative.clone(),
|
||||
],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let recip_next_to_claimed_distance = recip(
|
||||
config,
|
||||
region,
|
||||
&[pow2_next_to_claimed_distance],
|
||||
scale_as_felt,
|
||||
scale_as_felt * scale_as_felt,
|
||||
)?;
|
||||
|
||||
let interpolated_distance_next = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
recip_next_to_claimed_distance.clone(),
|
||||
distance_to_claimed.clone(),
|
||||
],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let gated_next_interpolated_distance = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
interpolated_distance_next.clone(),
|
||||
sign_of_distance_to_claimed_is_positive.clone(),
|
||||
],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let scaled_claimed_output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
claimed_output.clone(),
|
||||
assigned_triple_scaled_as_felt_tensor,
|
||||
],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let claimed_output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
scaled_claimed_output.clone(),
|
||||
gated_prior_interpolated_distance.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
let claimed_output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
claimed_output.clone(),
|
||||
gated_next_interpolated_distance.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
// now multiply the claimed output by ln2
|
||||
pairwise(config, region, &[claimed_output, ln2_tensor], BaseOp::Mult)
|
||||
}
|
||||
|
||||
/// round layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
@@ -4654,6 +4980,155 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
)
|
||||
}
|
||||
|
||||
/// round half to even layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 1]
|
||||
/// * `scale` - utils::F32
|
||||
/// * `legs` - usize
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::round;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let result = round::<Fp>(&dummy_config, &mut dummy_region, &[x], 4.0.into(), 2).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, -4, 4, 0]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn round_half_to_even<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
|
||||
// if scale is not exactly divisible by 2 we warn
|
||||
if scale.0 % 2.0 != 0.0 {
|
||||
log::warn!("Scale is not exactly divisible by 2.0, rounding may not be accurate");
|
||||
}
|
||||
|
||||
let midway_point: ValTensor<F> = create_constant_tensor(
|
||||
integer_rep_to_felt((scale.0 / 2.0).round() as IntegerRep),
|
||||
1,
|
||||
);
|
||||
let assigned_midway_point = region.assign(&config.custom_gates.inputs[1], &midway_point)?;
|
||||
|
||||
region.increment(1);
|
||||
|
||||
let dims = decomposition.dims().to_vec();
|
||||
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();
|
||||
|
||||
let mut incremented_tensor = Tensor::new(None, &first_dims)?;
|
||||
|
||||
let cartesian_coord = first_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = decomposition.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
let last_elem = sliced_input.last()?;
|
||||
|
||||
let penultimate_elem =
|
||||
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?;
|
||||
|
||||
let is_equal_to_midway = equals(
|
||||
config,
|
||||
region,
|
||||
&[last_elem.clone(), assigned_midway_point.clone()],
|
||||
)?;
|
||||
// penultimate_elem is equal to midway point and even, do nothing
|
||||
let is_odd = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[penultimate_elem.clone()],
|
||||
&LookupOp::IsOdd,
|
||||
)?;
|
||||
|
||||
let is_odd_and_equal_to_midway = and(
|
||||
config,
|
||||
region,
|
||||
&[is_odd.clone(), is_equal_to_midway.clone()],
|
||||
)?;
|
||||
|
||||
let is_greater_than_midway = greater(
|
||||
config,
|
||||
region,
|
||||
&[last_elem.clone(), assigned_midway_point.clone()],
|
||||
)?;
|
||||
|
||||
// if the number is equal to midway point and odd increment, or if it is is_greater_than_midway
|
||||
let is_odd_and_equal_to_midway_or_greater_than_midway = or(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
is_odd_and_equal_to_midway.clone(),
|
||||
is_greater_than_midway.clone(),
|
||||
],
|
||||
)?;
|
||||
|
||||
// increment the penultimate element
|
||||
let incremented_elem = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?,
|
||||
is_odd_and_equal_to_midway_or_greater_than_midway.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
let mut inner_tensor = sliced_input.get_inner_tensor()?.clone();
|
||||
inner_tensor[sliced_input.len() - 2] =
|
||||
incremented_elem.get_inner_tensor()?.clone()[0].clone();
|
||||
|
||||
// set the last elem to zero
|
||||
inner_tensor[sliced_input.len() - 1] = zero.clone();
|
||||
|
||||
Ok(inner_tensor.clone())
|
||||
};
|
||||
|
||||
region.update_max_min_lookup_inputs_force(0, scale.0 as IntegerRep)?;
|
||||
|
||||
region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?;
|
||||
|
||||
let mut incremented_tensor = incremented_tensor.combine()?;
|
||||
incremented_tensor.reshape(&dims)?;
|
||||
|
||||
recompose(
|
||||
config,
|
||||
region,
|
||||
&[incremented_tensor.into()],
|
||||
&(scale.0 as usize),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn recompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
|
||||
@@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::{
|
||||
circuit::{layouts, table::Range, utils},
|
||||
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
@@ -16,12 +15,12 @@ use halo2curves::ff::PrimeField;
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||
pub enum LookupOp {
|
||||
Div { denom: utils::F32 },
|
||||
Cast { scale: utils::F32 },
|
||||
RoundHalfToEven { scale: utils::F32 },
|
||||
IsOdd,
|
||||
PowersOfTwo { scale: utils::F32 },
|
||||
Ln { scale: utils::F32 },
|
||||
Sqrt { scale: utils::F32 },
|
||||
Rsqrt { scale: utils::F32 },
|
||||
Sigmoid { scale: utils::F32 },
|
||||
Ln { scale: utils::F32 },
|
||||
Exp { scale: utils::F32 },
|
||||
Cos { scale: utils::F32 },
|
||||
ACos { scale: utils::F32 },
|
||||
@@ -51,16 +50,16 @@ impl LookupOp {
|
||||
/// as path
|
||||
pub fn as_path(&self) -> String {
|
||||
match self {
|
||||
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
|
||||
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
|
||||
LookupOp::Ln { scale } => format!("ln_{}", scale),
|
||||
LookupOp::PowersOfTwo { scale } => format!("pow2_{}", scale),
|
||||
LookupOp::IsOdd => "is_odd".to_string(),
|
||||
LookupOp::Div { denom } => format!("div_{}", denom),
|
||||
LookupOp::Cast { scale } => format!("cast_{}", scale),
|
||||
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
|
||||
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
|
||||
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
|
||||
LookupOp::Erf { scale } => format!("erf_{}", scale),
|
||||
LookupOp::Exp { scale } => format!("exp_{}", scale),
|
||||
LookupOp::Ln { scale } => format!("ln_{}", scale),
|
||||
LookupOp::Cos { scale } => format!("cos_{}", scale),
|
||||
LookupOp::ACos { scale } => format!("acos_{}", scale),
|
||||
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
|
||||
@@ -85,18 +84,19 @@ impl LookupOp {
|
||||
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
|
||||
let res =
|
||||
match &self {
|
||||
LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
|
||||
),
|
||||
LookupOp::Ln { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
|
||||
}
|
||||
LookupOp::PowersOfTwo { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ipow2(&x, scale.0.into()))
|
||||
}
|
||||
LookupOp::IsOdd => Ok::<_, TensorError>(tensor::ops::nonlinearities::is_odd(&x)),
|
||||
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Div { denom } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
|
||||
),
|
||||
LookupOp::Cast { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()),
|
||||
),
|
||||
LookupOp::Sigmoid { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
|
||||
}
|
||||
@@ -112,9 +112,6 @@ impl LookupOp {
|
||||
LookupOp::Exp { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::exp(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Ln { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Cos { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
|
||||
}
|
||||
@@ -171,11 +168,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
/// Returns the name of the operation
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
|
||||
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
LookupOp::PowersOfTwo { scale } => format!("POWERS_OF_TWO(scale={})", scale),
|
||||
LookupOp::IsOdd => "IS_ODD".to_string(),
|
||||
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
|
||||
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
|
||||
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
|
||||
@@ -214,10 +211,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
/// Returns the scale of the output of the operation.
|
||||
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
LookupOp::Cast { scale } => {
|
||||
let in_scale = inputs_scale[0];
|
||||
in_scale + multiplier_to_scale(1. / scale.0 as f64)
|
||||
}
|
||||
_ => inputs_scale[0],
|
||||
};
|
||||
Ok(scale)
|
||||
|
||||
@@ -474,6 +474,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min forcefully
|
||||
pub fn update_max_min_lookup_inputs_force(
|
||||
&mut self,
|
||||
min: IntegerRep,
|
||||
max: IntegerRep,
|
||||
) -> Result<(), CircuitError> {
|
||||
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
|
||||
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
|
||||
if range.0 > range.1 {
|
||||
|
||||
@@ -150,12 +150,16 @@ pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
/// get largest element represented by the range
|
||||
pub fn largest(&self) -> IntegerRep {
|
||||
self.range.0 + (self.col_size * self.table_inputs.len() - 1) as IntegerRep
|
||||
}
|
||||
fn name(&self) -> String {
|
||||
format!(
|
||||
"{}_{}_{}",
|
||||
self.nonlinearity.as_path(),
|
||||
self.range.0,
|
||||
self.range.1
|
||||
self.largest()
|
||||
)
|
||||
}
|
||||
/// Configures the table.
|
||||
@@ -222,7 +226,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
}
|
||||
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
let largest = self.largest();
|
||||
|
||||
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
|
||||
let inputs = Tensor::from(smallest..=largest)
|
||||
@@ -291,6 +295,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
|
||||
row_offset += chunk_idx * self.col_size;
|
||||
let (x, y) = self.cartesian_coord(row_offset);
|
||||
|
||||
if !preassigned_input {
|
||||
table.assign_cell(
|
||||
|| format!("nl_i_col row {}", row_offset),
|
||||
|
||||
@@ -803,7 +803,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
// If the input scale is larger than the params scale
|
||||
SupportedOp::Hybrid(HybridOp::Recip {
|
||||
@@ -837,61 +837,70 @@ pub fn new_op_from_onnx(
|
||||
"Abs" => SupportedOp::Linear(PolyOp::Abs),
|
||||
"Neg" => SupportedOp::Linear(PolyOp::Neg),
|
||||
"HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sqrt" => SupportedOp::Nonlinear(LookupOp::Sqrt {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Rsqrt" => SupportedOp::Nonlinear(LookupOp::Rsqrt {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
}),
|
||||
"Ln" => SupportedOp::Nonlinear(LookupOp::Ln {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Ln" => {
|
||||
if run_args.bounded_log_lookup {
|
||||
SupportedOp::Hybrid(HybridOp::Ln {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
})
|
||||
} else {
|
||||
SupportedOp::Nonlinear(LookupOp::Ln {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
"Sin" => SupportedOp::Nonlinear(LookupOp::Sin {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Cos" => SupportedOp::Nonlinear(LookupOp::Cos {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Tan" => SupportedOp::Nonlinear(LookupOp::Tan {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Asin" => SupportedOp::Nonlinear(LookupOp::ASin {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Acos" => SupportedOp::Nonlinear(LookupOp::ACos {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Atan" => SupportedOp::Nonlinear(LookupOp::ATan {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sinh" => SupportedOp::Nonlinear(LookupOp::Sinh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Cosh" => SupportedOp::Nonlinear(LookupOp::Cosh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Tanh" => SupportedOp::Nonlinear(LookupOp::Tanh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Asinh" => SupportedOp::Nonlinear(LookupOp::ASinh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Acosh" => SupportedOp::Nonlinear(LookupOp::ACosh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Atanh" => SupportedOp::Nonlinear(LookupOp::ATanh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Erf" => SupportedOp::Nonlinear(LookupOp::Erf {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Source" => {
|
||||
let dt = node.outputs[0].fact.datum_type;
|
||||
@@ -935,11 +944,9 @@ pub fn new_op_from_onnx(
|
||||
replace_const(
|
||||
0,
|
||||
0,
|
||||
SupportedOp::Nonlinear(LookupOp::Cast {
|
||||
scale: crate::circuit::utils::F32(scale_to_multiplier(
|
||||
input_scales[0],
|
||||
)
|
||||
as f32),
|
||||
SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
)?
|
||||
} else {
|
||||
@@ -1045,7 +1052,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
};
|
||||
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::Softmax {
|
||||
@@ -1084,19 +1091,20 @@ pub fn new_op_from_onnx(
|
||||
})
|
||||
}
|
||||
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Round" => SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Sign" => SupportedOp::Linear(PolyOp::Sign),
|
||||
"Pow" => {
|
||||
@@ -1116,7 +1124,7 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Linear(PolyOp::Pow(exponent as u32))
|
||||
} else {
|
||||
SupportedOp::Nonlinear(LookupOp::Pow {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
a: crate::circuit::utils::F32(exponent),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -317,11 +317,18 @@ pub struct RunArgs {
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
|
||||
/// the number of legs used for decompositions
|
||||
pub decomp_legs: usize,
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// use unbounded lookup for the log
|
||||
pub bounded_log_lookup: bool,
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bounded_log_lookup: false,
|
||||
tolerance: Tolerance::default(),
|
||||
input_scale: 7,
|
||||
param_scale: 7,
|
||||
|
||||
@@ -1474,6 +1474,85 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Checks if a tensor's elements are odd
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::is_odd;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// let result = is_odd(&x);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 1, 0, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn is_odd(a: &Tensor<IntegerRep>) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rounded = if a_i % 2 == 0 { 0 } else { 1 };
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Powers of 2
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::ipow2;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = ipow2(&x, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 32768, 4, 2, 2, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ipow2(a: &Tensor<IntegerRep>, scale_output: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = a_i as f64;
|
||||
let kix = scale_output * (2.0_f64).powf(kix);
|
||||
let rounded = kix.round();
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies ln base 2 to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale_input` - Single value
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::ilog2;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 2]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = ilog2(&x, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 1, 0, 0, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ilog2(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale_input;
|
||||
let kix = (kix).log2();
|
||||
let rounded = kix.round();
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies sigmoid to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -1602,12 +1681,11 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies exponential to a tensor of integers.
|
||||
/// Elementwise applies ln to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `scale_input` - Single value
|
||||
/// * `scale_output` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
|
||||
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -27,7 +27,8 @@
|
||||
"check_mode": "UNSAFE",
|
||||
"commitment": "KZG",
|
||||
"decomp_base": 128,
|
||||
"decomp_legs": 2
|
||||
"decomp_legs": 2,
|
||||
"bounded_log_lookup": false
|
||||
},
|
||||
"num_rows": 46,
|
||||
"total_assignments": 92,
|
||||
|
||||
@@ -205,7 +205,7 @@ mod native_tests {
|
||||
"1l_tiny_div",
|
||||
];
|
||||
|
||||
const TESTS: [&str; 95] = [
|
||||
const TESTS: [&str; 96] = [
|
||||
"1l_mlp", //0
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
@@ -305,6 +305,7 @@ mod native_tests {
|
||||
"lstm_medium", // 92
|
||||
"lenet_5", // 93
|
||||
"rsqrt", // 94
|
||||
"log", // 95
|
||||
];
|
||||
|
||||
const WASM_TESTS: [&str; 46] = [
|
||||
@@ -543,7 +544,7 @@ mod native_tests {
|
||||
}
|
||||
});
|
||||
|
||||
seq!(N in 0..=94 {
|
||||
seq!(N in 0..=95 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
#[ignore]
|
||||
|
||||
Reference in New Issue
Block a user