From ab2b1b2bdedf9f50504fbeb5f049c80a30d48857 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 24 Sep 2025 18:53:30 +1000 Subject: [PATCH] feat(mm): port flux redux to new api --- .../backend/flux/redux/flux_redux_state_dict_utils.py | 4 ++-- invokeai/backend/model_manager/config.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py b/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py index a5a13b402d..83e96d3845 100644 --- a/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py +++ b/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py @@ -1,7 +1,7 @@ -from typing import Any, Dict +from typing import Any -def is_state_dict_likely_flux_redux(state_dict: Dict[str, Any]) -> bool: +def is_state_dict_likely_flux_redux(state_dict: dict[str | int, Any]) -> bool: """Checks if the provided state dict is likely a FLUX Redux model.""" expected_keys = {"redux_down.bias", "redux_down.weight", "redux_up.bias", "redux_up.weight"} diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 640ab56318..a1e043dc35 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -45,6 +45,7 @@ from typing_extensions import Annotated, Any, Dict from invokeai.app.services.config.config_default import get_config from invokeai.app.util.misc import uuid_string +from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux from invokeai.backend.model_hash.hash_validator import validate_hash from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS from invokeai.backend.model_manager.model_on_disk import ModelOnDisk @@ -1117,8 +1118,8 @@ class FluxReduxConfig(ModelConfigBase): format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint VALID_OVERRIDES: ClassVar = { - "type": ModelType.SigLIP, - "format": ModelFormat.Diffusers, + "type": ModelType.FluxRedux, + "format": ModelFormat.Checkpoint, } @classmethod @@ -1128,6 +1129,9 @@ class FluxReduxConfig(ModelConfigBase): if _validate_overrides(cls, fields, cls.VALID_OVERRIDES): return cls(**fields) + if not is_state_dict_likely_flux_redux(mod.load_state_dict()): + raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics") + return cls(**fields)