mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(benchmarks): add more metrics to logistic regression benchmark
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
# Full Target: Logistic Regression
|
||||
|
||||
# Disable line length warnings as we have a looooong metric...
|
||||
# flake8: noqa: E501
|
||||
# pylint: disable=C0301
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common import BENCHMARK_CONFIGURATION
|
||||
@@ -76,6 +80,9 @@ def main():
|
||||
self.zp = zp
|
||||
self.n = n
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.q == other.q and self.zp == other.zp and self.n == other.n
|
||||
|
||||
class QuantizedArray:
|
||||
def __init__(self, values, parameters):
|
||||
self.values = np.array(values)
|
||||
@@ -249,22 +256,53 @@ def main():
|
||||
)
|
||||
# Measure: End
|
||||
|
||||
correct = 0
|
||||
for x_i, y_i in zip(x_q, y):
|
||||
non_homomorphic_correct = 0
|
||||
homomorphic_correct = 0
|
||||
|
||||
for i, (x_i, y_i) in enumerate(zip(x_q, y)):
|
||||
x_i = [int(value) for value in x_i]
|
||||
|
||||
non_homomorphic_prediction = round(
|
||||
sigmoid.apply(
|
||||
QuantizedArray(x_i, QuantizationParameters(q_x, zp_x, input_bits)).affine(
|
||||
QuantizedArray.of(w, parameter_bits),
|
||||
QuantizedArray.of(b, parameter_bits),
|
||||
intermediate.min(),
|
||||
intermediate.max(),
|
||||
output_bits,
|
||||
)
|
||||
).dequantize()[0]
|
||||
)
|
||||
# Measure: Evaluation Time (ms)
|
||||
prediction = round(QuantizedArray(engine.run(*x_i), y_parameters).dequantize())
|
||||
homomorphic_prediction = round(QuantizedArray(engine.run(*x_i), y_parameters).dequantize())
|
||||
# Measure: End
|
||||
|
||||
if prediction == y_i:
|
||||
correct += 1
|
||||
if non_homomorphic_prediction == y_i:
|
||||
non_homomorphic_correct += 1
|
||||
if homomorphic_prediction == y_i:
|
||||
homomorphic_correct += 1
|
||||
|
||||
accuracy = (correct / len(y)) * 100
|
||||
print(f"Accuracy: {accuracy:.2f}%")
|
||||
print()
|
||||
|
||||
# Measure: Accuracy (%) = accuracy
|
||||
# Alert: Accuracy (%) < 85
|
||||
print(f"input = {x[i][0]}, {x[i][1]}")
|
||||
print(f"output = {y_i:.4f}")
|
||||
|
||||
print(f"non homomorphic prediction = {non_homomorphic_prediction:.4f}")
|
||||
print(f"homomorphic prediction = {homomorphic_prediction:.4f}")
|
||||
|
||||
non_homomorphic_accuracy = (non_homomorphic_correct / len(y)) * 100
|
||||
homomorphic_accuracy = (homomorphic_correct / len(y)) * 100
|
||||
difference = abs(homomorphic_accuracy - non_homomorphic_accuracy)
|
||||
|
||||
print()
|
||||
print(f"Non Homomorphic Accuracy: {non_homomorphic_accuracy:.4f}")
|
||||
print(f"Homomorphic Accuracy: {homomorphic_accuracy:.4f}")
|
||||
print(f"Difference Percentage: {difference:.2f}%")
|
||||
|
||||
# Measure: Non Homomorphic Accuracy = non_homomorphic_accuracy
|
||||
# Measure: Homomorphic Accuracy = homomorphic_accuracy
|
||||
# Measure: Accuracy Difference Between Homomorphic and Non Homomorphic Implementation (%) = difference
|
||||
# Alert: Accuracy Difference Between Homomorphic and Non Homomorphic Implementation (%) > 5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user