mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): port flux redux to new api
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
|
||||
def is_state_dict_likely_flux_redux(state_dict: Dict[str, Any]) -> bool:
|
||||
def is_state_dict_likely_flux_redux(state_dict: dict[str | int, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely a FLUX Redux model."""
|
||||
|
||||
expected_keys = {"redux_down.bias", "redux_down.weight", "redux_up.bias", "redux_up.weight"}
|
||||
|
||||
@@ -45,6 +45,7 @@ from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
|
||||
from invokeai.backend.model_hash.hash_validator import validate_hash
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
||||
@@ -1117,8 +1118,8 @@ class FluxReduxConfig(ModelConfigBase):
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.SigLIP,
|
||||
"format": ModelFormat.Diffusers,
|
||||
"type": ModelType.FluxRedux,
|
||||
"format": ModelFormat.Checkpoint,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -1128,6 +1129,9 @@ class FluxReduxConfig(ModelConfigBase):
|
||||
if _validate_overrides(cls, fields, cls.VALID_OVERRIDES):
|
||||
return cls(**fields)
|
||||
|
||||
if not is_state_dict_likely_flux_redux(mod.load_state_dict()):
|
||||
raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics")
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user