New API for model classification

This commit is contained in:
Billy
2025-03-06 14:01:39 +11:00
parent a89d68b93a
commit bfdace6437
10 changed files with 1082 additions and 813 deletions

View File

@@ -38,9 +38,11 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
InvalidModelConfigException,
ModelConfigBase,
ModelRepoVariant,
ModelSourceType,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
HuggingFaceMetadataFetch,
@@ -49,7 +51,6 @@ from invokeai.backend.model_manager.metadata import (
RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
@@ -182,9 +183,7 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or ModelRecordChanges()
info: AnyModelConfig = ModelProbe.probe(
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
) # type: ignore
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore
if preferred_name := config.name:
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@@ -644,12 +643,22 @@ class ModelInstallService(ModelInstallServiceBase):
move(old_path, new_path)
return new_path
def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
config = config or ModelRecordChanges()
overrides = config.model_dump()
try:
return ModelConfigBase.classify(model_path, **overrides)
except InvalidModelConfigException:
return ModelProbe.probe(
model_path=model_path, fields=overrides, hash_algo=self._app_config.hashing_algorithm
) # type: ignore
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
) -> str:
config = config or ModelRecordChanges()
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
info = info or self._probe(model_path, config)
model_path = model_path.resolve()

View File

@@ -13,8 +13,8 @@ from invokeai.backend.model_manager.config import (
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
__all__ = [

View File

@@ -20,21 +20,32 @@ Validation errors will raise an InvalidModelConfigException error.
"""
import logging
import time
from abc import ABC, abstractmethod
from enum import Enum
from typing import Literal, Optional, Type, TypeAlias, Union
from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union
import diffusers
import onnxruntime as ort
import safetensors.torch
import torch
from diffusers.models.modeling_utils import ModelMixin
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.raw_model import RawModel
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.silence_warnings import SilenceWarnings
logger = logging.getLogger(__name__)
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
@@ -44,7 +55,7 @@ AnyModel = Union[
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
"""Exception for when config parser doesn't recognize this combination of model type and format."""
class BaseModelType(str, Enum):
@@ -190,12 +201,68 @@ class MainModelDefaultSettings(BaseModel):
class ControlAdapterDefaultSettings(BaseModel):
# This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
preprocessor: str | None
model_config = ConfigDict(extra="forbid")
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
def __init__(self, path: Path):
self.path = path
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
def lazy_load_state_dict(self) -> dict[str, torch.Tensor]:
if self.format_type == ModelFormat.Diffusers:
raise NotImplementedError()
with SilenceWarnings():
if self.path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(self.path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(self.path, map_location="cpu")
elif self.path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(self.path, compute_dtype=torch.float32)
else:
checkpoint = safetensors.torch.load_file(self.path)
state_dict = checkpoint.get("state_dict") or checkpoint
return state_dict
class MatchSpeed(int, Enum):
"""Represents the estimated runtime speed of a config's 'matches' method."""
FAST = 0
MED = 1
SLOW = 2
class ModelConfigBase(ABC, BaseModel):
"""
Abstract Base class for model configurations.
To create a new config type, inherit from this class and implement its interface:
- (mandatory) override methods 'matches' and 'parse'
- (mandatory) define fields 'type' and 'format' as class attributes
- (mandatory) return field 'base' in 'matches' return value OR as a class attribute
- (optional) override method 'get_tag'
- (optional) override field _MATCH_SPEED
See MinimalConfigExample in test_model_probe.py for an example implementation.
"""
@staticmethod
def json_schema_extra(schema: dict[str, Any]) -> None:
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
hash: str = Field(description="The hash of the model file(s).")
@@ -203,27 +270,120 @@ class ModelConfigBase(BaseModel):
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
)
name: str = Field(description="Name of the model.")
type: ModelType = Field(description="Model type")
format: ModelFormat = Field(description="Model format")
base: BaseModelType = Field(description="The base model.")
description: Optional[str] = Field(description="Model description", default=None)
source: str = Field(description="The original source of the model (path, URL or repo_id).")
source_type: ModelSourceType = Field(description="The type of source")
hash_algo: Optional[HASHING_ALGORITHMS] = Field(description="The algorithm used to compute the hash.", default=None)
description: Optional[str] = Field(description="Model description", default=None)
source_api_response: Optional[str] = Field(
description="The original API response from the source, as stringified JSON.", default=None
)
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
)
def __init__(self, **fields):
path = Path(fields["path"])
fields["path"] = path.as_posix()
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""
default_hash_algo: HASHING_ALGORITHMS = "blake3_single"
fields["hash_algo"] = hash_algo = fields.get("hash_algo", default_hash_algo)
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(path)
name = fields.get("name")
if not name:
if path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
fields["name"] = path.stem
else:
fields["name"] = path.name
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["source"] = fields.get("source") or path.as_posix()
super().__init__(**fields)
@classmethod
def get_tag(cls) -> Tag:
type = cls.model_fields["type"].default.value
format = cls.model_fields["format"].default.value
return Tag(f"{type}.{format}")
@classmethod
@abstractmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
"""Returns a dictionary with the fields needed to construct the model.
Raises InvalidModelConfigException if the model is invalid.
"""
pass
@classmethod
@abstractmethod
def matches(cls, mod: ModelOnDisk) -> bool:
"""Performs a quick check to determine if the config matches the model.
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
pass
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
if not cls.matches(mod):
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
fields = cls.parse(mod)
fields["path"] = fields.get("path") or mod.path
fields.update(overrides)
return cls(**fields)
_USING_LEGACY_PROBE: ClassVar[set] = set()
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
@staticmethod
def classify(path: Path, **overrides):
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
Created to deprecate ModelProbe.probe
"""
candidates = concrete_subclasses(ModelConfigBase) - ModelConfigBase._USING_LEGACY_PROBE
sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
mod = ModelOnDisk(path)
for config_cls in sorted_by_match_speed:
try:
return config_cls.from_model_on_disk(mod, **overrides)
except InvalidModelConfigException:
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
except Exception as e:
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
raise InvalidModelConfigException("No valid config found")
def legacy_probe(cls):
"""Registers classes using the legacy probe for model classification.
NOT intended for base classes like LoRAConfigBase or T5EncoderConfigBase.
To port a config over, remove this decorator and implement 'matches' and 'parse'.
"""
def matches(c):
raise NotImplementedError()
def parse(c):
raise NotImplementedError()
cls.matches = cls.matches or classmethod(matches)
cls.parse = cls.parse or classmethod(parse)
cls.__abstractmethods__ -= {"matches", "parse"}
ModelConfigBase._USING_LEGACY_PROBE.add(cls)
return cls
class CheckpointConfigBase(BaseModel):
"""Base class for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
@@ -234,47 +394,42 @@ class CheckpointConfigBase(ModelConfigBase):
)
class DiffusersConfigBase(ModelConfigBase):
"""Model config for diffusers-style models."""
class DiffusersConfigBase(BaseModel):
"""Base class for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
class LoRAConfigBase(ModelConfigBase):
class LoRAConfigBase(BaseModel):
"""Base class for LoRA models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
class T5EncoderConfigBase(ModelConfigBase):
class T5EncoderConfigBase(BaseModel):
"""Base class for T5 encoder models."""
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
class T5EncoderConfig(T5EncoderConfigBase):
@legacy_probe
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
@legacy_probe
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")
class LoRALyCORISConfig(LoRAConfigBase):
@legacy_probe
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class ControlAdapterConfigBase(BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
@@ -282,105 +437,78 @@ class ControlAdapterConfigBase(BaseModel):
)
class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase):
@legacy_probe
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, ModelConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.LyCORIS.value}")
class ControlLoRADiffusersConfig(ModelConfigBase, ControlAdapterConfigBase):
@legacy_probe
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, ModelConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.Diffusers.value}")
class LoRADiffusersConfig(LoRAConfigBase):
@legacy_probe
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
class VAECheckpointConfig(CheckpointConfigBase):
@legacy_probe
class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = ModelType.VAE
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
@legacy_probe
class VAEDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
@legacy_probe
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase):
@legacy_probe
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, ModelConfigBase):
"""Model config for ControlNet models (checkpoint version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
@legacy_probe
class TextualInversionFileConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
@legacy_probe
class TextualInversionFolderConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
class MainConfigBase(ModelConfigBase):
class MainConfigBase(BaseModel):
type: Literal[ModelType.Main] = ModelType.Main
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings] = Field(
@@ -389,219 +517,173 @@ class MainConfigBase(ModelConfigBase):
variant: AnyVariant = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
@legacy_probe
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
@legacy_probe
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.format = ModelFormat.BnbQuantizednf4b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase):
@legacy_probe
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.GGUFQuantized] = ModelFormat.GGUFQuantized
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.format = ModelFormat.GGUFQuantized
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}")
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
@legacy_probe
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main diffusers models."""
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
pass
class IPAdapterBaseConfig(ModelConfigBase):
class IPAdapterConfigBase(BaseModel):
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
@legacy_probe
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
"""Model config for IP Adapter diffusers format models."""
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
# time. Need to go through the history to make sure I'm understanding this fully.
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
@legacy_probe
class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
"""Model config for IP Adapter checkpoint format models."""
format: Literal[ModelFormat.Checkpoint]
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
"""Model config for Clip Embeddings."""
variant: ClipVariantType = Field(description="Clip variant for this model")
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
variant: ClipVariantType = ClipVariantType.L
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
@legacy_probe
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
"""Model config for CLIP-G Embeddings."""
variant: Literal[ClipVariantType.G] = ClipVariantType.G
variant: ClipVariantType = ClipVariantType.G
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G}")
@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
@legacy_probe
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
"""Model config for CLIP-L Embeddings."""
variant: ClipVariantType = ClipVariantType.L
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L}")
variant: Literal[ClipVariantType.L] = ClipVariantType.L
@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
@legacy_probe
class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for CLIPVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
@legacy_probe
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
@legacy_probe
class SpandrelImageToImageConfig(ModelConfigBase):
"""Model config for Spandrel Image to Image models."""
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
class SigLIPConfig(DiffusersConfigBase):
@legacy_probe
class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for SigLIP."""
type: Literal[ModelType.SigLIP] = ModelType.SigLIP
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.SigLIP.value}.{ModelFormat.Diffusers.value}")
@legacy_probe
class FluxReduxConfig(ModelConfigBase):
"""Model config for FLUX Tools Redux model."""
type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.FluxRedux.value}.{ModelFormat.Checkpoint.value}")
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
"""
format_ = None
type_ = None
format_ = type_ = variant_ = None
if isinstance(v, dict):
format_ = v.get("format")
if isinstance(format_, Enum):
format_ = format_.value
type_ = v.get("type")
if isinstance(type_, Enum):
type_ = type_.value
variant_ = v.get("variant")
if isinstance(variant_, Enum):
variant_ = variant_.value
else:
format_ = v.format.value
type_ = v.type.value
v = f"{type_}.{format_}"
return v
variant_ = getattr(v, "variant", None)
# Ideally, each config would be uniquely identified with a combination of fields
# i.e. (type, format, variant) without any special cases. Alas...
# Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field.
# To maintain compatibility, we default to ClipVariantType.L in this case.
if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value:
variant_ = variant_ or ClipVariantType.L.value
return f"{type_}.{format_}.{variant_}"
return f"{type_}.{format_}"
def concrete_subclasses(base):
subclasses = set(base.__subclasses__())
for sc in base.__subclasses__():
subclasses.update(concrete_subclasses(sc))
return {sc for sc in subclasses if not isabstract(sc)}
config_classes = sorted(concrete_subclasses(ModelConfigBase), key=lambda c: c.__name__) # sorted for consistency
AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
],
Union[tuple(Annotated[cls, cls.get_tag()] for cls in config_classes)],
Discriminator(get_model_discriminator_value),
]
@@ -609,39 +691,12 @@ AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
@classmethod
def make_config(
cls,
model_data: Union[Dict[str, Any], AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type[ModelConfigBase]] = None,
timestamp: Optional[float] = None,
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.
:param model_data: A raw dict corresponding the obect fields to be
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
object, which will be passed through unchanged.
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
model: Optional[ModelConfigBase] = None
if isinstance(model_data, ModelConfigBase):
model = model_data
elif dest_class:
model = dest_class.model_validate(model_data)
else:
# mypy doesn't typecheck TypeAdapters well?
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
assert model is not None
if key:
model.key = key
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
class ModelConfigFactory:
@staticmethod
def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
"""Return the appropriate config object from raw dict values."""
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
if isinstance(model, CheckpointConfigBase) and timestamp:
model.converted_at = timestamp
if model:
validate_hash(model.hash)
validate_hash(model.hash)
return model # type: ignore

View File

@@ -1,4 +1,4 @@
"""Utilities for parsing model files, used mostly by probe.py"""
"""Utilities for parsing model files, used mostly by legacy_probe.py"""
import json
from pathlib import Path

File diff suppressed because it is too large Load Diff

View File

@@ -52,9 +52,9 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type CLIPEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
export type CLIPLEmbedModelConfig = S['CLIPLEmbedDiffusersConfig'];
export type CLIPGEmbedModelConfig = S['CLIPGEmbedDiffusersConfig'];
export type CLIPEmbedModelConfig = CLIPLEmbedModelConfig | CLIPGEmbedModelConfig;
export type T5EncoderModelConfig = S['T5EncoderConfig'];
export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig'];
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];

View File

@@ -125,6 +125,7 @@ dependencies = [
"pytest-datadir",
"requests_testadapter",
"httpx",
"polyfactory==2.19.0"
]
[project.scripts]

View File

@@ -1,18 +1,25 @@
import abc
import json
from pathlib import Path
from typing import Any
import pydantic
import pytest
import torch
from polyfactory.factories.pydantic_factory import ModelFactory
from sympy.testing.pytest import slow
from torch import tensor
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
from invokeai.backend.model_manager.config import InvalidModelConfigException, MainDiffusersConfig, ModelVariantType
from invokeai.backend.model_manager.probe import (
from invokeai.backend.model_manager.config import *
from invokeai.backend.model_manager.legacy_probe import (
CkptType,
ModelProbe,
VaeFolderProbe,
get_default_settings_control_adapters,
get_default_settings_main,
)
from invokeai.backend.model_manager.search import ModelSearch
@pytest.mark.parametrize(
@@ -88,3 +95,135 @@ def test_probe_sd1_diffusers_inpainting(datadir: Path):
assert config.base is BaseModelType.StableDiffusion1
assert config.variant is ModelVariantType.Inpaint
assert config.repo_variant is ModelRepoVariant.FP16
class MinimalConfigExample(ModelConfigBase):
    """Smallest viable ModelConfigBase subclass, used as a fixture by the tests below."""

    type: ModelType = ModelType.Main
    format: ModelFormat = ModelFormat.Checkpoint
    fun_quote: str

    @classmethod
    def matches(cls, mod: ModelOnDisk) -> bool:
        # Claim any model whose on-disk path ends in ".json".
        return mod.path.suffix == ".json"

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        # Extract the quote from the JSON payload; BaseModelType.Any acts as
        # the catch-all base unless an override replaces it during classify().
        payload = json.loads(mod.path.read_text())
        return {
            "fun_quote": payload["quote"],
            "base": BaseModelType.Any,
        }
def test_minimal_working_example(datadir: Path):
    """Classify a .json fixture via MinimalConfigExample and check that overrides propagate."""
    model_path = datadir / "minimal_config_model.json"
    # Pass the base override as a keyword, mirroring how callers forward fields.
    config = ModelConfigBase.classify(model_path, base=BaseModelType.StableDiffusion1)

    assert isinstance(config, MinimalConfigExample)
    assert config.base == BaseModelType.StableDiffusion1
    assert config.path == model_path.as_posix()
    assert config.fun_quote == "Minimal working example of a ModelConfigBase subclass"
def test_regression_against_model_probe(datadir: Path):
    """Ensure ModelConfigBase.classify returns consistent results as ModelProbe.probe"""

    def attempt(classifier, model_path):
        # A classifier either yields a config or rejects the model; configs are
        # pydantic models, so None is a safe "rejected" sentinel here.
        try:
            return classifier(model_path)
        except InvalidModelConfigException:
            return None

    # TODO: add more 'stripped' models to test_model_probe directory
    for path in ModelSearch().search(datadir):
        legacy_config = attempt(ModelProbe.probe, path)
        new_config = attempt(ModelConfigBase.classify, path)

        if legacy_config is not None and new_config is not None:
            # Both APIs recognised the model - they must agree exactly.
            assert legacy_config == new_config
        elif legacy_config is not None:
            # Only the legacy probe succeeded: the class must still be flagged
            # as relying on the legacy probe.
            assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE
        elif new_config is not None:
            # Only the new API succeeded: the class must have migrated off the
            # legacy probe.
            assert type(new_config) not in ModelConfigBase._USING_LEGACY_PROBE
        else:
            raise ValueError(f"Both probe and classify failed to classify model at path {path}.")
class ConfigMocker:
    """Utility class to create config instances with random data"""

    def __init__(self):
        # One factory per config class, cached so the seeded RNG is shared
        # across repeated mock() calls for the same class.
        self._factories = {}

    def mock(self, config_cls, count):
        factory = self._factories.get(config_cls)
        if factory is None:

            class Factory(ModelFactory[config_cls]):
                __use_defaults__ = True
                __random_seed__ = 1234
                __check_model__ = True

            factory = Factory()
            self._factories[config_cls] = factory
        return [factory.build() for _ in range(count)]
@slow
def test_serialisation_roundtrip():
    """After classification, models are serialised to json and stored in the database.
    We need to ensure they are de-serialised into the original config with all relevant fields restored.
    """
    trials_per_class = 50
    mocker = ConfigMocker()
    # MinimalConfigExample is a test-only fixture, not a real config class.
    for config_cls in concrete_subclasses(ModelConfigBase) - {MinimalConfigExample}:
        for config in mocker.mock(config_cls, trials_per_class):
            serialised = config.model_dump_json()
            reconstructed = ModelConfigFactory.make_config(json.loads(serialised))
            assert isinstance(reconstructed, config_cls)
            # Round-tripping through JSON must be lossless.
            assert serialised == reconstructed.model_dump_json()
def test_inheritance_order():
    """
    Safeguard test to warn against incorrect inheritance order.
    Config classes using multiple inheritance should inherit from ModelConfigBase last
    to ensure that more specific fields take precedence over the generic defaults.
    It may be worth rethinking our config taxonomy in the future, but in the meantime,
    this test can help prevent the debugging effort I went through discovering this.
    """
    # Framework roots carry no config semantics, so drop them from the MRO.
    ignored = (abc.ABC, pydantic.BaseModel, object)
    for config_cls in concrete_subclasses(ModelConfigBase):
        mro_chain = [cls for cls in config_cls.mro() if cls not in ignored]
        # ModelConfigBase must be the last meaningful ancestor.
        assert mro_chain[-1] is ModelConfigBase
def test_concrete_subclasses():
    """The set of concrete config classes must match this explicit whitelist, so any
    newly added config type is consciously registered (or excluded) in this test."""
    actual = concrete_subclasses(ModelConfigBase) - {MinimalConfigExample}
    expected = {
        CLIPGEmbedDiffusersConfig, MainGGUFCheckpointConfig, T2IAdapterConfig, TextualInversionFolderConfig,
        IPAdapterInvokeAIConfig, ControlNetDiffusersConfig, ControlLoRALyCORISConfig, MainDiffusersConfig,
        LoRALyCORISConfig, CLIPVisionDiffusersConfig, MainCheckpointConfig, T5EncoderConfig, IPAdapterCheckpointConfig,
        VAEDiffusersConfig, LoRADiffusersConfig, ControlNetCheckpointConfig, FluxReduxConfig,
        T5EncoderBnbQuantizedLlmInt8bConfig, SpandrelImageToImageConfig, MainBnbQuantized4bCheckpointConfig,
        TextualInversionFileConfig, CLIPLEmbedDiffusersConfig, VAECheckpointConfig, ControlLoRADiffusersConfig,
        SigLIPConfig,
    }
    assert actual == expected

View File

@@ -0,0 +1,3 @@
{
"quote": "Minimal working example of a ModelConfigBase subclass"
}