diff --git a/invokeai/app/invocations/flux_lora_loader.py b/invokeai/app/invocations/flux_lora_loader.py
index 9ec58755f2..a12f21cb9a 100644
--- a/invokeai/app/invocations/flux_lora_loader.py
+++ b/invokeai/app/invocations/flux_lora_loader.py
@@ -40,7 +40,7 @@ class FluxLoRALoaderInvocation(BaseInvocation):
             raise ValueError(f"Unknown lora: {lora_key}!")
 
         if any(lora.lora.key == lora_key for lora in self.transformer.loras):
-            raise Exception(f'LoRA "{lora_key}" already applied to transformer.')
+            raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
 
         transformer = self.transformer.model_copy(deep=True)
         transformer.loras.append(
diff --git a/invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py b/invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py
index 1d21eeccd0..ccac032686 100644
--- a/invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py
+++ b/invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py
@@ -30,7 +30,7 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
     return all_keys_in_peft_format and all_expected_keys_present
 
 
-def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float) -> LoRAModelRaw:  # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
+def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float) -> LoRAModelRaw:
     """Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
 
     This function is based on:
diff --git a/invokeai/backend/lora/layers/concatenated_lora_layer.py b/invokeai/backend/lora/layers/concatenated_lora_layer.py
index d69dabb03f..f9a860d07e 100644
--- a/invokeai/backend/lora/layers/concatenated_lora_layer.py
+++ b/invokeai/backend/lora/layers/concatenated_lora_layer.py
@@ -41,6 +41,3 @@ class ConcatenatedLoRALayer(LoRALayerBase):
 
         assert len(layer_biases) == len(self.lora_layers)
         return torch.cat(layer_biases, dim=self.concat_axis)
-
-    def calc_size(self) -> int:
-        return sum(lora_layer.calc_size() for lora_layer in self.lora_layers)
diff --git a/invokeai/backend/lora/lora_patcher.py b/invokeai/backend/lora/lora_patcher.py
index 86e2f59f7e..7ef6360c47 100644
--- a/invokeai/backend/lora/lora_patcher.py
+++ b/invokeai/backend/lora/lora_patcher.py
@@ -28,11 +28,14 @@ class LoRAPatcher:
     ):
         """Apply one or more LoRA patches to a model within a context manager.
 
-        :param model: The model to patch.
-        :param loras: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
-            that the LoRA patches do not need to be loaded into memory all at once.
-        :param prefix: The keys in the patches will be filtered to only include weights with this prefix.
-        :cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
+        Args:
+            model (torch.nn.Module): The model to patch.
+            patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
+                associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
+                all at once.
+            prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
+            cached_weights (Optional[Dict[str, torch.Tensor]], optional): Read-only copy of the model's state dict in
+                CPU RAM, for efficient unpatching purposes.
         """
         original_weights = OriginalWeightsStorage(cached_weights)
         try:
@@ -60,15 +63,15 @@ class LoRAPatcher:
         patch_weight: float,
         original_weights: OriginalWeightsStorage,
     ):
-        """
-        Apply a single LoRA patch to a model.
-        :param model: The model to patch.
-        :param patch: LoRA model to patch in.
-        :param patch_weight: LoRA patch weight.
-        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
-        """
+        """Apply a single LoRA patch to a model.
+        Args:
+            model (torch.nn.Module): The model to patch.
+            prefix (str): A string prefix that precedes keys used in the LoRAs weight layers.
+            patch (LoRAModelRaw): The LoRA model to patch in.
+            patch_weight (float): The weight of the LoRA patch.
+            original_weights (OriginalWeightsStorage): Storage for the original weights of the model, for unpatching.
+        """
 
         if patch_weight == 0:
             return
 
@@ -126,6 +129,17 @@ class LoRAPatcher:
         patches: Iterable[Tuple[LoRAModelRaw, float]],
         prefix: str,
    ):
+        """Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
+        overhead compared to normal LoRA patching, but they allow for LoRA layers to be applied to base layers in any
+        quantization format.
+
+        Args:
+            model (torch.nn.Module): The model to patch.
+            patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
+                associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
+                all at once.
+            prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
+        """
         original_modules: dict[str, torch.nn.Module] = {}
         try:
             for patch, patch_weight in patches:
@@ -136,7 +150,6 @@ class LoRAPatcher:
                     patch_weight=patch_weight,
                     original_modules=original_modules,
                 )
-
             yield
         finally:
             # Restore original modules.
@@ -154,6 +167,8 @@ class LoRAPatcher:
         prefix: str,
         original_modules: dict[str, torch.nn.Module],
     ):
+        """Apply a single LoRA sidecar patch to a model."""
+
         if patch_weight == 0:
             return
 
@@ -178,8 +193,8 @@ class LoRAPatcher:
 
             # Move the LoRA sidecar layer to the same device/dtype as the orig module.
             # TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
-            # HACK(ryand): Set the dtype properly here. We want to set it to the *compute* dtype of the original module.
-            # In the case of quantized layers, this may be different than the weight dtype.
+            # HACK(ryand): Figure out how to set the dtype properly here. We want to set it to the *compute* dtype of
+            # the original module. In the case of quantized layers, this may be different than the weight dtype.
             lora_sidecar_layer.to(device=module.weight.device, dtype=torch.bfloat16)
 
             if module_key in original_modules:
@@ -196,6 +211,7 @@ class LoRAPatcher:
 
     @staticmethod
     def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
+        # TODO(ryand): Add support for more original layer types and LoRA layer types.
         if isinstance(orig_layer, torch.nn.Linear):
             if isinstance(lora_layer, LoRALayer):
                 return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
@@ -211,7 +227,7 @@ class LoRAPatcher:
         try:
             submodule_index = int(module_name)
             # If the module name is an integer, then we use the __setitem__ method to set the submodule.
-            parent_module[submodule_index] = submodule
+            parent_module[submodule_index] = submodule  # type: ignore
         except ValueError:
             # If the module name is not an integer, then we use the setattr method to set the submodule.
             setattr(parent_module, module_name, submodule)
@@ -221,12 +237,16 @@ class LoRAPatcher:
         model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
     ) -> tuple[str, torch.nn.Module]:
         """Get the submodule corresponding to the given layer key.
-        :param model: The model to search.
-        :param layer_key: The layer key to search for.
-        :param layer_key_is_flattened: Whether the layer key is flattened. If flattened, then all '.' have been replaced
-            with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly without
-            searching, but some legacy code still uses flattened keys.
-        :return: A tuple containing the module key and the submodule.
+
+        Args:
+            model (torch.nn.Module): The model to search.
+            layer_key (str): The layer key to search for.
+            layer_key_is_flattened (bool): Whether the layer key is flattened. If flattened, then all '.' have been
+                replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed
+                directly without searching, but some legacy code still uses flattened keys.
+
+        Returns:
+            tuple[str, torch.nn.Module]: A tuple containing the module key and the submodule.
         """
         if not layer_key_is_flattened:
             return layer_key, model.get_submodule(layer_key)
diff --git a/invokeai/backend/lora/sidecar_layers/lora_sidecar_module.py b/invokeai/backend/lora/sidecar_layers/lora_sidecar_module.py
index f14d46e56a..b479518e76 100644
--- a/invokeai/backend/lora/sidecar_layers/lora_sidecar_module.py
+++ b/invokeai/backend/lora/sidecar_layers/lora_sidecar_module.py
@@ -2,6 +2,8 @@ import torch
 
 
 class LoRASidecarModule(torch.nn.Module):
+    """A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
+
     def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
         super().__init__()
         self._orig_module = orig_module
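
For reviewers who want to see how the documented context managers are meant to be driven, here is a minimal usage sketch. It is illustrative only: the public method names (`apply_lora_patches`, `apply_lora_sidecar_patches`) and the import path for `LoRAModelRaw` are assumptions inferred from the docstrings and file paths in this diff, not something the diff itself shows.

```python
# Minimal usage sketch for the patching API documented above. The method names
# `apply_lora_patches` / `apply_lora_sidecar_patches` and the LoRAModelRaw import
# path are assumptions; check invokeai/backend/lora/lora_patcher.py for the
# actual public API. Requires an InvokeAI checkout and a loaded LoRA model.
from typing import Dict, Iterable, Optional, Tuple

import torch

from invokeai.backend.lora.lora_model_raw import LoRAModelRaw  # assumed import path
from invokeai.backend.lora.lora_patcher import LoRAPatcher


def run_with_loras(
    model: torch.nn.Module,
    patches: Iterable[Tuple[LoRAModelRaw, float]],
    prefix: str,
    cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
    # Standard patching: layer weights are modified in place and restored on exit,
    # using cached_weights (if provided) for efficient unpatching.
    with LoRAPatcher.apply_lora_patches(
        model=model, patches=patches, prefix=prefix, cached_weights=cached_weights
    ):
        pass  # run inference with the patched model here


def run_with_sidecar_loras(
    model: torch.nn.Module,
    patches: Iterable[Tuple[LoRAModelRaw, float]],
    prefix: str,
) -> None:
    # Sidecar patching: original modules are wrapped (see LoRASidecarModule) instead of
    # having their weights rewritten, which is what allows LoRA on quantized base layers
    # at the cost of some runtime overhead.
    with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=patches, prefix=prefix):
        pass  # run inference with the patched model here
```

Per the single-patch helpers above, entries whose patch weight is 0 are skipped entirely in both code paths.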