Disable device meta for spandrel

This commit is contained in:
Billy
2025-03-17 11:30:05 +11:00
parent 654e992630
commit 7d5687e9ff

View File

@@ -42,7 +42,7 @@ def dress(v):
match v:
case {"shape": shape, "dtype": dtype_str, "fakeTensor": True}:
dtype = STR_TO_DTYPE[dtype_str]
return torch.empty(shape, dtype=dtype, device="meta")
return torch.empty(shape, dtype=dtype)
case dict():
return {k: dress(v) for k, v in v.items()}
case list() | tuple():
@@ -54,7 +54,7 @@ def dress(v):
def load_stripped_model(path: Path, *args, **kwargs):
with open(path, "r") as f:
contents = json.load(f)
return { "state_dict": dress(contents) }
return dress(contents)
def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk: