From 7ce41bf7e0d2ecc820742a331caab512b0714a2e Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 12 Sep 2024 22:27:49 +0000
Subject: [PATCH] Fixup unit tests.

---
 .../conversions/test_flux_diffusers_lora_conversion_utils.py | 4 ++--
 tests/backend/lora/test_lora_patcher.py                      | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py b/tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py
index 12c45b579d..d770788381 100644
--- a/tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py
+++ b/tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py
@@ -37,7 +37,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
     # Construct a state dict that is in the Diffusers FLUX LoRA format.
     state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
     # Load the state dict into a LoRAModelRaw object.
-    model = lora_model_from_flux_diffusers_state_dict(state_dict)
+    model = lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)
 
     # Check that the model has the correct number of LoRA layers.
     expected_lora_layers: set[str] = set()
@@ -63,4 +63,4 @@ def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
 
     # Check that an error is raised.
     with pytest.raises(AssertionError):
-        lora_model_from_flux_diffusers_state_dict(state_dict)
+        lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)
diff --git a/tests/backend/lora/test_lora_patcher.py b/tests/backend/lora/test_lora_patcher.py
index e82a8dad85..0d8d3b6964 100644
--- a/tests/backend/lora/test_lora_patcher.py
+++ b/tests/backend/lora/test_lora_patcher.py
@@ -27,7 +27,7 @@ def test_apply_lora(device: str):
     )
 
     lora_layers = {
-        "linear_layer_1": LoRALayer(
+        "linear_layer_1": LoRALayer.from_state_dict_values(
             values={
                 "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                 "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
@@ -70,7 +70,7 @@ def test_apply_lora_change_device():
     )
 
     lora_layers = {
-        "linear_layer_1": LoRALayer(
+        "linear_layer_1": LoRALayer.from_state_dict_values(
             values={
                 "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                 "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
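-- 
Reviewer note (below the signature separator, so `git am` drops it): both
hunks track API changes in the LoRA backend. `lora_model_from_flux_diffusers_state_dict`
now takes an explicit `alpha` argument, and `LoRALayer` objects are built via
the `LoRALayer.from_state_dict_values` classmethod rather than the constructor.
A minimal sketch of the new call patterns follows; the import paths, toy
dimensions, and input file name are illustrative assumptions, not taken from
the repo:

    import torch
    from safetensors.torch import load_file

    # Assumed import locations, mirroring the test-file paths; adjust to the
    # actual module layout.
    from invokeai.backend.lora.layers.lora_layer import LoRALayer
    from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
        lora_model_from_flux_diffusers_state_dict,
    )

    lora_dim, in_features, out_features = 4, 8, 16  # toy dimensions (assumed)

    # New pattern: build the layer from raw state-dict tensors via the
    # classmethod. Only the `values` keyword is visible in the hunks above;
    # any further keyword arguments are not shown in this patch.
    layer = LoRALayer.from_state_dict_values(
        values={
            "lora_down.weight": torch.ones((lora_dim, in_features), dtype=torch.float16),
            "lora_up.weight": torch.ones((out_features, lora_dim), dtype=torch.float16),
        },
    )

    # New pattern: the Diffusers FLUX converter requires an explicit alpha;
    # the updated tests pass alpha=8.0. The file name here is hypothetical --
    # any state dict in the Diffusers FLUX LoRA format would do.
    state_dict = load_file("flux_lora.safetensors")
    model = lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)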