bf16 fix + cleanups from mixtral (#2698)

* bf16 fix + cleanups from mixtral

* generic bf16 cast
This commit is contained in:
George Hotz
2023-12-10 16:31:52 -08:00
committed by GitHub
parent 7fbebb3df6
commit 0fd44259cd
8 changed files with 24 additions and 17 deletions

View File

@@ -1,7 +1,7 @@
import pathlib
import unittest
import numpy as np
-from tinygrad.tensor import Tensor, Device
+from tinygrad.tensor import Tensor, Device, dtypes
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import CI, fetch, temp
from tinygrad.helpers import Timing
@@ -13,6 +13,7 @@ def compare_weights_both(url):
torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu')), tensor_type=torch.Tensor)
assert list(tg_weights.keys()) == list(torch_weights.keys())
for k in tg_weights:
+        if tg_weights[k].dtype == dtypes.bfloat16: tg_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
if torch_weights[k].dtype == torch.bfloat16: torch_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
print(f"compared {len(tg_weights)} weights")