fix(benchmarks): use quantized cleartext loss instead of non quantized one in linear regression

This commit is contained in:
Umut
2021-10-12 17:31:50 +03:00
parent 47f03e427f
commit 280ba7f8cd

View File

@@ -85,6 +85,30 @@ def main():
def dequantize(self):
return (self.values.astype(np.float32) + float(self.parameters.zp)) / self.parameters.q
def affine(self, w, b, min_y, max_y, n_y):
x_q = self.values
w_q = w.values
b_q = b.values
q_x = self.parameters.q
q_w = w.parameters.q
q_b = b.parameters.q
zp_x = self.parameters.zp
zp_w = w.parameters.zp
zp_b = b.parameters.zp
q_y = (2 ** n_y - 1) / (max_y - min_y)
zp_y = int(round(min_y * q_y))
y_q = (q_y / (q_x * q_w)) * (
(x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b)
)
y_q -= min_y * q_y
y_q = y_q.round().clip(0, 2 ** n_y - 1).astype(np.uint)
return QuantizedArray(y_q, QuantizationParameters(q_y, zp_y, n_y))
class QuantizedFunction:
def __init__(self, table):
self.table = table
@@ -163,7 +187,17 @@ def main():
for i, (x_i, y_i) in enumerate(zip(x_q, y)):
x_i = [int(value) for value in x_i]
non_homomorphic_prediction = model.evaluate(x[i])[0]
non_homomorphic_prediction = (
QuantizedArray(x_i, QuantizationParameters(q_x, zp_x, input_bits))
.affine(
QuantizedArray.of(model.w, parameter_bits),
QuantizedArray.of(model.b, parameter_bits),
min_y,
max_y,
output_bits,
)
.dequantize()[0]
)
# Measure: Evaluation Time (ms)
homomorphic_prediction = QuantizedArray(engine.run(*x_i), y_parameters).dequantize()
# Measure: End
@@ -176,7 +210,7 @@ def main():
print(f"input = {x[i][0]}")
print(f"output = {y_i:.4f}")
print(f"non homomorphic prediction = {non_homomorphic_loss:.4f}")
print(f"non homomorphic prediction = {non_homomorphic_prediction:.4f}")
print(f"homomorphic prediction = {homomorphic_prediction:.4f}")
non_homomorphic_loss /= len(y)
@@ -191,7 +225,7 @@ def main():
# Measure: Non Homomorphic Loss = non_homomorphic_loss
# Measure: Homomorphic Loss = homomorphic_loss
# Measure: Relative Loss Difference Between Homomorphic and Non Homomorphic Implementation (%) = difference
# Alert: Relative Loss Difference Between Homomorphic and Non Homomorphic Implementation (%) > 20
# Alert: Relative Loss Difference Between Homomorphic and Non Homomorphic Implementation (%) > 5
if __name__ == "__main__":