apply the same fix_bf16 in llama and coder (#3789)

* apply the same fix_bf16 in llama and coder

did not realize the same logic was in llama too.
really fix #2775

* flag for native SUPPORT_BF16 cast
chenyu
2024-03-17 21:25:24 -04:00
committed by GitHub
parent 639bd5dbfc
commit 5ac1fa933f
3 changed files with 12 additions and 9 deletions
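For context, the SUPPORT_BF16 flag introduced below is read with tinygrad's getenv and defaults to 1. A minimal sketch of forcing the fallback path from Python (illustrative only, not part of this commit; the variable can equally be exported in the shell as SUPPORT_BF16=0):

import os
os.environ["SUPPORT_BF16"] = "0"          # 0 -> take the llvm_bf16_cast fallback; 1 (the default) -> plain cast to float16
from extra.models.llama import fix_bf16   # the flag is read inside fix_bf16 via getenv, so set it before the first call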

View File

@@ -6,7 +6,7 @@ from io import StringIO
from contextlib import redirect_stdout
from tinygrad import Tensor, nn, Device, dtypes
from tinygrad.helpers import Timing, colored, getenv, fetch
-from extra.models.llama import Transformer, convert_from_huggingface
+from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
from sentencepiece import SentencePieceProcessor
def create_fixed_tokenizer(output_file):
@@ -33,9 +33,6 @@ if __name__ == "__main__":
part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))
-# fix bf16, TODO: check if device supports bf16
-def fix_bf16(weights): return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
with Timing("weights -> model: "):
nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part1, model, 32, 8)), strict=False)
nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part2, model, 32, 8)), strict=False)

View File

@@ -12,7 +12,7 @@ from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, colored
from tinygrad import Device, GlobalCounters, dtypes, nn
from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
-from extra.models.llama import Transformer, convert_from_huggingface
+from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
from sentencepiece import SentencePieceProcessor
MAX_CONTEXT = getenv("MAX_CONTEXT", 4096)
@@ -176,8 +176,7 @@ class LLaMa:
if "model.embed_tokens.weight" in weights:
weights = convert_from_huggingface(weights, model, params["args"]["n_heads"], params["args"].get("n_kv_heads", params["args"]["n_heads"]))
-# fix bf16, TODO: check if device supports bf16
-weights = {k:v.to(Device.DEFAULT).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+weights = fix_bf16(weights)
if quantize:
weights = AbsmaxQuantizedLinear.quantize(weights)

View File

@@ -1,4 +1,4 @@
-from typing import Tuple, Union, Optional, Dict
+from typing import Tuple, Union, Optional, Dict, Any
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv
@@ -162,4 +162,11 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
elif "k_proj" in k:
v = permute(v, n_kv_heads)
sd[keymap[k]] = v
-return sd
+return sd
+
+def fix_bf16(weights:Dict[Any, Tensor]):
+  if getenv("SUPPORT_BF16", 1):
+    # TODO: without casting to float16, 70B llama OOM on tinybox.
+    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+  # TODO: check if device supports bf16
+  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
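A rough usage sketch of the shared helper (illustrative only; the key names are made up, and it assumes the default SUPPORT_BF16=1 path, where the cast stays lazy):

from tinygrad import Tensor, dtypes
from extra.models.llama import fix_bf16

weights = {
  "tok_embeddings.weight": Tensor.ones(8, 4).cast(dtypes.bfloat16),  # stand-in for a bf16 checkpoint tensor
  "norm.weight": Tensor.ones(4),                                     # non-bf16 tensors pass through unchanged
}
weights = fix_bf16(weights)
assert weights["tok_embeddings.weight"].dtype == dtypes.float16  # bf16 entries are cast to half
assert weights["norm.weight"].dtype == dtypes.float32            # assumes the default float dtype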