mirror of https://github.com/zama-ai/concrete.git
feat: separate matmul from the floating points and zero points values
@@ -56,18 +56,37 @@ class QuantizedLinear:
         # Satisfy mypy.
         assert self.q_out is not None
         assert self.q_bias is not None
-        # We need to expand the following equation to get the main computation
-        # (q_input.qvalues @ self.q_weights.qvalues.T) without zero_point values.
-        # See https://github.com/google/gemmlowp/blob/master/doc/quantization.md
-
-        m_product = (q_input.scale * self.q_weights.scale) / self.q_out.scale
-        dot_product = (q_input.qvalues - q_input.zero_point) @ (
-            self.q_weights.qvalues - self.q_weights.zero_point
-        ).T
-        # The following MatMul is done with integers, and thus does not use any PBS.
-        # Only the final conversion to float is done with a PBS, which can actually
-        # be merged with the PBS of the following activation.
-        # State-of-the-art quantization methods assume the following fits in an int32 accumulator.
-
-        m_bias = self.q_bias.scale / (q_input.scale * self.q_weights.scale)
-        bias_part = m_bias * (self.q_bias.qvalues - self.q_bias.zero_point)
-        numpy_q_out = m_product * (dot_product + bias_part) + self.q_out.zero_point
+        # Here we follow Eq. 7 in https://arxiv.org/abs/1712.05877 to split the core
+        # computation from the zero points and scales.
+
+        # p is the inner dimension shared by the inputs and the (transposed) weights.
+        p = self.q_weights.qvalues.shape[1]
+
+        # Core matmul operation in full integers, with a shape change (INTEGERS)
+        matmul = q_input.qvalues @ self.q_weights.qvalues.T
+
+        # Sum operations in full integers, resulting in large integers (INTEGERS)
+        sum_input = self.q_weights.zero_point * numpy.sum(q_input.qvalues, axis=1, keepdims=True)
+        sum_weights = q_input.zero_point * numpy.sum(
+            self.q_weights.qvalues.T, axis=0, keepdims=True
+        )
+
+        # Quantization scales and zero points (FLOATS involved)
+        # This part is going to be compiled with a PBS (along with the following
+        # activation function).
+        m_matmul = (q_input.scale * self.q_weights.scale) / self.q_out.scale
+        bias_part = (
+            self.q_bias.scale / self.q_out.scale * (self.q_bias.qvalues - self.q_bias.zero_point)
+        )
+        final_term = p * q_input.zero_point * self.q_weights.zero_point
+
+        numpy_q_out = matmul - sum_input - sum_weights + final_term
+        numpy_q_out = m_matmul * numpy_q_out
+        numpy_q_out = self.q_out.zero_point + bias_part + numpy_q_out
+
+        numpy_q_out = numpy_q_out.round().clip(0, 2 ** self.q_out.n_bits - 1).astype(int)
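To see why the integer-only rewrite is valid, expand the previous zero-point-subtracted product as in Eq. 7 of https://arxiv.org/abs/1712.05877:

(q_x - z_x) @ (q_w - z_w).T = q_x @ q_w.T - z_w * rowsum(q_x) - z_x * colsum(q_w.T) + p * z_x * z_w

with p the inner (contraction) dimension. The matmul and the two correction sums involve integers only, so in FHE they run without any PBS; the scales enter exactly once, at the end. A minimal standalone sketch checking this identity numerically (not repository code; the shapes, seed, and names z_x/z_w are illustrative assumptions):

import numpy

rng = numpy.random.default_rng(0)
q_x = rng.integers(0, 2**7, size=(4, 5))  # quantized inputs, shape (n, k)
q_w = rng.integers(0, 2**7, size=(3, 5))  # quantized weights, shape (out, k)
z_x, z_w = 3, 7                           # assumed zero points

# Direct form: zero points subtracted inside the matmul operands.
direct = (q_x - z_x) @ (q_w - z_w).T

# Eq. 7 split: one pure-integer matmul plus integer correction terms.
p = q_w.shape[1]  # inner dimension
matmul = q_x @ q_w.T
sum_input = z_w * numpy.sum(q_x, axis=1, keepdims=True)
sum_weights = z_x * numpy.sum(q_w.T, axis=0, keepdims=True)
final_term = p * z_x * z_w

assert numpy.array_equal(direct, matmul - sum_input - sum_weights + final_term)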
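The float bookkeeping is preserved as well: since m_product * m_bias = q_bias.scale / q_out.scale, the old bias folding matches the new bias_part. A hypothetical end-to-end comparison of the two dequantization paths (made-up scales and zero points, not library code):

import numpy

rng = numpy.random.default_rng(1)
q_x = rng.integers(0, 2**7, size=(4, 5))
q_w = rng.integers(0, 2**7, size=(3, 5))
q_b = rng.integers(0, 2**7, size=(1, 3))
z_x, z_w, z_b, z_out = 3, 7, 2, 11             # assumed zero points
s_x, s_w, s_b, s_out = 0.05, 0.02, 0.001, 0.1  # assumed scales

# Previous implementation: zero points folded into the matmul operands.
m_product = (s_x * s_w) / s_out
dot_product = (q_x - z_x) @ (q_w - z_w).T
m_bias = s_b / (s_x * s_w)
old_out = m_product * (dot_product + m_bias * (q_b - z_b)) + z_out

# Refactored implementation: integer-only core, floats applied once at the end.
p = q_w.shape[1]
matmul = q_x @ q_w.T
sum_input = z_w * numpy.sum(q_x, axis=1, keepdims=True)
sum_weights = z_x * numpy.sum(q_w.T, axis=0, keepdims=True)
final_term = p * z_x * z_w
m_matmul = (s_x * s_w) / s_out
bias_part = s_b / s_out * (q_b - z_b)
new_out = m_matmul * (matmul - sum_input - sum_weights + final_term) + bias_part + z_out

assert numpy.allclose(old_out, new_out)

The two paths differ only in floating-point association order, hence allclose rather than array_equal.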