diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 5edcb84299..d98bf08a90 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -47,6 +47,10 @@ jobs:
         JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
     - name: Run LLaMA with BEAM
       run: JIT=1 BEAM=2 CACHELEVEL=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
+    - name: Run quantized LLaMA
+      run: |
+        JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
+        JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt
     - name: Run LLaMA 7B on 4 (virtual) GPUs
       run: JIT=1 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
     - name: Run GPT2
@@ -76,6 +80,8 @@ jobs:
           llama_unjitted.txt
           llama_jitted.txt
           llama_beam.txt
+          llama_int8.txt
+          llama_nf4.txt
           llama_four_gpu.txt
           gpt2_unjitted.txt
           gpt2_jitted.txt
diff --git a/examples/llama.py b/examples/llama.py
index 6cfa0d4938..2e194e3a43 100755
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -190,69 +190,6 @@ def load(fn:str):
   else:
     return torch_load(fn)
 
-class Int8Linear:
-  def __init__(self, in_features, out_features, bias=False):
-    assert bias == False
-    self.weight = Tensor.ones(out_features, in_features, dtype=dtypes.int8)
-    self.scale = Tensor.ones(out_features, dtype=dtypes.half)
-
-  def __call__(self, x):
-    return x.dot(self.weight.cast(dtype=dtypes.half).T*self.scale)
-
-  @staticmethod
-  def quantize(tensors, device):
-    new_tensors = {}
-    for name,v in tensors.items():
-      if "feed_forward" in name or "attention.w" in name or name == "output.weight":
-        assert "weight" in name, name
-        scale = v.abs().max(axis=1) / 127.0
-        int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
-        new_tensors[name] = int8_weight
-        new_tensors[name.replace('weight', 'scale')] = scale
-        if isinstance(device, tuple):
-          new_tensors[name].shard_(device, axis=-1)
-          new_tensors[name.replace('weight', 'scale')].shard_(device, axis=None)
-      else:
-        new_tensors[name] = v
-    return new_tensors
-
-def NF4Linear(block_size):
-  CODE = Tensor([
-    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
-    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
-  ], dtype=dtypes.float16)
-  class _NF4Linear:
-    def __init__(self, in_features, out_features, bias=False):
-      assert not bias, "bias not supported"
-      self.in_features, self.out_features = in_features, out_features
-      self.weight = Tensor.empty(int(out_features * in_features / 2), dtype=dtypes.uint8)
-      self.scale = Tensor.empty(int(out_features * in_features / block_size), 1, dtype=dtypes.float16)
-
-    def __call__(self, x: Tensor) -> Tensor:
-      high_bits = self.weight
-      low_bits = self.weight.lshift(4).contiguous()
-      unpacked = Tensor.stack([high_bits, low_bits], dim=-1).rshift(4)
-      unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
-      return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
-
-    @staticmethod
-    def quantize(state_dict: dict[str, Tensor], device) -> dict[str, Tensor]:
-      new_state_dict = {}
-      for k, v in state_dict.items():
-        if "feed_forward" in k or "attention.w" in k or k == "output.weight":
-          grouped = v.reshape(-1, block_size)
-          scale = grouped.abs().max(axis=1, keepdim=True)
-          coded = ((grouped / scale).unsqueeze(-1) - CODE.to(v.device)).abs().argmin(axis=-1).cast(dtypes.uint8).flatten()
-          new_state_dict[k] = coded[::2] * 2 ** 4 + coded[1::2]
-          new_state_dict[k.replace(".weight", ".scale")] = scale.cast(dtypes.float16)
-          if isinstance(device, tuple):
-            new_state_dict[k].shard_(device, axis=-1)
-            new_state_dict[k.replace('weight', 'scale')].shard_(device, axis=None)
-        else:
-          new_state_dict[k] = v
-      return new_state_dict
-  return _NF4Linear
-
 class LLaMa:
   @staticmethod
   def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=None, device=None):
@@ -261,9 +198,17 @@ class LLaMa:
     assert tokenizer.vocab_size() == params["args"]["vocab_size"], f"{tokenizer.vocab_size()=} not equal to {params['args']['vocab_size']}"
 
     jit = bool(getenv("JIT", 1))
-    if quantize == "int8": model = Transformer(**params["args"], linear=Int8Linear, max_context=MAX_CONTEXT, jit=jit)
-    elif quantize == "nf4": model = Transformer(**params["args"], linear=NF4Linear(64), max_context=MAX_CONTEXT, jit=jit)
-    else: model = Transformer(**params["args"], max_context=MAX_CONTEXT, jit=jit)
+
+    if quantize == "int8":
+      from llama3 import Int8Linear
+      linear = Int8Linear
+    elif quantize == "nf4":
+      from llama3 import NF4Linear
+      linear = NF4Linear(64)
+    else:
+      linear = nn.Linear
+
+    model = Transformer(**params["args"], linear=linear, max_context=MAX_CONTEXT, jit=jit)
 
     if model_path.is_dir():
       weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]], device[0] if isinstance(device, tuple) else device)
@@ -274,24 +219,27 @@ class LLaMa:
 
     weights = fix_bf16(weights)
 
-    if quantize is not None:
-      with Context(BEAM=0):
-        weights = model.output.__class__.quantize(weights, device)
+    with Context(BEAM=0):
+      # quantize
+      if quantize is not None:
+        weights = linear.quantize(weights, device)
         for _,v in weights.items(): v.realize()
 
-    if isinstance(device, tuple):
-      for k,v in nn.state.get_state_dict(model).items():
-        if 'scale' in k: v.shard_(device, axis=None)  # from quantized
-        elif '.attention.' in k: v.shard_(device, axis=-1)
-        elif '.feed_forward.' in k: v.shard_(device, axis=-1)
-        elif 'tok_embeddings.weight' in k: v.shard_(device, axis=-1)
-        elif 'output.weight' in k: v.shard_(device, axis=-1)
-        #elif k.endswith('.weight'): v.shard_(device, axis=-1)
-        #elif 'norm.' in k: v.shard_(device, axis=-1)
-        else: v.shard_(device, axis=None)
-        #print(k, v.shape, v.lazydata.axis)
+      # shard
+      if isinstance(device, tuple):
+        for k,v in nn.state.get_state_dict(model).items():
+          if 'scale' in k: v.shard_(device, axis=None)  # from quantized
+          elif '.attention.' in k: v.shard_(device, axis=-1)
+          elif '.feed_forward.' in k: v.shard_(device, axis=-1)
+          elif 'tok_embeddings.weight' in k: v.shard_(device, axis=-1)
+          elif 'output.weight' in k: v.shard_(device, axis=-1)
+          #elif k.endswith('.weight'): v.shard_(device, axis=-1)
+          #elif 'norm.' in k: v.shard_(device, axis=-1)
+          else: v.shard_(device, axis=None)
+          #print(k, v.shape, v.lazydata.axis)
 
-    load_state_dict(model, weights, strict=False, consume=True)
+      # replace weights in model
+      load_state_dict(model, weights, strict=False, consume=True)
 
     return LLaMa(model, tokenizer)
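Reviewer note (not part of the patch): the int8 path now imported from llama3 uses per-output-channel absmax quantization. Below is a minimal standalone NumPy sketch of that scheme, with hypothetical names and toy shapes of our own choosing, not tinygrad code:

  import numpy as np

  def int8_quantize(w):
    # one absmax scale per output channel, as in Int8Linear.quantize
    scale = np.abs(w).max(axis=1) / 127.0
    q = (w / scale[:, None]).astype(np.int8)  # the (v.T/scale).T cast in the original
    return q, scale.astype(np.float16)

  def int8_linear(x, q, scale):
    # mirrors Int8Linear.__call__: dequantize the weight on the fly, then matmul
    return x @ (q.astype(np.float16).T * scale)

  w = np.random.randn(8, 16).astype(np.float16)  # toy weight: 8 output channels
  x = np.random.randn(4, 16).astype(np.float16)
  q, scale = int8_quantize(w)
  print(np.abs(x @ w.T - int8_linear(x, q, scale)).max())  # small quantization error

Note that quantize() only touches the large matmul weights (feed_forward, attention.w*, output.weight); norms and embeddings pass through in full precision.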
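The NF4 path packs two 4-bit code indices per uint8 byte (even-indexed codes in the high nibble, odd-indexed in the low nibble) plus one float16 absmax scale per block of block_size weights. Again a standalone NumPy sketch under those assumptions, not tinygrad code; the CODE table is copied from the removed NF4Linear:

  import numpy as np

  CODE = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0], dtype=np.float16)

  def nf4_quantize_block(w):
    # per-block absmax scale, nearest CODE entry, two codes packed per byte
    scale = np.abs(w).max()
    codes = np.abs(w[:, None] / scale - CODE[None, :]).argmin(axis=-1).astype(np.uint8)
    return codes[::2] * 2**4 + codes[1::2], scale  # same packing as quantize()

  def nf4_dequantize_block(packed, scale):
    # mirrors the lshift/rshift unpack in __call__: high nibble first, then low
    codes = np.stack([packed >> 4, (packed << 4) >> 4], axis=-1).reshape(-1)
    return CODE[codes] * scale

  w = np.random.randn(64).astype(np.float16)  # one group, block_size=64 as in the patch
  packed, scale = nf4_quantize_block(w)
  print(np.abs(w - nf4_dequantize_block(packed, scale)).max())  # small reconstruction error

At block_size=64 this costs 4 bits per weight plus 16 bits of scale per 64 weights, about 4.25 bits per weight, versus 8 bits per weight plus one scale per output row for the int8 path.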