quantized llama multilazybuffer fix (#4557)

Author: wozeparrot
Date:   2024-05-12 14:19:21 -07:00
Committed by: GitHub
Parent: bcee4743ce
Commit: d7670f8141

3 changed files with 45 additions and 6 deletions

@@ -200,7 +200,7 @@ class Int8Linear:
     return x.dot(self.weight.cast(dtype=dtypes.half).T*self.scale)
 
   @staticmethod
-  def quantize(tensors):
+  def quantize(tensors, device):
     new_tensors = {}
     for name,v in tensors.items():
       if "feed_forward" in name or "attention.w" in name or name == "output.weight":
@@ -209,6 +209,9 @@ class Int8Linear:
         int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
         new_tensors[name] = int8_weight
         new_tensors[name.replace('weight', 'scale')] = scale
+        if isinstance(device, tuple):
+          new_tensors[name].shard_(device, axis=-1)
+          new_tensors[name.replace('weight', 'scale')].shard_(device, axis=None)
       else:
         new_tensors[name] = v
     return new_tensors
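
For reference, the arithmetic these hunks extend, sketched in plain NumPy (illustrative only; it assumes the usual per-row abs-max/127 scale, which is computed just above the lines shown): each output row of the weight is quantized to int8 with one float scale per row, and __call__ dequantizes on the fly. The new device-tuple branch only changes where those two tensors live, not this math.

import numpy as np

def quantize_int8(w):
  # assumed scale rule: per-row abs-max divided by 127 (that line sits outside the hunk)
  scale = np.abs(w).max(axis=1) / 127.0
  int8_w = (w.T / scale).T.astype(np.int8)   # mirrors (v.T/scale).T.cast(dtypes.int8)
  return int8_w, scale

def int8_linear(x, int8_w, scale):
  # mirrors Int8Linear.__call__: x.dot(weight.cast(half).T * scale)
  return x @ (int8_w.astype(np.float16).T * scale)

w, x = np.random.randn(8, 16), np.random.randn(4, 16)
int8_w, scale = quantize_int8(w)
print(np.abs(int8_linear(x, int8_w, scale) - x @ w.T).max())  # small quantization error
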
@@ -233,15 +236,18 @@ def NF4Linear(block_size):
       return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
 
     @staticmethod
-    def quantize(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
+    def quantize(state_dict: dict[str, Tensor], device) -> dict[str, Tensor]:
       new_state_dict = {}
       for k, v in state_dict.items():
         if "feed_forward" in k or "attention.w" in k or k == "output.weight":
-          grouped = v.to(CODE.device).reshape(-1, block_size)
+          grouped = v.reshape(-1, block_size)
           scale = (grouped.abs().max(axis=1, keepdim=True))
-          coded = ((grouped / scale).unsqueeze(-1) - CODE).abs().argmin(axis=-1).cast(dtypes.uint8).flatten()
+          coded = ((grouped / scale).unsqueeze(-1) - CODE.to(v.device)).abs().argmin(axis=-1).cast(dtypes.uint8).flatten()
           new_state_dict[k] = coded[::2] * 2 ** 4 + coded[1::2]
           new_state_dict[k.replace(".weight", ".scale")] = scale.cast(dtypes.float16)
+          if isinstance(device, tuple):
+            new_state_dict[k].shard_(device, axis=-1)
+            new_state_dict[k.replace('weight', 'scale')].shard_(device, axis=None)
         else:
           new_state_dict[k] = v
       return new_state_dict
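
A side note on the packing line kept above: every weight becomes a 4-bit index into the CODE table, and two indices are packed into one uint8, high nibble first. A tiny NumPy illustration of that round trip (not tinygrad code):

import numpy as np

codes = np.array([3, 12, 0, 15, 7, 9], dtype=np.uint8)  # 4-bit code indices, 0..15
packed = codes[::2] * 2 ** 4 + codes[1::2]              # same packing as the diff
high, low = packed >> 4, packed & 0x0F                  # unpack both nibbles
assert np.array_equal(np.stack([high, low], axis=1).ravel(), codes)
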
@@ -270,7 +276,7 @@ class LLaMa:
     if quantize is not None:
       with Context(BEAM=0):
-        weights = model.output.__class__.quantize(weights)
+        weights = model.output.__class__.quantize(weights, device)
         for _,v in weights.items(): v.realize()
 
     if isinstance(device, tuple):
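
One note on the unchanged "with Context(BEAM=0):" line above: Context temporarily overrides tinygrad context variables, here turning beam search off while the one-off quantization kernels run. A minimal sketch of that mechanism:

from tinygrad.helpers import BEAM, Context

with Context(BEAM=0):
  print(BEAM.value)  # 0 inside the block
print(BEAM.value)    # previous value is restored on exit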

@@ -6,6 +6,7 @@
 from tinygrad import Tensor, Device, TinyJit
 from tinygrad.helpers import CI, Context
 from tinygrad.ops import BufferOps
 from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm
+from tinygrad.nn.state import load_state_dict
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
@@ -372,6 +373,35 @@ class TestNN(unittest.TestCase):
     self.assertEqual(1, len([item for item in schedule if item.ast[0].op is BufferOps.STORE]), "second run realizes embedding only")
     run_schedule(schedule)
 
+  def test_load_state_dict(self):
+    layer = Conv2d(3, 5, kernel_size=3)
+
+    state_dict = {
+      'weight': Tensor.randn(5, 3, 3, 3),
+      'bias': Tensor.randn(5),
+    }
+    load_state_dict(layer, state_dict)
+
+    np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
+    np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
+
+  @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
+  def test_load_state_dict_sharded(self):
+    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
+    layer = Conv2d(3, 5, kernel_size=3)
+    layer.weight.shard_(devices, -1)
+    layer.bias.shard_(devices, None)
+
+    state_dict = {
+      'weight': Tensor.randn(5, 3, 3, 3).shard(devices, -1),
+      'bias': Tensor.randn(5).shard(devices, None),
+    }
+    load_state_dict(layer, state_dict)
+
+    self.assertEqual(layer.weight.device, devices)
+    self.assertEqual(layer.bias.device, devices)
+    np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
+    np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
+
 if __name__ == '__main__':
   unittest.main()
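
The new tests lean on tinygrad's sharding API: shard(devices, axis) copies a tensor across the device tuple, split along axis or replicated when axis is None, and shard_ does the same in place, after which tensor.device reports the tuple. A standalone sketch of just that behaviour, assuming two instances of the default device are usable:

from tinygrad import Device, Tensor

devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
w = Tensor.randn(5, 3, 3, 3).shard(devices, -1)  # split on the last axis
b = Tensor.randn(5).shard(devices, None)         # replicated on every device
print(w.device, b.device)                        # both report the device tuple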

@@ -67,7 +67,10 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=True, consume=False)
       if k not in state_dict and not strict:
         if DEBUG >= 1: print(f"WARNING: not loading {k}")
         continue
-      v.replace(state_dict[k].shard(mlb.device, mlb.axis) if isinstance((mlb:=v.lazydata), MultiLazyBuffer) else state_dict[k].to(v.device)).realize()
+      if isinstance((mlb:=v.lazydata), MultiLazyBuffer):
+        if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize()
+        else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize()
+      else: v.replace(state_dict[k].to(v.device)).realize()
       if consume: del state_dict[k]
 
 # torch support!
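
This branch is what lets the quantize() changes above hand already-sharded tensors to load_state_dict: when the model parameter is a MultiLazyBuffer and the incoming tensor already is one too, it is taken as-is, and only unsharded inputs get re-sharded to match; previously the already-sharded case still went through shard(), which cannot be applied to a tensor that is already multi-device. A hedged usage sketch (essentially what test_load_state_dict_sharded exercises), assuming two instances of the default device:

from tinygrad import Device, Tensor
from tinygrad.nn import Conv2d
from tinygrad.nn.state import load_state_dict

devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
layer = Conv2d(3, 5, kernel_size=3)
layer.weight.shard_(devices, -1)   # parameters are now MultiLazyBuffers
layer.bias.shard_(devices, None)

# the incoming tensors are already sharded, so the new first branch applies
state_dict = {"weight": Tensor.randn(5, 3, 3, 3).shard(devices, -1),
              "bias": Tensor.randn(5).shard(devices, None)}
load_state_dict(layer, state_dict)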