Fix bias handling in LoRAModuleWrapper and add a unit test checking that all LoRA patching methods produce the same outputs.

Ryan Dick
2024-12-09 16:59:37 +00:00
parent 9353bfbdd6
commit 3f28d3afad
2 changed files with 60 additions and 55 deletions
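
For context on the fix: before this change, _get_lora_patched_parameters only collected the parameters that a LoRA layer actually patched, and the wrappers then looked the bias up with params.get("bias", None). When a LoRA carried no bias delta, the base module's own bias could be silently dropped from the forward pass. Below is a standalone, simplified sketch of that failure mode; it is not the project's API, it only mirrors the shape of the old vs. new logic.

import torch

torch.manual_seed(0)
linear = torch.nn.Linear(4, 8)               # base module with a (non-zero) bias
lora_delta = {"weight": torch.zeros(8, 4)}   # a LoRA patch that only touches the weight
x = torch.randn(1, 4)

# Old behavior: patched params were collected only for parameters the LoRA touched,
# so an unpatched bias never made it into the dict.
old_params = {name: linear.get_parameter(name) + delta for name, delta in lora_delta.items()}
out_old = torch.nn.functional.linear(x, old_params["weight"], old_params.get("bias", None))

# New behavior: start from the module's own weight and bias, then add the deltas on top.
new_params = {"weight": linear.weight, "bias": linear.bias}
for name, delta in lora_delta.items():
    new_params[name] = new_params[name] + delta
out_new = torch.nn.functional.linear(x, new_params["weight"], new_params["bias"])

print(torch.allclose(out_old, out_new))  # False: the old path lost the bias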


@@ -19,39 +19,40 @@ class LoRAModuleWrapper(torch.nn.Module):
         self._lora_weights.append(lora_weight)
 
     @torch.no_grad()
-    def _get_lora_patched_parameters(self) -> dict[str, torch.Tensor]:
-        out_params: dict[str, torch.Tensor] = {}
+    def _get_lora_patched_parameters(self, params: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         for lora_layer, lora_weight in zip(self._lora_layers, self._lora_weights, strict=True):
             layer_params = lora_layer.get_parameters(self._orig_module)
 
             for param_name, param_weight in layer_params.items():
-                # If the parameter already exists in out_params, use that one. Otherwise, use original parameter.
-                if param_name not in out_params:
-                    out_params[param_name] = self._orig_module.get_parameter(param_name)
-
-                if out_params[param_name].shape != param_weight.shape:
-                    param_weight = param_weight.reshape(out_params[param_name].shape)
+                if params[param_name].shape != param_weight.shape:
+                    param_weight = param_weight.reshape(params[param_name].shape)
 
-                # NOTE: It is important that out_params[param_name] is not modified in-place, because we initialize it
+                # NOTE: It is important that params[param_name] is not modified in-place, because we initialize it
                 # with the original parameter - which we don't want to modify. In other words,
                 # `out_params[param_name] += ...` would not work.
-                out_params[param_name] = out_params[param_name] + param_weight * (lora_layer.scale() * lora_weight)
+                params[param_name] = params[param_name] + param_weight * (lora_layer.scale() * lora_weight)
 
-        return out_params
+        return params
 
 
 class LoRALinearWrapper(LoRAModuleWrapper):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        params = self._get_lora_patched_parameters()
-        return torch.nn.functional.linear(input, params["weight"], params.get("bias", None))
+        params = self._get_lora_patched_parameters(
+            params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
+        )
+        return torch.nn.functional.linear(input, params["weight"], params["bias"])
 
 
 class LoRAConv1dWrapper(LoRAModuleWrapper):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        params = self._get_lora_patched_parameters()
-        return torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))
+        params = self._get_lora_patched_parameters(
+            params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
+        )
+        return torch.nn.functional.conv1d(input, params["weight"], params["bias"])
 
 
 class LoRAConv2dWrapper(LoRAModuleWrapper):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        params = self._get_lora_patched_parameters()
-        return torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))
+        params = self._get_lora_patched_parameters(
+            params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
+        )
+        return torch.nn.functional.conv2d(input, params["weight"], params["bias"])


@@ -159,44 +159,6 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
     assert torch.allclose(output_before_patch, output_after_patch)
 
 
-@torch.no_grad()
-@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
-def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
-    """Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...)."""
-    dtype = torch.float32
-    linear_in_features = 4
-    linear_out_features = 8
-    lora_rank = 2
-    model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
-
-    # Initialize num_layers LoRA models with weights of 0.5.
-    lora_weight = 0.5
-    lora_models: list[tuple[LoRAModelRaw, float]] = []
-    for _ in range(num_layers):
-        lora_layers = {
-            "linear_layer_1": LoRALayer.from_state_dict_values(
-                values={
-                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
-                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
-                },
-            )
-        }
-        lora = LoRAModelRaw(lora_layers)
-        lora_models.append((lora, lora_weight))
-
-    input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)
-
-    with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
-        output_lora_patches = model(input)
-
-    with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
-        output_lora_sidecar_patches = model(input)
-
-    # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
-    # differences are tolerable and expected due to the difference between sidecar vs. patching.
-    assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
-
-
 @pytest.mark.parametrize(
     ["device", "num_layers"],
     [
@@ -245,3 +207,45 @@ def test_apply_lora_wrapper_patches(device: str, num_layers: int):
 
     # Check that the output before patching is the same as the output after patching.
     assert torch.allclose(output_before_patch, output_after_patch)
+
+
+@torch.no_grad()
+@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
+def test_all_patching_methods_produce_same_output(num_layers: int):
+    """Test that apply_lora_patches(...), apply_lora_sidecar_patches(...), and apply_lora_wrapper_patches(...) all
+    produce the same model outputs."""
+    dtype = torch.float32
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_rank = 2
+    model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
+
+    # Initialize num_layers LoRA models with weights of 0.5.
+    lora_weight = 0.5
+    lora_models: list[tuple[LoRAModelRaw, float]] = []
+    for _ in range(num_layers):
+        lora_layers = {
+            "linear_layer_1": LoRALayer.from_state_dict_values(
+                values={
+                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
+                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
+                },
+            )
+        }
+        lora = LoRAModelRaw(lora_layers)
+        lora_models.append((lora, lora_weight))
+
+    input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)
+
+    with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
+        output_lora_patches = model(input)
+
+    with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+        output_lora_sidecar_patches = model(input)
+
+    with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix=""):
+        output_lora_wrapper_patches = model(input)
+
+    # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
+    # differences between the patching implementations are tolerable and expected.
+    assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
+    assert torch.allclose(output_lora_patches, output_lora_wrapper_patches, atol=1e-5)
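
The new test reuses fixtures defined elsewhere in the test module (DummyModule, LoRALayer, LoRAModelRaw, LoRAPatcher). For orientation only, a hypothetical minimal DummyModule consistent with how the test uses it would be a module exposing a bias-carrying linear submodule named linear_layer_1; the actual fixture may differ.

import torch

class DummyModule(torch.nn.Module):
    """Hypothetical stand-in for the real test fixture; linear_layer_1 matches the LoRA layer key used above."""

    def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
        super().__init__()
        # bias=True (the default) is what makes this a useful regression target for the bias fix.
        self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer_1(x)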