Rename ModelPatcher methods to reflect that they are general model patching methods and are not LoRA-specific.

This commit is contained in:
Ryan Dick
2024-12-14 15:37:26 +00:00
parent c604a0956e
commit b820862eab
9 changed files with 24 additions and 24 deletions

View File

@@ -17,7 +17,7 @@ class ModelPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_patches(
def apply_model_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
@@ -37,7 +37,7 @@ class ModelPatcher:
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
ModelPatcher.apply_lora_patch(
ModelPatcher.apply_model_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -54,7 +54,7 @@ class ModelPatcher:
@staticmethod
@torch.no_grad()
def apply_lora_patch(
def apply_model_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
@@ -89,7 +89,7 @@ class ModelPatcher:
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
ModelPatcher._apply_lora_layer_patch(
ModelPatcher._apply_model_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
@@ -99,7 +99,7 @@ class ModelPatcher:
@staticmethod
@torch.no_grad()
def _apply_lora_layer_patch(
def _apply_model_layer_patch(
module_to_patch: torch.nn.Module,
module_to_patch_key: str,
patch: BaseLayerPatch,
@@ -146,7 +146,7 @@ class ModelPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_sidecar_patches(
def apply_model_sidecar_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
@@ -169,7 +169,7 @@ class ModelPatcher:
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
ModelPatcher._apply_lora_sidecar_patch(
ModelPatcher._apply_model_sidecar_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -187,7 +187,7 @@ class ModelPatcher:
ModelPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
def _apply_lora_sidecar_patch(
def _apply_model_sidecar_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch_weight: float,
@@ -216,7 +216,7 @@ class ModelPatcher:
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
ModelPatcher._apply_lora_layer_wrapper_patch(
ModelPatcher._apply_model_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
@@ -228,7 +228,7 @@ class ModelPatcher:
@staticmethod
@torch.no_grad()
def _apply_lora_layer_wrapper_patch(
def _apply_model_layer_wrapper_patch(
model: torch.nn.Module,
module_to_patch: torch.nn.Module,
module_to_patch_key: str,

View File

@@ -31,7 +31,7 @@ class LoRAExt(ExtensionBase):
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
assert isinstance(lora_model, LoRAModelRaw)
ModelPatcher.apply_lora_patch(
ModelPatcher.apply_model_patch(
model=unet,
prefix="lora_unet_",
patch=lora_model,