From ab1e15e4f5212ded8d9516fb76f15b487cf8e062 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 10 Oct 2025 14:23:50 +1100 Subject: [PATCH] tests(mm): flux state dict tests --- .../test_flux_aitoolkit_lora_conversion_utils.py | 4 ++-- .../lora_conversions/test_flux_kohya_lora_conversion_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py index 051ed210cd..f9c20e82a5 100644 --- a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py +++ b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py @@ -3,7 +3,7 @@ import pytest from invokeai.backend.flux.model import Flux from invokeai.backend.flux.util import get_flux_transformers_params -from invokeai.backend.model_manager.taxonomy import ModelVariantType +from invokeai.backend.model_manager.taxonomy import FluxVariantType from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import ( _group_state_by_submodel, is_state_dict_likely_in_flux_aitoolkit_format, @@ -45,7 +45,7 @@ def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format(): # Initialize a FLUX model on the meta device. with accelerate.init_empty_weights(): - model = Flux(get_flux_transformers_params(ModelVariantType.FluxSchnell)) + model = Flux(get_flux_transformers_params(FluxVariantType.Schnell)) model_keys = set(model.state_dict().keys()) for converted_key_prefix in converted_key_prefixes: diff --git a/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py index eb8846f456..35a5f5a909 100644 --- a/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py +++ b/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py @@ -4,7 +4,7 @@ import torch from invokeai.backend.flux.model import Flux from invokeai.backend.flux.util import get_flux_transformers_params -from invokeai.backend.model_manager.taxonomy import ModelVariantType +from invokeai.backend.model_manager.taxonomy import FluxVariantType from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( _convert_flux_transformer_kohya_state_dict_to_invoke_format, is_state_dict_likely_in_flux_kohya_format, @@ -64,7 +64,7 @@ def test_convert_flux_transformer_kohya_state_dict_to_invoke_format(): # Initialize a FLUX model on the meta device. with accelerate.init_empty_weights(): - model = Flux(get_flux_transformers_params(ModelVariantType.FluxSchnell)) + model = Flux(get_flux_transformers_params(FluxVariantType.Schnell)) model_keys = set(model.state_dict().keys()) # Assert that the converted state dict matches the keys in the actual model.