feat: add tanh nonlinearity (#175)

---------

Co-authored-by: disirulla <sidalluri@gmail.com>
Co-authored-by: Alexander Camuto <alexander.camuto@st-hughs.ox.ac.uk>
This commit is contained in:
Alluri
2023-03-29 13:32:31 +07:00
committed by GitHub
parent 122378456a
commit a111214da0
8 changed files with 99 additions and 11 deletions

View File

@@ -245,7 +245,7 @@ jobs:
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: KZG prove and verify aggr tests
run: cargo test --release --verbose tests_aggr::kzg_aggr_prove_and_verify_
run: cargo test --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ -- --test-threads 8
prove-and-verify-aggr-evm-tests:
runs-on: self-hosted

2
.gitignore vendored
View File

@@ -1,5 +1,7 @@
target
data
*.ipynb_checkpoints
*.ipynb
*.sol
*.pf
*.vk

View File

@@ -0,0 +1 @@
{"input_data":[[-0.4007,2.4938,0.5796]],"input_shapes":[[3]],"output_data":[[-0.375,0.984375,0.515625]]}

View File

@@ -0,0 +1,13 @@
pytorch2.0.0:v
"
inputoutput /layer/Tanh"Tanh torch_jitZ!
input


batch_size
b"
output


batch_size
B

View File

@@ -35,6 +35,9 @@ pub enum Op {
Sigmoid {
scales: (usize, usize),
},
Tanh{
scales: (usize, usize),
},
}
impl fmt::Display for Op {
@@ -50,6 +53,7 @@ impl fmt::Display for Op {
}
Op::Sigmoid { scales } => write!(f, "sigmoid w/ scale: {}", scales.0),
Op::Sqrt { scales } => write!(f, "sqrt w/ scale: {}", scales.0),
Op::Tanh { scales } => write!(f, "tanh w/ scale: {}", scales.0),
}
}
}
@@ -64,6 +68,7 @@ impl Op {
Op::PReLU { scale, slopes } => leakyrelu(&x, *scale, slopes[0].0),
Op::Sigmoid { scales } => sigmoid(&x, scales.0, scales.1),
Op::Sqrt { scales } => sqrt(&x, scales.0, scales.1),
Op::Tanh { scales } => tanh(&x, scales.0, scales.1),
}
}
@@ -75,6 +80,7 @@ impl Op {
Op::PReLU { .. } => "PRELU",
Op::Sigmoid { .. } => "SIGMOID",
Op::Sqrt { .. } => "SQRT",
Op::Tanh { .. } => "TANH"
}
}

View File

@@ -71,6 +71,7 @@ impl OpKind {
}),
"Sigmoid" => OpKind::Lookup(LookupOp::Sigmoid { scales: (1, 1) }),
"Sqrt" => OpKind::Lookup(LookupOp::Sqrt { scales: (1, 1) }),
"Tanh" => OpKind::Lookup(LookupOp::Tanh {scales: (1, 1)}),
"Div" => OpKind::Lookup(LookupOp::Div { denom: F32(1.0) }),
"Const" => OpKind::Const,
"Source" => OpKind::Input,
@@ -330,6 +331,33 @@ impl Node {
}
}
LookupOp::Tanh { .. } => {
let input_node = &inputs[0];
let scale_diff = input_node.out_scale;
if scale_diff > 0 {
let mult = scale_to_multiplier(scale_diff);
opkind = OpKind::Lookup(LookupOp::Tanh {
scales: (mult as usize, scale_to_multiplier(scale) as usize),
});
} else {
opkind = OpKind::Lookup(LookupOp::Tanh {
scales: (1, scale_to_multiplier(scale) as usize),
});
}
Node {
idx,
opkind,
inputs: node.inputs.clone(),
in_dims: vec![input_node.out_dims.clone()],
out_dims: input_node.out_dims.clone(),
in_scale: input_node.out_scale,
out_scale: scale,
output_max: scale_to_multiplier(scale),
..Default::default()
}
}
LookupOp::ReLU { .. } => {
let input_node = &inputs[0];
let scale_diff = input_node.out_scale - scale;

View File

@@ -740,7 +740,7 @@ pub mod nonlinearities {
output
}
/// Elementwise applies sigmoid to a tensor of integers.
/// Elementwise applies square root to a tensor of integers.
/// # Arguments
///
/// * `a` - Tensor
@@ -771,6 +771,40 @@ pub mod nonlinearities {
output
}
/// Elementwise applies tanh activation to a tensor of integers.
///
/// Each element is divided by `scale_input`, passed through the hyperbolic
/// tangent, multiplied by `scale_output` and truncated toward zero.
/// # Arguments
///
/// * `a` - Tensor
/// * `scale_input` - Single value
/// * `scale_output` - Single value
/// # Examples
/// ```
/// use ezkl_lib::tensor::Tensor;
/// use ezkl_lib::tensor::ops::nonlinearities::tanh;
/// let x = Tensor::<i128>::new(
///     Some(&[4, 25, 8, 1, 1, 0]),
///     &[2, 3],
/// ).unwrap();
/// let result = tanh(&x, 1, 1);
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // Inputs far outside the linear region saturate to +/- `scale_output`.
/// let x = Tensor::<i128>::new(Some(&[100, -100]), &[2]).unwrap();
/// let result = tanh(&x, 1, 1);
/// let expected = Tensor::<i128>::new(Some(&[1, -1]), &[2]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn tanh(a: &Tensor<i128>, scale_input: usize, scale_output: usize) -> Tensor<i128> {
    let mut output = a.clone();
    for i in 0..a.len() {
        let z = a[i] as f32 / (scale_input as f32);
        // Use the library tanh rather than (e^z - e^-z) / (e^z + e^-z): the
        // explicit form overflows to inf / inf = NaN once |z| exceeds ~88 in
        // f32, and a NaN casts to 0 instead of saturating to +/- scale_output.
        let tanhz = (scale_output as f32) * z.tanh();
        // Float-to-int `as` truncates toward zero.
        output[i] = tanhz as i128;
    }
    output
}
/// Elementwise applies leaky relu to a tensor of integers.
/// # Arguments
///

View File

@@ -40,7 +40,7 @@ fn init() {
assert!(status.success());
}
const TESTS: [&str; 19] = [
const TESTS: [&str; 20] = [
"1l_mlp",
"1l_flatten",
"1l_average",
@@ -51,6 +51,7 @@ const TESTS: [&str; 19] = [
"1l_sqrt",
"1l_leakyrelu",
"1l_relu",
"1l_tanh",
"2l_relu_sigmoid_small",
"2l_relu_fc",
"2l_relu_small",
@@ -62,7 +63,7 @@ const TESTS: [&str; 19] = [
"4l_relu_conv_fc",
];
const PACKING_TESTS: [&str; 11] = [
const PACKING_TESTS: [&str; 12] = [
"1l_mlp",
"1l_average",
"1l_div",
@@ -71,12 +72,13 @@ const PACKING_TESTS: [&str; 11] = [
"1l_sqrt",
"1l_leakyrelu",
"1l_relu",
"1l_tanh",
"2l_relu_sigmoid_small",
"2l_relu_fc",
"2l_relu_small",
];
const TESTS_AGGR: [&str; 14] = [
const TESTS_AGGR: [&str; 15] = [
"1l_mlp",
"1l_flatten",
"1l_average",
@@ -87,6 +89,7 @@ const TESTS_AGGR: [&str; 14] = [
"1l_sqrt",
"1l_leakyrelu",
"1l_relu",
"1l_tanh",
"2l_relu_fc",
"2l_relu_sigmoid_small",
"2l_relu_small",
@@ -98,7 +101,7 @@ const NEG_TESTS: [(&str, &str); 2] = [
("2l_relu_small", "2l_relu_sigmoid_small"),
];
const TESTS_EVM: [&str; 12] = [
const TESTS_EVM: [&str; 13] = [
"1l_mlp",
"1l_flatten",
"1l_average",
@@ -108,6 +111,7 @@ const TESTS_EVM: [&str; 12] = [
"1l_sqrt",
"1l_leakyrelu",
"1l_relu",
"1l_tanh",
"2l_relu_sigmoid_small",
"2l_relu_small",
"2l_relu_fc",
@@ -123,7 +127,7 @@ macro_rules! test_func_aggr {
use crate::TESTS_AGGR;
use test_case::test_case;
use crate::kzg_aggr_prove_and_verify;
seq!(N in 0..=13 {
seq!(N in 0..=14 {
#(#[test_case(TESTS_AGGR[N])])*
fn kzg_aggr_prove_and_verify_(test: &str) {
@@ -145,7 +149,7 @@ macro_rules! test_packed_func {
use crate::mock_packed_outputs;
use crate::mock_everything;
seq!(N in 0..=10 {
seq!(N in 0..=11 {
#(#[test_case(PACKING_TESTS[N])])*
fn mock_packed_outputs_(test: &str) {
@@ -178,7 +182,7 @@ macro_rules! test_func {
use crate::render_circuit;
seq!(N in 0..=18 {
seq!(N in 0..=19 {
#(#[test_case(TESTS[N])])*
fn render_circuit_(test: &str) {
@@ -238,7 +242,7 @@ macro_rules! test_func_evm {
"2l_relu_fc"
];
seq!(N in 0..=11 {
seq!(N in 0..=12 {
#(#[test_case(TESTS_EVM[N])])*
fn kzg_evm_prove_and_verify_(test: &str) {
@@ -844,4 +848,4 @@ fn build_ezkl() {
.status()
.expect("failed to execute process");
assert!(status.success());
}
}