diff --git a/examples/coder.py b/examples/coder.py
index b8982c2c51..f74fbc2e3f 100644
--- a/examples/coder.py
+++ b/examples/coder.py
@@ -6,7 +6,7 @@ from io import StringIO
 from contextlib import redirect_stdout
 from tinygrad import Tensor, nn, Device, dtypes
 from tinygrad.helpers import Timing, colored, getenv, fetch
-from extra.models.llama import Transformer, convert_from_huggingface
+from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from sentencepiece import SentencePieceProcessor
 
 def create_fixed_tokenizer(output_file):
@@ -33,9 +33,6 @@ if __name__ == "__main__":
     part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
     part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))
 
-  # fix bf16, TODO: check if device supports bf16
-  def fix_bf16(weights): return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
-
   with Timing("weights -> model: "):
     nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part1, model, 32, 8)), strict=False)
     nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part2, model, 32, 8)), strict=False)
diff --git a/examples/llama.py b/examples/llama.py
index aacbfbe709..eb1150c7ac 100755
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -12,7 +12,7 @@ from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, colored
 from tinygrad import Device, GlobalCounters, dtypes, nn
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
-from extra.models.llama import Transformer, convert_from_huggingface
+from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from sentencepiece import SentencePieceProcessor
 
 MAX_CONTEXT = getenv("MAX_CONTEXT", 4096)
@@ -176,8 +176,7 @@ class LLaMa:
     if "model.embed_tokens.weight" in weights:
       weights = convert_from_huggingface(weights, model, params["args"]["n_heads"], params["args"].get("n_kv_heads", params["args"]["n_heads"]))
 
-    # fix bf16, TODO: check if device supports bf16
-    weights = {k:v.to(Device.DEFAULT).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+    weights = fix_bf16(weights)
 
     if quantize:
       weights = AbsmaxQuantizedLinear.quantize(weights)
diff --git a/extra/models/llama.py b/extra/models/llama.py
index 20998c313b..49c2e451de 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -1,4 +1,4 @@
-from typing import Tuple, Union, Optional, Dict
+from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 
@@ -162,4 +162,11 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
     if "q_proj" in k: v = permute(v, n_heads)
     elif "k_proj" in k: v = permute(v, n_kv_heads)
     sd[keymap[k]] = v
-  return sd
\ No newline at end of file
+  return sd
+
+def fix_bf16(weights:Dict[Any, Tensor]):
+  if getenv("SUPPORT_BF16", 1):
+    # TODO: without casting to float16, 70B llama OOM on tinybox.
+    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+  # TODO: check if device supports bf16
+  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
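
A minimal usage sketch of the shared fix_bf16 helper this patch factors out, assuming a hypothetical load_hf_weights wrapper and placeholder shard URLs; the head counts 32/8 mirror the Mistral-7B calls in examples/coder.py above. With SUPPORT_BF16=1 (the default), bfloat16 weights are cast straight to float16 (per the TODO, 70B llama OOMs on tinybox otherwise); SUPPORT_BF16=0 keeps the older llvm_bf16_cast path.

    from tinygrad import nn
    from tinygrad.helpers import fetch
    from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16

    def load_hf_weights(model: Transformer, urls: list):
      # hypothetical helper: fetch each HF checkpoint shard, remap its keys to
      # tinygrad's llama layout, normalize bfloat16 tensors, then load into the model
      for url in urls:
        shard = nn.state.torch_load(fetch(url))
        shard = convert_from_huggingface(shard, model, 32, 8)  # n_heads=32, n_kv_heads=8
        nn.state.load_state_dict(model, fix_bf16(shard), strict=False)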