feat(mm): port flux redux to new api

This commit is contained in:
psychedelicious
2025-09-24 18:53:30 +10:00
parent eb1eee37f1
commit ab2b1b2bde
2 changed files with 8 additions and 4 deletions

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict
from typing import Any
def is_state_dict_likely_flux_redux(state_dict: Dict[str, Any]) -> bool:
def is_state_dict_likely_flux_redux(state_dict: dict[str | int, Any]) -> bool:
"""Checks if the provided state dict is likely a FLUX Redux model."""
expected_keys = {"redux_down.bias", "redux_down.weight", "redux_up.bias", "redux_up.weight"}

View File

@@ -45,6 +45,7 @@ from typing_extensions import Annotated, Any, Dict
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.misc import uuid_string
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
@@ -1117,8 +1118,8 @@ class FluxReduxConfig(ModelConfigBase):
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
VALID_OVERRIDES: ClassVar = {
"type": ModelType.SigLIP,
"format": ModelFormat.Diffusers,
"type": ModelType.FluxRedux,
"format": ModelFormat.Checkpoint,
}
@classmethod
@@ -1128,6 +1129,9 @@ class FluxReduxConfig(ModelConfigBase):
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
return cls(**fields)
if not is_state_dict_likely_flux_redux(mod.load_state_dict()):
raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics")
return cls(**fields)