fix llama3 with nf4 quantize (#10107)

also the int8 output is wrong
This commit is contained in:
chenyu
2025-04-29 15:14:36 -04:00
committed by GitHub
parent 9c1b80499f
commit 4a04098389
2 changed files with 8 additions and 1 deletion

View File

@@ -85,6 +85,10 @@ jobs:
run: |
python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt
- name: Run quantized LLaMA3
run: |
python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8 | tee llama3_int8.txt
python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4 | tee llama3_nf4.txt
#- name: Run LLaMA 7B on 4 (virtual) GPUs
# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
- name: Run GPT2
@@ -118,6 +122,8 @@ jobs:
llama_beam.txt
llama_int8.txt
llama_nf4.txt
llama3_int8.txt
llama3_nf4.txt
llama_four_gpu.txt
gpt2_unjitted.txt
gpt2_jitted.txt

View File

@@ -126,7 +126,8 @@ def NF4Linear(block_size):
return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
@staticmethod
def quantize(state_dict: dict[str, Tensor], device, scale_dtype=dtypes.float16) -> dict[str, Tensor]:
def quantize(state_dict: dict[str, Tensor], device, scale_dtype=dtypes.float16, quantize_embeds=False) -> dict[str, Tensor]:
assert not quantize_embeds # TODO: support this?
new_state_dict = {}
for k, v in state_dict.items():
if "feed_forward" in k or "attention.w" in k: