From 07d3676019ec023056031350650bb779e99ab66e Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Fri, 31 Jan 2025 23:16:47 +0100 Subject: [PATCH] weights_only=False (#8839) --- test/unit/test_disk_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index a76c194076..6bb099aa76 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -10,7 +10,7 @@ def compare_weights_both(url): import torch fn = fetch(url) tg_weights = get_state_dict(torch_load(fn)) - torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu'), weights_only=True), tensor_type=torch.Tensor) + torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu'), weights_only=False), tensor_type=torch.Tensor) assert list(tg_weights.keys()) == list(torch_weights.keys()) for k in tg_weights: if tg_weights[k].dtype == dtypes.bfloat16: tg_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16