diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py
index f3be042146..c33cb3f4ab 100644
--- a/invokeai/backend/ip_adapter/ip_adapter.py
+++ b/invokeai/backend/ip_adapter/ip_adapter.py
@@ -125,13 +125,16 @@ class IPAdapter(RawModel):
             self.device, dtype=self.dtype
         )
 
-    def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
-        self.device = device
+    def to(
+        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
+    ):
+        if device is not None:
+            self.device = device
         if dtype is not None:
             self.dtype = dtype
 
-        self._image_proj_model.to(device=self.device, dtype=self.dtype)
-        self.attn_weights.to(device=self.device, dtype=self.dtype)
+        self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
+        self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
 
     def calc_size(self):
         # workaround for circular import
diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py
index 0b7128034a..f7c3863a6a 100644
--- a/invokeai/backend/lora.py
+++ b/invokeai/backend/lora.py
@@ -61,9 +61,10 @@ class LoRALayerBase:
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype)
+            self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 # TODO: find and debug lora/locon with bias
@@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
-        super().to(device=device, dtype=dtype)
+        super().to(device=device, dtype=dtype, non_blocking=non_blocking)
 
-        self.up = self.up.to(device=device, dtype=dtype)
-        self.down = self.down.to(device=device, dtype=dtype)
+        self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.mid is not None:
-            self.mid = self.mid.to(device=device, dtype=dtype)
+            self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class LoHALayer(LoRALayerBase):
@@ -169,18 +171,19 @@ class LoHALayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
-        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype)
+            self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
-        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class LoKRLayer(LoRALayerBase):
@@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
@@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
         else:
             assert self.w1_a is not None
             assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype)
+            self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
         else:
             assert self.w2_a is not None
             assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class FullLayer(LoRALayerBase):
@@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
-        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class IA3Layer(LoRALayerBase):
@@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ):
         super().to(device=device, dtype=dtype)
 
-        self.weight = self.weight.to(device=device, dtype=dtype)
-        self.on_input = self.on_input.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         # TODO: try revert if exception?
         for _key, layer in self.layers.items():
-            layer.to(device=device, dtype=dtype)
+            layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
     def calc_size(self) -> int:
         model_size = 0
@@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()
 
-            layer.to(device=device, dtype=dtype)
+            layer.to(device=device, dtype=dtype, non_blocking=True)
             model.layers[layer_key] = layer
 
         return model
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 335a15a5c8..d48e45426e 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -285,9 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 else:
                     new_dict: Dict[str, torch.Tensor] = {}
                     for k, v in cache_entry.state_dict.items():
-                        new_dict[k] = v.to(torch.device(target_device), copy=True)
+                        new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
                     cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device)
+            cache_entry.model.to(target_device, non_blocking=True)
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index c407cd8472..fdc79539ae 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -67,7 +67,7 @@ class ModelPatcher:
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> None:
+    ) -> Generator[None, None, None]:
         with cls.apply_lora(
             unet,
             loras=loras,
@@ -83,7 +83,7 @@ class ModelPatcher:
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> None:
+    ) -> Generator[None, None, None]:
         with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
             yield
 
@@ -95,7 +95,7 @@ class ModelPatcher:
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Generator[Any, None, None]:
+    ) -> Generator[None, None, None]:
         """
         Apply one or more LoRAs to a model.
 
@@ -139,12 +139,12 @@ class ModelPatcher:
                         # We intentionally move to the target device first, then cast. Experimentally, this was found to
                         # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                         # same thing in a single call to '.to(...)'.
-                        layer.to(device=device)
-                        layer.to(dtype=torch.float32)
+                        layer.to(device=device, non_blocking=True)
+                        layer.to(dtype=torch.float32, non_blocking=True)
                         # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                         # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                         layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(device=torch.device("cpu"))
+                        layer.to(device=torch.device("cpu"), non_blocking=True)
 
                         assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                         if module.weight.shape != layer_weight.shape:
@@ -153,7 +153,7 @@ class ModelPatcher:
                             layer_weight = layer_weight.reshape(module.weight.shape)
 
                         assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype)
+                        module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
 
             yield  # wait for context manager exit
 
@@ -161,7 +161,7 @@
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
+                    model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
 
     @classmethod
     @contextmanager
diff --git a/invokeai/backend/onnx/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py
index 8916865dd5..9fcd4d093f 100644
--- a/invokeai/backend/onnx/onnx_runtime.py
+++ b/invokeai/backend/onnx/onnx_runtime.py
@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import onnx
+import torch
 from onnx import numpy_helper
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
 
@@ -188,6 +189,15 @@ class IAIOnnxRuntimeModel(RawModel):
         #     return self.io_binding.copy_outputs_to_cpu()
         return self.session.run(None, inputs)
 
+    # compatability with RawModel ABC
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        pass
+
     # compatability with diffusers load code
     @classmethod
     def from_pretrained(
diff --git a/invokeai/backend/raw_model.py b/invokeai/backend/raw_model.py
index d0dc50c456..7bca6945d9 100644
--- a/invokeai/backend/raw_model.py
+++ b/invokeai/backend/raw_model.py
@@ -10,6 +10,20 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
 that adds additional methods and attributes.
 """
 
+from abc import ABC, abstractmethod
+from typing import Optional
 
-class RawModel:
-    """Base class for 'Raw' model wrappers."""
+import torch
+
+
+class RawModel(ABC):
+    """Abstract base class for 'Raw' model wrappers."""
+
+    @abstractmethod
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        pass
diff --git a/invokeai/backend/textual_inversion.py b/invokeai/backend/textual_inversion.py
index 98104f769e..0408176edb 100644
--- a/invokeai/backend/textual_inversion.py
+++ b/invokeai/backend/textual_inversion.py
@@ -65,6 +65,18 @@ class TextualInversionModelRaw(RawModel):
 
         return result
 
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        if not torch.cuda.is_available():
+            return
+        for emb in [self.embedding, self.embedding_2]:
+            if emb is not None:
+                emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
+
 
 class TextualInversionManager(BaseTextualInversionManager):
     """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
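
Note (illustrative only, not part of the patch above): in PyTorch, a host-to-device `Tensor.to(..., non_blocking=True)` copy is only truly asynchronous when the source CPU tensor lives in pinned (page-locked) memory, and the copied data must not be read before the transfer has been synchronized. A minimal standalone sketch of that general pattern, assuming a CUDA device is available and using only stock PyTorch APIs:

    import torch

    # Illustrative sketch; none of the InvokeAI classes above are involved.
    if torch.cuda.is_available():
        device = torch.device("cuda")

        # Pin the source tensor so the host-to-device copy can run asynchronously.
        cpu_weight = torch.randn(4096, 4096).pin_memory()

        # Queue the copy without blocking the host thread.
        gpu_weight = cpu_weight.to(device, non_blocking=True)

        # ...other host-side work can overlap with the transfer here...

        # Synchronize before relying on gpu_weight (or results derived from it).
        torch.cuda.synchronize()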