Fixup unit tests.

Author: Ryan Dick
Date: 2024-09-12 22:27:49 +00:00
Committed by: Kent Keirsey
Parent: 5bb0c79c14
Commit: 7ce41bf7e0
2 changed files with 4 additions and 4 deletions


@@ -37,7 +37,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
     # Construct a state dict that is in the Diffusers FLUX LoRA format.
     state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
     # Load the state dict into a LoRAModelRaw object.
-    model = lora_model_from_flux_diffusers_state_dict(state_dict)
+    model = lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)
     # Check that the model has the correct number of LoRA layers.
     expected_lora_layers: set[str] = set()
@@ -63,4 +63,4 @@ def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
     # Check that an error is raised.
     with pytest.raises(AssertionError):
-        lora_model_from_flux_diffusers_state_dict(state_dict)
+        lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)
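For context, the alpha argument added in this file is the standard LoRA scaling hyperparameter: a patched layer's effective weight delta is (alpha / rank) * up @ down. A minimal sketch of that scaling, using generic LoRA math rather than this repository's implementation (tensor sizes are illustrative):

    import torch

    def lora_delta(down: torch.Tensor, up: torch.Tensor, alpha: float) -> torch.Tensor:
        # down has shape (rank, in_features); up has shape (out_features, rank).
        # The delta applied to the base weight is scaled by alpha / rank.
        rank = down.shape[0]
        return (alpha / rank) * (up @ down)

    down = torch.ones((8, 32))    # rank=8, in_features=32
    up = torch.ones((64, 8))      # out_features=64, rank=8
    delta = lora_delta(down, up, alpha=8.0)  # alpha=8.0 matches the tests
    assert delta.shape == (64, 32)

Note that when alpha equals the rank the scale factor is 1.0, which makes 8.0 a natural choice for tests whose mock layers use a LoRA dimension of 8.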


@@ -27,7 +27,7 @@ def test_apply_lora(device: str):
     )
     lora_layers = {
-        "linear_layer_1": LoRALayer(
+        "linear_layer_1": LoRALayer.from_state_dict_values(
             values={
                 "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                 "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
@@ -70,7 +70,7 @@ def test_apply_lora_change_device():
     )
     lora_layers = {
-        "linear_layer_1": LoRALayer(
+        "linear_layer_1": LoRALayer.from_state_dict_values(
             values={
                 "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                 "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),