diff --git a/tests/model_identification/stripped_model_on_disk.py b/tests/model_identification/stripped_model_on_disk.py index 13d91e23e8..387e04839e 100644 --- a/tests/model_identification/stripped_model_on_disk.py +++ b/tests/model_identification/stripped_model_on_disk.py @@ -67,7 +67,7 @@ class StrippedModelOnDisk(ModelOnDisk): ) case {"shape": shape, "dtype": dtype_str, "fakeTensor": True}: dtype = cls.STR_TO_DTYPE[dtype_str] - return torch.empty(shape, dtype=dtype) + return torch.empty(shape, dtype=dtype, device="meta") case dict(): return {k: cls.dress(v) for k, v in v.items()} case list() | tuple():