Add utility test function for creating a dummy state_dict.

This commit is contained in:
Ryan Dick
2024-09-09 19:09:42 +00:00
parent 2e8effe83f
commit cbba28bdec
3 changed files with 29 additions and 28 deletions

View File

@@ -0,0 +1,8 @@
import torch
def keys_to_mock_state_dict(keys: list[str]) -> dict[str, torch.Tensor]:
state_dict: dict[str, torch.Tensor] = {}
for k in keys:
state_dict[k] = torch.empty(1)
return state_dict

View File

@@ -1,5 +1,3 @@
import torch
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
lora_model_from_flux_diffusers_state_dict,
@@ -10,14 +8,13 @@ from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
state_dict_keys as flux_kohya_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
def test_is_state_dict_likely_in_flux_diffusers_format_true():
"""Test that is_state_dict_likely_in_flux_diffusers_format() can identify a state dict in the Diffusers FLUX LoRA format."""
# Construct a state dict that is in the Diffusers FLUX LoRA format.
state_dict: dict[str, torch.Tensor] = {}
for k in flux_diffusers_state_dict_keys:
state_dict[k] = torch.empty(1)
state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
assert is_state_dict_likely_in_flux_diffusers_format(state_dict)
@@ -27,9 +24,7 @@ def test_is_state_dict_likely_in_flux_diffusers_format_false():
FLUX LoRA format.
"""
# Construct a state dict that is not in the Kohya FLUX LoRA format.
state_dict: dict[str, torch.Tensor] = {}
for k in flux_kohya_state_dict_keys:
state_dict[k] = torch.empty(1)
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
assert not is_state_dict_likely_in_flux_diffusers_format(state_dict)
@@ -37,10 +32,7 @@ def test_is_state_dict_likely_in_flux_diffusers_format_false():
def test_lora_model_from_flux_diffusers_state_dict():
"""Test that lora_model_from_flux_diffusers_state_dict() can load a state dict in the Diffusers FLUX LoRA format."""
# Construct a state dict that is in the Diffusers FLUX LoRA format.
state_dict: dict[str, torch.Tensor] = {}
for k in flux_diffusers_state_dict_keys:
state_dict[k] = torch.empty(1)
state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
# Load the state dict into a LoRAModelRaw object.
model = lora_model_from_flux_diffusers_state_dict(state_dict)

View File

@@ -8,31 +8,34 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import state_dict_keys
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
state_dict_keys as flux_kohya_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
def test_is_state_dict_likely_in_flux_kohya_format_true():
"""Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
# Construct a state dict that is in the Kohya FLUX LoRA format.
state_dict: dict[str, torch.Tensor] = {}
for k in state_dict_keys:
state_dict[k] = torch.empty(1)
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
assert is_state_dict_likely_in_flux_kohya_format(state_dict)
def test_is_state_dict_likely_in_flux_kohya_format_false():
"""Test that is_state_dict_likely_in_flux_kohya_format() returns False for a state dict that is not in the Kohya FLUX LoRA format."""
state_dict: dict[str, torch.Tensor] = {
"unexpected_key.lora_up.weight": torch.empty(1),
}
"""Test that is_state_dict_likely_in_flux_kohya_format() returns False for a state dict that is in the Diffusers
FLUX LoRA format.
"""
state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
assert not is_state_dict_likely_in_flux_kohya_format(state_dict)
def test_convert_flux_kohya_state_dict_to_invoke_format():
# Construct state_dict from state_dict_keys.
state_dict: dict[str, torch.Tensor] = {}
for k in state_dict_keys:
state_dict[k] = torch.empty(1)
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)
@@ -75,16 +78,14 @@ def test_convert_flux_kohya_state_dict_to_invoke_format_error():
def test_lora_model_from_flux_kohya_state_dict():
"""Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
# Construct state_dict from state_dict_keys.
state_dict: dict[str, torch.Tensor] = {}
for k in state_dict_keys:
state_dict[k] = torch.empty(1)
# Construct a state dict that is in the Kohya FLUX LoRA format.
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
lora_model = lora_model_from_flux_kohya_state_dict(state_dict)
# Prepare expected layer keys.
expected_layer_keys: set[str] = set()
for k in state_dict_keys:
for k in flux_kohya_state_dict_keys:
k = k.replace("lora_unet_", "")
k = k.replace(".lora_up.weight", "")
k = k.replace(".lora_down.weight", "")