mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Move handling of LoRA scale and patch weight down into the layer patch classes.
This commit is contained in:
@@ -107,8 +107,8 @@ def test_lora_layer_get_parameters():
|
||||
# Create mock original module
|
||||
orig_module = torch.nn.Linear(in_features, out_features)
|
||||
|
||||
params = layer.get_parameters(orig_module)
|
||||
params = layer.get_parameters(orig_module, weight=1.0)
|
||||
assert "weight" in params
|
||||
assert params["weight"].shape == orig_module.weight.shape
|
||||
assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha / rank)
|
||||
assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha)
|
||||
assert "bias" not in params # No bias in this case
|
||||
|
||||
@@ -10,7 +10,7 @@ def test_set_parameter_layer_get_parameters():
|
||||
target_weight = torch.randn(8, 4)
|
||||
layer = SetParameterLayer(param_name="weight", weight=target_weight)
|
||||
|
||||
params = layer.get_parameters(orig_module)
|
||||
params = layer.get_parameters(orig_module, weight=1.0)
|
||||
assert len(params) == 1
|
||||
new_weight = orig_module.weight + params["weight"]
|
||||
assert torch.allclose(new_weight, target_weight)
|
||||
|
||||
Reference in New Issue
Block a user