refactor(mm): continue iterating on config

This commit is contained in:
psychedelicious
2025-09-25 20:08:48 +10:00
parent 7ca0a0a0fd
commit eaddd6f533
18 changed files with 1370 additions and 1118 deletions

View File

@@ -24,8 +24,9 @@ class ModelIdentifierField(BaseModel):
name: str = Field(description="The model's name")
base: BaseModelType = Field(description="The model's base model type")
type: ModelType = Field(description="The model's type")
submodel_type: Optional[SubModelType] = Field(
description="The submodel to load, if this is a main model", default=None
submodel_type: SubModelType | None = Field(
description="The submodel to load, if this is a main model",
default=None,
)
@classmethod

View File

@@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFie
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.model_manager.config import AnyModelConfigValidator
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
@@ -115,6 +116,13 @@ def get_openapi_func(
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
move_defs_to_top_level(openapi_schema, additional_schemas[1])
any_model_config_schema = AnyModelConfigValidator.json_schema(
mode="serialization",
ref_template="#/components/schemas/{model}",
)
move_defs_to_top_level(openapi_schema, any_model_config_schema)
openapi_schema["components"]["schemas"]["AnyModelConfig"] = any_model_config_schema
if post_transform is not None:
openapi_schema = post_transform(openapi_schema)

View File

@@ -20,7 +20,6 @@ Validation errors will raise an InvalidModelConfigException error.
"""
# pyright: reportIncompatibleVariableOverride=false
import json
import logging
import re
@@ -111,7 +110,9 @@ def _get_config_or_raise(
raise NotAMatch(config_class, f"missing config file: {config_path}")
try:
config = load_json(config_path)
with open(config_path, "r") as file:
config = json.load(file)
return config
except Exception as e:
raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e
@@ -165,7 +166,6 @@ def _validate_overrides(
)
def _raise_if_not_file(
config_class: type,
mod: ModelOnDisk,
@@ -245,45 +245,75 @@ class ModelConfigBase(ABC, BaseModel):
See MinimalConfigExample in test_model_probe.py for an example implementation.
"""
@staticmethod
def json_schema_extra(schema: dict[str, Any]) -> None:
schema["required"].extend(["key", "base", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
hash: str = Field(description="The hash of the model file(s).")
key: str = Field(
description="A unique key for this model.",
default_factory=uuid_string,
)
hash: str = Field(
description="The hash of the model file(s).",
)
path: str = Field(
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory.",
)
file_size: int = Field(description="The size of the model in bytes.")
name: str = Field(description="Name of the model.")
type: ModelType = Field(description="Model type")
format: ModelFormat = Field(description="Model format")
base: BaseModelType = Field(description="The base model.")
source: str = Field(description="The original source of the model (path, URL or repo_id).")
source_type: ModelSourceType = Field(description="The type of source")
description: Optional[str] = Field(description="Model description", default=None)
source_api_response: Optional[str] = Field(
description="The original API response from the source, as stringified JSON.", default=None
file_size: int = Field(
description="The size of the model in bytes.",
)
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
name: str = Field(
description="Name of the model.",
)
description: str | None = Field(
description="Model description",
default=None,
)
source: str = Field(
description="The original source of the model (path, URL or repo_id).",
)
source_type: ModelSourceType = Field(
description="The type of source",
)
source_api_response: str | None = Field(
description="The original API response from the source, as stringified JSON.",
default=None,
)
cover_image: str | None = Field(
description="Url for image to preview model",
default=None,
)
submodels: dict[SubModelType, SubmodelDefinition] | None = Field(
description="Loadable submodels in this model",
default=None,
)
usage_info: str | None = Field(
default=None,
description="Usage information for this model",
)
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
USING_LEGACY_PROBE: ClassVar[set[Type["AnyModelConfig"]]] = set()
USING_CLASSIFY_API: ClassVar[set[Type["AnyModelConfig"]]] = set()
model_config = ConfigDict(
validate_assignment=True,
json_schema_serialization_defaults_required=True,
json_schema_mode_override="serialization",
)
@classmethod
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, LegacyProbeMixin):
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
else:
ModelConfigBase.USING_CLASSIFY_API.add(cls)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
# Ensure that subclasses define 'base', 'type' and 'format' fields. These are not in this base class, because
# subclasses may redefine them as different types, causing type-checking issues.
assert "base" in cls.model_fields, f"{cls.__name__} must define a 'base' field"
assert "type" in cls.model_fields, f"{cls.__name__} must define a 'type' field"
assert "format" in cls.model_fields, f"{cls.__name__} must define a 'format' field"
@staticmethod
def all_config_classes():
subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API
@@ -305,9 +335,9 @@ class ModelConfigBase(ABC, BaseModel):
class UnknownModelConfig(ModelConfigBase):
base: Literal[BaseModelType.Unknown] = BaseModelType.Unknown
type: Literal[ModelType.Unknown] = ModelType.Unknown
format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown
base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown)
format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
@@ -317,14 +347,7 @@ class UnknownModelConfig(ModelConfigBase):
class CheckpointConfigBase(ABC, BaseModel):
"""Base class for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
description="Format of the provided checkpoint model",
default=ModelFormat.Checkpoint,
)
config_path: str | None = Field(
description="path to the checkpoint model config file",
default=None,
)
config_path: str | None = Field(None, description="Path to the config for this model, if any.")
converted_at: float | None = Field(
description="When this model was last converted to diffusers",
default_factory=time.time,
@@ -334,66 +357,14 @@ class CheckpointConfigBase(ABC, BaseModel):
class DiffusersConfigBase(ABC, BaseModel):
"""Base class for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
repo_variant: Optional[ModelRepoVariant] = Field(ModelRepoVariant.Default)
class LoRAConfigBase(ABC, BaseModel):
"""Base class for LoRA models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[LoraModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
@classmethod
def flux_lora_format(cls, mod: ModelOnDisk):
key = "FLUX_LORA_FORMAT"
if key in mod.cache:
return mod.cache[key]
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd, mod.metadata())
mod.cache[key] = value
return value
@classmethod
def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
if cls.flux_lora_format(mod):
return BaseModelType.Flux
state_dict = mod.load_state_dict()
# If we've gotten here, we assume that the model is a Stable Diffusion model
token_vector_length = lora_token_vector_length(state_dict)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException("Unknown LoRA type")
class T5EncoderConfigBase(ABC, BaseModel):
"""Base class for diffusers-style models."""
base: Literal[BaseModelType.Any] = BaseModelType.Any
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
def load_json(path: Path) -> dict[str, Any]:
    """Load and return the JSON object stored at *path*.

    Args:
        path: Filesystem path of a JSON file expected to contain an object.

    Returns:
        The parsed JSON content as a dict.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON is defined as UTF-8 (RFC 8259); pin the encoding explicitly so the
    # result does not depend on the platform's locale default.
    with open(path, "r", encoding="utf-8") as file:
        return json.load(file)
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
class T5EncoderConfig(ModelConfigBase):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.T5Encoder,
@@ -421,8 +392,10 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
return cls(**fields)
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.T5Encoder,
@@ -454,8 +427,32 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
return cls(**fields)
class LoRAConfigBase(ABC, BaseModel):
"""Base class for LoRA models."""
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[LoraModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
def get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
    """Return the FLUX LoRA format of *mod*, or None if it is not a FLUX LoRA.

    The result is computed once from the model's state dict and memoized in
    ``mod.cache`` so repeated probes do not reload the state dict.
    """
    cache_key = "FLUX_LORA_FORMAT"
    if cache_key not in mod.cache:
        # Imported lazily — presumably to avoid an import cycle at module load
        # time; TODO(review): confirm.
        from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict

        state_dict = mod.load_state_dict(mod.path)
        mod.cache[cache_key] = flux_format_from_state_dict(state_dict, mod.metadata())
    return mod.cache[cache_key]
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
base: Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL] = Field()
format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LoRA,
@@ -470,7 +467,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
if get_flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA")
# Heuristic: Look for OMI LoRA metadata
@@ -489,7 +486,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]:
metadata = mod.metadata()
architecture = metadata["modelspec.architecture"]
@@ -501,10 +498,20 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}")
LoRALyCORIS_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
BaseModelType.Flux,
]
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
base: LoRALyCORIS_SupportedBases = Field()
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LoRA,
@@ -518,7 +525,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
if get_flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA")
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
@@ -547,33 +554,69 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
return cls(**fields)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> LoRALyCORIS_SupportedBases:
    """Infer which base model family this LyCORIS LoRA targets.

    FLUX LoRAs are identified from their state-dict format; otherwise the
    model is assumed to be a Stable Diffusion LoRA and the CLIP token vector
    length is used to distinguish SD1/SD2/SDXL.

    Raises:
        NotAMatch: if the token vector length maps to no known base.
    """
    if get_flux_lora_format(mod):
        return BaseModelType.Flux
    state_dict = mod.load_state_dict()
    # If we've gotten here, we assume that the model is a Stable Diffusion model
    token_vector_length = lora_token_vector_length(state_dict)
    if token_vector_length == 768:
        return BaseModelType.StableDiffusion1
    elif token_vector_length == 1024:
        return BaseModelType.StableDiffusion2
    elif token_vector_length == 1280:
        # SDXL variant with a shorter vector length
        return BaseModelType.StableDiffusionXL  # recognizes format at https://civitai.com/models/224641
    elif token_vector_length == 2048:
        return BaseModelType.StableDiffusionXL
    else:
        raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}")
class ControlAdapterConfigBase(ABC, BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)
default_settings: ControlAdapterDefaultSettings | None = Field(None)
ControlLoRALyCORIS_SupportedBases: TypeAlias = Literal[BaseModelType.Flux]
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
base: ControlLoRALyCORIS_SupportedBases = Field()
type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa)
format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
trigger_phrases: set[str] | None = Field(None)
ControlLoRADiffusers_SupportedBases: TypeAlias = Literal[BaseModelType.Flux]
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
base: ControlLoRADiffusers_SupportedBases = Field()
type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
trigger_phrases: set[str] | None = Field(None)
LoRADiffusers_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
BaseModelType.Flux,
]
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
base: LoRADiffusers_SupportedBases = Field()
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LoRA,
@@ -587,7 +630,7 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
is_flux_lora_diffusers = get_flux_lora_format(mod) is FluxLoRAFormat.Diffusers
suffixes = ["bin", "safetensors"]
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
@@ -599,20 +642,33 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
return cls(**fields)
class VAEConfigBase(ABC, BaseModel):
type: Literal[ModelType.VAE] = ModelType.VAE
VAECheckpointConfig_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
BaseModelType.Flux,
]
class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase):
"""Model config for standalone VAE models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
base: VAECheckpointConfig_SupportedBases = Field()
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.VAE,
"format": ModelFormat.Checkpoint,
}
REGEX_TO_BASE: ClassVar[dict[str, VAECheckpointConfig_SupportedBases]] = {
r"xl": BaseModelType.StableDiffusionXL,
r"sd2": BaseModelType.StableDiffusion2,
r"vae": BaseModelType.StableDiffusion1,
r"FLUX.1-schnell_ae": BaseModelType.Flux,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_raise_if_not_file(cls, mod)
@@ -626,24 +682,27 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAECheckpointConfig_SupportedBases:
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
for regexp, basetype in [
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
]:
for regexp, base in cls.REGEX_TO_BASE.items():
if re.search(regexp, mod.path.name, re.IGNORECASE):
return basetype
return base
raise NotAMatch(cls, "cannot determine base type")
class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
VAEDiffusersConfig_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusionXL,
]
class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
base: VAEDiffusersConfig_SupportedBases = Field()
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.VAE,
@@ -685,7 +744,7 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
return name
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBases:
config = _get_config_or_raise(cls, mod.path / "config.json")
if cls._config_looks_like_sdxl(config):
return BaseModelType.StableDiffusionXL
@@ -696,21 +755,48 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
return BaseModelType.StableDiffusion1
ControlNetDiffusers_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
BaseModelType.Flux,
]
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
base: ControlNetDiffusers_SupportedBases = Field()
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
ControlNetCheckpoint_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
BaseModelType.Flux,
]
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
base: ControlNetDiffusers_SupportedBases = Field()
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
TextualInversion_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
]
class TextualInversionConfigBase(ABC, BaseModel):
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
base: TextualInversion_SupportedBases = Field()
type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion)
KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"}
@@ -743,7 +829,7 @@ class TextualInversionConfigBase(ABC, BaseModel):
return False
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> TextualInversion_SupportedBases:
p = path or mod.path
try:
@@ -777,7 +863,7 @@ class TextualInversionConfigBase(ABC, BaseModel):
class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
"""Model config for textual inversion embeddings."""
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.TextualInversion,
@@ -804,7 +890,7 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
"""Model config for textual inversion embeddings."""
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.TextualInversion,
@@ -830,65 +916,89 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
class MainConfigBase(ABC, BaseModel):
type: Literal[ModelType.Main] = ModelType.Main
type: Literal[ModelType.Main] = Field(default=ModelType.Main)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: ModelVariantType | FluxVariantType = ModelVariantType.Normal
variant: ModelVariantType | FluxVariantType = Field()
class VideoConfigBase(ABC, BaseModel):
type: Literal[ModelType.Video] = ModelType.Video
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: ModelVariantType = ModelVariantType.Normal
MainCheckpointConfigBase_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusion3,
BaseModelType.StableDiffusionXL,
BaseModelType.StableDiffusionXLRefiner,
BaseModelType.Flux,
BaseModelType.CogView4,
]
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
base: MainCheckpointConfigBase_SupportedBases = Field()
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
upcast_attention: bool = Field(False)
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b)
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
upcast_attention: bool = False
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.GGUFQuantized] = ModelFormat.GGUFQuantized
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
upcast_attention: bool = Field(False)
MainDiffusers_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusion3,
BaseModelType.StableDiffusionXL,
BaseModelType.StableDiffusionXLRefiner,
BaseModelType.Flux,
BaseModelType.CogView4,
]
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main diffusers models."""
pass
base: MainDiffusers_SupportedBases = Field()
class IPAdapterConfigBase(ABC, BaseModel):
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
IPAdapterInvokeAI_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
]
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
"""Model config for IP Adapter diffusers format models."""
base: IPAdapterInvokeAI_SupportedBases = Field()
format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI)
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
# time. Need to go through the history to make sure I'm understanding this fully.
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
image_encoder_model_id: str = Field()
VALID_OVERRIDES: ClassVar = {
"type": ModelType.IPAdapter,
@@ -913,7 +1023,7 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAI_SupportedBases:
state_dict = mod.load_state_dict()
try:
@@ -932,12 +1042,19 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
IPAdapterCheckpoint_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
BaseModelType.Flux,
]
class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
"""Model config for IP Adapter checkpoint format models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
base: IPAdapterCheckpoint_SupportedBases = Field()
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.IPAdapter,
@@ -964,7 +1081,7 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpoint_SupportedBases:
state_dict = mod.load_state_dict()
if is_state_dict_xlabs_ip_adapter(state_dict):
@@ -989,10 +1106,9 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
"""Model config for Clip Embeddings."""
variant: ClipVariantType = Field(...)
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
base: Literal[BaseModelType.Any] = BaseModelType.Any
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
VALID_CLASS_NAMES: ClassVar = {
"CLIPModel",
@@ -1018,7 +1134,7 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
"""Model config for CLIP-G Embeddings."""
variant: Literal[ClipVariantType.G] = ClipVariantType.G
variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.CLIPEmbed,
@@ -1053,7 +1169,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
"""Model config for CLIP-L Embeddings."""
variant: Literal[ClipVariantType.L] = ClipVariantType.L
variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.CLIPEmbed,
@@ -1087,8 +1203,9 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for CLIPVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.CLIPVision,
@@ -1112,24 +1229,31 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
return cls(**fields)
T2IAdapterCheckpoint_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
]
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
base: T2IAdapterCheckpoint_SupportedBases = Field()
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
class SpandrelImageToImageConfig(ModelConfigBase):
"""Model config for Spandrel Image to Image models."""
base: Literal[BaseModelType.Any] = BaseModelType.Any
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.SpandrelImageToImage,
"format": ModelFormat.Checkpoint,
"base": BaseModelType.Any,
}
@classmethod
@@ -1148,8 +1272,7 @@ class SpandrelImageToImageConfig(ModelConfigBase):
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
# maintain it, and the risk of false positive detections is higher.
SpandrelImageToImageModel.load_from_file(mod.path)
base = fields.get("base") or BaseModelType.Any
return cls(**fields, base=base)
return cls(**fields)
except Exception as e:
raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e
@@ -1157,8 +1280,9 @@ class SpandrelImageToImageConfig(ModelConfigBase):
class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for SigLIP."""
type: Literal[ModelType.SigLIP] = ModelType.SigLIP
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.SigLIP,
@@ -1185,8 +1309,9 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
class FluxReduxConfig(ModelConfigBase):
"""Model config for FLUX Tools Redux model."""
type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.FluxRedux,
@@ -1208,9 +1333,9 @@ class FluxReduxConfig(ModelConfigBase):
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for Llava Onevision models."""
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
base: Literal[BaseModelType.Any] = BaseModelType.Any
variant: Literal[ModelVariantType.Normal] = ModelVariantType.Normal
type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision)
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
variant: Literal[ModelVariantType.Normal] = Field(default=ModelVariantType.Normal)
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LlavaOnevision,
@@ -1234,20 +1359,44 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
return cls(**fields)
ApiModel_SupportedBases: TypeAlias = Literal[
BaseModelType.ChatGPT4o,
BaseModelType.Gemini2_5,
BaseModelType.Imagen3,
BaseModelType.Imagen4,
BaseModelType.FluxKontext,
]
class ApiModelConfig(MainConfigBase, ModelConfigBase):
"""Model config for API-based models."""
format: Literal[ModelFormat.Api] = ModelFormat.Api
type: Literal[ModelType.Main] = Field(default=ModelType.Main)
format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api)
base: ApiModel_SupportedBases = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
raise NotAMatch(cls, "API models cannot be built from disk")
class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
VideoApiModel_SupportedBases: TypeAlias = Literal[
BaseModelType.Veo3,
BaseModelType.Runway,
]
class VideoApiModelConfig(ModelConfigBase):
"""Model config for API-based video models."""
format: Literal[ModelFormat.Api] = ModelFormat.Api
type: Literal[ModelType.Video] = Field(default=ModelType.Video)
base: VideoApiModel_SupportedBases = Field()
format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api)
trigger_phrases: set[str] | None = Field(description="Set of trigger phrases for this model", default=None)
default_settings: MainModelDefaultSettings | None = Field(
description="Default settings for this model", default=None
)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:

View File

@@ -1,41 +1,70 @@
from enum import Enum
from typing import Dict, TypeAlias, Union
import diffusers
import onnxruntime as ort
import torch
from diffusers import ModelMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from pydantic import TypeAdapter
from invokeai.backend.raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[
ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, ort.InferenceSession
AnyModel: TypeAlias = Union[
ModelMixin,
RawModel,
torch.nn.Module,
Dict[str, torch.Tensor],
DiffusionPipeline,
ort.InferenceSession,
]
"""Type alias for any kind of runtime, in-memory model representation. For example, a torch module or diffusers pipeline."""
class BaseModelType(str, Enum):
"""Base model type."""
"""An enumeration of base model architectures. For example, Stable Diffusion 1.x, Stable Diffusion 2.x, FLUX, etc.
Every model config must have a base architecture type.
Not all models are associated with a base architecture. For example, CLIP models are their own thing, not related
to any particular model architecture. To simplify internal APIs and make it easier to work with models, we use a
fallback/null value `BaseModelType.Any` for these models, instead of making the model base optional."""
Any = "any"
"""`Any` is essentially a fallback/null value for models with no base architecture association.
For example, CLIP models are not related to Stable Diffusion, FLUX, or any other model arch."""
StableDiffusion1 = "sd-1"
"""Indicates the model is associated with the Stable Diffusion 1.x model architecture, including 1.4 and 1.5."""
StableDiffusion2 = "sd-2"
"""Indicates the model is associated with the Stable Diffusion 2.x model architecture, including 2.0 and 2.1."""
StableDiffusion3 = "sd-3"
"""Indicates the model is associated with the Stable Diffusion 3.5 model architecture."""
StableDiffusionXL = "sdxl"
"""Indicates the model is associated with the Stable Diffusion XL model architecture."""
StableDiffusionXLRefiner = "sdxl-refiner"
"""Indicates the model is associated with the Stable Diffusion XL Refiner model architecture."""
Flux = "flux"
"""Indicates the model is associated with FLUX.1 model architecture, including FLUX Dev, Schnell and Fill."""
CogView4 = "cogview4"
"""Indicates the model is associated with CogView 4 model architecture."""
Imagen3 = "imagen3"
"""Indicates the model is associated with Google Imagen 3 model architecture. This is an external API model."""
Imagen4 = "imagen4"
"""Indicates the model is associated with Google Imagen 4 model architecture. This is an external API model."""
Gemini2_5 = "gemini-2.5"
"""Indicates the model is associated with Google Gemini 2.5 Flash Image model architecture. This is an external API model."""
ChatGPT4o = "chatgpt-4o"
# This is actually the FLUX Kontext API model. Local FLUX Kontext is just BaseModelType.Flux.
"""Indicates the model is associated with OpenAI ChatGPT 4o Image model architecture. This is an external API model."""
FluxKontext = "flux-kontext"
"""Indicates the model is associated with FLUX Kontext model architecture. This is an external API model; local FLUX
Kontext models use the base `Flux`."""
Veo3 = "veo3"
"""Indicates the model is associated with Google Veo 3 video model architecture. This is an external API model."""
Runway = "runway"
"""Indicates the model is associated with Runway video model architecture. This is an external API model."""
Unknown = "unknown"
"""Indicates the model's base architecture is unknown."""
class ModelType(str, Enum):

View File

@@ -83,14 +83,14 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str,
return checkpoint
def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
def lora_token_vector_length(checkpoint: dict[str | int, torch.Tensor]) -> Optional[int]:
"""
Given a checkpoint in memory, return the lora token vector length
:param checkpoint: The checkpoint
"""
def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: dict[str | int, torch.Tensor]) -> Optional[int]:
lora_token_vector_length = None
if "." not in key:
@@ -136,6 +136,8 @@ def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[in
lora_te1_length = None
lora_te2_length = None
for key, tensor in checkpoint.items():
if isinstance(key, int):
continue
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_unet_") and (

View File

@@ -17,7 +17,7 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
def flux_format_from_state_dict(
state_dict: dict[str, Any],
state_dict: dict[str | int, Any],
metadata: dict[str, Any] | None = None,
) -> FluxLoRAFormat | None:
if is_state_dict_likely_in_flux_kohya_format(state_dict):

View File

@@ -1,4 +1,4 @@
import type { BaseModelType, ModelFormat, ModelType, ModelVariantType } from 'features/nodes/types/common';
import type { AnyModelVariant, BaseModelType, ModelFormat, ModelType } from 'features/nodes/types/common';
import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
@@ -219,13 +219,15 @@ export const MODEL_BASE_TO_SHORT_NAME: Record<BaseModelType, string> = {
unknown: 'Unknown',
};
export const MODEL_VARIANT_TO_LONG_NAME: Record<ModelVariantType, string> = {
export const MODEL_VARIANT_TO_LONG_NAME: Record<AnyModelVariant, string> = {
normal: 'Normal',
inpaint: 'Inpaint',
depth: 'Depth',
dev: 'FLUX Dev',
dev_fill: 'FLUX Dev Fill',
dev_fill: 'FLUX Dev - Fill',
schnell: 'FLUX Schnell',
large: 'CLIP L',
gigantic: 'CLIP G',
};
export const MODEL_FORMAT_TO_LONG_NAME: Record<ModelFormat, string> = {

View File

@@ -1,12 +1,12 @@
import { Badge } from '@invoke-ai/ui-library';
import type { ModelFormat } from 'features/nodes/types/common';
import { memo } from 'react';
import type { AnyModelConfig } from 'services/api/types';
type Props = {
format: AnyModelConfig['format'];
format: ModelFormat;
};
const FORMAT_NAME_MAP: Record<AnyModelConfig['format'], string> = {
const FORMAT_NAME_MAP: Record<ModelFormat, string> = {
diffusers: 'diffusers',
lycoris: 'lycoris',
checkpoint: 'checkpoint',
@@ -20,9 +20,11 @@ const FORMAT_NAME_MAP: Record<AnyModelConfig['format'], string> = {
api: 'api',
omi: 'omi',
unknown: 'unknown',
olive: 'olive',
onnx: 'onnx',
};
const FORMAT_COLOR_MAP: Record<AnyModelConfig['format'], string> = {
const FORMAT_COLOR_MAP: Record<ModelFormat, string> = {
diffusers: 'base',
omi: 'base',
lycoris: 'base',
@@ -36,6 +38,8 @@ const FORMAT_COLOR_MAP: Record<AnyModelConfig['format'], string> = {
gguf_quantized: 'base',
api: 'base',
unknown: 'red',
olive: 'base',
onnx: 'base',
};
const ModelFormatBadge = ({ format }: Props) => {

View File

@@ -12,6 +12,7 @@ import type {
T2IAdapterField,
zBaseModelType,
zClipVariantType,
zFluxVariantType,
zModelFormat,
zModelVariantType,
zSubModelType,
@@ -45,6 +46,7 @@ describe('Common types', () => {
test('ModelIdentifier', () => assert<Equals<z.infer<typeof zSubModelType>, S['SubModelType']>>());
test('ClipVariantType', () => assert<Equals<z.infer<typeof zClipVariantType>, S['ClipVariantType']>>());
test('ModelVariantType', () => assert<Equals<z.infer<typeof zModelVariantType>, S['ModelVariantType']>>());
test('FluxVariantType', () => assert<Equals<z.infer<typeof zFluxVariantType>, S['FluxVariantType']>>());
test('ModelFormat', () => assert<Equals<z.infer<typeof zModelFormat>, S['ModelFormat']>>());
// Misc types

View File

@@ -147,8 +147,10 @@ export const zSubModelType = z.enum([
]);
export const zClipVariantType = z.enum(['large', 'gigantic']);
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth', 'dev', 'dev_fill', 'schnell']);
export type ModelVariantType = z.infer<typeof zModelVariantType>;
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
export const zFluxVariantType = z.enum(['dev', 'dev_fill', 'schnell']);
export const zAnyModelVariant = z.union([zModelVariantType, zClipVariantType, zFluxVariantType]);
export type AnyModelVariant = z.infer<typeof zAnyModelVariant>;
export const zModelFormat = z.enum([
'omi',
'diffusers',

View File

@@ -10,15 +10,14 @@ import { z } from 'zod';
import type { ImageField } from './common';
import {
zAnyModelVariant,
zBaseModelType,
zBoardField,
zClipVariantType,
zColorField,
zImageField,
zModelFormat,
zModelIdentifierField,
zModelType,
zModelVariantType,
zSchedulerField,
} from './common';
@@ -73,7 +72,7 @@ const zFieldInputTemplateBase = zFieldTemplateBase.extend({
ui_choice_labels: z.record(z.string(), z.string()).nullish(),
ui_model_base: z.array(zBaseModelType).nullish(),
ui_model_type: z.array(zModelType).nullish(),
ui_model_variant: z.array(zModelVariantType.or(zClipVariantType)).nullish(),
ui_model_variant: z.array(zAnyModelVariant).nullish(),
ui_model_format: z.array(zModelFormat).nullish(),
});
const zFieldOutputTemplateBase = zFieldTemplateBase.extend({

View File

@@ -673,6 +673,8 @@ describe('Graph', () => {
variant: 'inpaint',
format: 'diffusers',
repo_variant: 'fp16',
submodels: null,
usage_info: null,
});
expect(field).toEqual({
key: 'b00ee8df-523d-40d2-9578-597283b07cb2',

View File

@@ -25,9 +25,7 @@ export const MainModelPicker = memo(() => {
const isFluxDevSelected = useMemo(
() =>
selectedModelConfig &&
isCheckpointMainModelConfig(selectedModelConfig) &&
selectedModelConfig.variant === 'flux_dev',
selectedModelConfig && isCheckpointMainModelConfig(selectedModelConfig) && selectedModelConfig.variant === 'dev',
[selectedModelConfig]
);

View File

@@ -24,9 +24,7 @@ export const InitialStateMainModelPicker = memo(() => {
const isFluxDevSelected = useMemo(
() =>
selectedModelConfig &&
isCheckpointMainModelConfig(selectedModelConfig) &&
selectedModelConfig.variant === 'flux_dev',
selectedModelConfig && isCheckpointMainModelConfig(selectedModelConfig) && selectedModelConfig.variant === 'dev',
[selectedModelConfig]
);

File diff suppressed because it is too large Load Diff

View File

@@ -119,8 +119,6 @@ type LlavaOnevisionConfig = S['LlavaOnevisionConfig'];
export type T5EncoderModelConfig = S['T5EncoderConfig'];
export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig'];
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
type DiffusersModelConfig = S['MainDiffusersConfig'];
export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
type SigLipModelConfig = S['SigLIPConfig'];
@@ -128,29 +126,11 @@ export type FLUXReduxModelConfig = S['FluxReduxConfig'];
type ApiModelConfig = S['ApiModelConfig'];
export type VideoApiModelConfig = S['VideoApiModelConfig'];
type UnknownModelConfig = S['UnknownModelConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig | ApiModelConfig;
export type MainModelConfig = Extract<S['AnyModelConfig'], { type: 'main' }>;
export type FLUXKontextModelConfig = MainModelConfig;
export type ChatGPT4oModelConfig = ApiModelConfig;
export type Gemini2_5ModelConfig = ApiModelConfig;
export type AnyModelConfig =
| ControlLoRAModelConfig
| LoRAModelConfig
| VAEModelConfig
| ControlNetModelConfig
| IPAdapterModelConfig
| T5EncoderModelConfig
| T5EncoderBnbQuantizedLlmInt8bModelConfig
| CLIPEmbedModelConfig
| T2IAdapterModelConfig
| SpandrelImageToImageModelConfig
| TextualInversionModelConfig
| MainModelConfig
| VideoApiModelConfig
| CLIPVisionDiffusersConfig
| SigLipModelConfig
| FLUXReduxModelConfig
| LlavaOnevisionConfig
| UnknownModelConfig;
export type AnyModelConfig = S['AnyModelConfig'];
/**
* Checks if a list of submodels contains any that match a given variant or type

View File

@@ -295,18 +295,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
const { id, config } = data;
if (
config.type === 'unknown' ||
config.base === 'unknown' ||
/**
* Checking if type/base are 'unknown' technically narrows the config such that it's not possible for a config
* that passes to the `config.[type|base] === 'unknown'` checks. In the future, if we have more model config
* classes, this may change, so we will continue to check all three. Any one being 'unknown' is concerning
* enough to warrant a toast.
*/
/* @ts-expect-error See note above */
config.format === 'unknown'
) {
if (config.type === 'unknown') {
toast({
id: 'UNKNOWN_MODEL',
title: t('modelManager.unidentifiedModelTitle'),

View File

@@ -31,7 +31,10 @@ def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
try:
return ModelProbe.probe(path, hash_algo=hash_algo)
except InvalidModelConfigException:
return ModelConfigFactory.from_model_on_disk(mod=path, hash_algo=hash_algo,)
return ModelConfigFactory.from_model_on_disk(
mod=path,
hash_algo=hash_algo,
)
for path in args.model_path: