feat: separate matmul from the floating point and zero point values

jfrery
2021-11-11 21:15:16 +01:00
committed by jfrery
parent 5448669e83
commit fac0479465

@@ -56,18 +56,37 @@ class QuantizedLinear:
# Satisfy mypy.
assert self.q_out is not None
assert self.q_bias is not None
# We need to expand the following equation to get the main computation
# (self.q_weights.qvalues @ self.q_inputs.qvalues) without the zero_point values.
# See https://github.com/google/gemmlowp/blob/master/doc/quantization.md
m_product = (q_input.scale * self.q_weights.scale) / (self.q_out.scale)
dot_product = (q_input.qvalues - q_input.zero_point) @ (
self.q_weights.qvalues - self.q_weights.zero_point
).T
# The following MatMul is done with integers, and thus, does not use any PBS.
# Only the final conversion to float is done with a PBS, which can actually
# be merged with the PBS of the following activation.
# State-of-the-art quantization methods assume the following fits in an int32 accumulator.
m_bias = self.q_bias.scale / (q_input.scale * self.q_weights.scale)
bias_part = m_bias * (self.q_bias.qvalues - self.q_bias.zero_point)
numpy_q_out = m_product * (dot_product + bias_part) + self.q_out.zero_point
# Here we follow Eq. 7 in https://arxiv.org/abs/1712.05877 to split the core computation
# from the zero points and scales.
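# Expanding (q_x - z_x) @ (q_w - z_w).T gives:
#   q_x @ q_w.T - z_w * sum(q_x, axis=1) - z_x * sum(q_w.T, axis=0) + p * z_x * z_w
# which is computed term by term below.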
# p is the size of the shared dimension of the matmul (the number of summed terms).
p = self.q_weights.qvalues.shape[1]
# Core matmul operation in full integers, with a shape change (INTEGERS)
matmul = q_input.qvalues @ self.q_weights.qvalues.T
# Sum operation in full integers resulting in large integers (INTEGERS)
sum_input = self.q_weights.zero_point * numpy.sum(q_input.qvalues, axis=1, keepdims=True)
sum_weights = q_input.zero_point * numpy.sum(
self.q_weights.qvalues.T, axis=0, keepdims=True
)
# Quantization scales and zero points (FLOATS involved)
# This is going to be compiled with a PBS (along with the following activation function)
m_matmul = (q_input.scale * self.q_weights.scale) / (self.q_out.scale)
bias_part = (
self.q_bias.scale / self.q_out.scale * (self.q_bias.qvalues - self.q_bias.zero_point)
)
final_term = p * q_input.zero_point * self.q_weights.zero_point
numpy_q_out = matmul - sum_input - sum_weights + final_term
numpy_q_out = m_matmul * numpy_q_out
numpy_q_out = self.q_out.zero_point + bias_part + numpy_q_out
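# Finally, round and clip back to the unsigned n_bits output range (a float-to-int step).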
numpy_q_out = numpy_q_out.round().clip(0, 2 ** self.q_out.n_bits - 1).astype(int)
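
As a sanity check, here is a minimal, self-contained numpy sketch (not part of the commit) of the Eq. 7 identity used above: the split form, a pure-integer matmul plus integer correction terms, matches the direct form that subtracts the zero points before the matmul. The variable names mirror the diff; the shapes, qvalues and zero points z_x, z_w are made up for illustration.

import numpy

rng = numpy.random.default_rng(0)
n_examples, n_features, n_neurons = 4, 8, 3

# Unsigned 8-bit quantized inputs (n_examples, n_features) and weights (n_neurons, n_features).
q_x = rng.integers(0, 256, size=(n_examples, n_features))
q_w = rng.integers(0, 256, size=(n_neurons, n_features))
z_x, z_w = 7, 13  # hypothetical zero points

# Direct form: subtract the zero points before the matmul.
direct = (q_x - z_x) @ (q_w - z_w).T

# Split form: a pure-integer matmul plus integer correction terms.
p = q_w.shape[1]
matmul = q_x @ q_w.T
sum_input = z_w * numpy.sum(q_x, axis=1, keepdims=True)
sum_weights = z_x * numpy.sum(q_w.T, axis=0, keepdims=True)
final_term = p * z_x * z_w

assert numpy.array_equal(matmul - sum_input - sum_weights + final_term, direct)

The point of the split is that matmul, sum_input, sum_weights and final_term all stay in plain integers; only the final rescaling by m_matmul, the bias term and the output zero point involve floats, which is why the comments in the diff note that this last step can be compiled into a PBS together with the following activation.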