Switch LoRAPatcher to use the new sidecar_wrappers/ rather than sidecar_layers/.

Ryan Dick
2024-12-13 20:02:05 +00:00
parent ac28370fd2
commit 46133b5656
2 changed files with 17 additions and 46 deletions
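
The gist of the change: instead of converting each LoRA layer into a type-specific sidecar layer and appending it to a LoRASidecarModule, LoRAPatcher now wraps the target module once with a sidecar wrapper and adds the raw patch (plus its weight) to that wrapper. Below is a minimal sketch of the new per-module flow using the names from this diff; the standalone function and its argument list are illustrative, not the real LoRAPatcher method signature.

import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper


def apply_patch_sketch(
    module: torch.nn.Module, patch: BaseLayerPatch, patch_weight: float, dtype: torch.dtype
) -> torch.nn.Module:
    # Illustrative helper, not part of InvokeAI. Wrap the original module once;
    # subsequent patches reuse the same wrapper.
    if not isinstance(module, BaseSidecarWrapper):
        module = wrap_module_with_sidecar_wrapper(orig_module=module)
    # The raw patch is moved to the original module's device/dtype and attached directly,
    # rather than being converted into a type-specific sidecar layer first.
    patch.to(device=module.orig_module.weight.device, dtype=dtype)
    module.add_patch(patch, patch_weight)
    return module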


@@ -4,14 +4,9 @@ from typing import Dict, Iterable, Optional, Tuple
 import torch
 
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
-from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
-from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
-from invokeai.backend.patches.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
-    ConcatenatedLoRALinearSidecarLayer,
-)
-from invokeai.backend.patches.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
-from invokeai.backend.patches.sidecar_layers.lora_sidecar_module import LoRASidecarModule
+from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
+from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
@@ -253,28 +248,22 @@ class LoRAPatcher:
         dtype: torch.dtype,
     ):
         """Apply a single LoRA wrapper patch to a model."""
-        # Initialize the LoRA sidecar layer.
-        lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module_to_patch, patch, patch_weight)
-
-        # Replace the original module with a LoRASidecarModule if it has not already been done.
-        if module_to_patch_key in original_modules:
-            # The module has already been patched with a LoRASidecarModule. Append to it.
-            assert isinstance(module_to_patch, LoRASidecarModule)
-            lora_sidecar_module = module_to_patch
-        else:
-            # The module has not yet been patched with a LoRASidecarModule. Create one.
-            lora_sidecar_module = LoRASidecarModule(module_to_patch, [])
+        # Replace the original module with a BaseSidecarWrapper if it has not already been done.
+        if not isinstance(module_to_patch, BaseSidecarWrapper):
+            wrapped_module = wrap_module_with_sidecar_wrapper(orig_module=module_to_patch)
             original_modules[module_to_patch_key] = module_to_patch
             module_parent_key, module_name = LoRAPatcher._split_parent_key(module_to_patch_key)
             module_parent = model.get_submodule(module_parent_key)
-            LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
+            LoRAPatcher._set_submodule(module_parent, module_name, wrapped_module)
+        else:
+            assert module_to_patch_key in original_modules
+            wrapped_module = module_to_patch
 
-        # Move the LoRA sidecar layer to the same device/dtype as the orig module.
-        # TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
-        lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)
+        # Move the LoRA layer to the same device/dtype as the orig module.
+        patch.to(device=wrapped_module.orig_module.weight.device, dtype=dtype)
 
-        # Add the LoRA sidecar layer to the LoRASidecarModule.
-        lora_sidecar_module.add_lora_layer(lora_sidecar_layer)
+        # Add the patch to the sidecar wrapper.
+        wrapped_module.add_patch(patch, patch_weight)
 
     @staticmethod
     def _split_parent_key(module_key: str) -> tuple[str, str]:
@@ -294,21 +283,6 @@ class LoRAPatcher:
         else:
             raise ValueError(f"Invalid module key: {module_key}")
 
-    @staticmethod
-    def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: BaseLayerPatch, patch_weight: float):
-        # TODO(ryand): Add support for more original layer types and LoRA layer types.
-        if isinstance(orig_layer, torch.nn.Linear) or (
-            isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
-        ):
-            if isinstance(lora_layer, LoRALayer):
-                return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
-            elif isinstance(lora_layer, ConcatenatedLoRALayer):
-                return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
-            else:
-                raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
-        else:
-            raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
-
     @staticmethod
     def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
         try:
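
For context on what the wrapper side of this interface looks like, here is a deliberately simplified sketch of a linear sidecar wrapper. It is illustrative only: the class name and forward logic are assumptions, not the InvokeAI implementation. The real wrappers live under invokeai/backend/patches/sidecar_wrappers/ and may apply patches differently.

import torch


class SketchLinearSidecarWrapper(torch.nn.Module):
    """Illustrative wrapper: runs the original Linear, then adds weighted patch outputs."""

    def __init__(self, orig_module: torch.nn.Linear):
        super().__init__()
        self.orig_module = orig_module
        self._patches_and_weights: list[tuple[torch.nn.Module, float]] = []

    def add_patch(self, patch: torch.nn.Module, weight: float) -> None:
        # Patches accumulate on the wrapper, so several LoRAs can stack on one layer.
        self._patches_and_weights.append((patch, weight))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Start from the unmodified layer's output, then add each patch's weighted residual.
        output = self.orig_module(x)
        for patch, weight in self._patches_and_weights:
            output = output + weight * patch(x)
        return output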


@@ -1,19 +1,16 @@
 import torch
 
-from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
 from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
 from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
 
 
-def wrap_module_with_sidecar_wrapper(
-    orig_module: torch.nn.Module, patches_and_weights: list[tuple[BaseLayerPatch, float]]
-) -> torch.nn.Module:
+def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module:
     if isinstance(orig_module, torch.nn.Linear):
-        return LinearSidecarWrapper(orig_module, patches_and_weights)
+        return LinearSidecarWrapper(orig_module)
     elif isinstance(orig_module, torch.nn.Conv1d):
-        return Conv1dSidecarWrapper(orig_module, patches_and_weights)
+        return Conv1dSidecarWrapper(orig_module)
    elif isinstance(orig_module, torch.nn.Conv2d):
-        return Conv2dSidecarWrapper(orig_module, patches_and_weights)
+        return Conv2dSidecarWrapper(orig_module)
     else:
         raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")