Mirror of https://github.com/tinygrad/tinygrad.git
slightly faster nf4 llama (#4542)
@@ -226,8 +226,9 @@ def NF4Linear(block_size):
       self.scale = Tensor.empty(int(out_features * in_features / block_size), 1, dtype=dtypes.float16)

     def __call__(self, x: Tensor) -> Tensor:
-      high_bits, low_bits = self.weight.div(2 ** 4, upcast=False), (self.weight * 2 ** 4).contiguous().div(2 ** 4, upcast=False)
-      unpacked = Tensor.stack([high_bits, low_bits], dim=-1).flatten()
+      high_bits = self.weight
+      low_bits = (self.weight * 2 ** 4).contiguous()
+      unpacked = Tensor.stack([high_bits, low_bits], dim=-1).div(2 ** 4, upcast=False)
       unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
       return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
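For context: self.weight packs two 4-bit NF4 codebook indices into each uint8 byte. div(2 ** 4, upcast=False) is an integer divide that shifts the high nibble down, and multiplying by 2 ** 4 wraps the uint8 so only the low nibble survives in the top bits. The old code ran that divide twice, once per nibble stream; the new code stacks the raw bytes first and divides once over the stacked tensor, which presumably lets the scheduler emit fewer kernels. Below is a minimal NumPy sketch of the same unpacking trick; unpack_nf4 and this CODE table are illustrative stand-ins, not tinygrad API (the 16 levels are the NF4 quantiles from the QLoRA paper).

import numpy as np

# the 16 NF4 quantization levels from the QLoRA paper
CODE = np.array([
  -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
  -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
  0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
  0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
], dtype=np.float32)

def unpack_nf4(packed: np.ndarray, scale: np.ndarray, block_size: int) -> np.ndarray:
  high_bits = packed                      # high nibble sits in the top 4 bits of each byte
  low_bits = packed * np.uint8(2 ** 4)    # uint8 wraparound discards the high nibble
  # stack first, then shift both nibbles down with a single integer divide,
  # mirroring the new code path in the diff above
  unpacked = np.stack([high_bits, low_bits], axis=-1) // 2 ** 4
  return CODE[unpacked].reshape(-1, block_size) * scale

# usage: byte 0x3C packs codebook index 3 (high nibble) and 12 (low nibble)
packed = np.array([3 << 4 | 12], dtype=np.uint8)
scale = np.array([[2.0]], dtype=np.float32)
print(unpack_nf4(packed, scale, block_size=2))  # ~[[-0.7898  0.8814]]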