Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
apply the same fix_bf16 in llama and coder (#3789)
* apply the same fix_bf16 in llama and coder

  did not realize the same logic was in llama too. really fix #2775

* flag for native SUPPORT_BF16 cast
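For context on why both examples need this at all: these HuggingFace checkpoints ship many tensors in bfloat16, which not every tinygrad backend handles natively, so the loaders downcast them to float16. A minimal illustration of the cast the helper performs (not from the commit; assumes a backend that can at least create bf16 buffers):

from tinygrad import Tensor, dtypes

w = Tensor.ones(2, 2, dtype=dtypes.bfloat16)
print(w.dtype)                       # dtypes.bfloat16
print(w.cast(dtypes.float16).dtype)  # dtypes.float16, what fix_bf16 produces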
examples/coder.py

@@ -6,7 +6,7 @@ from io import StringIO
 from contextlib import redirect_stdout
 from tinygrad import Tensor, nn, Device, dtypes
 from tinygrad.helpers import Timing, colored, getenv, fetch
-from extra.models.llama import Transformer, convert_from_huggingface
+from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from sentencepiece import SentencePieceProcessor
 
 def create_fixed_tokenizer(output_file):
@@ -33,9 +33,6 @@ if __name__ == "__main__":
   part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
   part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))
 
-  # fix bf16, TODO: check if device supports bf16
-  def fix_bf16(weights): return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
-
   with Timing("weights -> model: "):
     nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part1, model, 32, 8)), strict=False)
     nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part2, model, 32, 8)), strict=False)
examples/llama.py

@@ -12,7 +12,7 @@ from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, colored
 from tinygrad import Device, GlobalCounters, dtypes, nn
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
-from extra.models.llama import Transformer, convert_from_huggingface
+from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from sentencepiece import SentencePieceProcessor
 
 MAX_CONTEXT = getenv("MAX_CONTEXT", 4096)
@@ -176,8 +176,7 @@ class LLaMa:
     if "model.embed_tokens.weight" in weights:
       weights = convert_from_huggingface(weights, model, params["args"]["n_heads"], params["args"].get("n_kv_heads", params["args"]["n_heads"]))
 
-    # fix bf16, TODO: check if device supports bf16
-    weights = {k:v.to(Device.DEFAULT).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+    weights = fix_bf16(weights)
 
     if quantize:
       weights = AbsmaxQuantizedLinear.quantize(weights)
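(One detail worth flagging: the replacement here is not byte-for-byte. The inline version removed above moved each bfloat16 tensor to Device.DEFAULT before casting, while the shared fix_bf16 added below casts without that explicit device move.)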
extra/models/llama.py

@@ -1,4 +1,4 @@
-from typing import Tuple, Union, Optional, Dict
+from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 
@@ -162,4 +162,11 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
     elif "k_proj" in k:
       v = permute(v, n_kv_heads)
     sd[keymap[k]] = v
-  return sd
+  return sd
+
+def fix_bf16(weights:Dict[Any, Tensor]):
+  if getenv("SUPPORT_BF16", 1):
+    # TODO: without casting to float16, 70B llama OOM on tinybox.
+    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+  # TODO: check if device supports bf16
+  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
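A hedged usage sketch of the new shared helper (the weight names below are made-up stand-ins, not from the commit): fix_bf16 only rewrites bfloat16 entries and passes everything else through unchanged. Since SUPPORT_BF16 defaults to 1, the native cast path runs by default; launching the process with SUPPORT_BF16=0 takes the llvm_bf16_cast fallback instead.

from tinygrad import Tensor, dtypes
from extra.models.llama import fix_bf16

# hypothetical state dict: one bf16 tensor, one fp32 tensor
weights = {
  "tok_embeddings.weight": Tensor.ones(4, 4, dtype=dtypes.bfloat16),
  "norm.weight": Tensor.ones(4, dtype=dtypes.float32),
}
fixed = fix_bf16(weights)
assert fixed["tok_embeddings.weight"].dtype == dtypes.float16  # bf16 downcast
assert fixed["norm.weight"].dtype == dtypes.float32            # left untouched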