mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
weights_only=False (#8839)
This commit is contained in:
@@ -10,7 +10,7 @@ def compare_weights_both(url):
|
||||
import torch
|
||||
fn = fetch(url)
|
||||
tg_weights = get_state_dict(torch_load(fn))
|
||||
torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu'), weights_only=True), tensor_type=torch.Tensor)
|
||||
torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu'), weights_only=False), tensor_type=torch.Tensor)
|
||||
assert list(tg_weights.keys()) == list(torch_weights.keys())
|
||||
for k in tg_weights:
|
||||
if tg_weights[k].dtype == dtypes.bfloat16: tg_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
|
||||
|
||||
Reference in New Issue
Block a user