mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
bf16 fix + cleanups from mixtral (#2698)
* bf16 fix + cleanups from mixtral * generic bf16 cast
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import pathlib
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.tensor import Tensor, Device, dtypes
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.helpers import CI, fetch, temp
|
||||
from tinygrad.helpers import Timing
|
||||
@@ -13,6 +13,7 @@ def compare_weights_both(url):
|
||||
torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu')), tensor_type=torch.Tensor)
|
||||
assert list(tg_weights.keys()) == list(torch_weights.keys())
|
||||
for k in tg_weights:
|
||||
if tg_weights[k].dtype == dtypes.bfloat16: tg_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
|
||||
if torch_weights[k].dtype == torch.bfloat16: torch_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
|
||||
np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
|
||||
print(f"compared {len(tg_weights)} weights")
|
||||
|
||||
Reference in New Issue
Block a user