Rename ModelPatcher -> LayerPatcher to avoid conflicts with another ModelPatcher definition.
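The motivation is a plain name collision: the codebase contains a second class also named ModelPatcher, so the two cannot be imported into the same module without aliasing, and a second import silently shadows the first. A minimal sketch of the problem and the fix (the invokeai.backend.model_patcher path for the other definition is an assumption for illustration):

# Before the rename: the second import shadows the first.
from invokeai.backend.model_patcher import ModelPatcher          # hypothetical other definition
from invokeai.backend.patches.model_patcher import ModelPatcher  # rebinds the same name!

# After the rename: both names coexist without an alias.
from invokeai.backend.model_patcher import ModelPatcher          # hypothetical other definition
from invokeai.backend.patches.model_patcher import LayerPatcher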
invokeai/backend/patches/model_patcher.py:

@@ -13,7 +13,7 @@ from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
-class ModelPatcher:
+class LayerPatcher:
     @staticmethod
     @torch.no_grad()
     @contextmanager
@@ -37,7 +37,7 @@ class ModelPatcher:
         original_weights = OriginalWeightsStorage(cached_weights)
         try:
             for patch, patch_weight in patches:
-                ModelPatcher.apply_model_patch(
+                LayerPatcher.apply_model_patch(
                     model=model,
                     prefix=prefix,
                     patch=patch,
@@ -85,11 +85,11 @@ class ModelPatcher:
             if not layer_key.startswith(prefix):
                 continue
 
-            module_key, module = ModelPatcher._get_submodule(
+            module_key, module = LayerPatcher._get_submodule(
                 model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
             )
 
-            ModelPatcher._apply_model_layer_patch(
+            LayerPatcher._apply_model_layer_patch(
                 module_to_patch=module,
                 module_to_patch_key=module_key,
                 patch=layer,
@@ -169,7 +169,7 @@ class ModelPatcher:
         original_modules: dict[str, torch.nn.Module] = {}
         try:
             for patch, patch_weight in patches:
-                ModelPatcher._apply_model_sidecar_patch(
+                LayerPatcher._apply_model_sidecar_patch(
                     model=model,
                     prefix=prefix,
                     patch=patch,
@@ -182,9 +182,9 @@ class ModelPatcher:
             # Restore original modules.
             # Note: This logic assumes no nested modules in original_modules.
             for module_key, orig_module in original_modules.items():
-                module_parent_key, module_name = ModelPatcher._split_parent_key(module_key)
+                module_parent_key, module_name = LayerPatcher._split_parent_key(module_key)
                 parent_module = model.get_submodule(module_parent_key)
-                ModelPatcher._set_submodule(parent_module, module_name, orig_module)
+                LayerPatcher._set_submodule(parent_module, module_name, orig_module)
 
     @staticmethod
     def _apply_model_sidecar_patch(
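The helpers _split_parent_key and _set_submodule are not part of this diff, but the call sites above make the contract clear: split a dotted module key into a parent path plus a leaf name, then rebind that leaf on the parent to restore the original module. A minimal sketch of what such helpers plausibly look like (an assumption inferred from the usage above, not the repository's code):

import torch

def _split_parent_key(module_key: str) -> tuple[str, str]:
    # Assumed behavior: "encoder.block.0.attn" -> ("encoder.block.0", "attn").
    # For a top-level key the parent is "", and model.get_submodule("")
    # returns the model itself.
    parent_key, _, module_name = module_key.rpartition(".")
    return parent_key, module_name

def _set_submodule(parent: torch.nn.Module, name: str, module: torch.nn.Module) -> None:
    # Assumed behavior: rebind the child attribute; nn.Module.__setattr__
    # registers the new module under parent._modules[name].
    setattr(parent, name, module)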
@@ -212,11 +212,11 @@ class ModelPatcher:
             if not layer_key.startswith(prefix):
                 continue
 
-            module_key, module = ModelPatcher._get_submodule(
+            module_key, module = LayerPatcher._get_submodule(
                 model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
             )
 
-            ModelPatcher._apply_model_layer_wrapper_patch(
+            LayerPatcher._apply_model_layer_wrapper_patch(
                 model=model,
                 module_to_patch=module,
                 module_to_patch_key=module_key,
@@ -242,9 +242,9 @@ class ModelPatcher:
         if not isinstance(module_to_patch, BaseSidecarWrapper):
             wrapped_module = wrap_module_with_sidecar_wrapper(orig_module=module_to_patch)
             original_modules[module_to_patch_key] = module_to_patch
-            module_parent_key, module_name = ModelPatcher._split_parent_key(module_to_patch_key)
+            module_parent_key, module_name = LayerPatcher._split_parent_key(module_to_patch_key)
             module_parent = model.get_submodule(module_parent_key)
-            ModelPatcher._set_submodule(module_parent, module_name, wrapped_module)
+            LayerPatcher._set_submodule(module_parent, module_name, wrapped_module)
         else:
             assert module_to_patch_key in original_modules
             wrapped_module = module_to_patch
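The branch above is the usual PyTorch wrap-and-restore pattern: wrap the target module in a sidecar, record the original so it can be restored later, and rebind the parent's attribute to swap the wrapper in. A self-contained sketch of the same pattern (SidecarWrapper is a stand-in for BaseSidecarWrapper, which is defined outside this diff):

import torch

class SidecarWrapper(torch.nn.Module):
    # Stand-in wrapper: delegates to the original module, leaving a hook
    # point for patched behavior around the call.
    def __init__(self, orig_module: torch.nn.Module):
        super().__init__()
        self.orig_module = orig_module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.orig_module(x)

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
original = model.get_submodule("0")
setattr(model, "0", SidecarWrapper(original))  # swap the wrapper in
assert isinstance(model[0], SidecarWrapper)
setattr(model, "0", original)                  # restore the original module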
invokeai/backend/stable_diffusion/extensions/lora.py:

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING
 from diffusers import UNet2DConditionModel
 
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
-from invokeai.backend.patches.model_patcher import ModelPatcher
+from invokeai.backend.patches.model_patcher import LayerPatcher
 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
 
 if TYPE_CHECKING:
@@ -31,7 +31,7 @@ class LoRAExt(ExtensionBase):
     def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         lora_model = self._node_context.models.load(self._model_id).model
         assert isinstance(lora_model, ModelPatchRaw)
-        ModelPatcher.apply_model_patch(
+        LayerPatcher.apply_model_patch(
             model=unet,
             prefix="lora_unet_",
             patch=lora_model,
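For callers, nothing but the class name changes; the call pattern stays exactly as it was. A hedged sketch of the post-rename call site (patch_weight is an assumed keyword, inferred from the "for patch, patch_weight in patches" loop in the first file; the hunk above truncates before it):

# Sketch only: unet and lora_model come from the surrounding extension code.
LayerPatcher.apply_model_patch(
    model=unet,
    prefix="lora_unet_",
    patch=lora_model,
    patch_weight=1.0,  # assumed keyword, not shown in the truncated hunk
)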