Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 06:58:11 -05:00
fix llama example quantize (#4699)
* fix llama example quantize: import quantize layers from new example llama3; add to mac benchmark
* fix that
* save the files
.github/workflows/benchmark.yml (vendored, 6 additions)
@@ -47,6 +47,10 @@ jobs:
         JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
     - name: Run LLaMA with BEAM
       run: JIT=1 BEAM=2 CACHELEVEL=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
+    - name: Run quantized LLaMA
+      run: |
+        JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
+        JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt
     - name: Run LLaMA 7B on 4 (virtual) GPUs
       run: JIT=1 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
     - name: Run GPT2
@@ -76,6 +80,8 @@ jobs:
           llama_unjitted.txt
           llama_jitted.txt
           llama_beam.txt
+          llama_int8.txt
+          llama_nf4.txt
           llama_four_gpu.txt
           gpt2_unjitted.txt
           gpt2_jitted.txt

examples/llama.py
@@ -190,69 +190,6 @@ def load(fn:str):
   else:
     return torch_load(fn)

-class Int8Linear:
-  def __init__(self, in_features, out_features, bias=False):
-    assert bias == False
-    self.weight = Tensor.ones(out_features, in_features, dtype=dtypes.int8)
-    self.scale = Tensor.ones(out_features, dtype=dtypes.half)
-
-  def __call__(self, x):
-    return x.dot(self.weight.cast(dtype=dtypes.half).T*self.scale)
-
-  @staticmethod
-  def quantize(tensors, device):
-    new_tensors = {}
-    for name,v in tensors.items():
-      if "feed_forward" in name or "attention.w" in name or name == "output.weight":
-        assert "weight" in name, name
-        scale = v.abs().max(axis=1) / 127.0
-        int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
-        new_tensors[name] = int8_weight
-        new_tensors[name.replace('weight', 'scale')] = scale
-        if isinstance(device, tuple):
-          new_tensors[name].shard_(device, axis=-1)
-          new_tensors[name.replace('weight', 'scale')].shard_(device, axis=None)
-      else:
-        new_tensors[name] = v
-    return new_tensors
-
-def NF4Linear(block_size):
-  CODE = Tensor([
-    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
-    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
-  ], dtype=dtypes.float16)
-  class _NF4Linear:
-    def __init__(self, in_features, out_features, bias=False):
-      assert not bias, "bias not supported"
-      self.in_features, self.out_features = in_features, out_features
-      self.weight = Tensor.empty(int(out_features * in_features / 2), dtype=dtypes.uint8)
-      self.scale = Tensor.empty(int(out_features * in_features / block_size), 1, dtype=dtypes.float16)
-
-    def __call__(self, x: Tensor) -> Tensor:
-      high_bits = self.weight
-      low_bits = self.weight.lshift(4).contiguous()
-      unpacked = Tensor.stack([high_bits, low_bits], dim=-1).rshift(4)
-      unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
-      return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
-
-    @staticmethod
-    def quantize(state_dict: dict[str, Tensor], device) -> dict[str, Tensor]:
-      new_state_dict = {}
-      for k, v in state_dict.items():
-        if "feed_forward" in k or "attention.w" in k or k == "output.weight":
-          grouped = v.reshape(-1, block_size)
-          scale = (grouped.abs().max(axis=1, keepdim=True))
-          coded = ((grouped / scale).unsqueeze(-1) - CODE.to(v.device)).abs().argmin(axis=-1).cast(dtypes.uint8).flatten()
-          new_state_dict[k] = coded[::2] * 2 ** 4 + coded[1::2]
-          new_state_dict[k.replace(".weight", ".scale")] = scale.cast(dtypes.float16)
-          if isinstance(device, tuple):
-            new_state_dict[k].shard_(device, axis=-1)
-            new_state_dict[k.replace('weight', 'scale')].shard_(device, axis=None)
-        else:
-          new_state_dict[k] = v
-      return new_state_dict
-  return _NF4Linear
-
 class LLaMa:
   @staticmethod
   def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=None, device=None):
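
(Note: Int8Linear, removed above, now lives in the llama3 example. Its quantize step is plain per-output-channel symmetric quantization: one scale per weight row, chosen so the row's largest magnitude maps to 127, and __call__ dequantizes by casting back to half and multiplying by that scale. Below is a minimal numpy sketch of that math, for illustration only; the helper names and the use of numpy are not part of this diff.)

import numpy as np

def int8_quantize(w: np.ndarray):
  # one scale per output row, so the row's absmax maps to 127
  scale = np.abs(w).max(axis=1) / 127.0
  q = (w / scale[:, None]).astype(np.int8)  # the cast truncates, like the Tensor.cast in the removed code
  return q, scale.astype(np.float16)

def int8_dequantize(q: np.ndarray, scale: np.ndarray):
  # __call__ does the same on the fly: cast back to half, multiply by the per-row scale
  return q.astype(np.float16) * scale[:, None]

w = np.random.randn(4, 8).astype(np.float32)
q, s = int8_quantize(w)
print(np.abs(int8_dequantize(q, s) - w).max())  # small per-row quantization error
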
@@ -261,9 +198,17 @@ class LLaMa:
     assert tokenizer.vocab_size() == params["args"]["vocab_size"], f"{tokenizer.vocab_size()=} not equal to {params['args']['vocab_size']}"

     jit = bool(getenv("JIT", 1))
-    if quantize == "int8": model = Transformer(**params["args"], linear=Int8Linear, max_context=MAX_CONTEXT, jit=jit)
-    elif quantize == "nf4": model = Transformer(**params["args"], linear=NF4Linear(64), max_context=MAX_CONTEXT, jit=jit)
-    else: model = Transformer(**params["args"], max_context=MAX_CONTEXT, jit=jit)
+
+    if quantize == "int8":
+      from llama3 import Int8Linear
+      linear = Int8Linear
+    elif quantize == "nf4":
+      from llama3 import NF4Linear
+      linear = NF4Linear(64)
+    else:
+      linear = nn.Linear
+
+    model = Transformer(**params["args"], linear=linear, max_context=MAX_CONTEXT, jit=jit)

     if model_path.is_dir():
       weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]], device[0] if isinstance(device, tuple) else device)
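
(Note: NF4Linear(64), selected above, builds a 4-bit linear layer with block size 64: weights are grouped into blocks of 64 values, each block stores one fp16 absmax scale, every value is replaced by the index of the nearest entry in a fixed 16-entry codebook, and two 4-bit indices are packed into each uint8 byte with the first index in the high nibble. A rough numpy sketch of that pack/unpack round trip follows, for illustration only; the function names and numpy usage are assumptions, not from this diff.)

import numpy as np

CODE = np.array([
  -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
  -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
  0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
  0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0], dtype=np.float32)

def nf4_quantize(w: np.ndarray, block_size: int = 64):
  grouped = w.reshape(-1, block_size)
  scale = np.abs(grouped).max(axis=1, keepdims=True)              # one absmax scale per block
  idx = np.abs(grouped[..., None] / scale[..., None] - CODE).argmin(axis=-1).astype(np.uint8).flatten()
  packed = idx[::2] * 16 + idx[1::2]                              # first index in the high nibble
  return packed, scale.astype(np.float16)

def nf4_dequantize(packed: np.ndarray, scale: np.ndarray, block_size: int = 64):
  idx = np.stack([packed >> 4, packed & 0xF], axis=-1).flatten()  # unpack high/low nibbles
  return CODE[idx].reshape(-1, block_size) * scale

w = np.random.randn(2, 128).astype(np.float32)
packed, scale = nf4_quantize(w)
print(np.abs(nf4_dequantize(packed, scale).reshape(w.shape) - w).max())  # error bounded by codebook spacing times scale
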
@@ -274,24 +219,27 @@ class LLaMa:

     weights = fix_bf16(weights)

-    if quantize is not None:
-      with Context(BEAM=0):
-        weights = model.output.__class__.quantize(weights, device)
+    with Context(BEAM=0):
+      # quantize
+      if quantize is not None:
+        weights = linear.quantize(weights, device)
         for _,v in weights.items(): v.realize()

-    if isinstance(device, tuple):
-      for k,v in nn.state.get_state_dict(model).items():
-        if 'scale' in k: v.shard_(device, axis=None)  # from quantized
-        elif '.attention.' in k: v.shard_(device, axis=-1)
-        elif '.feed_forward.' in k: v.shard_(device, axis=-1)
-        elif 'tok_embeddings.weight' in k: v.shard_(device, axis=-1)
-        elif 'output.weight' in k: v.shard_(device, axis=-1)
-        #elif k.endswith('.weight'): v.shard_(device, axis=-1)
-        #elif 'norm.' in k: v.shard_(device, axis=-1)
-        else: v.shard_(device, axis=None)
-        #print(k, v.shape, v.lazydata.axis)
+      # shard
+      if isinstance(device, tuple):
+        for k,v in nn.state.get_state_dict(model).items():
+          if 'scale' in k: v.shard_(device, axis=None)  # from quantized
+          elif '.attention.' in k: v.shard_(device, axis=-1)
+          elif '.feed_forward.' in k: v.shard_(device, axis=-1)
+          elif 'tok_embeddings.weight' in k: v.shard_(device, axis=-1)
+          elif 'output.weight' in k: v.shard_(device, axis=-1)
+          #elif k.endswith('.weight'): v.shard_(device, axis=-1)
+          #elif 'norm.' in k: v.shard_(device, axis=-1)
+          else: v.shard_(device, axis=None)
+          #print(k, v.shape, v.lazydata.axis)

-    load_state_dict(model, weights, strict=False, consume=True)
+      # replace weights in model
+      load_state_dict(model, weights, strict=False, consume=True)

     return LLaMa(model, tokenizer)
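
(Note: with this change, the quantize argument selects a linear class imported from the llama3 example, and that class's own quantize() is applied to the loaded weights; quantization, sharding and the final load_state_dict are grouped under the Context(BEAM=0) block above. A hypothetical call into the updated build path is sketched below; the weight and tokenizer paths are placeholders, not taken from this diff.)

from pathlib import Path
from tinygrad import Device

# hypothetical usage; adjust the paths to wherever the LLaMA weights actually live
llama = LLaMa.build(Path("weights/LLaMA/7B"), Path("weights/LLaMA/tokenizer.model"),
                    model_gen="1", model_size="7B", quantize="int8", device=Device.DEFAULT)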