mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
tests(mm): flux state dict tests
This commit is contained in:
@@ -3,7 +3,7 @@ import pytest
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.util import get_flux_transformers_params
|
||||
from invokeai.backend.model_manager.taxonomy import ModelVariantType
|
||||
from invokeai.backend.model_manager.taxonomy import FluxVariantType
|
||||
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
|
||||
_group_state_by_submodel,
|
||||
is_state_dict_likely_in_flux_aitoolkit_format,
|
||||
@@ -45,7 +45,7 @@ def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
|
||||
|
||||
# Initialize a FLUX model on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(get_flux_transformers_params(ModelVariantType.FluxSchnell))
|
||||
model = Flux(get_flux_transformers_params(FluxVariantType.Schnell))
|
||||
model_keys = set(model.state_dict().keys())
|
||||
|
||||
for converted_key_prefix in converted_key_prefixes:
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.util import get_flux_transformers_params
|
||||
from invokeai.backend.model_manager.taxonomy import ModelVariantType
|
||||
from invokeai.backend.model_manager.taxonomy import FluxVariantType
|
||||
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
|
||||
_convert_flux_transformer_kohya_state_dict_to_invoke_format,
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
@@ -64,7 +64,7 @@ def test_convert_flux_transformer_kohya_state_dict_to_invoke_format():
|
||||
|
||||
# Initialize a FLUX model on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(get_flux_transformers_params(ModelVariantType.FluxSchnell))
|
||||
model = Flux(get_flux_transformers_params(FluxVariantType.Schnell))
|
||||
model_keys = set(model.state_dict().keys())
|
||||
|
||||
# Assert that the converted state dict matches the keys in the actual model.
|
||||
|
||||
Reference in New Issue
Block a user