mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-09 14:28:00 -05:00
feat: add tanh nonlinearity (#175)
--------- Co-authored-by: disirulla <sidalluri@gmail.com> Co-authored-by: Alexander Camuto <alexander.camuto@st-hughs.ox.ac.uk>
This commit is contained in:
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
@@ -245,7 +245,7 @@ jobs:
|
||||
- name: Install wasm32-wasi
|
||||
run: rustup target add wasm32-wasi
|
||||
- name: KZG prove and verify aggr tests
|
||||
run: cargo test --release --verbose tests_aggr::kzg_aggr_prove_and_verify_
|
||||
run: cargo test --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ -- --test-threads 8
|
||||
|
||||
prove-and-verify-aggr-evm-tests:
|
||||
runs-on: self-hosted
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,5 +1,7 @@
|
||||
target
|
||||
data
|
||||
*.ipynb_checkpoints
|
||||
*.ipynb
|
||||
*.sol
|
||||
*.pf
|
||||
*.vk
|
||||
|
||||
1
examples/onnx/1l_tanh/input.json
Normal file
1
examples/onnx/1l_tanh/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data":[[-0.4007,2.4938,0.5796]],"input_shapes":[[3]],"output_data":[[-0.375,0.984375,0.515625]]}
|
||||
13
examples/onnx/1l_tanh/network.onnx
Normal file
13
examples/onnx/1l_tanh/network.onnx
Normal file
@@ -0,0 +1,13 @@
|
||||
pytorch2.0.0:v
|
||||
"
|
||||
inputoutput/layer/Tanh"Tanh torch_jitZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
@@ -35,6 +35,9 @@ pub enum Op {
|
||||
Sigmoid {
|
||||
scales: (usize, usize),
|
||||
},
|
||||
Tanh{
|
||||
scales: (usize, usize),
|
||||
},
|
||||
}
|
||||
|
||||
impl fmt::Display for Op {
|
||||
@@ -50,6 +53,7 @@ impl fmt::Display for Op {
|
||||
}
|
||||
Op::Sigmoid { scales } => write!(f, "sigmoid w/ scale: {}", scales.0),
|
||||
Op::Sqrt { scales } => write!(f, "sqrt w/ scale: {}", scales.0),
|
||||
Op::Tanh { scales } => write!(f, "tanh w/ scale: {}", scales.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -64,6 +68,7 @@ impl Op {
|
||||
Op::PReLU { scale, slopes } => leakyrelu(&x, *scale, slopes[0].0),
|
||||
Op::Sigmoid { scales } => sigmoid(&x, scales.0, scales.1),
|
||||
Op::Sqrt { scales } => sqrt(&x, scales.0, scales.1),
|
||||
Op::Tanh { scales } => tanh(&x, scales.0, scales.1),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,6 +80,7 @@ impl Op {
|
||||
Op::PReLU { .. } => "PRELU",
|
||||
Op::Sigmoid { .. } => "SIGMOID",
|
||||
Op::Sqrt { .. } => "SQRT",
|
||||
Op::Tanh { .. } => "TANH"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -71,6 +71,7 @@ impl OpKind {
|
||||
}),
|
||||
"Sigmoid" => OpKind::Lookup(LookupOp::Sigmoid { scales: (1, 1) }),
|
||||
"Sqrt" => OpKind::Lookup(LookupOp::Sqrt { scales: (1, 1) }),
|
||||
"Tanh" => OpKind::Lookup(LookupOp::Tanh {scales: (1, 1)}),
|
||||
"Div" => OpKind::Lookup(LookupOp::Div { denom: F32(1.0) }),
|
||||
"Const" => OpKind::Const,
|
||||
"Source" => OpKind::Input,
|
||||
@@ -330,6 +331,33 @@ impl Node {
|
||||
}
|
||||
}
|
||||
|
||||
LookupOp::Tanh { .. } => {
|
||||
let input_node = &inputs[0];
|
||||
let scale_diff = input_node.out_scale;
|
||||
if scale_diff > 0 {
|
||||
let mult = scale_to_multiplier(scale_diff);
|
||||
opkind = OpKind::Lookup(LookupOp::Tanh {
|
||||
scales: (mult as usize, scale_to_multiplier(scale) as usize),
|
||||
});
|
||||
} else {
|
||||
opkind = OpKind::Lookup(LookupOp::Tanh {
|
||||
scales: (1, scale_to_multiplier(scale) as usize),
|
||||
});
|
||||
}
|
||||
|
||||
Node {
|
||||
idx,
|
||||
opkind,
|
||||
inputs: node.inputs.clone(),
|
||||
in_dims: vec![input_node.out_dims.clone()],
|
||||
out_dims: input_node.out_dims.clone(),
|
||||
in_scale: input_node.out_scale,
|
||||
out_scale: scale,
|
||||
output_max: scale_to_multiplier(scale),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
LookupOp::ReLU { .. } => {
|
||||
let input_node = &inputs[0];
|
||||
let scale_diff = input_node.out_scale - scale;
|
||||
|
||||
@@ -740,7 +740,7 @@ pub mod nonlinearities {
|
||||
output
|
||||
}
|
||||
|
||||
/// Elementwise applies sigmoid to a tensor of integers.
|
||||
/// Elementwise applies square root to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
@@ -771,6 +771,40 @@ pub mod nonlinearities {
|
||||
output
|
||||
}
|
||||
|
||||
/// Elementwise applies tanh activation to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `scale_input` - Single value
|
||||
/// * `scale_output` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl_lib::tensor::Tensor;
|
||||
/// use ezkl_lib::tensor::ops::nonlinearities::tanh;
|
||||
/// let x = Tensor::<i128>::new(
|
||||
/// Some(&[4, 25, 8, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = tanh(&x, 1, 1);
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
|
||||
pub fn tanh(a: &Tensor<i128>, scale_input: usize, scale_output: usize) -> Tensor<i128> {
|
||||
let mut output = a.clone();
|
||||
|
||||
for i in 0..a.len(){
|
||||
let z = a[i].clone() as f32 / (scale_input as f32);
|
||||
let numerator = z.exp() - (1.0/z.exp());
|
||||
let denominator = z.exp() + (1.0/z.exp());
|
||||
let tanhz = (scale_output as f32) * (numerator/denominator);
|
||||
output[i] = tanhz as i128;
|
||||
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Elementwise applies leaky relu to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
|
||||
@@ -40,7 +40,7 @@ fn init() {
|
||||
assert!(status.success());
|
||||
}
|
||||
|
||||
const TESTS: [&str; 19] = [
|
||||
const TESTS: [&str; 20] = [
|
||||
"1l_mlp",
|
||||
"1l_flatten",
|
||||
"1l_average",
|
||||
@@ -51,6 +51,7 @@ const TESTS: [&str; 19] = [
|
||||
"1l_sqrt",
|
||||
"1l_leakyrelu",
|
||||
"1l_relu",
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_fc",
|
||||
"2l_relu_small",
|
||||
@@ -62,7 +63,7 @@ const TESTS: [&str; 19] = [
|
||||
"4l_relu_conv_fc",
|
||||
];
|
||||
|
||||
const PACKING_TESTS: [&str; 11] = [
|
||||
const PACKING_TESTS: [&str; 12] = [
|
||||
"1l_mlp",
|
||||
"1l_average",
|
||||
"1l_div",
|
||||
@@ -71,12 +72,13 @@ const PACKING_TESTS: [&str; 11] = [
|
||||
"1l_sqrt",
|
||||
"1l_leakyrelu",
|
||||
"1l_relu",
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_fc",
|
||||
"2l_relu_small",
|
||||
];
|
||||
|
||||
const TESTS_AGGR: [&str; 14] = [
|
||||
const TESTS_AGGR: [&str; 15] = [
|
||||
"1l_mlp",
|
||||
"1l_flatten",
|
||||
"1l_average",
|
||||
@@ -87,6 +89,7 @@ const TESTS_AGGR: [&str; 14] = [
|
||||
"1l_sqrt",
|
||||
"1l_leakyrelu",
|
||||
"1l_relu",
|
||||
"1l_tanh",
|
||||
"2l_relu_fc",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_small",
|
||||
@@ -98,7 +101,7 @@ const NEG_TESTS: [(&str, &str); 2] = [
|
||||
("2l_relu_small", "2l_relu_sigmoid_small"),
|
||||
];
|
||||
|
||||
const TESTS_EVM: [&str; 12] = [
|
||||
const TESTS_EVM: [&str; 13] = [
|
||||
"1l_mlp",
|
||||
"1l_flatten",
|
||||
"1l_average",
|
||||
@@ -108,6 +111,7 @@ const TESTS_EVM: [&str; 12] = [
|
||||
"1l_sqrt",
|
||||
"1l_leakyrelu",
|
||||
"1l_relu",
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_small",
|
||||
"2l_relu_fc",
|
||||
@@ -123,7 +127,7 @@ macro_rules! test_func_aggr {
|
||||
use crate::TESTS_AGGR;
|
||||
use test_case::test_case;
|
||||
use crate::kzg_aggr_prove_and_verify;
|
||||
seq!(N in 0..=13 {
|
||||
seq!(N in 0..=14 {
|
||||
|
||||
#(#[test_case(TESTS_AGGR[N])])*
|
||||
fn kzg_aggr_prove_and_verify_(test: &str) {
|
||||
@@ -145,7 +149,7 @@ macro_rules! test_packed_func {
|
||||
use crate::mock_packed_outputs;
|
||||
use crate::mock_everything;
|
||||
|
||||
seq!(N in 0..=10 {
|
||||
seq!(N in 0..=11 {
|
||||
|
||||
#(#[test_case(PACKING_TESTS[N])])*
|
||||
fn mock_packed_outputs_(test: &str) {
|
||||
@@ -178,7 +182,7 @@ macro_rules! test_func {
|
||||
use crate::render_circuit;
|
||||
|
||||
|
||||
seq!(N in 0..=18 {
|
||||
seq!(N in 0..=19 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn render_circuit_(test: &str) {
|
||||
@@ -238,7 +242,7 @@ macro_rules! test_func_evm {
|
||||
"2l_relu_fc"
|
||||
];
|
||||
|
||||
seq!(N in 0..=11 {
|
||||
seq!(N in 0..=12 {
|
||||
|
||||
#(#[test_case(TESTS_EVM[N])])*
|
||||
fn kzg_evm_prove_and_verify_(test: &str) {
|
||||
@@ -844,4 +848,4 @@ fn build_ezkl() {
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user