lshift and rshift (#4591)

This commit is contained in:
chenyu
2024-05-14 19:16:31 -04:00
committed by GitHub
parent 45e7400e3c
commit 2b0ee74bb6
3 changed files with 38 additions and 4 deletions

View File

@@ -230,8 +230,8 @@ def NF4Linear(block_size):
def __call__(self, x: Tensor) -> Tensor:
# Applies the layer to `x`: derives `unpacked` from self.weight, uses it to index
# CODE, scales the looked-up values per block, and passes the result to x.linear.
#
# NOTE(review): this span looks like a rendered diff hunk whose +/- markers and
# indentation were stripped. `low_bits` and `unpacked` are each assigned twice;
# read top to bottom, the second pair simply overwrites the first. Presumably the
# first pair is the removed side and the second pair the added side of the
# commit -- confirm against the real file before relying on either.
high_bits = self.weight
# First pair: multiply the weight by 2 ** 4, stack with the untouched weight along
# a new last axis, then divide the stacked tensor by 2 ** 4 with upcast=False.
low_bits = (self.weight * 2 ** 4).contiguous()
unpacked = Tensor.stack([high_bits, low_bits], dim=-1).div(2 ** 4, upcast=False)
# Second pair: same shape of computation, written with lshift(4) / rshift(4)
# instead of the multiply / divide by 2 ** 4.
# Presumably self.weight packs two 4-bit codes per element, so the left shift
# discards the upper 4 bits by overflow and the right shift leaves one 4-bit
# index per stacked entry -- TODO confirm the weight dtype.
low_bits = self.weight.lshift(4).contiguous()
unpacked = Tensor.stack([high_bits, low_bits], dim=-1).rshift(4)
# Index CODE with `unpacked`, move the result to x's device, view it as rows of
# `block_size` values, and multiply by self.scale.
unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
# Reshape to (out_features, in_features) and transpose before handing it to x.linear.
return x.linear(unscaled.reshape(self.out_features, self.in_features).T)