weights_only=False (#8839)

This commit is contained in:
Ahmed Harmouche
2025-01-31 23:16:47 +01:00
committed by GitHub
parent 741bbc900d
commit 07d3676019

View File

@@ -10,7 +10,7 @@ def compare_weights_both(url):
import torch
fn = fetch(url)
tg_weights = get_state_dict(torch_load(fn))
torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu'), weights_only=True), tensor_type=torch.Tensor)
torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu'), weights_only=False), tensor_type=torch.Tensor)
assert list(tg_weights.keys()) == list(torch_weights.keys())
for k in tg_weights:
if tg_weights[k].dtype == dtypes.bfloat16: tg_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16