fix import ordering, remove code I reverted that the resync added back

This commit is contained in:
David Burnett
2025-04-29 12:10:57 +01:00
committed by psychedelicious
parent 99e154d773
commit 6c0bd7d150
2 changed files with 4 additions and 17 deletions

View File

@@ -1,8 +1,9 @@
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from typing import Any
import torch
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
class CachedModelOnlyFullLoad:
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
@@ -78,8 +79,7 @@ class CachedModelOnlyFullLoad:
new_state_dict[k] = v.to(self._compute_device, copy=True)
self._model.load_state_dict(new_state_dict, assign=True)
check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
if isinstance(check_for_gguf, GGMLTensor):
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
torch.__future__.set_overwrite_module_params_on_conversion(True)
@@ -103,7 +103,7 @@ class CachedModelOnlyFullLoad:
if self._cpu_state_dict is not None:
self._model.load_state_dict(self._cpu_state_dict, assign=True)
check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
if isinstance(check_for_gguf, GGMLTensor):
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
torch.__future__.set_overwrite_module_params_on_conversion(True)