mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add LoRAPatcher.smart_apply_lora_patches()
This commit is contained in:
@@ -16,6 +16,114 @@ from invokeai.backend.util.original_weights_storage import OriginalWeightsStorag
|
||||
|
||||
|
||||
class LoRAPatcher:
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_smart_lora_patches(
|
||||
model: torch.nn.Module,
|
||||
patches: Iterable[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
dtype: torch.dtype,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Apply 'smart' LoRA patching that chooses whether to use direct patching or a sidecar wrapper for each module."""
|
||||
|
||||
# original_weights are stored for unpatching layers that are directly patched.
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
# original_modules are stored for unpatching layers that are wrapped in a LoRASidecarWrapper.
|
||||
original_modules: dict[str, torch.nn.Module] = {}
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
LoRAPatcher._apply_smart_lora_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
yield
|
||||
finally:
|
||||
# Restore directly patched layers.
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
# Restore LoRASidecarWrapper modules.
|
||||
# Note: This logic assumes no nested modules in original_modules.
|
||||
for module_key, orig_module in original_modules.items():
|
||||
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
|
||||
parent_module = model.get_submodule(module_parent_key)
|
||||
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _apply_smart_lora_patch(
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
patch: LoRAModelRaw,
|
||||
patch_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
|
||||
patching or a sidecar wrapper for each module.
|
||||
"""
|
||||
if patch_weight == 0:
|
||||
return
|
||||
|
||||
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
|
||||
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
|
||||
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
|
||||
# without searching, but some legacy code still uses flattened keys.
|
||||
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
|
||||
|
||||
prefix_len = len(prefix)
|
||||
|
||||
for layer_key, layer in patch.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module_key, module = LoRAPatcher._get_submodule(
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
# Decide whether to use direct patching or a sidecar wrapper.
|
||||
# Direct patching is preferred, because it results in better runtime speed.
|
||||
# Reasons to use sidecar patching:
|
||||
# - The module is already wrapped in a LoRASidecarWrapper.
|
||||
# - The module is quantized.
|
||||
# - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
|
||||
# CPU, since this would double the RAM usage)
|
||||
# NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
|
||||
# and that the caller will use the 'apply_lora_wrapper_patches' method if the layer is quantized.
|
||||
# TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
|
||||
# forcing full patching even on the CPU?
|
||||
if isinstance(module, LoRASidecarWrapper) or LoRAPatcher._is_any_part_of_layer_on_cpu(module):
|
||||
LoRAPatcher._apply_lora_layer_wrapper_patch(
|
||||
model=model,
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
LoRAPatcher._apply_lora_layer_patch(
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
|
||||
return any(p.device.type == "cpu" for p in layer.parameters())
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
|
||||
@@ -6,7 +6,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
|
||||
|
||||
class DummyModule(torch.nn.Module):
|
||||
class DummyModuleWithOneLayer(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
|
||||
super().__init__()
|
||||
self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
|
||||
@@ -15,6 +15,16 @@ class DummyModule(torch.nn.Module):
|
||||
return self.linear_layer_1(x)
|
||||
|
||||
|
||||
class DummyModuleWithTwoLayers(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
|
||||
super().__init__()
|
||||
self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
|
||||
self.linear_layer_2 = torch.nn.Linear(out_features, out_features, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear_layer_2(self.linear_layer_1(x))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["device", "num_layers"],
|
||||
[
|
||||
@@ -33,7 +43,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=torch.float16)
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=torch.float16)
|
||||
|
||||
# Initialize num_layers LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
@@ -79,7 +89,7 @@ def test_apply_lora_patches_change_device():
|
||||
linear_out_features = 8
|
||||
lora_dim = 2
|
||||
# Initialize the model on the CPU.
|
||||
model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
|
||||
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
@@ -124,7 +134,7 @@ def test_apply_lora_wrapper_patches(device: str, num_layers: int):
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=dtype)
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
|
||||
|
||||
# Initialize num_layers LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
@@ -159,6 +169,57 @@ def test_apply_lora_wrapper_patches(device: str, num_layers: int):
|
||||
assert torch.allclose(output_before_patch, output_after_patch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["device", "num_layers"],
|
||||
[
|
||||
("cpu", 1),
|
||||
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
("cpu", 2),
|
||||
pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
],
|
||||
)
|
||||
@torch.no_grad()
|
||||
def test_apply_smart_lora_patches(device: str, num_layers: int):
|
||||
"""Test the basic behavior of ModelPatcher.apply_smart_lora_patches(...). Check that unpatching works correctly."""
|
||||
dtype = torch.float16
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
|
||||
|
||||
# Initialize num_layers LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[LoRAModelRaw, float]] = []
|
||||
for _ in range(num_layers):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
)
|
||||
}
|
||||
lora = LoRAModelRaw(lora_layers)
|
||||
lora_models.append((lora, lora_weight))
|
||||
|
||||
# Run inference before patching the model.
|
||||
input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
|
||||
output_before_patch = model(input)
|
||||
|
||||
# Patch the model and run inference during the patch.
|
||||
with LoRAPatcher.apply_smart_lora_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
output_during_patch = model(input)
|
||||
|
||||
# Run inference after unpatching.
|
||||
output_after_patch = model(input)
|
||||
|
||||
# Check that the output before patching is different from the output during patching.
|
||||
assert not torch.allclose(output_before_patch, output_during_patch)
|
||||
|
||||
# Check that the output before patching is the same as the output after patching.
|
||||
assert torch.allclose(output_before_patch, output_after_patch)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
|
||||
def test_all_patching_methods_produce_same_output(num_layers: int):
|
||||
@@ -167,7 +228,7 @@ def test_all_patching_methods_produce_same_output(num_layers: int):
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
|
||||
|
||||
# Initialize num_layers LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
|
||||
Reference in New Issue
Block a user