From 4b52cc25461528fac9ca135bc7a008e1fed67e1f Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 23 Sep 2025 13:00:26 +1000
Subject: [PATCH] refactor: port MM probes to new api

- Add concept of match certainty to new probe
- Port CLIP Embed models to new API
- Replace string-keyed FLUX param lookups with variant-based accessors
---
 invokeai/app/invocations/flux_model_loader.py | 4 +-
 .../model_install/model_install_default.py | 7 +-
 .../model_records/model_records_base.py | 5 +-
 .../model_records/model_records_sql.py | 21 +-
 .../migrations/migration_22.py | 9 +-
 invokeai/backend/flux/util.py | 61 ++--
 invokeai/backend/model_manager/config.py | 287 +++++++++++++++---
 .../backend/model_manager/legacy_probe.py | 23 +-
 .../model_manager/load/model_loaders/flux.py | 37 ++-
 .../model_manager/single_file_config_files.py | 86 ++++++
 invokeai/backend/model_manager/taxonomy.py | 13 +-
 .../flux_aitoolkit_lora_conversion_utils.py | 5 +-
 .../patches/lora_conversions/formats.py | 7 +-
 .../scripts/load_flux_model_bnb_llm_int8.py | 5 +-
 .../scripts/load_flux_model_bnb_nf4.py | 5 +-
 .../web/src/features/modelManagerV2/models.ts | 3 +
 .../web/src/features/nodes/types/common.ts | 2 +-
 .../frontend/web/src/services/api/schema.ts | 12 +-
 ...st_flux_aitoolkit_lora_conversion_utils.py | 5 +-
 .../test_flux_kohya_lora_conversion_utils.py | 5 +-
 tests/test_model_probe.py | 2 +-
 21 files changed, 488 insertions(+), 116 deletions(-)
 create mode 100644 invokeai/backend/model_manager/single_file_config_files.py

diff --git a/invokeai/app/invocations/flux_model_loader.py b/invokeai/app/invocations/flux_model_loader.py
index e5a1966c65..4ed3b91bc6 100644
--- a/invokeai/app/invocations/flux_model_loader.py
+++ b/invokeai/app/invocations/flux_model_loader.py
@@ -13,7 +13,7 @@ from invokeai.app.util.t5_model_identifier import (
     preprocess_t5_encoder_model_identifier,
     preprocess_t5_tokenizer_model_identifier,
 )
-from invokeai.backend.flux.util import max_seq_lengths
+from invokeai.backend.flux.util import get_flux_max_seq_length
 from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
 )
@@ -94,5 +94,5 @@ class FluxModelLoaderInvocation(BaseInvocation):
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
             vae=VAEField(vae=vae),
-            max_seq_len=max_seq_lengths[transformer_config.config_path],
+            max_seq_len=get_flux_max_seq_length(transformer_config.variant),
         )
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index 454697ea5a..5bc9af8e6b 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -5,6 +5,7 @@ import os
 import re
 import threading
 import time
+from copy import deepcopy
 from pathlib import Path
 from queue import Empty, Queue
 from shutil import move, rmtree
@@ -370,6 +371,11 @@ class ModelInstallService(ModelInstallServiceBase):
         model_path = self.app_config.models_path / model.path
         if model_path.is_file() or model_path.is_symlink():
             model_path.unlink()
+            # Remove the model's parent directory too, but only if it is now empty
+            # and is not the models root itself.
+            parent = model_path.parent
+            if parent != self.app_config.models_path and not any(parent.iterdir()):
+                parent.rmdir()
         elif model_path.is_dir():
             rmtree(model_path)
         self.unregister(key)
@@ -607,9 +610,9 @@ class ModelInstallService(ModelInstallServiceBase):
         # any given model - eliminating ambiguity and removing reliance on order.
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
         try:
-            return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)  # type: ignore
+            return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo)  # type: ignore
         except InvalidModelConfigException:
-            return ModelConfigBase.classify(model_path, hash_algo, **fields)
+            return ModelConfigBase.classify(mod=model_path, hash_algo=hash_algo, **fields)
 
     def _register(
         self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py
index 740d548a4a..48f5317536 100644
--- a/invokeai/app/services/model_records/model_records_base.py
+++ b/invokeai/app/services/model_records/model_records_base.py
@@ -21,6 +21,7 @@ from invokeai.backend.model_manager.config import (
 from invokeai.backend.model_manager.taxonomy import (
     BaseModelType,
     ClipVariantType,
+    FluxVariantType,
     ModelFormat,
     ModelSourceType,
     ModelType,
@@ -90,7 +91,9 @@ class ModelRecordChanges(BaseModelExcludeNull):
 
     # Checkpoint-specific changes
     # TODO(MM2): Should we expose these? Feels footgun-y...
-    variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
+    variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType] = Field(
+        description="The variant of the model.", default=None
+    )
     prediction_type: Optional[SchedulerPredictionType] = Field(
         description="The prediction type of the model.", default=None
     )
diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py
index e3b24a6e62..9d4892a141 100644
--- a/invokeai/app/services/model_records/model_records_sql.py
+++ b/invokeai/app/services/model_records/model_records_sql.py
@@ -141,10 +141,25 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
         with self._db.transaction() as cursor:
             record = self.get_model(key)
 
-            # Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
-            for field_name in changes.model_fields_set:
-                setattr(record, field_name, getattr(changes, field_name))
+            # Applying the changes may alter which model config class the record belongs to, so we:
+            #
+            # 1. convert the existing record to a dict
+            # 2. apply the changes to the dict
+            # 3. create a new model config from the updated dict
+            #
+            # This way we ensure that the update does not inadvertently create an invalid model config.
 
+            # 1. convert the existing record to a dict
+            record_as_dict = record.model_dump()
+
+            # 2. apply the changes to the dict
+            for field_name in changes.model_fields_set:
+                record_as_dict[field_name] = getattr(changes, field_name)
+
+            # 3. create a new model config from the updated dict
+            record = ModelConfigFactory.make_config(record_as_dict)
+
+            # If we get this far, the updated model config is valid, so we can save it to the database.
json_serialized = record.model_dump_json() cursor.execute( diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py index c79b58bf2a..08b0e76068 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py @@ -8,7 +8,7 @@ from pydantic import ValidationError from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration -from invokeai.backend.model_manager.config import AnyModelConfig, AnyModelConfigValidator +from invokeai.backend.model_manager.config import AnyModelConfigValidator class NormalizeResult(NamedTuple): @@ -30,7 +30,7 @@ class Migration22Callback: for model_id, config_json in rows: try: # Get the model config as a pydantic object - config = self._load_model_config(config_json) + config = AnyModelConfigValidator.validate_json(config_json) except ValidationError: # This could happen if the config schema changed in a way that makes old configs invalid. Unlikely # for users, more likely for devs testing out migration paths. @@ -216,11 +216,6 @@ class Migration22Callback: self._logger.info("Pruned %d empty directories under %s", len(removed_dirs), self._models_dir) - def _load_model_config(self, config_json: str) -> AnyModelConfig: - # The typing of the validator says it returns Unknown, but it's really a AnyModelConfig. This utility function - # just makes that clear. - return AnyModelConfigValidator.validate_json(config_json) - def build_migration_22(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: """Builds the migration object for migrating from version 21 to version 22. 
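The flux/util.py hunks that follow replace the string-keyed max_seq_lengths, ae_params, and params dicts with accessor functions keyed by FluxVariantType. A minimal usage sketch of the new helpers (assuming the patch is applied; the variant values are defined in taxonomy.py):

    from invokeai.backend.flux.util import (
        get_flux_ae_params,
        get_flux_max_seq_length,
        get_flux_transformers_params,
    )
    from invokeai.backend.model_manager.taxonomy import FluxVariantType

    # T5 max sequence length is variant-specific: 512 for dev/dev-fill, 256 for schnell.
    seq_len = get_flux_max_seq_length(FluxVariantType.Schnell)

    # Transformer hyperparameters are looked up by variant; unknown variants raise ValueError.
    transformer_params = get_flux_transformers_params(FluxVariantType.Dev)

    # The FLUX VAE parameters are shared by all variants.
    ae_params = get_flux_ae_params()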
diff --git a/invokeai/backend/flux/util.py b/invokeai/backend/flux/util.py
index 2a5261cb5c..2cf52b6ec1 100644
--- a/invokeai/backend/flux/util.py
+++ b/invokeai/backend/flux/util.py
@@ -1,10 +1,11 @@
 # Initially pulled from https://github.com/black-forest-labs/flux
 
 from dataclasses import dataclass
-from typing import Dict, Literal
+from typing import Literal
 
 from invokeai.backend.flux.model import FluxParams
 from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
+from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType
 
 
 @dataclass
@@ -41,30 +42,39 @@ PREFERED_KONTEXT_RESOLUTIONS = [
 ]
 
 
-max_seq_lengths: Dict[str, Literal[256, 512]] = {
-    "flux-dev": 512,
-    "flux-dev-fill": 512,
-    "flux-schnell": 256,
+_flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = {
+    FluxVariantType.Dev: 512,
+    FluxVariantType.DevFill: 512,
+    FluxVariantType.Schnell: 256,
 }
 
 
-ae_params = {
-    "flux": AutoEncoderParams(
-        resolution=256,
-        in_channels=3,
-        ch=128,
-        out_ch=3,
-        ch_mult=[1, 2, 4, 4],
-        num_res_blocks=2,
-        z_channels=16,
-        scale_factor=0.3611,
-        shift_factor=0.1159,
-    )
-}
+def get_flux_max_seq_length(variant: AnyVariant) -> Literal[256, 512]:
+    try:
+        return _flux_max_seq_lengths[variant]
+    except KeyError as e:
+        raise ValueError(f"Unknown variant for FLUX max seq len: {variant}") from e
 
 
-params = {
-    "flux-dev": FluxParams(
+_flux_ae_params = AutoEncoderParams(
+    resolution=256,
+    in_channels=3,
+    ch=128,
+    out_ch=3,
+    ch_mult=[1, 2, 4, 4],
+    num_res_blocks=2,
+    z_channels=16,
+    scale_factor=0.3611,
+    shift_factor=0.1159,
+)
+
+
+def get_flux_ae_params() -> AutoEncoderParams:
+    return _flux_ae_params
+
+
+_flux_transformer_params: dict[AnyVariant, FluxParams] = {
+    FluxVariantType.Dev: FluxParams(
         in_channels=64,
         vec_in_dim=768,
         context_in_dim=4096,
@@ -78,7 +88,7 @@ params = {
         qkv_bias=True,
         guidance_embed=True,
     ),
-    "flux-schnell": FluxParams(
+    FluxVariantType.Schnell: FluxParams(
         in_channels=64,
         vec_in_dim=768,
         context_in_dim=4096,
@@ -92,7 +102,7 @@ params = {
         qkv_bias=True,
         guidance_embed=False,
     ),
-    "flux-dev-fill": FluxParams(
+    FluxVariantType.DevFill: FluxParams(
         in_channels=384,
         out_channels=64,
         vec_in_dim=768,
@@ -108,3 +118,10 @@ params = {
         guidance_embed=True,
     ),
 }
+
+
+def get_flux_transformers_params(variant: AnyVariant) -> FluxParams:
+    try:
+        return _flux_transformer_params[variant]
+    except KeyError as e:
+        raise ValueError(f"Unknown variant for FLUX transformer params: {variant}") from e
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 83f0c1d2bf..b552328153 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -44,6 +44,7 @@ from invokeai.backend.model_manager.taxonomy import (
     BaseModelType,
     ClipVariantType,
     FluxLoRAFormat,
+    FluxVariantType,
     ModelFormat,
     ModelRepoVariant,
     ModelSourceType,
@@ -51,6 +52,7 @@ from invokeai.backend.model_manager.taxonomy import (
     ModelVariantType,
     SchedulerPredictionType,
     SubModelType,
+    variant_type_adapter,
 )
 from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -71,7 +73,7 @@ DEFAULTS_PRECISION = Literal["fp16", "fp32"]
 class SubmodelDefinition(BaseModel):
     path_or_prefix: str
     model_type: ModelType
-    variant: AnyVariant = None
+    variant: AnyVariant | None = None
 
     model_config = ConfigDict(protected_namespaces=())
 
@@ -111,6 +113,15 @@ class MatchSpeed(int, Enum):
     SLOW = 2
 
 
+class MatchCertainty(int, Enum):
+    """Represents the certainty of a match as reported by a config's 'matches' method."""
+
+    NEVER = 0
+    MAYBE = 1
+    EXACT = 2
+    OVERRIDE = 3
+
+
 class LegacyProbeMixin:
     """Mixin for classes using the legacy probe for model classification."""
 
@@ -194,20 +205,47 @@ class ModelConfigBase(ABC, BaseModel):
         Created to deprecate ModelProbe.probe
         """
         if isinstance(mod, Path | str):
-            mod = ModelOnDisk(mod, hash_algo)
+            mod = ModelOnDisk(Path(mod), hash_algo)
 
         candidates = ModelConfigBase.USING_CLASSIFY_API
         sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
 
+        overrides = overrides or {}
+        overrides = ModelConfigBase.cast_overrides(**overrides)
+
+        matches: dict[Type[ModelConfigBase], MatchCertainty] = {}
+
         for config_cls in sorted_by_match_speed:
             try:
-                if not config_cls.matches(mod):
+                score = config_cls.matches(mod, **overrides)
+
+                # A score of 0 means "no match"
+                if score is MatchCertainty.NEVER:
                     continue
+
+                matches[config_cls] = score
+
+                if score is MatchCertainty.EXACT or score is MatchCertainty.OVERRIDE:
+                    # Perfect match - skip further checks
+                    break
             except Exception as e:
                 logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
                 continue
-            else:
-                return config_cls.from_model_on_disk(mod, **overrides)
+
+        if matches:
+            # Select the config class with the highest score
+            sorted_by_score = sorted(matches.items(), key=lambda item: item[1].value)
+            # Check if there are multiple classes with the same top score
+            top_score = sorted_by_score[-1][1]
+            top_classes = [cls for cls, score in sorted_by_score if score is top_score]
+            if len(top_classes) > 1:
+                logger.warning(
+                    f"Multiple model config classes matched with the same top score ({top_score}) for model {mod.name}: {[cls.__name__ for cls in top_classes]}. Using {top_classes[0].__name__}."
+                )
+            config_cls = top_classes[0]
+            # Finally, create the config instance
+            logger.info(f"Model {mod.name} classified as {config_cls.__name__} with score {top_score.name}")
+            return config_cls.from_model_on_disk(mod, **overrides)
 
         if app_config.allow_unknown_models:
             try:
@@ -234,13 +272,13 @@ class ModelConfigBase(ABC, BaseModel):
 
     @classmethod
     @abstractmethod
-    def matches(cls, mod: ModelOnDisk) -> bool:
+    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
         """Performs a quick check to determine if the config matches the model.
 
-        This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
+        Returns a MatchCertainty score."""
         pass
 
     @staticmethod
-    def cast_overrides(overrides: dict[str, Any]):
-        """Casts user overrides from str to Enum"""
+    def cast_overrides(**overrides) -> dict[str, Any]:
+        """Casts user overrides from str to Enum and returns the updated overrides."""
         if "type" in overrides:
             overrides["type"] = ModelType(overrides["type"])
@@ -255,13 +293,15 @@ class ModelConfigBase(ABC, BaseModel):
             overrides["source_type"] = ModelSourceType(overrides["source_type"])
 
         if "variant" in overrides:
-            overrides["variant"] = ModelVariantType(overrides["variant"])
+            overrides["variant"] = variant_type_adapter.validate_strings(overrides["variant"])
+
+        return overrides
 
     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
         """Creates an instance of this config or raises InvalidModelConfigException."""
         fields = cls.parse(mod)
-        cls.cast_overrides(overrides)
+        overrides = cls.cast_overrides(**overrides)
         fields.update(overrides)
 
         fields["path"] = mod.path.as_posix()
@@ -283,8 +321,8 @@ class UnknownModelConfig(ModelConfigBase):
     format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown
 
     @classmethod
-    def matches(cls, mod: ModelOnDisk) -> bool:
-        return False
+    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
+        return MatchCertainty.NEVER
 
     @classmethod
     def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
@@ -370,17 +408,37 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
     format: Literal[ModelFormat.OMI] = ModelFormat.OMI
 
     @classmethod
-    def matches(cls, mod: ModelOnDisk) -> bool:
+    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
+        is_lora_override = overrides.get("type") is ModelType.LoRA
+        is_omi_override = overrides.get("format") is ModelFormat.OMI
+
+        # If both type and format are overridden, skip the heuristic checks
+        if is_lora_override and is_omi_override:
+            return MatchCertainty.OVERRIDE
+
+        # OMI LoRAs are always files, never directories
         if mod.path.is_dir():
-            return False
+            return MatchCertainty.NEVER
+
+        # Avoid false positive match against ControlLoRA and Diffusers
+        if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
+            return MatchCertainty.NEVER
 
         metadata = mod.metadata()
-        return (
-            bool(metadata.get("modelspec.sai_model_spec"))
-            and metadata.get("ot_branch") == "omi_format"
-            and metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
-        )
+        architecture = metadata.get("modelspec.architecture", "")
+        is_omi_lora_heuristic = (
+            bool(metadata.get("modelspec.sai_model_spec"))
+            and metadata.get("ot_branch") == "omi_format"
+            # Guard the index lookup - an architecture without a "/" would raise IndexError
+            and "/" in architecture
+            and architecture.split("/")[1].lower() == "lora"
+        )
+
+        if is_omi_lora_heuristic:
+            return MatchCertainty.EXACT
+
+        return MatchCertainty.NEVER
+
     @classmethod
     def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
         metadata = mod.metadata()
@@ -402,27 +457,55 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
     format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
 
     @classmethod
-    def matches(cls, mod: ModelOnDisk) -> bool:
+    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
+        is_lora_override = overrides.get("type") is ModelType.LoRA
+        is_lycoris_override = overrides.get("format") is ModelFormat.LyCORIS
+
+        # If both type and format are overridden, skip the heuristic checks and return a perfect score
+        if is_lora_override and is_lycoris_override:
+            return MatchCertainty.OVERRIDE
+
+        # LyCORIS LoRAs are always files, never directories
         if mod.path.is_dir():
-            return False
+            return MatchCertainty.NEVER
 
         # Avoid false positive match against ControlLoRA and Diffusers
         if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
-            return False
+            return MatchCertainty.NEVER
 
state_dict = mod.load_state_dict() for key in state_dict.keys(): if isinstance(key, int): continue - if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")): - return True + # Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. + # Some main models have these keys, likely due to the creator merging in a LoRA. + + has_key_with_lora_prefix = key.startswith( + ( + "lora_te_", + "lora_unet_", + "lora_te1_", + "lora_te2_", + "lora_transformer_", + ) + ) + # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models. - if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")): - return True + has_key_with_lora_suffix = key.endswith( + ( + "to_k_lora.up.weight", + "to_q_lora.down.weight", + "lora_A.weight", + "lora_B.weight", + ) + ) - return False + if has_key_with_lora_prefix or has_key_with_lora_suffix: + return MatchCertainty.MAYBE + + return MatchCertainty.NEVER @classmethod def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: @@ -459,13 +542,31 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: + def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: + is_lora_override = overrides.get("type") is ModelType.LoRA + is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers + + # If both type and format are overridden, skip the heuristic checks and return a perfect score + if is_lora_override and is_diffusers_override: + return MatchCertainty.OVERRIDE + + # Diffusers LoRAs are always directories, never files if mod.path.is_file(): - return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers + return MatchCertainty.NEVER + + is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers suffixes = ["bin", "safetensors"] weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] - return any(wf.exists() for wf in weight_files) + has_lora_weight_file = any(wf.exists() for wf in weight_files) + + if is_flux_lora_diffusers and has_lora_weight_file: + return MatchCertainty.EXACT + + if is_flux_lora_diffusers or has_lora_weight_file: + return MatchCertainty.MAYBE + + return MatchCertainty.NEVER @classmethod def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: @@ -520,7 +621,7 @@ class MainConfigBase(ABC, BaseModel): default_settings: Optional[MainModelDefaultSettings] = Field( description="Default settings for this model", default=None ) - variant: AnyVariant = ModelVariantType.Normal + variant: ModelVariantType | FluxVariantType = ModelVariantType.Normal class VideoConfigBase(ABC, BaseModel): @@ -529,7 +630,7 @@ class VideoConfigBase(ABC, BaseModel): default_settings: Optional[MainModelDefaultSettings] = Field( description="Default settings for this model", default=None ) - variant: AnyVariant = ModelVariantType.Normal + variant: ModelVariantType = ModelVariantType.Normal class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): @@ -538,6 +639,14 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixi prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False + # @classmethod + # def matches(cls, mod: ModelOnDisk) -> bool: + # pass + + # @classmethod + # 
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
+    #     pass
+
 
 class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
     """Model config for main checkpoint models."""
@@ -583,12 +692,51 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConf
 class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
     """Model config for Clip Embeddings."""
 
-    variant: ClipVariantType = Field(description="Clip variant for this model")
+    variant: ClipVariantType = Field(description="CLIP variant for this model")
     type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
     format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
+    base: Literal[BaseModelType.Any] = BaseModelType.Any
+
+    @classmethod
+    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
+        clip_variant = cls.get_clip_variant_type(mod)
+        if clip_variant is None:
+            raise InvalidModelConfigException("Unable to determine CLIP variant type")
+
+        return {"variant": clip_variant}
+
+    @classmethod
+    def get_clip_variant_type(cls, mod: ModelOnDisk) -> ClipVariantType | None:
+        try:
+            with open(mod.path / "config.json") as file:
+                config = json.load(file)
+            hidden_size = config.get("hidden_size")
+            match hidden_size:
+                case 1280:
+                    return ClipVariantType.G
+                case 768:
+                    return ClipVariantType.L
+                case _:
+                    return None
+        except Exception:
+            return None
+
+    @classmethod
+    def is_clip_text_encoder(cls, mod: ModelOnDisk) -> bool:
+        try:
+            with open(mod.path / "config.json", "r") as file:
+                config = json.load(file)
+            architectures = config.get("architectures")
+            return architectures[0] in (
+                "CLIPModel",
+                "CLIPTextModel",
+                "CLIPTextModelWithProjection",
+            )
+        except Exception:
+            return False
 
 
-class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
+class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
     """Model config for CLIP-G Embeddings."""
 
     variant: Literal[ClipVariantType.G] = ClipVariantType.G
@@ -597,8 +745,28 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, Mode
     def get_tag(cls) -> Tag:
         return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
 
+    @classmethod
+    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
+        is_clip_embed_override = overrides.get("type") is ModelType.CLIPEmbed
+        is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
+        has_clip_variant_override = overrides.get("variant") is ClipVariantType.G
 
-class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
+        if is_clip_embed_override and is_diffusers_override and has_clip_variant_override:
+            return MatchCertainty.OVERRIDE
+
+        if mod.path.is_file():
+            return MatchCertainty.NEVER
+
+        is_clip_embed = cls.is_clip_text_encoder(mod)
+        clip_variant = cls.get_clip_variant_type(mod)
+
+        if is_clip_embed and clip_variant is ClipVariantType.G:
+            return MatchCertainty.EXACT
+
+        return MatchCertainty.NEVER
+
+
+class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
     """Model config for CLIP-L Embeddings."""
 
     variant: Literal[ClipVariantType.L] = ClipVariantType.L
@@ -607,6 +775,26 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, Mode
     def get_tag(cls) -> Tag:
         return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
 
+    @classmethod
+    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
+        is_clip_embed_override = overrides.get("type") is ModelType.CLIPEmbed
+        is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
+        has_clip_variant_override = overrides.get("variant") is ClipVariantType.L
+
+        if is_clip_embed_override and is_diffusers_override and has_clip_variant_override:
+            return MatchCertainty.OVERRIDE
+
+        if mod.path.is_file():
+            return MatchCertainty.NEVER
+
+        is_clip_embed = cls.is_clip_text_encoder(mod)
+        clip_variant = cls.get_clip_variant_type(mod)
+
+        if is_clip_embed and clip_variant is ClipVariantType.L:
+            return MatchCertainty.EXACT
+
+        return MatchCertainty.NEVER
+
 
 class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
     """Model config for CLIPVision."""
@@ -649,22 +837,30 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
     """Model config for Llava Onevision models."""
 
     type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
-    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
 
     @classmethod
-    def matches(cls, mod: ModelOnDisk) -> bool:
+    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
+        is_llava_override = overrides.get("type") is ModelType.LlavaOnevision
+        is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
+
+        if is_llava_override and is_diffusers_override:
+            return MatchCertainty.OVERRIDE
+
         if mod.path.is_file():
-            return False
+            return MatchCertainty.NEVER
 
         config_path = mod.path / "config.json"
         try:
             with open(config_path, "r") as file:
                 config = json.load(file)
         except FileNotFoundError:
-            return False
+            return MatchCertainty.NEVER
 
         architectures = config.get("architectures")
-        return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"
+        if architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration":
+            return MatchCertainty.EXACT
+
+        return MatchCertainty.NEVER
 
     @classmethod
     def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
@@ -680,9 +876,9 @@ class ApiModelConfig(MainConfigBase, ModelConfigBase):
     format: Literal[ModelFormat.Api] = ModelFormat.Api
 
     @classmethod
-    def matches(cls, mod: ModelOnDisk) -> bool:
+    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
         # API models are not stored on disk, so we can't match them.
-        return False
+        return MatchCertainty.NEVER
 
     @classmethod
     def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
@@ -695,9 +891,9 @@ class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
     format: Literal[ModelFormat.Api] = ModelFormat.Api
 
     @classmethod
-    def matches(cls, mod: ModelOnDisk) -> bool:
+    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
         # API models are not stored on disk, so we can't match them.
-        return False
+        return MatchCertainty.NEVER
 
     @classmethod
     def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
@@ -735,8 +931,8 @@ def get_model_discriminator_value(v: Any) -> str:
 
     # Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field.
     # To maintain compatibility, we default to ClipVariantType.L in this case.
-    if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value:
-        variant_ = variant_ or ClipVariantType.L.value
+    if type_ == ModelType.CLIPEmbed.value:
+        variant_ = variant_ or ClipVariantType.L.value
         return f"{type_}.{format_}.{variant_}"
     return f"{type_}.{format_}"
 
@@ -779,7 +974,7 @@ AnyModelConfig = Annotated[
     Discriminator(get_model_discriminator_value),
 ]
 
-AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
+AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
 AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
 
@@ -787,8 +982,8 @@ class ModelConfigFactory:
     @staticmethod
     def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
         """Return the appropriate config object from raw dict values."""
-        model = AnyModelConfigValidator.validate_python(model_data)  # type: ignore
+        model = AnyModelConfigValidator.validate_python(model_data)
         if isinstance(model, CheckpointConfigBase) and timestamp:
             model.converted_at = timestamp
         validate_hash(model.hash)
-        return model  # type: ignore
+        return model
diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py
index 36fd82667d..3d3915353d 100644
--- a/invokeai/backend/model_manager/legacy_probe.py
+++ b/invokeai/backend/model_manager/legacy_probe.py
@@ -33,6 +33,7 @@ from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
 from invokeai.backend.model_manager.taxonomy import (
     AnyVariant,
     BaseModelType,
+    FluxVariantType,
     ModelFormat,
     ModelRepoVariant,
     ModelSourceType,
@@ -63,6 +64,7 @@ from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 CkptType = Dict[str | int, Any]
 
+
 LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
     BaseModelType.StableDiffusion1: {
         ModelVariantType.Normal: {
@@ -106,7 +108,7 @@ class ProbeBase(object):
         """Get model file format."""
         raise NotImplementedError
 
-    def get_variant_type(self) -> Optional[ModelVariantType]:
+    def get_variant_type(self) -> AnyVariant | None:
         """Get model variant type."""
         return None
 
@@ -256,7 +258,7 @@ class ModelProbe(object):
         if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
             fields["submodels"] = get_submodels()
 
-        model_info = ModelConfigFactory.make_config(fields)  # , key=fields.get("key", None))
+        model_info = ModelConfigFactory.make_config(fields)
         return model_info
 
     @classmethod
@@ -580,7 +582,7 @@ class CheckpointProbeBase(ProbeBase):
             return ModelFormat.GGUFQuantized
         return ModelFormat("checkpoint")
 
-    def get_variant_type(self) -> ModelVariantType:
+    def get_variant_type(self) -> AnyVariant:
         model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
         base_type = self.get_base_type()
         if model_type != ModelType.Main:
@@ -597,19 +599,26 @@ class CheckpointProbeBase(ProbeBase):
             )
             return ModelVariantType.Normal
 
+        is_flux_dev = (
+            "guidance_in.out_layer.weight" in state_dict
+            or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
+        )
+
         # FLUX Model variant types are distinguished by input channels:
         # - Unquantized Dev and Schnell have in_channels=64
         # - BNB-NF4 Dev and Schnell have in_channels=1
         # - FLUX Fill has in_channels=384
         # - Unsure of quantized FLUX Fill models
        # - Unsure of GGUF-quantized models
-        if in_channels == 384:
+        if is_flux_dev and in_channels == 384:
             # This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant
             # type is used to determine whether to use the fill model or the base model.
-            return ModelVariantType.Inpaint
-        else:
-            # Fall back on "normal" variant type for all other FLUX models.
-            return ModelVariantType.Normal
+            return FluxVariantType.DevFill
+        elif is_flux_dev:
+            # A guidance embed without the fill channels indicates a Dev variant.
+            return FluxVariantType.Dev
+        else:
+            return FluxVariantType.Schnell
 
         in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
         if in_channels == 9:
diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index 6ea7b53925..ccb962e747 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -33,10 +33,11 @@ from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import (
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
-from invokeai.backend.flux.util import ae_params, params
+from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     CheckpointConfigBase,
+    CLIPEmbedCheckpointConfig,
     CLIPEmbedDiffusersConfig,
     ControlNetCheckpointConfig,
     ControlNetDiffusersConfig,
@@ -56,6 +57,7 @@ from invokeai.backend.model_manager.taxonomy import (
     BaseModelType,
+    FluxVariantType,
     ModelFormat,
     ModelType,
     SubModelType,
 )
 from invokeai.backend.model_manager.util.model_util import (
@@ -90,7 +92,7 @@ class FluxVAELoader(ModelLoader):
         model_path = Path(config.path)
 
         with accelerate.init_empty_weights():
-            model = AutoEncoder(ae_params[config.config_path])
+            model = AutoEncoder(get_flux_ae_params())
         sd = load_file(model_path)
         model.load_state_dict(sd, assign=True)
         # VAE is broken in float16, which mps defaults to
@@ -107,7 +109,7 @@ class FluxVAELoader(ModelLoader):
 
 
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
-class ClipCheckpointModel(ModelLoader):
+class CLIPDiffusersLoader(ModelLoader):
     """Class to load main models."""
 
     def _load_model(
@@ -129,6 +131,27 @@
         )
 
 
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Checkpoint)
+class CLIPCheckpointLoader(ModelLoader):
+    """Class to load single-file CLIP text encoder models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, CLIPEmbedCheckpointConfig):
+            raise ValueError("Only CLIPEmbedCheckpointConfig models are currently supported here.")
+
+        match submodel_type:
+            case SubModelType.TextEncoder:
+                return CLIPTextModel.from_pretrained(Path(config.path), use_safetensors=True)
+            case _:
+                raise ValueError(
+                    f"Only TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
+                )
+
+
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
 class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
     """Class to load main models."""
@@ -229,7 +252,7 @@ class FluxCheckpointModel(ModelLoader):
         model_path = Path(config.path)
 
         with accelerate.init_empty_weights():
-            model = Flux(params[config.config_path])
+            model = Flux(get_flux_transformers_params(config.variant))
 
         sd = load_file(model_path)
         if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
@@ -271,7 +294,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
         model_path = Path(config.path)
 
         with accelerate.init_empty_weights():
-            model = Flux(params[config.config_path])
+            model = Flux(get_flux_transformers_params(config.variant))
 
         # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
         sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
@@ -322,7 +345,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
 
         with SilenceWarnings():
             with accelerate.init_empty_weights():
-                model = Flux(params[config.config_path])
+                model = Flux(get_flux_transformers_params(config.variant))
                 model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
             sd = load_file(model_path)
             if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
@@ -362,7 +385,7 @@ class FluxControlnetModel(ModelLoader):
     def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
         with accelerate.init_empty_weights():
             # HACK(ryand): Is it safe to assume dev here?
-            model = XLabsControlNetFlux(params["flux-dev"])
+            model = XLabsControlNetFlux(get_flux_transformers_params(FluxVariantType.Dev))
 
         model.load_state_dict(sd, assign=True)
         return model
diff --git a/invokeai/backend/model_manager/single_file_config_files.py b/invokeai/backend/model_manager/single_file_config_files.py
new file mode 100644
index 0000000000..22fe646b55
--- /dev/null
+++ b/invokeai/backend/model_manager/single_file_config_files.py
@@ -0,0 +1,86 @@
+from dataclasses import dataclass
+
+from invokeai.backend.model_manager.taxonomy import (
+    BaseModelType,
+    ModelType,
+    ModelVariantType,
+    SchedulerPredictionType,
+)
+
+
+@dataclass(frozen=True)
+class LegacyConfigKey:
+    type: ModelType
+    base: BaseModelType
+    variant: ModelVariantType | None = None
+    pred: SchedulerPredictionType | None = None
+
+
+LEGACY_CONFIG_MAP: dict[LegacyConfigKey, str] = {
+    LegacyConfigKey(
+        ModelType.Main,
+        BaseModelType.StableDiffusion1,
+        ModelVariantType.Normal,
+        SchedulerPredictionType.Epsilon,
+    ): "stable-diffusion/v1-inference.yaml",
+    LegacyConfigKey(
+        ModelType.Main,
+        BaseModelType.StableDiffusion1,
+        ModelVariantType.Normal,
+        SchedulerPredictionType.VPrediction,
+    ): "stable-diffusion/v1-inference-v.yaml",
+    LegacyConfigKey(
+        ModelType.Main,
+        BaseModelType.StableDiffusion1,
+        ModelVariantType.Inpaint,
+    ): "stable-diffusion/v1-inpainting-inference.yaml",
+    LegacyConfigKey(
+        ModelType.Main,
+        BaseModelType.StableDiffusion2,
+        ModelVariantType.Normal,
+        SchedulerPredictionType.Epsilon,
+    ): "stable-diffusion/v2-inference.yaml",
+    LegacyConfigKey(
+        ModelType.Main,
+        BaseModelType.StableDiffusion2,
+        ModelVariantType.Normal,
+        SchedulerPredictionType.VPrediction,
+    ): "stable-diffusion/v2-inference-v.yaml",
+    LegacyConfigKey(
+        ModelType.Main,
+        BaseModelType.StableDiffusion2,
+        ModelVariantType.Inpaint,
+        SchedulerPredictionType.Epsilon,
+    ): 
"stable-diffusion/v2-inpainting-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Inpaint, + SchedulerPredictionType.VPrediction, + ): "stable-diffusion/v2-inpainting-inference-v.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Depth, + ): "stable-diffusion/v2-midas-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusionXL, + ModelVariantType.Normal, + ): "stable-diffusion/sd_xl_base.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusionXL, + ModelVariantType.Inpaint, + ): "stable-diffusion/sd_xl_inpaint.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusionXLRefiner, + ModelVariantType.Normal, + ): "stable-diffusion/sd_xl_refiner.yaml", + LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion1): "controlnet/cldm_v15.yaml", + LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion2): "controlnet/cldm_v21.yaml", + LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion1): "stable-diffusion/v1-inference.yaml", + LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion2): "stable-diffusion/v2-inference.yaml", + LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusionXL): "stable-diffusion/sd_xl_base.yaml", +} diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index 07f8c8f5de..15a3d3ce72 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -5,6 +5,7 @@ import diffusers import onnxruntime as ort import torch from diffusers import ModelMixin +from pydantic import TypeAdapter from invokeai.backend.raw_model import RawModel @@ -30,6 +31,7 @@ class BaseModelType(str, Enum): Imagen4 = "imagen4" Gemini2_5 = "gemini-2.5" ChatGPT4o = "chatgpt-4o" + # This is actually the FLUX Kontext API model. Local FLUX Kontext is just BaseModelType.Flux. 
FluxKontext = "flux-kontext" Veo3 = "veo3" Runway = "runway" @@ -92,6 +94,12 @@ class ModelVariantType(str, Enum): Depth = "depth" +class FluxVariantType(str, Enum): + Schnell = "schnell" + Dev = "dev" + DevFill = "dev_fill" + + class ModelFormat(str, Enum): """Storage format of model.""" @@ -149,4 +157,7 @@ class FluxLoRAFormat(str, Enum): AIToolkit = "flux.aitoolkit" -AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None] +AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, FluxVariantType] +variant_type_adapter = TypeAdapter[ModelVariantType | ClipVariantType | FluxVariantType]( + ModelVariantType | ClipVariantType | FluxVariantType +) diff --git a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py index 6ca06a0355..db218d14bb 100644 --- a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py @@ -12,7 +12,10 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw from invokeai.backend.util import InvokeAILogger -def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool: +def is_state_dict_likely_in_flux_aitoolkit_format( + state_dict: dict[str, Any], + metadata: dict[str, Any] | None = None, +) -> bool: if metadata: try: software = json.loads(metadata.get("software", "{}")) diff --git a/invokeai/backend/patches/lora_conversions/formats.py b/invokeai/backend/patches/lora_conversions/formats.py index 94f71e05ee..4fe6eb8772 100644 --- a/invokeai/backend/patches/lora_conversions/formats.py +++ b/invokeai/backend/patches/lora_conversions/formats.py @@ -1,3 +1,5 @@ +from typing import Any + from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import ( is_state_dict_likely_in_flux_aitoolkit_format, @@ -14,7 +16,10 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u ) -def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None: +def flux_format_from_state_dict( + state_dict: dict[str, Any], + metadata: dict[str, Any] | None = None, +) -> FluxLoRAFormat | None: if is_state_dict_likely_in_flux_kohya_format(state_dict): return FluxLoRAFormat.Kohya elif is_state_dict_likely_in_flux_onetrainer_format(state_dict): diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py index 045ebbbf2c..8231e313fd 100644 --- a/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py +++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py @@ -4,7 +4,8 @@ import accelerate from safetensors.torch import load_file, save_file from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.util import params +from invokeai.backend.flux.util import get_flux_transformers_params +from invokeai.backend.model_manager.taxonomy import ModelVariantType from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time @@ -22,7 +23,7 @@ def main(): with log_time("Initialize FLUX transformer on meta device"): # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate 
config. - p = params["flux-schnell"] + p = get_flux_transformers_params(ModelVariantType.FluxSchnell) # Initialize the model on the "meta" device. with accelerate.init_empty_weights(): diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py index c8802b9e49..6a4ee3abf9 100644 --- a/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py +++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py @@ -7,7 +7,8 @@ import torch from safetensors.torch import load_file, save_file from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.util import params +from invokeai.backend.flux.util import get_flux_transformers_params +from invokeai.backend.model_manager.taxonomy import ModelVariantType from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 @@ -35,7 +36,7 @@ def main(): # inference_dtype = torch.bfloat16 with log_time("Initialize FLUX transformer on meta device"): # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. - p = params["flux-schnell"] + p = get_flux_transformers_params(ModelVariantType.FluxSchnell) # Initialize the model on the "meta" device. with accelerate.init_empty_weights(): diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index ec4ddf1a1d..85ab9b126e 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -223,6 +223,9 @@ export const MODEL_VARIANT_TO_LONG_NAME: Record = { normal: 'Normal', inpaint: 'Inpaint', depth: 'Depth', + dev: 'FLUX Dev', + dev_fill: 'FLUX Dev Fill', + schnell: 'FLUX Schnell', }; export const MODEL_FORMAT_TO_LONG_NAME: Record = { diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 4b97c2145d..96356f7cb4 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -147,7 +147,7 @@ export const zSubModelType = z.enum([ ]); export const zClipVariantType = z.enum(['large', 'gigantic']); -export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']); +export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth', 'dev', 'dev_fill', 'schnell']); export type ModelVariantType = z.infer; export const zModelFormat = z.enum([ 'omi', diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 533ff95d8d..09b13c8d52 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -5490,7 +5490,7 @@ export type components = { * Config Path * @description path to the checkpoint model config file */ - config_path: string; + config_path?: string | null; /** * Converted At * @description When this model was last converted to diffusers @@ -14818,7 +14818,7 @@ export type components = { * Config Path * @description path to the checkpoint model config file */ - config_path: string; + config_path?: string | null; /** * Converted At * @description When this model was last converted to diffusers @@ -14927,7 +14927,7 @@ export type components = { * Config Path * @description path to the checkpoint model config file */ - config_path: string; + config_path?: string | null; /** * Converted At * @description When this model was last converted to 
diffusers @@ -15128,7 +15128,7 @@ export type components = { * Config Path * @description path to the checkpoint model config file */ - config_path: string; + config_path?: string | null; /** * Converted At * @description When this model was last converted to diffusers @@ -17487,7 +17487,7 @@ export type components = { * @description Variant type. * @enum {string} */ - ModelVariantType: "normal" | "inpaint" | "depth"; + ModelVariantType: "normal" | "inpaint" | "depth" | "flux_dev" | "flux_dev_fill" | "flux_schnell"; /** * ModelsList * @description Return list of configs. @@ -22199,7 +22199,7 @@ export type components = { * Config Path * @description path to the checkpoint model config file */ - config_path: string; + config_path?: string | null; /** * Converted At * @description When this model was last converted to diffusers diff --git a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py index ed3e05a9b2..051ed210cd 100644 --- a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py +++ b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py @@ -2,7 +2,8 @@ import accelerate import pytest from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.util import params +from invokeai.backend.flux.util import get_flux_transformers_params +from invokeai.backend.model_manager.taxonomy import ModelVariantType from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import ( _group_state_by_submodel, is_state_dict_likely_in_flux_aitoolkit_format, @@ -44,7 +45,7 @@ def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format(): # Initialize a FLUX model on the meta device. with accelerate.init_empty_weights(): - model = Flux(params["flux-schnell"]) + model = Flux(get_flux_transformers_params(ModelVariantType.FluxSchnell)) model_keys = set(model.state_dict().keys()) for converted_key_prefix in converted_key_prefixes: diff --git a/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py index 52b8ecc9c9..eb8846f456 100644 --- a/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py +++ b/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py @@ -3,7 +3,8 @@ import pytest import torch from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.util import params +from invokeai.backend.flux.util import get_flux_transformers_params +from invokeai.backend.model_manager.taxonomy import ModelVariantType from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( _convert_flux_transformer_kohya_state_dict_to_invoke_format, is_state_dict_likely_in_flux_kohya_format, @@ -63,7 +64,7 @@ def test_convert_flux_transformer_kohya_state_dict_to_invoke_format(): # Initialize a FLUX model on the meta device. with accelerate.init_empty_weights(): - model = Flux(params["flux-dev"]) + model = Flux(get_flux_transformers_params(ModelVariantType.FluxSchnell)) model_keys = set(model.state_dict().keys()) # Assert that the converted state dict matches the keys in the actual model. 
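The FLUX variant values serialize as "dev", "dev_fill", and "schnell" over the API, which the zod enum and the generated ModelVariantType type above mirror. A quick illustration against the backend enum (values as defined in taxonomy.py):

    from invokeai.backend.model_manager.taxonomy import FluxVariantType

    # These string values are the wire format shared by the OpenAPI schema and the frontend.
    assert FluxVariantType.Dev.value == "dev"
    assert FluxVariantType.DevFill.value == "dev_fill"
    assert FluxVariantType.Schnell.value == "schnell"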
diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py
index 8ee4f8df1f..8112ccdd19 100644
--- a/tests/test_model_probe.py
+++ b/tests/test_model_probe.py
@@ -115,7 +115,8 @@ class MinimalConfigExample(ModelConfigBase):
     fun_quote: str
 
     @classmethod
-    def matches(cls, mod: ModelOnDisk) -> bool:
-        return mod.path.suffix == ".json"
+    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
+        # The new classify API scores matches rather than returning a bool.
+        return MatchCertainty.EXACT if mod.path.suffix == ".json" else MatchCertainty.NEVER
 
     @classmethod
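Taken together, the reworked probe walks the config classes ordered by match speed, records a MatchCertainty score for each candidate, stops early on EXACT or OVERRIDE, and instantiates the best-scoring config. A sketch of how the API is driven end to end (the model path here is hypothetical, and "blake3_single" is one of the supported hashing algorithms):

    from pathlib import Path

    from invokeai.backend.model_manager.config import ModelConfigBase

    # String overrides are cast to their enum types before matching, so supplying
    # both type and format short-circuits the heuristics with MatchCertainty.OVERRIDE.
    config = ModelConfigBase.classify(
        Path("models/my_lora.safetensors"),  # hypothetical path
        hash_algo="blake3_single",
        type="lora",
        format="lycoris",
    )
    print(type(config).__name__, config.type, config.format)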