Rename LoRAPatcher -> ModelPatcher.

Author: Ryan Dick
Date: 2024-12-14 15:31:05 +00:00
Parent: 9369b39a12
Commit: c604a0956e

9 changed files with 33 additions and 33 deletions


@@ -3,7 +3,7 @@ import torch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
-from invokeai.backend.patches.lora_patcher import LoRAPatcher
+from invokeai.backend.patches.model_patcher import ModelPatcher
class DummyModule(torch.nn.Module):
@@ -53,7 +53,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers)
-with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
+with ModelPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
for lora, _ in lora_models:
assert lora.layers["linear_layer_1"].up.device.type == "cpu"
@@ -93,7 +93,7 @@ def test_apply_lora_patches_change_device():
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
-with LoRAPatcher.apply_lora_patches(model=model, patches=[(lora, 0.5)], prefix=""):
+with ModelPatcher.apply_lora_patches(model=model, patches=[(lora, 0.5)], prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
@@ -146,7 +146,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
output_before_patch = model(input)
# Patch the model and run inference during the patch.
-with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+with ModelPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_during_patch = model(input)
# Run inference after unpatching.
@@ -186,10 +186,10 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)
-with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
+with ModelPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
output_lora_patches = model(input)
-with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+with ModelPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_lora_sidecar_patches = model(input)
# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
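For orientation, a minimal sketch of the renamed API, using only the call signatures that appear verbatim in the hunks above. Everything other than ModelPatcher and its two context managers is a stand-in: the toy model, the empty patch list (construction of real (LoRAModelRaw, weight) tuples is elided, and an empty list is assumed to be accepted), and the dtype.

import torch

from invokeai.backend.patches.model_patcher import ModelPatcher

model = torch.nn.Linear(8, 8)  # stand-in for the tests' DummyModule
x = torch.randn(1, 8)
lora_models = []  # placeholder: list of (LoRAModelRaw, weight) tuples
dtype = torch.float32

# In-place patching: layer weights are modified for the duration of the context
# and restored on exit (the tests compare outputs before, during, and after).
with ModelPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
    output_lora_patches = model(x)

# Sidecar patching: per the last hunk, its outputs are expected to match
# apply_lora_patches to within atol=1e-5.
with ModelPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
    output_lora_sidecar_patches = model(x)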