diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 3e88c2c932..7aaac0db84 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -238,6 +238,8 @@ jobs:
       run: BENCHMARK_LOG=llama3_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
     - name: Run LLaMA-3 8B on 4 GPUs with BEAM
       run: BENCHMARK_LOG=llama3_beam_4gpu NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
+    - name: Run quantized LLaMA3
+      run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8 | tee llama3_fp8.txt
     # - name: Run LLaMA-3 8B on 6 GPUs
     #   run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
     # - name: Run LLaMA-2 70B
@@ -271,6 +273,7 @@ jobs:
           llama3_beam.txt
           llama3_four_gpu.txt
           llama3_six_gpu.txt
+          llama3_fp8.txt
           llama_2_70B.txt
           mixtral.txt
           gpt2_unjitted.txt
diff --git a/examples/llama3.py b/examples/llama3.py
index d7c7f2c921..54aa8eafea 100644
--- a/examples/llama3.py
+++ b/examples/llama3.py
@@ -145,6 +145,41 @@ def NF4Linear(block_size):
       return new_state_dict
   return _NF4Linear
 
+def quantize_to_fp8(x: Tensor, dtype=dtypes.fp8e4m3):
+  fp8_min = -448.0 if dtype == dtypes.fp8e4m3 else -57344.0
+  fp8_max = 448.0 if dtype == dtypes.fp8e4m3 else 57344.0
+  scale = fp8_max / x.abs().max()
+  x_scl_sat = (x * scale).clamp(fp8_min, fp8_max)
+  return x_scl_sat.cast(dtype), scale.float().reciprocal()
+
+class FP8Linear:
+  def __init__(self, in_features, out_features, bias=True):
+    self.weight = Tensor.empty(out_features, in_features, dtype=dtypes.fp8e4m3)
+    self.bias = Tensor.empty(out_features, dtype=dtypes.float16) if bias else None
+    self.weight_scale = Tensor.empty((), dtype=dtypes.float16)
+
+  def __call__(self, x:Tensor):
+    y = x.dot(self.weight.T.cast(dtypes.float32)) * self.weight_scale
+    if self.bias is not None: y = y + self.bias.cast(y.dtype)
+    return y.cast(x.dtype)
+
+  @staticmethod
+  def quantize(tensors, device, scale_dtype=dtypes.float16, quantize_embeds=False):
+    assert not quantize_embeds
+    new_tensors = {}
+    for name,v in tensors.items():
+      if "feed_forward" in name or "attention.w" in name:
+        assert "weight" in name, name
+        fp8_weight, scale = quantize_to_fp8(v)
+        new_tensors[name] = fp8_weight
+        new_tensors[name.replace('weight', 'weight_scale')] = scale.cast(scale_dtype)
+        if isinstance(device, tuple):
+          new_tensors[name].shard_(device, axis=-1)
+          new_tensors[name.replace('weight', 'weight_scale')].shard_(device, axis=None)
+      else:
+        new_tensors[name] = v
+    return new_tensors
+
 MODEL_PARAMS = {
   "1B": {
     "args": {"dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192},
@@ -167,6 +202,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, scale_dt
   # build model
   if quantize == "int8": linear, embedding, quantize_embeds = Int8Linear, Int8Embedding, True
   elif quantize == "nf4": linear, embedding, quantize_embeds = NF4Linear(64), nn.Embedding, False
+  elif quantize == "fp8": linear, embedding, quantize_embeds = FP8Linear, nn.Embedding, False
   else: linear, embedding, quantize_embeds = nn.Linear, nn.Embedding, False
   model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, embedding=embedding, max_context=max_context, jit=True)
 
@@ -242,7 +278,7 @@ if __name__ == "__main__":
   parser.add_argument("--model", type=Path, help="Model path")
   parser.add_argument("--size", choices=["1B", "8B", "70B", "405B"], default="1B", help="Model size")
   parser.add_argument("--shard", type=int, default=1, help="Shard the model across multiple devices")
-  parser.add_argument("--quantize", choices=["int8", "nf4", "float16"], help="Quantization method")
+  parser.add_argument("--quantize", choices=["int8", "nf4", "float16", "fp8"], help="Quantization method")
   parser.add_argument("--no_api", action="store_true", help="Disable the api and run a cli test interface")
   parser.add_argument("--host", type=str, default="0.0.0.0", help="Web server bind address")
   parser.add_argument("--port", type=int, default=7776, help="Web server port")
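
Below is a quick sanity-check sketch of how the new fp8 path round-trips (not part of the diff): quantize_to_fp8 returns the saturation-clamped fp8 weight together with the reciprocal of the scale, so dequantization is a single multiply, which is exactly what FP8Linear.__call__ applies after the matmul. The snippet assumes tinygrad's Tensor/dtypes API as used in examples/llama3.py, a device with fp8e4m3 support, and the repo root on PYTHONPATH; the printed error is illustrative, not a tested bound.

    from tinygrad import Tensor, dtypes
    from examples.llama3 import quantize_to_fp8  # helper added in this diff

    w = Tensor.randn(64, 64, dtype=dtypes.float16)       # stand-in weight matrix
    w_fp8, inv_scale = quantize_to_fp8(w)                # fp8 weight + reciprocal scale
    w_dequant = w_fp8.cast(dtypes.float32) * inv_scale   # dequantize: one multiply
    max_err = (w.float() - w_dequant).abs().max().item()
    print(f"max abs dequantization error: {max_err}")    # small relative to w.abs().max()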