mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
quantized llama multilazybuffer fix (#4557)
@@ -200,7 +200,7 @@ class Int8Linear:
     return x.dot(self.weight.cast(dtype=dtypes.half).T*self.scale)

   @staticmethod
-  def quantize(tensors):
+  def quantize(tensors, device):
     new_tensors = {}
     for name,v in tensors.items():
       if "feed_forward" in name or "attention.w" in name or name == "output.weight":
@@ -209,6 +209,9 @@ class Int8Linear:
         int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
         new_tensors[name] = int8_weight
         new_tensors[name.replace('weight', 'scale')] = scale
+        if isinstance(device, tuple):
+          new_tensors[name].shard_(device, axis=-1)
+          new_tensors[name.replace('weight', 'scale')].shard_(device, axis=None)
       else:
         new_tensors[name] = v
     return new_tensors
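For orientation, a minimal sketch of the symmetric int8 round trip these two methods implement. The per-row scale below is an assumption for illustration (the scale computation sits above this hunk and is not shown); only the int8 cast and the dequantizing matmul appear in the diff.

    from tinygrad import Tensor, dtypes

    # hypothetical fp16 weight of shape (out_features, in_features)
    w = Tensor.randn(64, 128, dtype=dtypes.half)

    # assumed per-row symmetric scale: map the largest |value| in each row onto 127
    scale = w.abs().max(axis=1) / 127.0                    # shape (out_features,)
    int8_weight = (w.T / scale).T.cast(dtypes.int8)        # same cast as in quantize above

    # dequantize on the fly inside the matmul, mirroring Int8Linear.__call__
    x = Tensor.randn(4, 128, dtype=dtypes.half)
    out = x.dot(int8_weight.cast(dtypes.half).T * scale)   # shape (4, 64)
    print(out.shape)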
@@ -233,15 +236,18 @@ def NF4Linear(block_size):
       return x.linear(unscaled.reshape(self.out_features, self.in_features).T)

     @staticmethod
-    def quantize(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
+    def quantize(state_dict: dict[str, Tensor], device) -> dict[str, Tensor]:
       new_state_dict = {}
       for k, v in state_dict.items():
         if "feed_forward" in k or "attention.w" in k or k == "output.weight":
-          grouped = v.to(CODE.device).reshape(-1, block_size)
+          grouped = v.reshape(-1, block_size)
           scale = (grouped.abs().max(axis=1, keepdim=True))
-          coded = ((grouped / scale).unsqueeze(-1) - CODE).abs().argmin(axis=-1).cast(dtypes.uint8).flatten()
+          coded = ((grouped / scale).unsqueeze(-1) - CODE.to(v.device)).abs().argmin(axis=-1).cast(dtypes.uint8).flatten()
           new_state_dict[k] = coded[::2] * 2 ** 4 + coded[1::2]
           new_state_dict[k.replace(".weight", ".scale")] = scale.cast(dtypes.float16)
+          if isinstance(device, tuple):
+            new_state_dict[k].shard_(device, axis=-1)
+            new_state_dict[k.replace('weight', 'scale')].shard_(device, axis=None)
         else:
           new_state_dict[k] = v
       return new_state_dict
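A side note on the packing line kept as context above: it stores two 4-bit code indices per uint8 byte, even positions in the high nibble and odd positions in the low nibble. A minimal pack/unpack sketch, independent of the surrounding class; the unpack step is illustrative only, the real decode lives in the NF4 __call__, which is not part of this hunk.

    import numpy as np
    from tinygrad import Tensor, dtypes

    # hypothetical 4-bit code indices, each in [0, 16)
    coded = Tensor([3, 15, 0, 7, 12, 1], dtype=dtypes.uint8)

    # pack: even-index codes to the high nibble, odd-index codes to the low nibble,
    # the same expression used for new_state_dict[k] above
    packed = coded[::2] * 2 ** 4 + coded[1::2]

    # unpack (numpy, for illustration): high nibble then low nibble restores the order
    p = packed.numpy()
    unpacked = np.stack([p >> 4, p & 0xF], axis=1).flatten()
    print(unpacked.tolist())   # [3, 15, 0, 7, 12, 1]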
@@ -270,7 +276,7 @@ class LLaMa:

     if quantize is not None:
       with Context(BEAM=0):
-        weights = model.output.__class__.quantize(weights)
+        weights = model.output.__class__.quantize(weights, device)
         for _,v in weights.items(): v.realize()

     if isinstance(device, tuple):
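A minimal usage sketch of the new call signature, assuming a two-device tuple and a tiny fake checkpoint; the import path and tensor names are illustrative, not taken from the diff. Only keys matching feed_forward / attention.w / output.weight are converted (and, with this change, sharded in place inside quantize); everything else passes through the else branch unchanged.

    from tinygrad import Device, Tensor, dtypes
    from examples.llama import Int8Linear   # assumed import path for the example's Int8Linear

    device = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")

    # hypothetical fp16 checkpoint entries; real keys come from the llama state dict
    weights = {
      "layers.0.feed_forward.w1.weight": Tensor.randn(64, 32, dtype=dtypes.half),  # quantized + sharded
      "norm.weight": Tensor.randn(32, dtype=dtypes.half),                          # passed through as-is
    }

    weights = Int8Linear.quantize(weights, device)   # new signature: device decides sharding
    for v in weights.values(): v.realize()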
@@ -6,6 +6,7 @@ from tinygrad import Tensor, Device, TinyJit
 from tinygrad.helpers import CI, Context
 from tinygrad.ops import BufferOps
 from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm
+from tinygrad.nn.state import load_state_dict
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule

@@ -372,6 +373,35 @@ class TestNN(unittest.TestCase):
     self.assertEqual(1, len([item for item in schedule if item.ast[0].op is BufferOps.STORE]), "second run realizes embedding only")
     run_schedule(schedule)

+  def test_load_state_dict(self):
+    layer = Conv2d(3, 5, kernel_size=3)
+
+    state_dict = {
+      'weight': Tensor.randn(5, 3, 3, 3),
+      'bias': Tensor.randn(5),
+    }
+    load_state_dict(layer, state_dict)
+
+    np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
+    np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
+
+  @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
+  def test_load_state_dict_sharded(self):
+    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
+
+    layer = Conv2d(3, 5, kernel_size=3)
+    layer.weight.shard_(devices, -1)
+    layer.bias.shard_(devices, None)
+    state_dict = {
+      'weight': Tensor.randn(5, 3, 3, 3).shard(devices, -1),
+      'bias': Tensor.randn(5).shard(devices, None),
+    }
+    load_state_dict(layer, state_dict)
+
+    self.assertEqual(layer.weight.device, devices)
+    self.assertEqual(layer.bias.device, devices)
+    np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
+    np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())

 if __name__ == '__main__':
   unittest.main()
@@ -67,7 +67,10 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
       if k not in state_dict and not strict:
         if DEBUG >= 1: print(f"WARNING: not loading {k}")
         continue
-      v.replace(state_dict[k].shard(mlb.device, mlb.axis) if isinstance((mlb:=v.lazydata), MultiLazyBuffer) else state_dict[k].to(v.device)).realize()
+      if isinstance((mlb:=v.lazydata), MultiLazyBuffer):
+        if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize()
+        else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize()
+      else: v.replace(state_dict[k].to(v.device)).realize()
       if consume: del state_dict[k]

 # torch support!
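For contrast with the added test, which feeds load_state_dict tensors that are already sharded, the shard(mlb.device, mlb.axis) path (present before this change and kept as the inner else above) still shards a plain single-device tensor to match the target parameter's devices and axis. A minimal sketch, assuming two instances of the default device:

    import numpy as np
    from tinygrad import Tensor, Device
    from tinygrad.nn import Linear
    from tinygrad.nn.state import load_state_dict

    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")

    layer = Linear(8, 4, bias=False)
    layer.weight.shard_(devices, axis=-1)          # parameter is a MultiLazyBuffer

    state_dict = {"weight": Tensor.randn(4, 8)}    # plain, unsharded checkpoint tensor
    load_state_dict(layer, state_dict)             # takes the shard(mlb.device, mlb.axis) branch

    assert layer.weight.device == devices
    np.testing.assert_allclose(layer.weight.numpy(), state_dict["weight"].numpy())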