diff --git a/benchmarks/logistic_regression.py b/benchmarks/logistic_regression.py index b91a91a4f..dc8bc9bb9 100644 --- a/benchmarks/logistic_regression.py +++ b/benchmarks/logistic_regression.py @@ -158,14 +158,17 @@ def main(): intermediate = x @ w + b intermediate_q = x_q.affine(w_q, b_q, intermediate.min(), intermediate.max(), output_bits) - n_y = output_bits - q_y = (2 ** output_bits - 1) / (intermediate.max() - intermediate.min()) - zp_y = int(round(intermediate.min() * q_y)) - y_parameters = QuantizationParameters(q_y, zp_y, n_y) + sigmoid = QuantizedFunction.plain( + lambda x: 1 / (1 + np.exp(-x)), intermediate_q.parameters, output_bits + ) + + y_q = sigmoid.apply(intermediate_q) + y_parameters = y_q.parameters q_x = x_q.parameters.q q_w = w_q.parameters.q q_b = b_q.parameters.q + q_intermediate = intermediate_q.parameters.q zp_x = x_q.parameters.zp zp_w = w_q.parameters.zp @@ -175,10 +178,10 @@ def main(): w_q = w_q.values b_q = b_q.values - c1 = q_y / (q_x * q_w) + c1 = q_intermediate / (q_x * q_w) c2 = w_q + zp_w c3 = (q_x * q_w / q_b) * (b_q + zp_b) - c4 = intermediate.min() * q_y + c4 = intermediate.min() * q_intermediate def f(x): values = ((c1 * (x + c3)) - c4).round().clip(0, 2 ** output_bits - 1).astype(np.uint) @@ -229,7 +232,10 @@ def main(): if prediction == y_i: correct += 1 - # Measure: Accuracy (%) = (correct / len(y)) * 100 + accuracy = (correct / len(y)) * 100 + print(f"Accuracy: {accuracy:.2f}%") + + # Measure: Accuracy (%) = accuracy if __name__ == "__main__":