From fac04794651ef1ca5a5681fa55a0dc0dd71a9e1d Mon Sep 17 00:00:00 2001
From: jfrery
Date: Thu, 11 Nov 2021 21:15:16 +0100
Subject: [PATCH] feat: separate the matmul from the floating point scales and zero point values

---
 concrete/quantization/quantized_layers.py | 39 +++++++++++++++++------
 1 file changed, 29 insertions(+), 10 deletions(-)

diff --git a/concrete/quantization/quantized_layers.py b/concrete/quantization/quantized_layers.py
index 9dd5be821..1932343ab 100644
--- a/concrete/quantization/quantized_layers.py
+++ b/concrete/quantization/quantized_layers.py
@@ -56,18 +56,37 @@ class QuantizedLinear:
         # Satisfy mypy.
         assert self.q_out is not None
         assert self.q_bias is not None
-        # We need to develop the following equation to have the main computation
-        # (self.q_weights.q_values @ self.q_inputs.q_values) without zero_point values.
-        # See https://github.com/google/gemmlowp/blob/master/doc/quantization.md #852
-        m_product = (q_input.scale * self.q_weights.scale) / (self.q_out.scale)
-        dot_product = (q_input.qvalues - q_input.zero_point) @ (
-            self.q_weights.qvalues - self.q_weights.zero_point
-        ).T
+        # The following matmul is done with integers and thus does not use any PBS.
+        # Only the final conversion to float is done with a PBS, which can actually
+        # be merged with the PBS of the following activation.
+        # State-of-the-art quantization methods assume the following fits in an int32 accumulator.
 
-        m_bias = self.q_bias.scale / (q_input.scale * self.q_weights.scale)
-        bias_part = m_bias * (self.q_bias.qvalues - self.q_bias.zero_point)
-        numpy_q_out = m_product * (dot_product + bias_part) + self.q_out.zero_point
+        # Here we follow Eq. 7 in https://arxiv.org/abs/1712.05877 to split the core computation
+        # from the zero points and scales.
+
+        p = self.q_weights.qvalues.shape[1]  # common (inner) dimension of the matmul
+
+        # Core matmul operation in full integers, with a shape change (INTEGERS)
+        matmul = q_input.qvalues @ self.q_weights.qvalues.T
+
+        # Sum operations in full integers, resulting in large integers (INTEGERS)
+        sum_input = self.q_weights.zero_point * numpy.sum(q_input.qvalues, axis=1, keepdims=True)
+        sum_weights = q_input.zero_point * numpy.sum(
+            self.q_weights.qvalues.T, axis=0, keepdims=True
+        )
+
+        # Quantization scales and zero points (FLOATS involved)
+        # This is going to be compiled with a PBS (along with the following activation function)
+        m_matmul = (q_input.scale * self.q_weights.scale) / self.q_out.scale
+        bias_part = (
+            self.q_bias.scale / self.q_out.scale * (self.q_bias.qvalues - self.q_bias.zero_point)
+        )
+        final_term = p * q_input.zero_point * self.q_weights.zero_point
+
+        numpy_q_out = matmul - sum_input - sum_weights + final_term
+        numpy_q_out = m_matmul * numpy_q_out
+        numpy_q_out = self.q_out.zero_point + bias_part + numpy_q_out
         numpy_q_out = numpy_q_out.round().clip(0, 2 ** self.q_out.n_bits - 1).astype(int)
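
Reviewer note (not part of the patch): the Eq. 7 decomposition used above can be
checked against the naive "subtract zero points, then matmul" form with plain
numpy. This is a minimal sketch; all names, shapes, scales and zero points below
are hypothetical and made up for the check, they do not come from the codebase.

    import numpy

    rng = numpy.random.default_rng(0)

    # Hypothetical 7-bit quantized operands: batch=4, in_features=5, out_features=3.
    s_in, z_in = 0.05, 7    # input scale / zero point
    s_w, z_w = 0.02, 3      # weight scale / zero point
    q_in = rng.integers(0, 2**7, size=(4, 5))   # quantized inputs
    q_w = rng.integers(0, 2**7, size=(3, 5))    # quantized weights, (out, in)

    # Naive form: zero points are subtracted before the matmul.
    naive = (q_in - z_in) @ (q_w - z_w).T

    # Eq. 7 form: one integer-only matmul plus zero-point correction terms,
    # mirroring matmul / sum_input / sum_weights / final_term in the patch.
    p = q_in.shape[1]  # common (inner) dimension
    matmul = q_in @ q_w.T
    sum_input = z_w * numpy.sum(q_in, axis=1, keepdims=True)
    sum_weights = z_in * numpy.sum(q_w.T, axis=0, keepdims=True)
    final_term = p * z_in * z_w
    decomposed = matmul - sum_input - sum_weights + final_term

    assert numpy.array_equal(decomposed, naive)

    # The only floating point work left is a single rescaling, which the patch
    # folds into m_matmul together with the output scale.
    real = (s_in * (q_in - z_in)) @ (s_w * (q_w - z_w)).T
    assert numpy.allclose(s_in * s_w * decomposed, real)

Because both asserts pass, everything up to the final rescaling stays in integers,
which is what lets the conversion to float be deferred to a single PBS.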