feat(benchmarks): add more metrics to logistic regression benchmark

This commit is contained in:
Umut
2021-10-12 17:37:51 +03:00
parent 280ba7f8cd
commit acdb80c6e3

View File

@@ -1,5 +1,9 @@
# Full Target: Logistic Regression
# Disable line length warnings as we have a looooong metric...
# flake8: noqa: E501
# pylint: disable=C0301
import numpy as np
import torch
from common import BENCHMARK_CONFIGURATION
@@ -76,6 +80,9 @@ def main():
self.zp = zp
self.n = n
def __eq__(self, other):
return self.q == other.q and self.zp == other.zp and self.n == other.n
class QuantizedArray:
def __init__(self, values, parameters):
self.values = np.array(values)
@@ -249,22 +256,53 @@ def main():
)
# Measure: End
correct = 0
for x_i, y_i in zip(x_q, y):
non_homomorphic_correct = 0
homomorphic_correct = 0
for i, (x_i, y_i) in enumerate(zip(x_q, y)):
x_i = [int(value) for value in x_i]
non_homomorphic_prediction = round(
sigmoid.apply(
QuantizedArray(x_i, QuantizationParameters(q_x, zp_x, input_bits)).affine(
QuantizedArray.of(w, parameter_bits),
QuantizedArray.of(b, parameter_bits),
intermediate.min(),
intermediate.max(),
output_bits,
)
).dequantize()[0]
)
# Measure: Evaluation Time (ms)
prediction = round(QuantizedArray(engine.run(*x_i), y_parameters).dequantize())
homomorphic_prediction = round(QuantizedArray(engine.run(*x_i), y_parameters).dequantize())
# Measure: End
if prediction == y_i:
correct += 1
if non_homomorphic_prediction == y_i:
non_homomorphic_correct += 1
if homomorphic_prediction == y_i:
homomorphic_correct += 1
accuracy = (correct / len(y)) * 100
print(f"Accuracy: {accuracy:.2f}%")
print()
# Measure: Accuracy (%) = accuracy
# Alert: Accuracy (%) < 85
print(f"input = {x[i][0]}, {x[i][1]}")
print(f"output = {y_i:.4f}")
print(f"non homomorphic prediction = {non_homomorphic_prediction:.4f}")
print(f"homomorphic prediction = {homomorphic_prediction:.4f}")
non_homomorphic_accuracy = (non_homomorphic_correct / len(y)) * 100
homomorphic_accuracy = (homomorphic_correct / len(y)) * 100
difference = abs(homomorphic_accuracy - non_homomorphic_accuracy)
print()
print(f"Non Homomorphic Accuracy: {non_homomorphic_accuracy:.4f}")
print(f"Homomorphic Accuracy: {homomorphic_accuracy:.4f}")
print(f"Difference Percentage: {difference:.2f}%")
# Measure: Non Homomorphic Accuracy = non_homomorphic_accuracy
# Measure: Homomorphic Accuracy = homomorphic_accuracy
# Measure: Accuracy Difference Between Homomorphic and Non Homomorphic Implementation (%) = difference
# Alert: Accuracy Difference Between Homomorphic and Non Homomorphic Implementation (%) > 5
if __name__ == "__main__":