mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
lshift and rshift (#4591)
This commit is contained in:
@@ -230,8 +230,8 @@ def NF4Linear(block_size):
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
high_bits = self.weight
|
||||
-low_bits = (self.weight * 2 ** 4).contiguous()
|
||||
-unpacked = Tensor.stack([high_bits, low_bits], dim=-1).div(2 ** 4, upcast=False)
|
||||
+low_bits = self.weight.lshift(4).contiguous()
|
||||
+unpacked = Tensor.stack([high_bits, low_bits], dim=-1).rshift(4)
|
||||
unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
|
||||
return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user