Mirror of https://github.com/zkonduit/ezkl.git (synced 2026-01-13 00:08:12 -05:00)

Compare commits: ac/remove-...ac/piecewi (1 commit)

Commit d4a5e65126
@@ -197,6 +197,9 @@ struct PyRunArgs {
     /// bool: Should the circuit use unbounded lookups for log
     #[pyo3(get, set)]
     pub bounded_log_lookup: bool,
+    /// bool: Should the circuit use unbounded lookups for exp
+    #[pyo3(get, set)]
+    pub bounded_exp_lookup: bool,
 }
 
 /// default instantiation of PyRunArgs
@@ -213,6 +216,7 @@ impl From<PyRunArgs> for RunArgs {
     fn from(py_run_args: PyRunArgs) -> Self {
         RunArgs {
             bounded_log_lookup: py_run_args.bounded_log_lookup,
+            bounded_exp_lookup: py_run_args.bounded_exp_lookup,
             tolerance: Tolerance::from(py_run_args.tolerance),
             input_scale: py_run_args.input_scale,
             param_scale: py_run_args.param_scale,
@@ -237,6 +241,7 @@ impl Into<PyRunArgs> for RunArgs {
     fn into(self) -> PyRunArgs {
         PyRunArgs {
             bounded_log_lookup: self.bounded_log_lookup,
+            bounded_exp_lookup: self.bounded_exp_lookup,
             tolerance: self.tolerance.val,
             input_scale: self.input_scale,
             param_scale: self.param_scale,

@@ -16,6 +16,9 @@ pub enum HybridOp {
     Ln {
         scale: utils::F32,
     },
+    Exp {
+        scale: utils::F32,
+    },
     Rsqrt {
         input_scale: utils::F32,
         output_scale: utils::F32,
@@ -130,6 +133,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
         ),
         HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
         HybridOp::Ln { scale } => format!("LN(scale={})", scale),
+        HybridOp::Exp { scale } => format!("EXP(scale={})", scale),
         HybridOp::RoundHalfToEven { scale, legs } => {
             format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
         }
@@ -215,6 +219,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
                 layouts::sqrt(config, region, values[..].try_into()?, *scale)?
             }
             HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
+            HybridOp::Exp { scale } => {
+                layouts::exp(config, region, values[..].try_into()?, *scale)?
+            }
             HybridOp::RoundHalfToEven { scale, legs } => {
                 layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
             }
@@ -357,7 +364,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
             } => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
             HybridOp::Ln {
                 scale: output_scale,
-            } => 4 * multiplier_to_scale(output_scale.0 as f64),
+            } => 3 * multiplier_to_scale(output_scale.0 as f64),
             _ => in_scales[0],
         };
         Ok(scale)

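For reference, a sketch of the scale bookkeeping behind the 4x-to-3x change, assuming the usual ezkl convention that scale_to_multiplier(s) = 2^s and multiplier_to_scale is its inverse (both names appear in the diff; the bodies below are illustrative, not the library's):

    // Illustrative stand-ins for ezkl's scale/multiplier conversions.
    fn scale_to_multiplier(scale: i32) -> f64 {
        2f64.powi(scale)
    }
    fn multiplier_to_scale(multiplier: f64) -> i32 {
        multiplier.log2().round() as i32
    }

    fn main() {
        // With an input multiplier of 2^7 = 128, Ln's output scale is now
        // 3 * 7 = 21 (multiplier 2^21) rather than 4 * 7 = 28, matching the
        // switch from triple to double scaling in the ln layout below.
        let input_multiplier = scale_to_multiplier(7);
        assert_eq!(3 * multiplier_to_scale(input_multiplier), 21);
    }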
@@ -5,6 +5,7 @@ use std::{
 
 use halo2_proofs::circuit::Value;
 use halo2curves::ff::PrimeField;
+use hybrid::HybridOp;
 use itertools::Itertools;
 use log::{error, trace};
 use maybe_rayon::{
@@ -19,6 +20,7 @@ use super::{chip::BaseConfig, region::RegionCtx};
 use crate::{
     circuit::{ops::base::BaseOp, utils},
     fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
+    graph::scale_to_multiplier,
     tensor::{
         create_unit_tensor, get_broadcasted_shape,
         ops::{accumulated, add, mult, sub},
@@ -91,11 +93,23 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     let f_x_is_opt_rhs = less_equal(config, region, &[f_x.clone(), f_x_plus_1.clone()])?;
     let f_x_is_opt_lhs = less_equal(config, region, &[f_x.clone(), f_x_minus_1.clone()])?;
 
-    let is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?;
+    let is_opt = and(
+        config,
+        region,
+        &[f_x_is_opt_lhs.clone(), f_x_is_opt_rhs.clone()],
+    )?;
 
     let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
     comparison_unit.reshape(is_opt.dims())?;
 
+    println!("fx {}", f_x.show());
+    println!("f_x_plus_1 {}", f_x_plus_1.show());
+    println!("f_x_minus_1 {}", f_x_minus_1.show());
+
+    println!("f_x_is_opt_lhs {}", f_x_is_opt_lhs.show());
+    println!("f_x_is_opt_rhs {}", f_x_is_opt_rhs.show());
+    println!("is_opt {}", is_opt.show());
+
     // assert that the result is 1
     enforce_equality(config, region, &[is_opt, comparison_unit])?;
 
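The gadget's test is the standard discrete optimality condition for a convex function: x minimizes f iff f(x) <= f(x+1) and f(x) <= f(x-1). A standalone sketch over plain integers (no halo2 types; names are illustrative):

    // Discrete optimality check for a convex function.
    fn is_optimum(f: impl Fn(i64) -> i64, x: i64) -> bool {
        f(x) <= f(x + 1) && f(x) <= f(x - 1)
    }

    fn main() {
        // A simple convex error function with its minimum at y = 5 stands in
        // for the circuit's err_func.
        let f = |y: i64| (y - 5) * (y - 5);
        assert!(is_optimum(f, 5));
        assert!(!is_optimum(f, 6));
    }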
@@ -132,7 +146,14 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         return Ok(value[0].clone());
     }
 
-    let input = value[0].clone();
+    let mut input = value[0].clone();
+
+    // assign the image
+    if input.all_prev_assigned() {
+        input = region.assign(&config.custom_gates.inputs[0], &input)?;
+        // don't need to increment because the claimed output is assigned to output and incremented accordingly
+    }
+
     let input_dims = input.dims();
 
     let divisor = create_constant_tensor(div, 1);
@@ -180,9 +201,15 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     input_scale: F,
     output_scale: F,
 ) -> Result<ValTensor<F>, CircuitError> {
-    let input = value[0].clone();
-    let input_dims = input.dims();
-
+    let mut input = value[0].clone();
+
+    // assigned
+    if input.all_prev_assigned() {
+        input = region.assign(&config.custom_gates.inputs[0], &input)?;
+        // don't need to increment because the claimed output is assigned to output and incremented accordingly
+    }
+
+    let input_dims = input.dims();
     let unit_scale = create_constant_tensor(output_scale * input_scale, 1);
 
     let is_assigned = !input.any_unknowns()?;
@@ -279,6 +306,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     input_scale: utils::F32,
 ) -> Result<ValTensor<F>, CircuitError> {
     let input = value[0].clone();
+
     let input_dims = input.dims();
 
     let unit_scale = create_constant_tensor(integer_rep_to_felt(input_scale.0 as IntegerRep), 1);
@@ -4645,7 +4673,7 @@ pub fn ceil<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     )
 }
 
-/// integer ln layout
+/// piecewise linear ln layout
 /// # Arguments
 /// * `config` - BaseConfig
 /// * `region` - RegionCtx
@@ -4687,8 +4715,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     let mut input = values[0].clone();
     let scale_as_felt = integer_rep_to_felt(scale.0.round() as IntegerRep);
 
-    let triple_scaled_as_felt_tensor =
-        create_constant_tensor(scale_as_felt * scale_as_felt * scale_as_felt, 1);
+    let double_scaled_as_felt_tensor = create_constant_tensor(scale_as_felt * scale_as_felt, 1);
 
     // natural ln is log2(x) * ln(2)
     let ln2 = utils::F32::from(2.0_f32.ln());
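The comment above relies on the change-of-base identity ln(x) = log2(x) * ln(2), which is what lets the layout reuse the piecewise log2 estimate. A quick f64 check:

    fn main() {
        // ln(x) recovered from log2(x) via the change-of-base identity.
        let x: f64 = 6.0;
        let ln_via_log2 = x.log2() * 2f64.ln();
        assert!((ln_via_log2 - x.ln()).abs() < 1e-12);
    }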
@@ -4850,7 +4877,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         region,
         &[pow2_prior_to_claimed_distance],
         scale_as_felt,
-        scale_as_felt * scale_as_felt,
+        scale_as_felt,
     )?;
 
     let interpolated_distance = pairwise(
@@ -4878,7 +4905,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         region,
         &[pow2_next_to_claimed_distance],
         scale_as_felt,
-        scale_as_felt * scale_as_felt,
+        scale_as_felt,
     )?;
 
     let interpolated_distance_next = pairwise(
@@ -4904,7 +4931,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     let scaled_claimed_output = pairwise(
         config,
         region,
-        &[claimed_output.clone(), triple_scaled_as_felt_tensor],
+        &[claimed_output.clone(), double_scaled_as_felt_tensor],
         BaseOp::Mult,
     )?;
 
@@ -4932,6 +4959,88 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     pairwise(config, region, &[claimed_output, ln2_tensor], BaseOp::Mult)
 }
 
+/// Exponential layout
+/// # Example
+/// ```
+/// use ezkl::tensor::Tensor;
+/// use ezkl::fieldutils::IntegerRep;
+/// use ezkl::circuit::ops::layouts::exp;
+/// use halo2curves::bn256::Fr as Fp;
+/// use ezkl::circuit::region::RegionCtx;
+/// use ezkl::circuit::region::RegionSettings;
+/// use ezkl::circuit::BaseConfig;
+/// use ezkl::tensor::ValTensor;
+/// let dummy_config = BaseConfig::dummy(12, 2);
+/// let mut dummy_region = RegionCtx::new_dummy(0, 2, RegionSettings::all_true(128, 2));
+/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
+///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 9]),
+///     &[3, 3],
+/// ).unwrap());
+/// let result = exp::<Fp>(&dummy_config, &mut dummy_region, &[x], 1.0.into()).unwrap();
+/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 4, 8, 4, 8, 16, 8, 16, 512]), &[3, 3]).unwrap();
+/// assert_eq!(result.int_evals().unwrap(), expected);
+/// ```
+pub fn exp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
+    config: &BaseConfig<F>,
+    region: &mut RegionCtx<F>,
+    value: &[ValTensor<F>; 1],
+    input_scale: utils::F32,
+) -> Result<ValTensor<F>, CircuitError> {
+    let input = value[0].clone();
+    let input_dims = input.dims();
+
+    let is_assigned = !input.any_unknowns()?;
+
+    let ln_op = HybridOp::Ln { scale: input_scale };
+    let rescale_factor = Op::<F>::out_scale(&ln_op, vec![])?;
+    let ratio = (scale_to_multiplier(rescale_factor) / (input_scale.0 as f64)).round() as u64;
+
+    let scaling_ratio = create_constant_tensor(F::from(ratio), 1);
+
+    let rescaled_input = pairwise(
+        config,
+        region,
+        &[input.clone(), scaling_ratio],
+        BaseOp::Mult,
+    )?;
+
+    let mut claimed_output: ValTensor<F> = if is_assigned {
+        let input_evals = input.int_evals()?;
+        tensor::ops::nonlinearities::inverse_ln_piecewise(&input_evals, input_scale.0 as f64)
+            .par_iter()
+            .map(|x| Value::known(integer_rep_to_felt(*x)))
+            .collect::<Tensor<Value<F>>>()
+            .into()
+    } else {
+        Tensor::new(
+            Some(&vec![Value::<F>::unknown(); input.len()]),
+            &[input.len()],
+        )?
+        .into()
+    };
+    claimed_output.reshape(input_dims)?;
+    let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
+    region.increment(claimed_output.len());
+
+    println!("claimed output {}", claimed_output.show());
+
+    let err_func = |config: &BaseConfig<F>,
+                    region: &mut RegionCtx<F>,
+                    x: &ValTensor<F>|
+     -> Result<ValTensor<F>, CircuitError> {
+        println!("x {}", x.show());
+        let ln_x = ln(config, region, &[x.clone()], input_scale)?;
+        let distance = l1_distance(config, region, &[ln_x.clone(), rescaled_input.clone()])?;
+        Ok(distance)
+    };
+
+    optimum_convex_function(config, region, &claimed_output, err_func)?;
+
+    println!("offset {}", region.row());
+
+    Ok(claimed_output)
+}
+
 /// round layout
 /// # Arguments
 /// * `config` - BaseConfig

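Note that the layout never evaluates exp in-circuit: the prover supplies a claimed output y, and the circuit checks via the convex-optimum gadget that the error |ln(y) - rescaled x| cannot be improved by stepping y by one in either direction. A plain-Rust sketch of that acceptance test (illustrative only; no field or tensor types):

    // Error function the circuit minimizes: distance between ln(y) and x.
    fn err(y: i64, x: f64) -> f64 {
        ((y as f64).ln() - x).abs()
    }

    fn main() {
        // exp(3) ~= 20.09, so the honest prover claims y = 20.
        let x = 3.0_f64;
        let y_claimed = 20_i64;
        // The claim is accepted because neither neighbor scores better.
        assert!(err(y_claimed, x) <= err(y_claimed - 1, x));
        assert!(err(y_claimed, x) <= err(y_claimed + 1, x));
    }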
@@ -853,9 +853,17 @@ pub fn new_op_from_onnx(
             output_scale: (scale_to_multiplier(max_scale) as f32).into(),
         })
     }
-    "Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
-        scale: scale_to_multiplier(input_scales[0]).into(),
-    }),
+    "Exp" => {
+        if run_args.bounded_exp_lookup {
+            SupportedOp::Hybrid(HybridOp::Exp {
+                scale: scale_to_multiplier(input_scales[0]).into(),
+            })
+        } else {
+            SupportedOp::Nonlinear(LookupOp::Exp {
+                scale: scale_to_multiplier(input_scales[0]).into(),
+            })
+        }
+    }
     "Ln" => {
         if run_args.bounded_log_lookup {
             SupportedOp::Hybrid(HybridOp::Ln {

@@ -331,14 +331,21 @@ pub struct RunArgs {
         all(feature = "ezkl", not(target_arch = "wasm32")),
         arg(long, default_value = "false")
     )]
-    /// use unbounded lookup for the log
+    /// use bounded lookup for the log
    pub bounded_log_lookup: bool,
+    #[cfg_attr(
+        all(feature = "ezkl", not(target_arch = "wasm32")),
+        arg(long, default_value = "false")
+    )]
+    /// use bounded lookup for the exp
+    pub bounded_exp_lookup: bool,
 }
 
 impl Default for RunArgs {
     fn default() -> Self {
         Self {
             bounded_log_lookup: false,
+            bounded_exp_lookup: false,
             tolerance: Tolerance::default(),
             input_scale: 7,
             param_scale: 7,

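The flag is exposed via arg(long), so presumably as --bounded-exp-lookup on the CLI, and as a plain field on RunArgs. A minimal sketch of opting in, assuming RunArgs is re-exported at the crate root as in ezkl's other doc examples:

    use ezkl::RunArgs;

    fn main() {
        let run_args = RunArgs {
            // route "Exp" nodes to HybridOp::Exp instead of LookupOp::Exp
            bounded_exp_lookup: true,
            ..RunArgs::default()
        };
        assert!(run_args.bounded_exp_lookup);
    }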
@@ -1728,6 +1728,172 @@ pub mod nonlinearities {
         .unwrap()
     }
 
+    /// Piecewise linear estimation of log2(x) for x > 0.
+    /// # Arguments
+    /// * `a` - Tensor
+    /// * `scale_input` - Single value
+    pub fn log2_piecewise(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
+        a.enum_map(|_, a_i| {
+            let kix = (a_i as f64) / scale_input;
+
+            let kix_log = kix.log2();
+
+            if kix_log.fract() == 0.0 {
+                return Ok::<_, TensorError>(kix_log.round() as IntegerRep);
+            }
+
+            let prev_log = kix_log.floor();
+            let next_log = prev_log + 1.0;
+
+            let prev_pow2 = (2.0_f64).powf(prev_log);
+            let next_pow2 = (2.0_f64).powf(next_log);
+
+            let gradient = (kix - prev_pow2) / (next_pow2 - prev_pow2);
+
+            println!(
+                "kix: {}, prev_log: {}, next_log: {}, prev_pow2: {}, next_pow2: {}, gradient: {}",
+                kix, prev_log, next_log, prev_pow2, next_pow2, gradient
+            );
+
+            let linear_estimation = prev_log + gradient;
+
+            let rounded = (linear_estimation * scale_input).round();
+
+            Ok::<_, TensorError>(rounded as IntegerRep)
+        })
+        .unwrap()
+    }
+
+    /// Piecewise linear estimation of ln(x) for x > 0.
+    /// # Arguments
+    /// * `a` - Tensor
+    /// * `scale_input` - Single value
+    pub fn ln_piecewise(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
+        let log2 = log2_piecewise(a, scale_input);
+        let ln2 = 2.0_f64.ln();
+
+        log2.par_enum_map(|_, a_i| {
+            let kix = (a_i as f64) / scale_input;
+
+            let rounded = (kix * ln2 * scale_input).round();
+            Ok::<_, TensorError>(rounded as IntegerRep)
+        })
+        .unwrap()
+    }
+
+    /// Piecewise inverse of linear estimation of log2(x) for x > 0.
+    /// # Arguments
+    /// * `a` - Tensor
+    /// * `scale_input` - Single value
+    ///
+    /// # Examples
+    /// ```
+    /// use ezkl::tensor::Tensor;
+    /// use ezkl::fieldutils::IntegerRep;
+    /// use ezkl::tensor::ops::nonlinearities::inverse_log2_piecewise;
+    /// use ezkl::tensor::ops::nonlinearities::log2_piecewise;
+    /// let x = Tensor::<IntegerRep>::new(
+    ///     Some(&[1, 2, 3, 4, 5, 6]),
+    ///     &[2, 3],
+    /// ).unwrap();
+    ///
+    /// let log = log2_piecewise(&x, 1.0);
+    /// let result = inverse_log2_piecewise(&log, 1.0);
+    ///
+    /// let rounded_x = Tensor::<IntegerRep>::new(
+    ///     Some(&[1, 2, 4, 4, 4, 8]),
+    ///     &[2, 3],
+    /// ).unwrap();
+    ///
+    /// assert_eq!(result, rounded_x);
+    /// ```
+    pub fn inverse_log2_piecewise(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
+        a.par_enum_map(|_, a_i| {
+            println!("a_i: {}", a_i);
+
+            let kix = (a_i as f64) / scale_input;
+
+            println!("kix: {}", kix);
+
+            if kix.fract() == 0.0 {
+                return Ok::<_, TensorError>(2.0_f64.powf(kix).round() as IntegerRep);
+            }
+            let prev_log = kix.floor();
+            let next_log = prev_log + 1.0;
+
+            println!("prev_log: {}, next_log: {}", prev_log, next_log);
+
+            let prev_pow2 = (2.0_f64).powf(prev_log);
+            let next_pow2 = (2.0_f64).powf(next_log);
+
+            println!("prev_pow2: {}, next_pow2: {}", prev_pow2, next_pow2);
+
+            let inv = (kix - prev_log) * (next_pow2 - prev_pow2) + prev_pow2;
+
+            let rounded = (inv * scale_input).round();
+
+            Ok::<_, TensorError>(rounded as IntegerRep)
+        })
+        .unwrap()
+    }
+
+    /// Piecewise inverse of linear estimation of ln(x) for x > 0.
+    /// # Arguments
+    /// * `a` - Tensor
+    /// * `scale_input` - Single value
+    ///
+    /// # Examples
+    /// ```
+    /// use ezkl::tensor::Tensor;
+    /// use ezkl::fieldutils::IntegerRep;
+    /// use ezkl::tensor::ops::nonlinearities::inverse_ln_piecewise;
+    /// use ezkl::tensor::ops::nonlinearities::ln_piecewise;
+    /// let x = Tensor::<IntegerRep>::new(
+    ///     Some(&[1, 2, 3, 4, 5, 6]),
+    ///     &[2, 3],
+    /// ).unwrap();
+    ///
+    /// let log = ln_piecewise(&x, 1.0);
+    /// let result = inverse_ln_piecewise(&log, 1.0);
+    ///
+    /// let rounded_x = Tensor::<IntegerRep>::new(
+    ///     Some(&[1, 2, 2, 2, 2, 4]),
+    ///     &[2, 3],
+    /// ).unwrap();
+    ///
+    /// assert_eq!(result, rounded_x);
+    /// ```
+    pub fn inverse_ln_piecewise(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
+        a.par_enum_map(|_, a_i| {
+            println!("a_i: {}", a_i);
+
+            let kix = (a_i as f64) / scale_input;
+
+            println!("kix: {}", kix);
+
+            if kix.fract() == 0.0 {
+                return Ok::<_, TensorError>(kix.exp().round() as IntegerRep);
+            }
+            let prev_log = kix.floor();
+            let next_log = prev_log + 1.0;
+
+            println!("prev_log: {}, next_log: {}", prev_log, next_log);
+
+            let prev_pow2 = (2.0_f64).powf(prev_log);
+            let next_pow2 = (2.0_f64).powf(next_log);
+
+            let ln2 = 2.0_f64.ln();
+
+            println!("prev_pow2: {}, next_pow2: {}", prev_pow2, next_pow2);
+
+            let inv = (kix - prev_log * ln2) * (next_pow2 - prev_pow2) + prev_pow2 * ln2;
+
+            let rounded = (inv * scale_input).round();
+
+            Ok::<_, TensorError>(rounded as IntegerRep)
+        })
+        .unwrap()
+    }
+
 /// Elementwise applies square root to a tensor of integers.
 /// # Arguments
 ///
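For intuition, the estimate interpolates linearly between neighboring powers of two: for 2^k <= x <= 2^(k+1), log2(x) is approximated as k + (x - 2^k) / (2^(k+1) - 2^k). A standalone sketch at scale 1.0 (the function name here is illustrative):

    // Piecewise-linear log2 estimate between adjacent powers of two.
    fn log2_estimate(x: f64) -> f64 {
        let k = x.log2().floor();
        let (lo, hi) = (2f64.powf(k), 2f64.powf(k + 1.0));
        k + (x - lo) / (hi - lo)
    }

    fn main() {
        // x = 6 lies between 4 and 8: estimate = 2 + (6 - 4)/(8 - 4) = 2.5,
        // versus the true log2(6) ~= 2.585.
        assert!((log2_estimate(6.0) - 2.5).abs() < 1e-12);
    }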