slightly faster nf4 llama (#4542)

chenyu
2024-05-12 14:24:42 -04:00
committed by GitHub
parent 4c232dc0ae
commit 01a0c1a948


@@ -226,8 +226,9 @@ def NF4Linear(block_size):
       self.scale = Tensor.empty(int(out_features * in_features / block_size), 1, dtype=dtypes.float16)
     def __call__(self, x: Tensor) -> Tensor:
-      high_bits, low_bits = self.weight.div(2 ** 4, upcast=False), (self.weight * 2 ** 4).contiguous().div(2 ** 4, upcast=False)
-      unpacked = Tensor.stack([high_bits, low_bits], dim=-1).flatten()
+      high_bits = self.weight
+      low_bits = (self.weight * 2 ** 4).contiguous()
+      unpacked = Tensor.stack([high_bits, low_bits], dim=-1).div(2 ** 4, upcast=False)
       unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
       return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
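Each uint8 weight byte holds two packed 4-bit NF4 code indices. The change keeps the raw bytes as the high half, uses the wrapping multiply by 2 ** 4 to isolate the low half, and applies a single `.div(2 ** 4, upcast=False)` after stacking instead of dividing each half separately before the stack, which is presumably where the small speedup comes from. Below is a rough sketch of the same nibble arithmetic, illustrative only and not part of the commit, with NumPy standing in for the tinygrad Tensor ops:

```python
import numpy as np

# Illustrative sketch: each uint8 holds two 4-bit NF4 code indices.
packed = np.array([0xAB, 0x3F], dtype=np.uint8)

# High nibble: integer division by 16 acts like a right shift by 4.
high_bits = packed // 16                      # [0x0A, 0x03]

# Low nibble: multiplying by 16 wraps around in uint8, dropping the high
# nibble; dividing by 16 then shifts the remaining bits back down.
low_bits = (packed * np.uint8(16)) // 16      # [0x0B, 0x0F]

# Interleave high/low so the order matches the original 4-bit stream,
# mirroring Tensor.stack([high_bits, low_bits], dim=-1) plus the single div.
unpacked = np.stack([high_bits, low_bits], axis=-1).reshape(-1)
print(unpacked)  # [10 11  3 15]
```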