mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
add quantize fp8 in llama3 (#12893)
* add quantize fp8 in llama3
* don't truncate fp8 alu result
* cast to float32 before matmul
* --model weights/LLaMA-3/8B-SF-DPO/

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
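For orientation, the scheme the diff below implements is per-tensor absmax scaling: weights are multiplied by fp8_max / absmax (448 is the largest finite fp8e4m3 value), clamped, and stored as fp8, while the reciprocal scale is kept separately; at call time the fp8 weight is cast to float32 before the matmul and the result is rescaled. A minimal sketch using the same tinygrad Tensor calls as the diff (the names quantize_fp8, w, x, y are illustrative only, not part of the commit):

from tinygrad import Tensor, dtypes

def quantize_fp8(x: Tensor):
  # per-tensor absmax scaling into the finite fp8e4m3 range [-448, 448]
  scale = 448.0 / x.abs().max()
  return (x * scale).clamp(-448.0, 448.0).cast(dtypes.fp8e4m3), scale.float().reciprocal()

w = Tensor.randn(64, 64)                       # a float32 weight matrix
w_fp8, w_scale = quantize_fp8(w)               # fp8 weight plus reciprocal scale
x = Tensor.randn(1, 64, dtype=dtypes.float16)  # an activation
# cast the fp8 weight to float32 before the matmul, then undo the scaling
y = (x.dot(w_fp8.T.cast(dtypes.float32)) * w_scale).cast(x.dtype)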
.github/workflows/benchmark.yml (3 changed lines, vendored)
@@ -238,6 +238,8 @@ jobs:
       run: BENCHMARK_LOG=llama3_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
     - name: Run LLaMA-3 8B on 4 GPUs with BEAM
       run: BENCHMARK_LOG=llama3_beam_4gpu NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
+    - name: Run quantized LLaMA3
+      run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8 | tee llama3_fp8.txt
     # - name: Run LLaMA-3 8B on 6 GPUs
     #   run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
     # - name: Run LLaMA-2 70B
@@ -271,6 +273,7 @@ jobs:
           llama3_beam.txt
           llama3_four_gpu.txt
           llama3_six_gpu.txt
+          llama3_fp8.txt
           llama_2_70B.txt
           mixtral.txt
           gpt2_unjitted.txt
examples/llama3.py
@@ -145,6 +145,41 @@ def NF4Linear(block_size):
       return new_state_dict
   return _NF4Linear
 
+def quantize_to_fp8(x: Tensor, dtype=dtypes.fp8e4m3):
+  fp8_min = -448.0 if dtype == dtypes.fp8e4m3 else -57344.0
+  fp8_max = 448.0 if dtype == dtypes.fp8e4m3 else 57344.0
+  scale = fp8_max / x.abs().max()
+  x_scl_sat = (x * scale).clamp(fp8_min, fp8_max)
+  return x_scl_sat.cast(dtype), scale.float().reciprocal()
+
+class FP8Linear:
+  def __init__(self, in_features, out_features, bias=True):
+    self.weight = Tensor.empty(out_features, in_features, dtype=dtypes.fp8e4m3)
+    self.bias = Tensor.empty(out_features, dtype=dtypes.float16) if bias else None
+    self.weight_scale = Tensor.empty((), dtype=dtypes.float16)
+
+  def __call__(self, x:Tensor):
+    y = x.dot(self.weight.T.cast(dtypes.float32)) * self.weight_scale
+    if self.bias is not None: y = y + self.bias.cast(y.dtype)
+    return y.cast(x.dtype)
+
+  @staticmethod
+  def quantize(tensors, device, scale_dtype=dtypes.float16, quantize_embeds=False):
+    assert not quantize_embeds
+    new_tensors = {}
+    for name,v in tensors.items():
+      if "feed_forward" in name or "attention.w" in name:
+        assert "weight" in name, name
+        fp8_weight, scale = quantize_to_fp8(v)
+        new_tensors[name] = fp8_weight
+        new_tensors[name.replace('weight', 'weight_scale')] = scale.cast(scale_dtype)
+        if isinstance(device, tuple):
+          new_tensors[name].shard_(device, axis=-1)
+          new_tensors[name.replace('weight', 'weight_scale')].shard_(device, axis=None)
+      else:
+        new_tensors[name] = v
+    return new_tensors
+
 MODEL_PARAMS = {
   "1B": {
     "args": {"dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192},
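As a rough illustration of what the FP8Linear.quantize added above does to a checkpoint dict: every weight whose name contains "feed_forward" or "attention.w" is replaced by an fp8 tensor plus a matching weight_scale entry, and everything else is passed through. This is a sketch only; the tensor names below are made up, and it assumes the example module can be imported from the repo root.

from tinygrad import Tensor, Device
from examples.llama3 import FP8Linear  # assumption: examples/llama3.py is importable as a module

state = {
  "layers.0.attention.wq.weight": Tensor.randn(32, 32),  # name matches "attention.w" -> quantized
  "tok_embeddings.weight": Tensor.randn(8, 32),           # no match -> passed through unchanged
}
quantized = FP8Linear.quantize(state, Device.DEFAULT)
# quantized["layers.0.attention.wq.weight"]       -> dtypes.fp8e4m3
# quantized["layers.0.attention.wq.weight_scale"] -> float16 scalar (the reciprocal scale)
# quantized["tok_embeddings.weight"]              -> unchanged
# with a tuple of devices, the fp8 weight is sharded on axis=-1 and the scale replicated (axis=None)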
@@ -167,6 +202,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, scale_dt
   # build model
   if quantize == "int8": linear, embedding, quantize_embeds = Int8Linear, Int8Embedding, True
   elif quantize == "nf4": linear, embedding, quantize_embeds = NF4Linear(64), nn.Embedding, False
+  elif quantize == "fp8": linear, embedding, quantize_embeds = FP8Linear, nn.Embedding, False
   else: linear, embedding, quantize_embeds = nn.Linear, nn.Embedding, False
   model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, embedding=embedding, max_context=max_context, jit=True)
 
@@ -242,7 +278,7 @@ if __name__ == "__main__":
   parser.add_argument("--model", type=Path, help="Model path")
   parser.add_argument("--size", choices=["1B", "8B", "70B", "405B"], default="1B", help="Model size")
   parser.add_argument("--shard", type=int, default=1, help="Shard the model across multiple devices")
-  parser.add_argument("--quantize", choices=["int8", "nf4", "float16"], help="Quantization method")
+  parser.add_argument("--quantize", choices=["int8", "nf4", "float16", "fp8"], help="Quantization method")
   parser.add_argument("--no_api", action="store_true", help="Disable the api and run a cli test interface")
   parser.add_argument("--host", type=str, default="0.0.0.0", help="Web server bind address")
   parser.add_argument("--port", type=int, default=7776, help="Web server port")
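With the extra "fp8" choice registered above, the quantized path is selected the same way as int8/nf4. The new benchmark step in this commit runs exactly that: python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8 (the CI step additionally sets BENCHMARK_LOG=llama3_fp8 and tees the output to llama3_fp8.txt).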