Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-04-23 03:00:31 -04:00)
Rename ModelPatcher methods to reflect that they are general model patching methods and are not LoRA-specific.
@@ -17,7 +17,7 @@ class ModelPatcher:
     @staticmethod
     @torch.no_grad()
     @contextmanager
-    def apply_lora_patches(
+    def apply_model_patches(
         model: torch.nn.Module,
         patches: Iterable[Tuple[LoRAModelRaw, float]],
         prefix: str,
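
Note: apply_model_patches (and apply_model_sidecar_patches further down) is a context manager, so the intended calling pattern is to patch inside a "with" block and let the exit path restore the original weights. A minimal usage sketch based only on the parameters visible in this hunk; unet, lora_model, and run_denoising are placeholder names, and imports of InvokeAI internals are omitted because their module paths are not shown in the diff:

    def denoise_with_patches(unet, lora_model, run_denoising):
        # lora_model is assumed to be a LoRAModelRaw; 0.75 is an arbitrary example weight.
        with ModelPatcher.apply_model_patches(
            model=unet,
            patches=[(lora_model, 0.75)],
            prefix="lora_unet_",
        ):
            # Weights are patched only inside this block and restored on exit.
            return run_denoising(unet)
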
@@ -37,7 +37,7 @@ class ModelPatcher:
         original_weights = OriginalWeightsStorage(cached_weights)
         try:
             for patch, patch_weight in patches:
-                ModelPatcher.apply_lora_patch(
+                ModelPatcher.apply_model_patch(
                     model=model,
                     prefix=prefix,
                     patch=patch,
@@ -54,7 +54,7 @@ class ModelPatcher:
 
     @staticmethod
     @torch.no_grad()
-    def apply_lora_patch(
+    def apply_model_patch(
         model: torch.nn.Module,
         prefix: str,
         patch: LoRAModelRaw,
@@ -89,7 +89,7 @@ class ModelPatcher:
                 model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
             )
 
-            ModelPatcher._apply_lora_layer_patch(
+            ModelPatcher._apply_model_layer_patch(
                 module_to_patch=module,
                 module_to_patch_key=module_key,
                 patch=layer,
@@ -99,7 +99,7 @@ class ModelPatcher:
 
     @staticmethod
     @torch.no_grad()
-    def _apply_lora_layer_patch(
+    def _apply_model_layer_patch(
         module_to_patch: torch.nn.Module,
         module_to_patch_key: str,
         patch: BaseLayerPatch,
@@ -146,7 +146,7 @@ class ModelPatcher:
     @staticmethod
     @torch.no_grad()
     @contextmanager
-    def apply_lora_sidecar_patches(
+    def apply_model_sidecar_patches(
         model: torch.nn.Module,
         patches: Iterable[Tuple[LoRAModelRaw, float]],
         prefix: str,
@@ -169,7 +169,7 @@ class ModelPatcher:
         original_modules: dict[str, torch.nn.Module] = {}
         try:
             for patch, patch_weight in patches:
-                ModelPatcher._apply_lora_sidecar_patch(
+                ModelPatcher._apply_model_sidecar_patch(
                     model=model,
                     prefix=prefix,
                     patch=patch,
@@ -187,7 +187,7 @@ class ModelPatcher:
                 ModelPatcher._set_submodule(parent_module, module_name, orig_module)
 
     @staticmethod
-    def _apply_lora_sidecar_patch(
+    def _apply_model_sidecar_patch(
         model: torch.nn.Module,
         patch: LoRAModelRaw,
         patch_weight: float,
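
The sidecar path restores the model by swapping the original submodules back in via ModelPatcher._set_submodule (see the restore line in the hunk above). As an illustrative sketch only (not InvokeAI's implementation), this is the generic swap-a-submodule-and-restore-it pattern that the restore step relies on; all names below are hypothetical:

    import torch

    def set_submodule(parent: torch.nn.Module, name: str, child: torch.nn.Module) -> None:
        # Assigning a module attribute by name replaces the existing child module.
        setattr(parent, name, child)

    model = torch.nn.Sequential(torch.nn.Linear(4, 4))
    original = model[0]

    # Swap in a stand-in module for patched inference...
    set_submodule(model, "0", torch.nn.Identity())
    print(model(torch.ones(1, 4)))  # flows through the stand-in module

    # ...then restore the original module afterwards.
    set_submodule(model, "0", original)
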
@@ -216,7 +216,7 @@ class ModelPatcher:
                 model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
             )
 
-            ModelPatcher._apply_lora_layer_wrapper_patch(
+            ModelPatcher._apply_model_layer_wrapper_patch(
                 model=model,
                 module_to_patch=module,
                 module_to_patch_key=module_key,
@@ -228,7 +228,7 @@ class ModelPatcher:
 
     @staticmethod
     @torch.no_grad()
-    def _apply_lora_layer_wrapper_patch(
+    def _apply_model_layer_wrapper_patch(
         model: torch.nn.Module,
         module_to_patch: torch.nn.Module,
         module_to_patch_key: str,
@@ -31,7 +31,7 @@ class LoRAExt(ExtensionBase):
     def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         lora_model = self._node_context.models.load(self._model_id).model
         assert isinstance(lora_model, LoRAModelRaw)
-        ModelPatcher.apply_lora_patch(
+        ModelPatcher.apply_model_patch(
             model=unet,
             prefix="lora_unet_",
             patch=lora_model,
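
Unlike the context-manager variants, apply_model_patch (decorated only with @staticmethod and @torch.no_grad() in the earlier hunk) returns immediately and leaves restoration to the caller, which is why patch_unet receives an OriginalWeightsStorage. The hunk cuts off after patch=lora_model; a hedged sketch of how the remainder of the call likely reads, where the trailing keyword arguments are assumptions rather than lines taken from the diff:

        ModelPatcher.apply_model_patch(
            model=unet,
            prefix="lora_unet_",
            patch=lora_model,
            patch_weight=self._weight,          # assumption: the extension's configured LoRA weight
            original_weights=original_weights,  # assumption: store used later to restore the UNet
        )
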