refactor: port MM probes to new api

- Add concept of match certainty to new probe
- Port CLIP Embed models to new API
- Miscellaneous cleanups: FLUX param helpers, legacy config map, typing fixes
psychedelicious
2025-09-23 13:00:26 +10:00
parent 613fa15ee7
commit 4b52cc2546
21 changed files with 488 additions and 116 deletions
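At a high level, this commit replaces the boolean `matches()` check in the new probe API with a graded `MatchCertainty` score, so the classifier can rank candidate config classes and stop early on a perfect match. A minimal standalone sketch of the scoring idea (the candidate names below are hypothetical):

from enum import Enum

class MatchCertainty(int, Enum):
    NEVER = 0  # definitely not this config
    MAYBE = 1  # heuristic match; keep looking for something better
    EXACT = 2  # unambiguous match; stop searching
    OVERRIDE = 3  # user overrides force this config; stop searching

# Hypothetical scores returned by candidate config classes for one model:
scores = {"LoRALyCORISConfig": MatchCertainty.MAYBE, "LoRAOmiConfig": MatchCertainty.NEVER}
viable = {name: s for name, s in scores.items() if s is not MatchCertainty.NEVER}
best = max(viable, key=lambda name: viable[name].value)
print(best)  # LoRALyCORISConfig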

View File

@@ -13,7 +13,7 @@ from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.flux.util import get_flux_max_seq_length
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
)
@@ -94,5 +94,5 @@ class FluxModelLoaderInvocation(BaseInvocation):
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
max_seq_len=get_flux_max_seq_length(transformer_config.variant),
)

View File

@@ -5,6 +5,7 @@ import os
import re
import threading
import time
from copy import deepcopy
from pathlib import Path
from queue import Empty, Queue
from shutil import move, rmtree
@@ -370,6 +371,8 @@ class ModelInstallService(ModelInstallServiceBase):
model_path = self.app_config.models_path / model.path
if model_path.is_file() or model_path.is_symlink():
model_path.unlink()
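# The parent directory is expected to contain only this model file. Note that
# os.rmdir raises OSError if the directory is not empty, and the assert guards
# against ever removing the models root itself.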
assert model_path.parent != self.app_config.models_path
os.rmdir(model_path.parent)
elif model_path.is_dir():
rmtree(model_path)
self.unregister(key)
@@ -607,9 +610,9 @@ class ModelInstallService(ModelInstallServiceBase):
# any given model - eliminating ambiguity and removing reliance on order.
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
try:
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo) # type: ignore
except InvalidModelConfigException:
return ModelConfigBase.classify(model_path, hash_algo, **fields)
return ModelConfigBase.classify(mod=model_path, hash_algo=hash_algo, **fields)
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None

View File

@@ -21,6 +21,7 @@ from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
FluxVariantType,
ModelFormat,
ModelSourceType,
ModelType,
@@ -90,7 +91,9 @@ class ModelRecordChanges(BaseModelExcludeNull):
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType] = Field(
description="The variant of the model.", default=None
)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)

View File

@@ -141,10 +141,25 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.transaction() as cursor:
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
# The changes may mean the model config class changes. So we need to:
#
# 1. convert the existing record to a dict
# 2. apply the changes to the dict
# 3. create a new model config from the updated dict
#
# This way we ensure that the update does not inadvertently create an invalid model config.
# 1. convert the existing record to a dict
record_as_dict = record.model_dump()
# 2. apply the changes to the dict
for field_name in changes.model_fields_set:
record_as_dict[field_name] = getattr(changes, field_name)
# 3. create a new model config from the updated dict
record = ModelConfigFactory.make_config(record_as_dict)
# If we get this far, the updated model config is valid, so we can save it to the database.
json_serialized = record.model_dump_json()
cursor.execute(

View File

@@ -8,7 +8,7 @@ from pydantic import ValidationError
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.backend.model_manager.config import AnyModelConfig, AnyModelConfigValidator
from invokeai.backend.model_manager.config import AnyModelConfigValidator
class NormalizeResult(NamedTuple):
@@ -30,7 +30,7 @@ class Migration22Callback:
for model_id, config_json in rows:
try:
# Get the model config as a pydantic object
config = self._load_model_config(config_json)
config = AnyModelConfigValidator.validate_json(config_json)
except ValidationError:
# This could happen if the config schema changed in a way that makes old configs invalid. Unlikely
# for users, more likely for devs testing out migration paths.
@@ -216,11 +216,6 @@ class Migration22Callback:
self._logger.info("Pruned %d empty directories under %s", len(removed_dirs), self._models_dir)
def _load_model_config(self, config_json: str) -> AnyModelConfig:
# The typing of the validator says it returns Unknown, but it's really an AnyModelConfig. This utility function
# just makes that clear.
return AnyModelConfigValidator.validate_json(config_json)
def build_migration_22(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""Builds the migration object for migrating from version 21 to version 22.

View File

@@ -1,10 +1,11 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
from typing import Dict, Literal
from typing import Literal
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType
@dataclass
@@ -41,30 +42,39 @@ PREFERED_KONTEXT_RESOLUTIONS = [
]
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-dev-fill": 512,
"flux-schnell": 256,
_flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = {
FluxVariantType.Dev: 512,
FluxVariantType.DevFill: 512,
FluxVariantType.Schnell: 256,
}
ae_params = {
"flux": AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
}
def get_flux_max_seq_length(variant: AnyVariant) -> Literal[256, 512]:
try:
return _flux_max_seq_lengths[variant]
except KeyError:
raise ValueError(f"Unknown variant for FLUX max seq len: {variant}")
params = {
"flux-dev": FluxParams(
_flux_ae_params = AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
def get_flux_ae_params() -> AutoEncoderParams:
return _flux_ae_params
_flux_transformer_params: dict[AnyVariant, FluxParams] = {
FluxVariantType.Dev: FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
@@ -78,7 +88,7 @@ params = {
qkv_bias=True,
guidance_embed=True,
),
"flux-schnell": FluxParams(
FluxVariantType.Schnell: FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
@@ -92,7 +102,7 @@ params = {
qkv_bias=True,
guidance_embed=False,
),
"flux-dev-fill": FluxParams(
FluxVariantType.DevFill: FluxParams(
in_channels=384,
out_channels=64,
vec_in_dim=768,
@@ -108,3 +118,10 @@ params = {
guidance_embed=True,
),
}
def get_flux_transformers_params(variant: AnyVariant) -> FluxParams:
try:
return _flux_transformer_params[variant]
except KeyError:
raise ValueError(f"Unknown variant for FLUX transformer params: {variant}")

View File

@@ -44,6 +44,7 @@ from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
FluxLoRAFormat,
FluxVariantType,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
@@ -51,6 +52,7 @@ from invokeai.backend.model_manager.taxonomy import (
ModelVariantType,
SchedulerPredictionType,
SubModelType,
variant_type_adapter,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -71,7 +73,7 @@ DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
variant: AnyVariant = None
variant: AnyVariant | None = None
model_config = ConfigDict(protected_namespaces=())
@@ -111,6 +113,15 @@ class MatchSpeed(int, Enum):
SLOW = 2
class MatchCertainty(int, Enum):
"""Represents the certainty of a config's 'matches' method."""
NEVER = 0
MAYBE = 1
EXACT = 2
OVERRIDE = 3
class LegacyProbeMixin:
"""Mixin for classes using the legacy probe for model classification."""
@@ -194,20 +205,47 @@ class ModelConfigBase(ABC, BaseModel):
Created to deprecate ModelProbe.probe
"""
if isinstance(mod, Path | str):
mod = ModelOnDisk(mod, hash_algo)
mod = ModelOnDisk(Path(mod), hash_algo)
candidates = ModelConfigBase.USING_CLASSIFY_API
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
overrides = overrides or {}
overrides = ModelConfigBase.cast_overrides(**overrides)
matches: dict[Type[ModelConfigBase], MatchCertainty] = {}
for config_cls in sorted_by_match_speed:
try:
if not config_cls.matches(mod):
score = config_cls.matches(mod, **overrides)
# A score of 0 means "no match"
if score is MatchCertainty.NEVER:
continue
matches[config_cls] = score
if score is MatchCertainty.EXACT or score is MatchCertainty.OVERRIDE:
# Perfect match - skip further checks
break
except Exception as e:
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
continue
else:
return config_cls.from_model_on_disk(mod, **overrides)
if matches:
# Select the config class with the highest score
sorted_by_score = sorted(matches.items(), key=lambda item: item[1].value)
# Check if there are multiple classes with the same top score
top_score = sorted_by_score[-1][1]
top_classes = [cls for cls, score in sorted_by_score if score is top_score]
if len(top_classes) > 1:
logger.warning(
f"Multiple model config classes matched with the same top score ({top_score}) for model {mod.name}: {[cls.__name__ for cls in top_classes]}. Using {top_classes[0].__name__}."
)
config_cls = top_classes[0]
# Finally, create the config instance
logger.info(f"Model {mod.name} classified as {config_cls.__name__} with score {top_score.name}")
return config_cls.from_model_on_disk(mod, **overrides)
if app_config.allow_unknown_models:
try:
@@ -234,13 +272,13 @@ class ModelConfigBase(ABC, BaseModel):
@classmethod
@abstractmethod
def matches(cls, mod: ModelOnDisk) -> bool:
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
"""Performs a quick check to determine if the config matches the model.
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
Returns a MatchCertainty score."""
pass
@staticmethod
def cast_overrides(overrides: dict[str, Any]):
def cast_overrides(**overrides):
"""Casts user overrides from str to Enum"""
if "type" in overrides:
overrides["type"] = ModelType(overrides["type"])
@@ -255,13 +293,13 @@ class ModelConfigBase(ABC, BaseModel):
overrides["source_type"] = ModelSourceType(overrides["source_type"])
if "variant" in overrides:
overrides["variant"] = ModelVariantType(overrides["variant"])
overrides["variant"] = variant_type_adapter.validate_strings(overrides["variant"])
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
fields = cls.parse(mod)
cls.cast_overrides(overrides)
overrides = cls.cast_overrides(**overrides)
fields.update(overrides)
fields["path"] = mod.path.as_posix()
@@ -283,8 +321,8 @@ class UnknownModelConfig(ModelConfigBase):
format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
return False
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
@@ -370,17 +408,34 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
is_lora_override = overrides.get("type") is ModelType.LoRA
is_omi_override = overrides.get("format") is ModelFormat.OMI
# If both type and format are overridden, skip the heuristic checks
if is_lora_override and is_omi_override:
return MatchCertainty.OVERRIDE
# OMI LoRAs are always files, never directories
if mod.path.is_dir():
return False
return MatchCertainty.NEVER
# Avoid false positive match against ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
return MatchCertainty.NEVER
metadata = mod.metadata()
return (
is_omi_lora_heuristic = (
bool(metadata.get("modelspec.sai_model_spec"))
and metadata.get("ot_branch") == "omi_format"
and metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora"
)
if is_omi_lora_heuristic:
return MatchCertainty.EXACT
return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
metadata = mod.metadata()
@@ -402,27 +457,55 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
is_lora_override = overrides.get("type") is ModelType.LoRA
is_lycoris_override = overrides.get("format") is ModelFormat.LyCORIS
# If both type and format are overridden, skip the heuristic checks and return a perfect score
if is_lora_override and is_lycoris_override:
return MatchCertainty.OVERRIDE
# LyCORIS LoRAs are always files, never directories
if mod.path.is_dir():
return False
return MatchCertainty.NEVER
# Avoid false positive match against ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
return False
return MatchCertainty.NEVER
state_dict = mod.load_state_dict()
for key in state_dict.keys():
if isinstance(key, int):
continue
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
return True
# Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
# Some main models have these keys, likely due to the creator merging in a LoRA.
has_key_with_lora_prefix = key.startswith(
(
"lora_te_",
"lora_unet_",
"lora_te1_",
"lora_te2_",
"lora_transformer_",
)
)
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return True
has_key_with_lora_suffix = key.endswith(
(
"to_k_lora.up.weight",
"to_q_lora.down.weight",
"lora_A.weight",
"lora_B.weight",
)
)
return False
if has_key_with_lora_prefix or has_key_with_lora_suffix:
return MatchCertainty.MAYBE
return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
@@ -459,13 +542,31 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
is_lora_override = overrides.get("type") is ModelType.LoRA
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
# If both type and format are overridden, skip the heuristic checks and return a perfect score
if is_lora_override and is_diffusers_override:
return MatchCertainty.OVERRIDE
# Diffusers LoRAs are always directories, never files
if mod.path.is_file():
return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
return MatchCertainty.NEVER
is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
suffixes = ["bin", "safetensors"]
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
return any(wf.exists() for wf in weight_files)
has_lora_weight_file = any(wf.exists() for wf in weight_files)
if is_flux_lora_diffusers and has_lora_weight_file:
return MatchCertainty.EXACT
if is_flux_lora_diffusers or has_lora_weight_file:
return MatchCertainty.MAYBE
return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
@@ -520,7 +621,7 @@ class MainConfigBase(ABC, BaseModel):
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: AnyVariant = ModelVariantType.Normal
variant: ModelVariantType | FluxVariantType = ModelVariantType.Normal
class VideoConfigBase(ABC, BaseModel):
@@ -529,7 +630,7 @@ class VideoConfigBase(ABC, BaseModel):
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: AnyVariant = ModelVariantType.Normal
variant: ModelVariantType = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
@@ -538,6 +639,14 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixi
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
# @classmethod
# def matches(cls, mod: ModelOnDisk) -> bool:
# pass
# @classmethod
# def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
# pass
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
@@ -583,12 +692,51 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConf
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
"""Model config for Clip Embeddings."""
variant: ClipVariantType = Field(description="Clip variant for this model")
variant: ClipVariantType = Field(...)
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
base: Literal[BaseModelType.Any] = BaseModelType.Any
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
clip_variant = cls.get_clip_variant_type(mod)
if clip_variant is None:
raise InvalidModelConfigException("Unable to determine CLIP variant type")
return {"variant": clip_variant}
@classmethod
def get_clip_variant_type(cls, mod: ModelOnDisk) -> ClipVariantType | None:
try:
with open(mod.path / "config.json") as file:
config = json.load(file)
hidden_size = config.get("hidden_size")
match hidden_size:
case 1280:
return ClipVariantType.G
case 768:
return ClipVariantType.L
case _:
return None
except Exception:
return None
@classmethod
def is_clip_text_encoder(cls, mod: ModelOnDisk) -> bool:
try:
with open(mod.path / "config.json", "r") as file:
config = json.load(file)
architectures = config.get("architectures")
return architectures[0] in (
"CLIPModel",
"CLIPTextModel",
"CLIPTextModelWithProjection",
)
except Exception:
return False
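For reference, minimal hypothetical `config.json` payloads that the two helpers above would classify (real files contain many more fields):

# Hypothetical config.json contents:
clip_l = {"architectures": ["CLIPTextModel"], "hidden_size": 768}  # -> ClipVariantType.L
clip_g = {"architectures": ["CLIPTextModelWithProjection"], "hidden_size": 1280}  # -> ClipVariantType.G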
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
"""Model config for CLIP-G Embeddings."""
variant: Literal[ClipVariantType.G] = ClipVariantType.G
@@ -597,8 +745,28 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, Mode
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
is_clip_embed_override = overrides.get("type") is ModelType.CLIPEmbed
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
has_clip_variant_override = overrides.get("variant") is ClipVariantType.G
if is_clip_embed_override and is_diffusers_override and has_clip_variant_override:
return MatchCertainty.OVERRIDE
if mod.path.is_file():
return MatchCertainty.NEVER
is_clip_embed = cls.is_clip_text_encoder(mod)
clip_variant = cls.get_clip_variant_type(mod)
if is_clip_embed and clip_variant is ClipVariantType.G:
return MatchCertainty.EXACT
return MatchCertainty.NEVER
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
"""Model config for CLIP-L Embeddings."""
variant: Literal[ClipVariantType.L] = ClipVariantType.L
@@ -607,6 +775,26 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, Mode
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
is_clip_embed_override = overrides.get("type") is ModelType.CLIPEmbed
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
has_clip_variant_override = overrides.get("variant") is ClipVariantType.L
if is_clip_embed_override and is_diffusers_override and has_clip_variant_override:
return MatchCertainty.OVERRIDE
if mod.path.is_file():
return MatchCertainty.NEVER
is_clip_embed = cls.is_clip_text_encoder(mod)
clip_variant = cls.get_clip_variant_type(mod)
if is_clip_embed and clip_variant is ClipVariantType.L:
return MatchCertainty.EXACT
return MatchCertainty.NEVER
class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for CLIPVision."""
@@ -649,22 +837,30 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for Llava Onevision models."""
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
is_llava_override = overrides.get("type") is ModelType.LlavaOnevision
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
if is_llava_override and is_diffusers_override:
return MatchCertainty.OVERRIDE
if mod.path.is_file():
return False
return MatchCertainty.NEVER
config_path = mod.path / "config.json"
try:
with open(config_path, "r") as file:
config = json.load(file)
except FileNotFoundError:
return False
return MatchCertainty.NEVER
architectures = config.get("architectures")
return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"
if architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration":
return MatchCertainty.EXACT
return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
@@ -680,9 +876,9 @@ class ApiModelConfig(MainConfigBase, ModelConfigBase):
format: Literal[ModelFormat.Api] = ModelFormat.Api
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
# API models are not stored on disk, so we can't match them.
return False
return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
@@ -695,9 +891,9 @@ class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
format: Literal[ModelFormat.Api] = ModelFormat.Api
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
# API models are not stored on disk, so we can't match them.
return False
return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
@@ -735,8 +931,7 @@ def get_model_discriminator_value(v: Any) -> str:
# Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field.
# To maintain compatibility, we default to ClipVariantType.L in this case.
if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value:
variant_ = variant_ or ClipVariantType.L.value
if type_ == ModelType.CLIPEmbed.value:
return f"{type_}.{format_}.{variant_}"
return f"{type_}.{format_}"
@@ -779,7 +974,7 @@ AnyModelConfig = Annotated[
Discriminator(get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
@@ -787,8 +982,8 @@ class ModelConfigFactory:
@staticmethod
def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
"""Return the appropriate config object from raw dict values."""
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
model = AnyModelConfigValidator.validate_python(model_data)
if isinstance(model, CheckpointConfigBase) and timestamp:
model.converted_at = timestamp
validate_hash(model.hash)
return model # type: ignore
return model

View File

@@ -33,6 +33,7 @@ from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
FluxVariantType,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
@@ -63,6 +64,7 @@ from invokeai.backend.util.silence_warnings import SilenceWarnings
CkptType = Dict[str | int, Any]
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: {
@@ -106,7 +108,7 @@ class ProbeBase(object):
"""Get model file format."""
raise NotImplementedError
def get_variant_type(self) -> Optional[ModelVariantType]:
def get_variant_type(self) -> AnyVariant | None:
"""Get model variant type."""
return None
@@ -256,7 +258,7 @@ class ModelProbe(object):
if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
fields["submodels"] = get_submodels()
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
model_info = ModelConfigFactory.make_config(fields)
return model_info
@classmethod
@@ -580,7 +582,7 @@ class CheckpointProbeBase(ProbeBase):
return ModelFormat.GGUFQuantized
return ModelFormat("checkpoint")
def get_variant_type(self) -> ModelVariantType:
def get_variant_type(self) -> AnyVariant:
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
base_type = self.get_base_type()
if model_type != ModelType.Main:
@@ -597,19 +599,26 @@ class CheckpointProbeBase(ProbeBase):
)
return ModelVariantType.Normal
is_flux_dev = (
"guidance_in.out_layer.weight" in state_dict
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
)
# FLUX Model variant types are distinguished by input channels:
# - Unquantized Dev and Schnell have in_channels=64
# - BNB-NF4 Dev and Schnell have in_channels=1
# - FLUX Fill has in_channels=384
# - Unsure of quantized FLUX Fill models
# - Unsure of GGUF-quantized models
if in_channels == 384:
if is_flux_dev and in_channels == 384:
# This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant
# type is used to determine whether to use the fill model or the base model.
return ModelVariantType.Inpaint
else:
return FluxVariantType.DevFill
elif is_flux_dev:
# Fall back on "normal" variant type for all other FLUX models.
return ModelVariantType.Normal
return FluxVariantType.Dev
else:
return FluxVariantType.Schnell
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
if in_channels == 9:

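A toy restatement of the FLUX variant rules above (hypothetical helper; the real probe derives both flags from the checkpoint's state dict):

def flux_variant(has_guidance: bool, in_channels: int) -> str:
    # guidance_in.* keys are present only in Dev-family models
    if has_guidance and in_channels == 384:
        return "dev_fill"
    if has_guidance:
        return "dev"
    return "schnell"

assert flux_variant(True, 384) == "dev_fill"
assert flux_variant(True, 64) == "dev"
assert flux_variant(False, 64) == "schnell"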
View File

@@ -33,10 +33,11 @@ from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import (
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
CLIPEmbedCheckpointConfig,
CLIPEmbedDiffusersConfig,
ControlNetCheckpointConfig,
ControlNetDiffusersConfig,
@@ -56,6 +57,7 @@ from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
FluxVariantType,
SubModelType,
)
from invokeai.backend.model_manager.util.model_util import (
@@ -90,7 +92,7 @@ class FluxVAELoader(ModelLoader):
model_path = Path(config.path)
with accelerate.init_empty_weights():
model = AutoEncoder(ae_params[config.config_path])
model = AutoEncoder(get_flux_ae_params())
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
# VAE is broken in float16, which mps defaults to
@@ -107,7 +109,7 @@ class FluxVAELoader(ModelLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(ModelLoader):
class CLIPDiffusersLoader(ModelLoader):
"""Class to load main models."""
def _load_model(
@@ -129,6 +131,27 @@ class ClipCheckpointModel(ModelLoader):
)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Checkpoint)
class CLIPCheckpointLoader(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CLIPEmbedCheckpointConfig):
raise ValueError("Only CLIPEmbedCheckpointConfig models are currently supported here.")
match submodel_type:
case SubModelType.TextEncoder:
return CLIPTextModel.from_pretrained(Path(config.path), use_safetensors=True)
case _:
raise ValueError(
f"Only TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"""Class to load main models."""
@@ -229,7 +252,7 @@ class FluxCheckpointModel(ModelLoader):
model_path = Path(config.path)
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
model = Flux(get_flux_transformers_params(config.variant))
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
@@ -271,7 +294,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
model_path = Path(config.path)
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
model = Flux(get_flux_transformers_params(config.variant))
# HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
@@ -322,7 +345,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
with SilenceWarnings():
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
model = Flux(get_flux_transformers_params(config.variant))
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
@@ -362,7 +385,7 @@ class FluxControlnetModel(ModelLoader):
def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
with accelerate.init_empty_weights():
# HACK(ryand): Is it safe to assume dev here?
model = XLabsControlNetFlux(params["flux-dev"])
model = XLabsControlNetFlux(get_flux_transformers_params(FluxVariantType.Dev))
model.load_state_dict(sd, assign=True)
return model

View File

@@ -0,0 +1,86 @@
from dataclasses import dataclass
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)
@dataclass(frozen=True)
class LegacyConfigKey:
type: ModelType
base: BaseModelType
variant: ModelVariantType | None = None
pred: SchedulerPredictionType | None = None
LEGACY_CONFIG_MAP: dict[LegacyConfigKey, str] = {
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion1,
ModelVariantType.Normal,
SchedulerPredictionType.Epsilon,
): "stable-diffusion/v1-inference.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion1,
ModelVariantType.Normal,
SchedulerPredictionType.VPrediction,
): "stable-diffusion/v1-inference-v.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion1,
ModelVariantType.Inpaint,
): "stable-diffusion/v1-inpainting-inference.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion2,
ModelVariantType.Normal,
SchedulerPredictionType.Epsilon,
): "stable-diffusion/v2-inference.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion2,
ModelVariantType.Normal,
SchedulerPredictionType.VPrediction,
): "stable-diffusion/v2-inference-v.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion2,
ModelVariantType.Inpaint,
SchedulerPredictionType.Epsilon,
): "stable-diffusion/v2-inpainting-inference.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion2,
ModelVariantType.Inpaint,
SchedulerPredictionType.VPrediction,
): "stable-diffusion/v2-inpainting-inference-v.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion2,
ModelVariantType.Depth,
): "stable-diffusion/v2-midas-inference.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusionXL,
ModelVariantType.Normal,
): "stable-diffusion/sd_xl_base.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusionXL,
ModelVariantType.Inpaint,
): "stable-diffusion/sd_xl_inpaint.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusionXLRefiner,
ModelVariantType.Normal,
): "stable-diffusion/sd_xl_refiner.yaml",
LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion1): "controlnet/cldm_v15.yaml",
LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion2): "controlnet/cldm_v21.yaml",
LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion1): "stable-diffusion/v1-inference.yaml",
LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion2): "stable-diffusion/v2-inference.yaml",
LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusionXL): "stable-diffusion/sd_xl_base.yaml",
}
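Because LegacyConfigKey is a frozen dataclass, keys are hashable and compared by value, so lookups use a fully-specified key:

# Example lookup against the map above:
key = LegacyConfigKey(
    ModelType.Main,
    BaseModelType.StableDiffusion1,
    ModelVariantType.Normal,
    SchedulerPredictionType.Epsilon,
)
assert LEGACY_CONFIG_MAP[key] == "stable-diffusion/v1-inference.yaml"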

View File

@@ -5,6 +5,7 @@ import diffusers
import onnxruntime as ort
import torch
from diffusers import ModelMixin
from pydantic import TypeAdapter
from invokeai.backend.raw_model import RawModel
@@ -30,6 +31,7 @@ class BaseModelType(str, Enum):
Imagen4 = "imagen4"
Gemini2_5 = "gemini-2.5"
ChatGPT4o = "chatgpt-4o"
# This is actually the FLUX Kontext API model. Local FLUX Kontext is just BaseModelType.Flux.
FluxKontext = "flux-kontext"
Veo3 = "veo3"
Runway = "runway"
@@ -92,6 +94,12 @@ class ModelVariantType(str, Enum):
Depth = "depth"
class FluxVariantType(str, Enum):
Schnell = "schnell"
Dev = "dev"
DevFill = "dev_fill"
class ModelFormat(str, Enum):
"""Storage format of model."""
@@ -149,4 +157,7 @@ class FluxLoRAFormat(str, Enum):
AIToolkit = "flux.aitoolkit"
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, FluxVariantType]
variant_type_adapter = TypeAdapter[ModelVariantType | ClipVariantType | FluxVariantType](
ModelVariantType | ClipVariantType | FluxVariantType
)
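The adapter resolves a raw string to the correct member across the three variant enums; a quick sketch of the intended use (validate_python shown here; the probe code also uses validate_strings):

# Example:
assert variant_type_adapter.validate_python("dev_fill") is FluxVariantType.DevFill
assert variant_type_adapter.validate_python("inpaint") is ModelVariantType.Inpaint
assert variant_type_adapter.validate_python("large") is ClipVariantType.L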

View File

@@ -12,7 +12,10 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.util import InvokeAILogger
def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
def is_state_dict_likely_in_flux_aitoolkit_format(
state_dict: dict[str, Any],
metadata: dict[str, Any] | None = None,
) -> bool:
if metadata:
try:
software = json.loads(metadata.get("software", "{}"))

View File

@@ -1,3 +1,5 @@
from typing import Any
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_flux_aitoolkit_format,
@@ -14,7 +16,10 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
)
def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
def flux_format_from_state_dict(
state_dict: dict[str, Any],
metadata: dict[str, Any] | None = None,
) -> FluxLoRAFormat | None:
if is_state_dict_likely_in_flux_kohya_format(state_dict):
return FluxLoRAFormat.Kohya
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):

View File

@@ -4,7 +4,8 @@ import accelerate
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.flux.util import get_flux_transformers_params
from invokeai.backend.model_manager.taxonomy import FluxVariantType
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
@@ -22,7 +23,7 @@ def main():
with log_time("Initialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
p = get_flux_transformers_params(FluxVariantType.Schnell)
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():

View File

@@ -7,7 +7,8 @@ import torch
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.flux.util import get_flux_transformers_params
from invokeai.backend.model_manager.taxonomy import FluxVariantType
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@@ -35,7 +36,7 @@ def main():
# inference_dtype = torch.bfloat16
with log_time("Initialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
p = get_flux_transformers_params(FluxVariantType.Schnell)
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():

View File

@@ -223,6 +223,9 @@ export const MODEL_VARIANT_TO_LONG_NAME: Record<ModelVariantType, string> = {
normal: 'Normal',
inpaint: 'Inpaint',
depth: 'Depth',
dev: 'FLUX Dev',
dev_fill: 'FLUX Dev Fill',
schnell: 'FLUX Schnell',
};
export const MODEL_FORMAT_TO_LONG_NAME: Record<ModelFormat, string> = {

View File

@@ -147,7 +147,7 @@ export const zSubModelType = z.enum([
]);
export const zClipVariantType = z.enum(['large', 'gigantic']);
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth', 'dev', 'dev_fill', 'schnell']);
export type ModelVariantType = z.infer<typeof zModelVariantType>;
export const zModelFormat = z.enum([
'omi',

View File

@@ -5490,7 +5490,7 @@ export type components = {
* Config Path
* @description path to the checkpoint model config file
*/
config_path: string;
config_path?: string | null;
/**
* Converted At
* @description When this model was last converted to diffusers
@@ -14818,7 +14818,7 @@ export type components = {
* Config Path
* @description path to the checkpoint model config file
*/
config_path: string;
config_path?: string | null;
/**
* Converted At
* @description When this model was last converted to diffusers
@@ -14927,7 +14927,7 @@ export type components = {
* Config Path
* @description path to the checkpoint model config file
*/
config_path: string;
config_path?: string | null;
/**
* Converted At
* @description When this model was last converted to diffusers
@@ -15128,7 +15128,7 @@ export type components = {
* Config Path
* @description path to the checkpoint model config file
*/
config_path: string;
config_path?: string | null;
/**
* Converted At
* @description When this model was last converted to diffusers
@@ -17487,7 +17487,7 @@ export type components = {
* @description Variant type.
* @enum {string}
*/
ModelVariantType: "normal" | "inpaint" | "depth";
ModelVariantType: "normal" | "inpaint" | "depth" | "flux_dev" | "flux_dev_fill" | "flux_schnell";
/**
* ModelsList
* @description Return list of configs.
@@ -22199,7 +22199,7 @@ export type components = {
* Config Path
* @description path to the checkpoint model config file
*/
config_path: string;
config_path?: string | null;
/**
* Converted At
* @description When this model was last converted to diffusers

View File

@@ -2,7 +2,8 @@ import accelerate
import pytest
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.flux.util import get_flux_transformers_params
from invokeai.backend.model_manager.taxonomy import FluxVariantType
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
_group_state_by_submodel,
is_state_dict_likely_in_flux_aitoolkit_format,
@@ -44,7 +45,7 @@ def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
# Initialize a FLUX model on the meta device.
with accelerate.init_empty_weights():
model = Flux(params["flux-schnell"])
model = Flux(get_flux_transformers_params(FluxVariantType.Schnell))
model_keys = set(model.state_dict().keys())
for converted_key_prefix in converted_key_prefixes:

View File

@@ -3,7 +3,8 @@ import pytest
import torch
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.flux.util import get_flux_transformers_params
from invokeai.backend.model_manager.taxonomy import FluxVariantType
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
_convert_flux_transformer_kohya_state_dict_to_invoke_format,
is_state_dict_likely_in_flux_kohya_format,
@@ -63,7 +64,7 @@ def test_convert_flux_transformer_kohya_state_dict_to_invoke_format():
# Initialize a FLUX model on the meta device.
with accelerate.init_empty_weights():
model = Flux(params["flux-dev"])
model = Flux(get_flux_transformers_params(FluxVariantType.Dev))
model_keys = set(model.state_dict().keys())
# Assert that the converted state dict matches the keys in the actual model.

View File

@@ -115,7 +115,7 @@ class MinimalConfigExample(ModelConfigBase):
fun_quote: str
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
return MatchCertainty.EXACT if mod.path.suffix == ".json" else MatchCertainty.NEVER
@classmethod