Move handling of LoRA scale and patch weight down into the layer patch classes.

Ryan Dick
2024-12-13 23:15:30 +00:00
parent e7e3f7e144
commit fe09f2d27a
7 changed files with 14 additions and 24 deletions


@@ -107,8 +107,8 @@ def test_lora_layer_get_parameters():
     # Create mock original module
     orig_module = torch.nn.Linear(in_features, out_features)
-    params = layer.get_parameters(orig_module)
+    params = layer.get_parameters(orig_module, weight=1.0)
     assert "weight" in params
     assert params["weight"].shape == orig_module.weight.shape
-    assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha / rank)
+    assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha)
     assert "bias" not in params  # No bias in this case


@@ -10,7 +10,7 @@ def test_set_parameter_layer_get_parameters():
     target_weight = torch.randn(8, 4)
     layer = SetParameterLayer(param_name="weight", weight=target_weight)
-    params = layer.get_parameters(orig_module)
+    params = layer.get_parameters(orig_module, weight=1.0)
     assert len(params) == 1
     new_weight = orig_module.weight + params["weight"]
     assert torch.allclose(new_weight, target_weight)
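
SetParameterLayer gets the same signature change. A hedged sketch of how such a layer might fold the patch weight into the delta it returns (again, assumed names and structure, not the actual class):

    import torch

    class SetParameterLayerSketch:
        """Illustrative patch that overrides a single named parameter."""

        def __init__(self, param_name: str, weight: torch.Tensor):
            self.param_name = param_name
            self.weight = weight

        def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
            # Return the delta that moves the original parameter toward the
            # target value; the patch weight scales that delta.
            orig_param = orig_module.get_parameter(self.param_name)
            return {self.param_name: (self.weight - orig_param) * weight}

At weight=1.0 this reproduces the behaviour asserted above: orig_module.weight + params["weight"] equals the target tensor.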