Rename ModelPatcher -> LayerPatcher to avoid conflicts with another ModelPatcher definition.

This commit is contained in:
Ryan Dick
2024-12-14 16:11:23 +00:00
parent 7fad4c9491
commit dd09509dbd
9 changed files with 33 additions and 33 deletions

View File

@@ -13,7 +13,7 @@ from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class ModelPatcher:
class LayerPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
@@ -37,7 +37,7 @@ class ModelPatcher:
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
ModelPatcher.apply_model_patch(
LayerPatcher.apply_model_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -85,11 +85,11 @@ class ModelPatcher:
if not layer_key.startswith(prefix):
continue
module_key, module = ModelPatcher._get_submodule(
module_key, module = LayerPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
ModelPatcher._apply_model_layer_patch(
LayerPatcher._apply_model_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
@@ -169,7 +169,7 @@ class ModelPatcher:
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
ModelPatcher._apply_model_sidecar_patch(
LayerPatcher._apply_model_sidecar_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -182,9 +182,9 @@ class ModelPatcher:
# Restore original modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = ModelPatcher._split_parent_key(module_key)
module_parent_key, module_name = LayerPatcher._split_parent_key(module_key)
parent_module = model.get_submodule(module_parent_key)
ModelPatcher._set_submodule(parent_module, module_name, orig_module)
LayerPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
def _apply_model_sidecar_patch(
@@ -212,11 +212,11 @@ class ModelPatcher:
if not layer_key.startswith(prefix):
continue
module_key, module = ModelPatcher._get_submodule(
module_key, module = LayerPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
ModelPatcher._apply_model_layer_wrapper_patch(
LayerPatcher._apply_model_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
@@ -242,9 +242,9 @@ class ModelPatcher:
if not isinstance(module_to_patch, BaseSidecarWrapper):
wrapped_module = wrap_module_with_sidecar_wrapper(orig_module=module_to_patch)
original_modules[module_to_patch_key] = module_to_patch
module_parent_key, module_name = ModelPatcher._split_parent_key(module_to_patch_key)
module_parent_key, module_name = LayerPatcher._split_parent_key(module_to_patch_key)
module_parent = model.get_submodule(module_parent_key)
ModelPatcher._set_submodule(module_parent, module_name, wrapped_module)
LayerPatcher._set_submodule(module_parent, module_name, wrapped_module)
else:
assert module_to_patch_key in original_modules
wrapped_module = module_to_patch

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING
from diffusers import UNet2DConditionModel
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.patches.model_patcher import LayerPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
if TYPE_CHECKING:
@@ -31,7 +31,7 @@ class LoRAExt(ExtensionBase):
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
assert isinstance(lora_model, ModelPatchRaw)
ModelPatcher.apply_model_patch(
LayerPatcher.apply_model_patch(
model=unet,
prefix="lora_unet_",
patch=lora_model,