diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index fab1336889..0a62916306 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -1,3 +1,4 @@
+from contextlib import ExitStack
 from typing import Callable, Iterator, Optional, Tuple
 
 import torch
@@ -31,6 +32,7 @@ from invokeai.backend.flux.sampling_utils import (
 )
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
@@ -191,21 +193,38 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
 
         with (
             transformer_info.model_on_device() as (cached_weights, transformer),
-            # Apply the LoRA after transformer has been moved to its target device for faster patching.
-            # LoRAPatcher.apply_lora_sidecar_patches(
-            #     model=transformer,
-            #     patches=self._lora_iterator(context),
-            #     prefix="",
-            # ),
-            LoRAPatcher.apply_lora_patches(
-                model=transformer,
-                patches=self._lora_iterator(context),
-                prefix="",
-                cached_weights=cached_weights,
-            ),
+            ExitStack() as exit_stack,
         ):
             assert isinstance(transformer, Flux)
+            config = transformer_info.config
+            assert config is not None
+
+            # Apply LoRA models to the transformer.
+            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
+            if config.format in [ModelFormat.Checkpoint]:
+                # The model is non-quantized, so we can apply the LoRA weights directly into the model.
+                exit_stack.enter_context(
+                    LoRAPatcher.apply_lora_patches(
+                        model=transformer,
+                        patches=self._lora_iterator(context),
+                        prefix="",
+                        cached_weights=cached_weights,
+                    )
+                )
+            elif config.format in [ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b]:
+                # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference
+                # than directly patching the weights, but is agnostic to the quantization format.
+                exit_stack.enter_context(
+                    LoRAPatcher.apply_lora_sidecar_patches(
+                        model=transformer,
+                        patches=self._lora_iterator(context),
+                        prefix="",
+                    )
+                )
+            else:
+                raise ValueError(f"Unsupported model format: {config.format}")
+
             x = denoise(
                 model=transformer,
                 img=x,
diff --git a/invokeai/backend/lora/lora_patcher.py b/invokeai/backend/lora/lora_patcher.py
index 9bd8fa5a2c..8a40cde689 100644
--- a/invokeai/backend/lora/lora_patcher.py
+++ b/invokeai/backend/lora/lora_patcher.py
@@ -179,7 +179,9 @@ class LoRAPatcher:
 
             # Move the LoRA sidecar layer to the same device/dtype as the orig module.
             # TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
-            lora_sidecar_layer.to(device=module.weight.device, dtype=module.weight.dtype)
+            # HACK(ryand): Set the dtype properly here. We want to set it to the *compute* dtype of the original module.
+            # In the case of quantized layers, this may be different than the weight dtype.
+            lora_sidecar_layer.to(device=module.weight.device, dtype=torch.bfloat16)
 
             if module_key in original_modules:
                 # The module has already been patched with a LoRASidecarModule. Append to it.
@@ -197,7 +199,7 @@ class LoRAPatcher:
     def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
         if isinstance(orig_layer, torch.nn.Linear):
             if isinstance(lora_layer, LoRALayer):
-                return LoRALinearSidecarLayer.from_layers(orig_layer, lora_layer, patch_weight)
+                return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
             else:
                 raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
         elif isinstance(orig_layer, torch.nn.Conv1d):
diff --git a/invokeai/backend/lora/sidecar_layers/lora/lora_conv_sidecar_layer.py b/invokeai/backend/lora/sidecar_layers/lora/lora_conv_sidecar_layer.py
index 7d2f7faa9f..0c9bc07aa8 100644
--- a/invokeai/backend/lora/sidecar_layers/lora/lora_conv_sidecar_layer.py
+++ b/invokeai/backend/lora/sidecar_layers/lora/lora_conv_sidecar_layer.py
@@ -70,6 +70,8 @@ class LoRAConvSidecarLayer(torch.nn.Module):
                 weight=weight,
             )
 
+        # TODO(ryand): Are there cases where we need to reshape the weight matrices to match the conv layers?
+
         # Inject weight into the LoRA layer.
         assert model._up.weight.shape == lora_layer.up.shape
         assert model._down.weight.shape == lora_layer.down.shape
diff --git a/invokeai/backend/lora/sidecar_layers/lora/lora_linear_sidecar_layer.py b/invokeai/backend/lora/sidecar_layers/lora/lora_linear_sidecar_layer.py
index f3e4fa990e..91ca0f9785 100644
--- a/invokeai/backend/lora/sidecar_layers/lora/lora_linear_sidecar_layer.py
+++ b/invokeai/backend/lora/sidecar_layers/lora/lora_linear_sidecar_layer.py
@@ -4,95 +4,24 @@ from invokeai.backend.lora.layers.lora_layer import LoRALayer
 
 
 class LoRALinearSidecarLayer(torch.nn.Module):
-    """An implementation of a linear LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'.
-    (https://arxiv.org/pdf/2106.09685.pdf)
-    """
-
     def __init__(
         self,
-        in_features: int,
-        out_features: int,
-        include_mid: bool,
-        rank: int,
-        alpha: float,
+        lora_layer: LoRALayer,
         weight: float,
-        device: torch.device | None = None,
-        dtype: torch.dtype | None = None,
     ):
         super().__init__()
 
-        if rank > min(in_features, out_features):
-            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")
-
-        self._down = torch.nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
-        self._up = torch.nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
-        self._mid = None
-        if include_mid:
-            self._mid = torch.nn.Linear(rank, rank, bias=False, device=device, dtype=dtype)
-
-        # Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
-        self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))
-
+        self._lora_layer = lora_layer
         self._weight = weight
-        self._rank = rank
 
-    @classmethod
-    def from_layers(cls, orig_layer: torch.nn.Module, lora_layer: LoRALayer, weight: float):
-        # Initialize the LoRA layer.
-        with torch.device("meta"):
-            model = cls.from_orig_layer(
-                orig_layer,
-                include_mid=lora_layer.mid is not None,
-                rank=lora_layer.rank,
-                # TODO(ryand): Is this the right default in case of missing alpha?
-                alpha=lora_layer.alpha if lora_layer.alpha is not None else lora_layer.rank,
-                weight=weight,
-            )
-
-        # TODO(ryand): Are there cases where we need to reshape the weight matrices to match the conv layers?
-
-        # Inject weight into the LoRA layer.
-        assert model._up.weight.shape == lora_layer.up.shape
-        assert model._down.weight.shape == lora_layer.down.shape
-        model._up.weight.data = lora_layer.up
-        model._down.weight.data = lora_layer.down
-        if lora_layer.mid is not None:
-            assert model._mid is not None
-            assert model._mid.weight.shape == lora_layer.mid.shape
-            model._mid.weight.data = lora_layer.mid
-
-        return model
-
-    @classmethod
-    def from_orig_layer(
-        cls,
-        layer: torch.nn.Module,
-        include_mid: bool,
-        rank: int,
-        alpha: float,
-        weight: float,
-        device: torch.device | None = None,
-        dtype: torch.dtype | None = None,
-    ):
-        if not isinstance(layer, torch.nn.Linear):
-            raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.")
-
-        return cls(
-            in_features=layer.in_features,
-            out_features=layer.out_features,
-            include_mid=include_mid,
-            rank=rank,
-            alpha=alpha,
-            weight=weight,
-            device=layer.weight.device if device is None else device,
-            dtype=layer.weight.dtype if dtype is None else dtype,
-        )
+    def to(self, device: torch.device, dtype: torch.dtype):
+        self._lora_layer.to(device, dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self._down(x)
-        if self._mid is not None:
-            x = self._mid(x)
-        x = self._up(x)
-
-        x *= self._weight * self.alpha / self._rank
+        x = torch.nn.functional.linear(x, self._lora_layer.down)
+        if self._lora_layer.mid is not None:
+            x = torch.nn.functional.linear(x, self._lora_layer.mid)
+        x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
+        scale = self._lora_layer.alpha / self._lora_layer.rank if self._lora_layer.alpha is not None else 1.0
+        x *= self._weight * scale
         return x
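
Note for reviewers: the diff above shows the new sidecar forward pass, but not the wrapper that combines it with the original (possibly quantized) layer. Below is a minimal, self-contained sketch of the sidecar pattern, not the InvokeAI implementation: the class names LoRALinearSidecar and SidecarWrapper are hypothetical stand-ins (the real LoRASidecarModule referenced in LoRAPatcher has its own interface), and the example omits the optional mid matrix and bias handling for brevity. Only the down/up projection scaled by weight * alpha / rank mirrors LoRALinearSidecarLayer.forward.

import torch


class LoRALinearSidecar(torch.nn.Module):
    """Hypothetical sidecar layer: computes a low-rank delta without touching the base weights."""

    def __init__(self, down: torch.Tensor, up: torch.Tensor, alpha: float, weight: float = 1.0):
        super().__init__()
        # down: (rank, in_features), up: (out_features, rank)
        self.register_buffer("down", down)
        self.register_buffer("up", up)
        self.alpha = alpha
        self.weight = weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rank = self.down.shape[0]
        h = torch.nn.functional.linear(x, self.down)  # project down to rank dimensions
        h = torch.nn.functional.linear(h, self.up)    # project back up to out_features
        return h * (self.weight * self.alpha / rank)


class SidecarWrapper(torch.nn.Module):
    """Hypothetical stand-in for a sidecar wrapper: the base layer (which may be
    bitsandbytes-quantized) runs unchanged, and each LoRA sidecar adds its delta."""

    def __init__(self, orig: torch.nn.Module, sidecars: list[torch.nn.Module]):
        super().__init__()
        self.orig = orig
        self.sidecars = torch.nn.ModuleList(sidecars)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.orig(x)
        for sidecar in self.sidecars:
            y = y + sidecar(x)  # each LoRA contributes an additive correction
        return y


if __name__ == "__main__":
    base = torch.nn.Linear(64, 32)
    rank = 4
    # Zero-initialized `up` means the LoRA delta is zero, so the wrapped output equals the base output.
    lora = LoRALinearSidecar(down=torch.randn(rank, 64), up=torch.zeros(32, rank), alpha=float(rank))
    wrapped = SidecarWrapper(base, [lora])
    x = torch.randn(2, 64)
    assert torch.allclose(wrapped(x), base(x))

Because the base layer's weights are never modified, this composition is agnostic to the quantization format, at the cost of an extra matmul pair per patched layer at inference time; that is the trade-off the flux_denoise.py branch above selects between.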