mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add state dict tensor shapes for existing LoRA unit tests.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
import torch
|
||||
|
||||
|
||||
def keys_to_mock_state_dict(keys: list[str]) -> dict[str, torch.Tensor]:
|
||||
def keys_to_mock_state_dict(keys: dict[str, list[int]]) -> dict[str, torch.Tensor]:
|
||||
state_dict: dict[str, torch.Tensor] = {}
|
||||
for k in keys:
|
||||
state_dict[k] = torch.empty(1)
|
||||
for k, shape in keys.items():
|
||||
state_dict[k] = torch.empty(shape)
|
||||
return state_dict
|
||||
|
||||
@@ -23,7 +23,7 @@ from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_s
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
|
||||
def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]):
|
||||
def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: dict[str, list[int]]):
|
||||
"""Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
|
||||
# Construct a state dict that is in the Kohya FLUX LoRA format.
|
||||
state_dict = keys_to_mock_state_dict(sd_keys)
|
||||
@@ -83,7 +83,7 @@ def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
|
||||
def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
|
||||
def test_lora_model_from_flux_kohya_state_dict(sd_keys: dict[str, list[int]]):
|
||||
"""Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
|
||||
# Construct a state dict that is in the Kohya FLUX LoRA format.
|
||||
state_dict = keys_to_mock_state_dict(sd_keys)
|
||||
|
||||
Reference in New Issue
Block a user