Mirror of https://github.com/zama-ai/concrete.git
feat: remove transpose from layers
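In short, the change stores each linear layer's weights as (n_features, n_neurons) instead of torch's (n_neurons, n_features), so every forward pass becomes a plain x @ W with no transpose. A minimal numpy sketch of why the two conventions give the same result (variable names below are illustrative, not taken from the repository):

    import numpy

    x = numpy.random.uniform(size=(8, 3))        # (n_examples, n_features)
    w_torch = numpy.random.uniform(size=(5, 3))  # torch layout: (out_features, in_features)
    w_stored = w_torch.T                         # transposed once: (in_features, out_features)
    bias = numpy.random.uniform(size=(5,))

    # Old convention: transpose at every call.
    out_old = x @ w_torch.T + bias
    # New convention: transpose once when loading, plain matmul afterwards.
    out_new = x @ w_stored + bias
    assert numpy.allclose(out_old, out_new)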
@@ -29,7 +29,7 @@ class QuantizedLinear:
         self.n_bits = n_bits
 
         if self.q_bias is None:
-            self.q_bias = QuantizedArray(n_bits, numpy.zeros(self.q_weights.values.shape[:-1]))
+            self.q_bias = QuantizedArray(n_bits, numpy.zeros(self.q_weights.values.shape[-1]))
         self.q_out = None
 
     def calibrate(self, x: numpy.ndarray):
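With weights stored as (n_features, n_neurons), the number of biases is the size of the last axis, which is what the new shape[-1] reads. A quick shape check under that assumption (a standalone sketch, not the repo's QuantizedArray):

    import numpy

    n_features, n_neurons = 3, 5
    weights = numpy.random.uniform(size=(n_features, n_neurons))  # new layout
    bias = numpy.zeros(weights.shape[-1])                         # one bias per neuron
    assert bias.shape == (n_neurons,)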
@@ -39,7 +39,7 @@ class QuantizedLinear:
             x (numpy.ndarray): Inputs.
         """
         assert self.q_bias is not None
-        self.q_out = QuantizedArray(self.n_bits, x @ self.q_weights.values.T + self.q_bias.values)
+        self.q_out = QuantizedArray(self.n_bits, (x @ self.q_weights.values) + self.q_bias.values)
 
     def __call__(self, q_input: QuantizedArray) -> QuantizedArray:
         """Process the forward pass of the quantized linear layer.
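Calibration now runs the float matmul in the same orientation: (n_examples, n_features) @ (n_features, n_neurons) yields (n_examples, n_neurons) directly. A hedged sketch of the shapes involved:

    import numpy

    n_examples, n_features, n_neurons = 8, 3, 5
    x = numpy.random.uniform(size=(n_examples, n_features))
    weights = numpy.random.uniform(size=(n_features, n_neurons))
    bias = numpy.random.uniform(size=(1, n_neurons))

    out = (x @ weights) + bias   # no transpose with the new weight layout
    assert out.shape == (n_examples, n_neurons)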
@@ -68,13 +68,11 @@ class QuantizedLinear:
         p = self.q_weights.qvalues.shape[0]
 
         # Core matmul operation in full integers with a shape change (INTEGERS)
-        matmul = q_input.qvalues @ self.q_weights.qvalues.T
+        matmul = q_input.qvalues @ self.q_weights.qvalues
 
         # Sum operation in full integers resulting in large integers (INTEGERS)
         sum_input = self.q_weights.zero_point * numpy.sum(q_input.qvalues, axis=1, keepdims=True)
-        sum_weights = q_input.zero_point * numpy.sum(
-            self.q_weights.qvalues.T, axis=0, keepdims=True
-        )
+        sum_weights = q_input.zero_point * numpy.sum(self.q_weights.qvalues, axis=0, keepdims=True)
 
         # Quantization scales and zero points (FLOATS involved)
         # This is going to be compiled with a PBS (along with the following activation function)
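For context, the surrounding code follows the usual zero-point expansion of a quantized matmul: with x ≈ s_x(q_x − z_x) and w ≈ s_w(q_w − z_w), the product x @ w ≈ s_x·s_w·(q_x @ q_w − z_w·Σ_k q_x − z_x·Σ_k q_w + p·z_x·z_w), where p is the shared dimension (n_features, now the first axis of the weights). The sketch below recreates that identity in plain numpy with a hypothetical quantize() helper; it is an illustration under those assumptions, not the repo's QuantizedArray implementation:

    import numpy

    def quantize(values, n_bits=7):
        # Simple affine quantization: values ~ scale * (qvalues - zero_point).
        vmin, vmax = values.min(), values.max()
        scale = (vmax - vmin) / (2**n_bits - 1)
        zero_point = numpy.round(-vmin / scale)
        qvalues = numpy.round(values / scale + zero_point)
        return qvalues, scale, zero_point

    x = numpy.random.uniform(size=(8, 3))   # (n_examples, n_features)
    w = numpy.random.uniform(size=(3, 5))   # (n_features, n_neurons), new layout

    q_x, s_x, z_x = quantize(x)
    q_w, s_w, z_w = quantize(w)
    p = q_w.shape[0]                         # shared dimension (n_features)

    # Integer-only terms (what runs in FHE), followed by a float rescale.
    matmul = q_x @ q_w
    sum_input = z_w * numpy.sum(q_x, axis=1, keepdims=True)
    sum_weights = z_x * numpy.sum(q_w, axis=0, keepdims=True)
    approx = s_x * s_w * (matmul - sum_input - sum_weights + p * z_x * z_w)

    assert numpy.allclose(approx, x @ w, atol=0.1)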
@@ -12,10 +12,10 @@ class NumpyModule:
     """Initialize our numpy module.
 
         Current constraint: All objects used in the forward have to be defined in the
-        __init__() of torch.nn.Module and follow the exact same order.
-        (i.e. each linear layer must have one variable defined in the
-        right order). This constraint will disappear when
-        TorchScript is in place. (issue #818)
+        __init__() of torch.nn.Module and follow the exact same order.
+        (i.e. each linear layer must have one variable defined in the
+        right order). This constraint will disappear when
+        TorchScript is in place. (issue #818)
 
     Args:
         torch_model (nn.Module): A fully trained torch model along with its parameters.
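To illustrate the constraint mentioned in that docstring, a compliant torch model declares its layers in __init__ in the same order they are applied in forward (a hypothetical example, not taken from the repository's tests):

    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Declared in the exact order they are used in forward().
            self.fc1 = nn.Linear(3, 5)
            self.act = nn.Sigmoid()
            self.fc2 = nn.Linear(5, 1)

        def forward(self, x):
            return self.fc2(self.act(self.fc1(x)))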
@@ -43,7 +43,7 @@ class NumpyModule:
 
         for name, weights in self.torch_model.state_dict().items():
             params = weights.detach().numpy()
-            self.numpy_module_dict[name] = params
+            self.numpy_module_dict[name] = params.T if "weight" in name else params
 
     def __call__(self, x: numpy.ndarray):
         """Return the function to be compiled by concretefhe.numpy."""
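torch's nn.Linear keeps its weight as (out_features, in_features), so transposing once while reading state_dict() yields the (in_features, out_features) layout used everywhere else in this commit. A small sketch of that conversion, assuming a bare nn.Linear:

    import torch.nn as nn

    layer = nn.Linear(3, 5)
    numpy_params = {}
    for name, weights in layer.state_dict().items():
        params = weights.detach().numpy()
        # Weights are transposed once here; biases are kept as-is.
        numpy_params[name] = params.T if "weight" in name else params

    assert numpy_params["weight"].shape == (3, 5)   # (in_features, out_features)
    assert numpy_params["bias"].shape == (5,)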
@@ -64,7 +64,7 @@ class NumpyModule:
             if isinstance(layer, nn.Linear):
                 # Apply a matmul product and add the bias.
                 x = (
-                    x @ self.numpy_module_dict[f"{name}.weight"].T
+                    x @ self.numpy_module_dict[f"{name}.weight"]
                     + self.numpy_module_dict[f"{name}.bias"]
                 )
             elif isinstance(layer, nn.Sigmoid):
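Under that layout, the numpy forward of a Linear layer matches torch's output up to float precision. A quick equivalence check (a sketch assuming the transposed-once weights from above):

    import numpy
    import torch
    import torch.nn as nn

    layer = nn.Linear(3, 5)
    w = layer.weight.detach().numpy().T   # (in_features, out_features)
    b = layer.bias.detach().numpy()

    x = numpy.random.uniform(size=(8, 3)).astype(numpy.float32)
    out_numpy = x @ w + b
    out_torch = layer(torch.from_numpy(x)).detach().numpy()
    assert numpy.allclose(out_numpy, out_torch, atol=1e-5)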
@@ -30,11 +30,11 @@ def test_quantized_linear(n_examples, n_features, n_neurons, n_bits):
     inputs = numpy.random.uniform(size=(n_examples, n_features))
     q_inputs = QuantizedArray(n_bits, inputs)
 
-    # shape of weights: (n_examples, n_features, n_neurons)
-    weights = numpy.random.uniform(size=(n_neurons, n_features))
+    # shape of weights: (n_features, n_neurons)
+    weights = numpy.random.uniform(size=(n_features, n_neurons))
     q_weights = QuantizedArray(n_bits, weights)
 
-    bias = numpy.random.uniform(size=(n_neurons))
+    bias = numpy.random.uniform(size=(1, n_neurons))
     q_bias = QuantizedArray(n_bits, bias)
 
     # Define our QuantizedLinear layer