mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Fix bias handling in LoRAModuleWrapper and add unit test that checks that all LoRA patching methods produce the same outputs.
This commit is contained in:
@@ -19,39 +19,40 @@ class LoRAModuleWrapper(torch.nn.Module):
|
||||
self._lora_weights.append(lora_weight)
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_lora_patched_parameters(self) -> dict[str, torch.Tensor]:
|
||||
out_params: dict[str, torch.Tensor] = {}
|
||||
def _get_lora_patched_parameters(self, params: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
for lora_layer, lora_weight in zip(self._lora_layers, self._lora_weights, strict=True):
|
||||
layer_params = lora_layer.get_parameters(self._orig_module)
|
||||
for param_name, param_weight in layer_params.items():
|
||||
# If the parameter already exists in out_params, use that one. Otherwise, use original parameter.
|
||||
if param_name not in out_params:
|
||||
out_params[param_name] = self._orig_module.get_parameter(param_name)
|
||||
if params[param_name].shape != param_weight.shape:
|
||||
param_weight = param_weight.reshape(params[param_name].shape)
|
||||
|
||||
if out_params[param_name].shape != param_weight.shape:
|
||||
param_weight = param_weight.reshape(out_params[param_name].shape)
|
||||
|
||||
# NOTE: It is important that out_params[param_name] is not modified in-place, because we initialize it
|
||||
# NOTE: It is important that params[param_name] is not modified in-place, because we initialize it
|
||||
# with the original parameter - which we don't want to modify. In other words,
|
||||
# `out_params[param_name] += ...` would not work.
|
||||
out_params[param_name] = out_params[param_name] + param_weight * (lora_layer.scale() * lora_weight)
|
||||
params[param_name] = params[param_name] + param_weight * (lora_layer.scale() * lora_weight)
|
||||
|
||||
return out_params
|
||||
return params
|
||||
|
||||
|
||||
class LoRALinearWrapper(LoRAModuleWrapper):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
params = self._get_lora_patched_parameters()
|
||||
return torch.nn.functional.linear(input, params["weight"], params.get("bias", None))
|
||||
params = self._get_lora_patched_parameters(
|
||||
params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
|
||||
)
|
||||
return torch.nn.functional.linear(input, params["weight"], params["bias"])
|
||||
|
||||
|
||||
class LoRAConv1dWrapper(LoRAModuleWrapper):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
params = self._get_lora_patched_parameters()
|
||||
return torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))
|
||||
params = self._get_lora_patched_parameters(
|
||||
params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
|
||||
)
|
||||
return torch.nn.functional.conv1d(input, params["weight"], params["bias"])
|
||||
|
||||
|
||||
class LoRAConv2dWrapper(LoRAModuleWrapper):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
params = self._get_lora_patched_parameters()
|
||||
return torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))
|
||||
params = self._get_lora_patched_parameters(
|
||||
params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
|
||||
)
|
||||
return torch.nn.functional.conv2d(input, params["weight"], params["bias"])
|
||||
|
||||
@@ -159,44 +159,6 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
|
||||
assert torch.allclose(output_before_patch, output_after_patch)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
|
||||
def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
|
||||
"""Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...)."""
|
||||
dtype = torch.float32
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
|
||||
|
||||
# Initialize num_layers LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[LoRAModelRaw, float]] = []
|
||||
for _ in range(num_layers):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
)
|
||||
}
|
||||
lora = LoRAModelRaw(lora_layers)
|
||||
lora_models.append((lora, lora_weight))
|
||||
|
||||
input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)
|
||||
|
||||
with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
|
||||
output_lora_patches = model(input)
|
||||
|
||||
with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
output_lora_sidecar_patches = model(input)
|
||||
|
||||
# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
|
||||
# differences are tolerable and expected due to the difference between sidecar vs. patching.
|
||||
assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["device", "num_layers"],
|
||||
[
|
||||
@@ -245,3 +207,45 @@ def test_apply_lora_wrapper_patches(device: str, num_layers: int):
|
||||
|
||||
# Check that the output before patching is the same as the output after patching.
|
||||
assert torch.allclose(output_before_patch, output_after_patch)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
|
||||
def test_all_patching_methods_produce_same_output(num_layers: int):
|
||||
"""Test that apply_lora_wrapper_patches(...) produces the same model outputs as apply_lora_patches(...)."""
|
||||
dtype = torch.float32
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
|
||||
|
||||
# Initialize num_layers LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[LoRAModelRaw, float]] = []
|
||||
for _ in range(num_layers):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
)
|
||||
}
|
||||
lora = LoRAModelRaw(lora_layers)
|
||||
lora_models.append((lora, lora_weight))
|
||||
|
||||
input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)
|
||||
|
||||
with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
|
||||
output_lora_patches = model(input)
|
||||
|
||||
with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
output_lora_sidecar_patches = model(input)
|
||||
|
||||
with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix=""):
|
||||
output_lora_wrapper_patches = model(input)
|
||||
|
||||
# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
|
||||
# differences are tolerable and expected due to the difference between sidecar vs. patching.
|
||||
assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
|
||||
assert torch.allclose(output_lora_patches, output_lora_wrapper_patches, atol=1e-5)
|
||||
|
||||
Reference in New Issue
Block a user