add quantize fp8 in llama3 (#12893)

* add quantize fp8 in llama3

* don't truncate fp8 ALU result

* cast to float32 before matmul

* --model weights/LLaMA-3/8B-SF-DPO/

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
b1tg authored 2025-10-27 22:22:57 +08:00, committed by GitHub
parent 25c2da1579 · commit 45e2f916a3
2 changed files with 40 additions and 1 deletion
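
The new path stores each quantized weight as fp8e4m3 plus one per-tensor scale: the tensor is multiplied onto the representable range [-448, 448] before the cast, and the reciprocal of that scale is kept for dequantization at matmul time. A minimal sketch of the scaling math, using plain Python floats in place of the actual fp8 cast (the real cast additionally rounds to the nearest representable fp8 value):

# per-tensor absmax scaling onto the fp8e4m3 range, as in quantize_to_fp8 below
x_absmax = 3.5                                 # hypothetical max |w| over the weight tensor
scale = 448.0 / x_absmax                       # fp8_max / x.abs().max()
inv_scale = 1.0 / scale                        # stored alongside the weight as weight_scale
w = 2.1                                        # hypothetical single weight value
w_scaled = max(-448.0, min(448.0, w * scale))  # clamp to the fp8 range (fp8 cast omitted)
w_dequant = w_scaled * inv_scale               # recovers ~2.1, up to fp8 rounding error
print(w_dequant)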

.github/workflows/benchmark.yml

@@ -238,6 +238,8 @@ jobs:
         run: BENCHMARK_LOG=llama3_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
       - name: Run LLaMA-3 8B on 4 GPUs with BEAM
         run: BENCHMARK_LOG=llama3_beam_4gpu NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
+      - name: Run quantized LLaMA3
+        run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8 | tee llama3_fp8.txt
       # - name: Run LLaMA-3 8B on 6 GPUs
       #   run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
       # - name: Run LLaMA-2 70B
@@ -271,6 +273,7 @@ jobs:
           llama3_beam.txt
           llama3_four_gpu.txt
           llama3_six_gpu.txt
+          llama3_fp8.txt
           llama_2_70B.txt
           mixtral.txt
           gpt2_unjitted.txt

examples/llama3.py

@@ -145,6 +145,41 @@ def NF4Linear(block_size):
     return new_state_dict
   return _NF4Linear
 
+def quantize_to_fp8(x: Tensor, dtype=dtypes.fp8e4m3):
+  fp8_min = -448.0 if dtype == dtypes.fp8e4m3 else -57344.0
+  fp8_max = 448.0 if dtype == dtypes.fp8e4m3 else 57344.0
+  scale = fp8_max / x.abs().max()
+  x_scl_sat = (x * scale).clamp(fp8_min, fp8_max)
+  return x_scl_sat.cast(dtype), scale.float().reciprocal()
+
+class FP8Linear:
+  def __init__(self, in_features, out_features, bias=True):
+    self.weight = Tensor.empty(out_features, in_features, dtype=dtypes.fp8e4m3)
+    self.bias = Tensor.empty(out_features, dtype=dtypes.float16) if bias else None
+    self.weight_scale = Tensor.empty((), dtype=dtypes.float16)
+
+  def __call__(self, x:Tensor):
+    y = x.dot(self.weight.T.cast(dtypes.float32)) * self.weight_scale
+    if self.bias is not None: y = y + self.bias.cast(y.dtype)
+    return y.cast(x.dtype)
+
+  @staticmethod
+  def quantize(tensors, device, scale_dtype=dtypes.float16, quantize_embeds=False):
+    assert not quantize_embeds
+    new_tensors = {}
+    for name,v in tensors.items():
+      if "feed_forward" in name or "attention.w" in name:
+        assert "weight" in name, name
+        fp8_weight, scale = quantize_to_fp8(v)
+        new_tensors[name] = fp8_weight
+        new_tensors[name.replace('weight', 'weight_scale')] = scale.cast(scale_dtype)
+        if isinstance(device, tuple):
+          new_tensors[name].shard_(device, axis=-1)
+          new_tensors[name.replace('weight', 'weight_scale')].shard_(device, axis=None)
+      else:
+        new_tensors[name] = v
+    return new_tensors
+
 MODEL_PARAMS = {
   "1B": {
     "args": {"dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192},
@@ -167,6 +202,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, scale_dt
   # build model
   if quantize == "int8": linear, embedding, quantize_embeds = Int8Linear, Int8Embedding, True
   elif quantize == "nf4": linear, embedding, quantize_embeds = NF4Linear(64), nn.Embedding, False
+  elif quantize == "fp8": linear, embedding, quantize_embeds = FP8Linear, nn.Embedding, False
   else: linear, embedding, quantize_embeds = nn.Linear, nn.Embedding, False
   model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, embedding=embedding, max_context=max_context, jit=True)
@@ -242,7 +278,7 @@ if __name__ == "__main__":
   parser.add_argument("--model", type=Path, help="Model path")
   parser.add_argument("--size", choices=["1B", "8B", "70B", "405B"], default="1B", help="Model size")
   parser.add_argument("--shard", type=int, default=1, help="Shard the model across multiple devices")
-  parser.add_argument("--quantize", choices=["int8", "nf4", "float16"], help="Quantization method")
+  parser.add_argument("--quantize", choices=["int8", "nf4", "float16", "fp8"], help="Quantization method")
   parser.add_argument("--no_api", action="store_true", help="Disable the api and run a cli test interface")
   parser.add_argument("--host", type=str, default="0.0.0.0", help="Web server bind address")
   parser.add_argument("--port", type=int, default=7776, help="Web server port")