Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-04-23 03:00:31 -04:00.
Add utility test function for creating a dummy state_dict.
This commit is contained in:
8
tests/backend/lora/conversions/lora_state_dicts/utils.py
Normal file
8
tests/backend/lora/conversions/lora_state_dicts/utils.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import torch
|
||||
|
||||
|
||||
def keys_to_mock_state_dict(keys: list[str]) -> dict[str, torch.Tensor]:
    """Build a dummy state dict that maps each key in `keys` to a placeholder 1-element tensor.

    Useful for tests that only inspect state-dict keys, not tensor contents.
    """
    return {key: torch.empty(1) for key in keys}
|
||||
@@ -1,5 +1,3 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
lora_model_from_flux_diffusers_state_dict,
|
||||
@@ -10,14 +8,13 @@ from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format
|
||||
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
|
||||
state_dict_keys as flux_kohya_state_dict_keys,
|
||||
)
|
||||
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
|
||||
|
||||
|
||||
def test_is_state_dict_likely_in_flux_diffusers_format_true():
    """Test that is_state_dict_likely_in_flux_diffusers_format() can identify a state dict in the Diffusers FLUX LoRA format."""
    # Construct a state dict that is in the Diffusers FLUX LoRA format.
    # (The manual key-by-key loop was dead code: its result was immediately
    # overwritten by this helper call, so it has been removed.)
    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)

    assert is_state_dict_likely_in_flux_diffusers_format(state_dict)
|
||||
|
||||
@@ -27,9 +24,7 @@ def test_is_state_dict_likely_in_flux_diffusers_format_false():
|
||||
FLUX LoRA format.
|
||||
"""
|
||||
# Construct a state dict that is not in the Kohya FLUX LoRA format.
|
||||
state_dict: dict[str, torch.Tensor] = {}
|
||||
for k in flux_kohya_state_dict_keys:
|
||||
state_dict[k] = torch.empty(1)
|
||||
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
|
||||
|
||||
assert not is_state_dict_likely_in_flux_diffusers_format(state_dict)
|
||||
|
||||
@@ -37,10 +32,7 @@ def test_is_state_dict_likely_in_flux_diffusers_format_false():
|
||||
def test_lora_model_from_flux_diffusers_state_dict():
|
||||
"""Test that lora_model_from_flux_diffusers_state_dict() can load a state dict in the Diffusers FLUX LoRA format."""
|
||||
# Construct a state dict that is in the Diffusers FLUX LoRA format.
|
||||
state_dict: dict[str, torch.Tensor] = {}
|
||||
for k in flux_diffusers_state_dict_keys:
|
||||
state_dict[k] = torch.empty(1)
|
||||
|
||||
state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
|
||||
# Load the state dict into a LoRAModelRaw object.
|
||||
model = lora_model_from_flux_diffusers_state_dict(state_dict)
|
||||
|
||||
|
||||
@@ -8,31 +8,34 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
lora_model_from_flux_kohya_state_dict,
|
||||
)
|
||||
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import state_dict_keys
|
||||
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
|
||||
state_dict_keys as flux_diffusers_state_dict_keys,
|
||||
)
|
||||
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
|
||||
state_dict_keys as flux_kohya_state_dict_keys,
|
||||
)
|
||||
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
|
||||
|
||||
|
||||
def test_is_state_dict_likely_in_flux_kohya_format_true():
    """Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
    # Construct a state dict that is in the Kohya FLUX LoRA format.
    # (Removed the dead manual loop that built a state dict from the unaliased
    # `state_dict_keys` only to be overwritten by this helper call; the helper
    # consistently uses the aliased `flux_kohya_state_dict_keys` import.)
    state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)

    assert is_state_dict_likely_in_flux_kohya_format(state_dict)
|
||||
|
||||
|
||||
def test_is_state_dict_likely_in_flux_kohya_format_false():
    """Test that is_state_dict_likely_in_flux_kohya_format() returns False for a state dict that is in the Diffusers
    FLUX LoRA format.
    """
    # The old body (a literal dict with one "unexpected_key" entry, plus a second
    # string that was a no-op statement rather than a docstring) was dead code;
    # only the helper-built Diffusers-format state dict is actually asserted on.
    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
    assert not is_state_dict_likely_in_flux_kohya_format(state_dict)
|
||||
|
||||
|
||||
def test_convert_flux_kohya_state_dict_to_invoke_format():
|
||||
# Construct state_dict from state_dict_keys.
|
||||
state_dict: dict[str, torch.Tensor] = {}
|
||||
for k in state_dict_keys:
|
||||
state_dict[k] = torch.empty(1)
|
||||
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
|
||||
|
||||
converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)
|
||||
|
||||
@@ -75,16 +78,14 @@ def test_convert_flux_kohya_state_dict_to_invoke_format_error():
|
||||
|
||||
def test_lora_model_from_flux_kohya_state_dict():
|
||||
"""Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
|
||||
# Construct state_dict from state_dict_keys.
|
||||
state_dict: dict[str, torch.Tensor] = {}
|
||||
for k in state_dict_keys:
|
||||
state_dict[k] = torch.empty(1)
|
||||
# Construct a state dict that is in the Kohya FLUX LoRA format.
|
||||
state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys)
|
||||
|
||||
lora_model = lora_model_from_flux_kohya_state_dict(state_dict)
|
||||
|
||||
# Prepare expected layer keys.
|
||||
expected_layer_keys: set[str] = set()
|
||||
for k in state_dict_keys:
|
||||
for k in flux_kohya_state_dict_keys:
|
||||
k = k.replace("lora_unet_", "")
|
||||
k = k.replace(".lora_up.weight", "")
|
||||
k = k.replace(".lora_down.weight", "")
|
||||
|
||||
Reference in New Issue
Block a user