diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 1bfc15c046..e221b57252 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -28,7 +28,7 @@ from abc import ABC, abstractmethod from enum import Enum from inspect import isabstract from pathlib import Path -from typing import ClassVar, Literal, Optional, TypeAlias, Union +from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from typing_extensions import Annotated, Any, Dict @@ -109,6 +109,18 @@ class MatchSpeed(int, Enum): SLOW = 2 +class LegacyProbeMixin: + """Mixin for classes using the legacy probe for model classification.""" + + @classmethod + def matches(cls, *args, **kwargs): + raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}") + + @classmethod + def parse(cls, *args, **kwargs): + raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}") + + class ModelConfigBase(ABC, BaseModel): """ Abstract Base class for model configurations. @@ -152,15 +164,15 @@ class ModelConfigBase(ABC, BaseModel): ) usage_info: Optional[str] = Field(default=None, description="Usage information for this model") - USING_LEGACY_PROBE: ClassVar[set] = set() - USING_CLASSIFY_API: ClassVar[set] = set() + USING_LEGACY_PROBE: ClassVar[set[Type["ModelConfigBase"]]] = set() + USING_CLASSIFY_API: ClassVar[set[Type["ModelConfigBase"]]] = set() _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if issubclass(cls, LegacyProbeMixin): ModelConfigBase.USING_LEGACY_PROBE.add(cls) - else: + elif cls is not UnknownModelConfig: ModelConfigBase.USING_CLASSIFY_API.add(cls) @staticmethod @@ -170,7 +182,9 @@ class ModelConfigBase(ABC, BaseModel): return concrete @staticmethod - def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides): + def classify( + mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides + ) -> "AnyModelConfig": """ Returns the best matching ModelConfig instance from a model's file/folder path. Raises InvalidModelConfigException if no valid configuration is found. @@ -192,7 +206,10 @@ class ModelConfigBase(ABC, BaseModel): else: return config_cls.from_model_on_disk(mod, **overrides) - raise InvalidModelConfigException("Unable to determine model type") + try: + return UnknownModelConfig.from_model_on_disk(mod, **overrides) + except Exception: + raise InvalidModelConfigException("Unable to determine model type") @classmethod def get_tag(cls) -> Tag: @@ -256,16 +273,17 @@ class ModelConfigBase(ABC, BaseModel): return cls(**fields) -class LegacyProbeMixin: - """Mixin for classes using the legacy probe for model classification.""" +class UnknownModelConfig(ModelConfigBase): + type: Literal[ModelType.Unknown] = ModelType.Unknown + format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown @classmethod - def matches(cls, *args, **kwargs): - raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}") + def matches(cls, *args, **kwargs) -> bool: + raise NotImplementedError("UnknownModelConfig cannot match anything") @classmethod - def parse(cls, *args, **kwargs): - raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}") + def parse(cls, *args, **kwargs) -> dict[str, Any]: + raise NotImplementedError("UnknownModelConfig cannot parse anything") class CheckpointConfigBase(ABC, BaseModel): @@ -353,7 +371,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): metadata = mod.metadata() return ( - metadata.get("modelspec.sai_model_spec") + bool(metadata.get("modelspec.sai_model_spec")) and metadata.get("ot_branch") == "omi_format" and metadata["modelspec.architecture"].split("/")[1].lower() == "lora" ) @@ -751,6 +769,7 @@ AnyModelConfig = Annotated[ Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()], Annotated[ApiModelConfig, ApiModelConfig.get_tag()], Annotated[VideoApiModelConfig, VideoApiModelConfig.get_tag()], + Annotated[UnknownModelConfig, UnknownModelConfig.get_tag()], ], Discriminator(get_model_discriminator_value), ] diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index ba3c8586db..120b4a4dd9 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -55,6 +55,7 @@ class ModelType(str, Enum): FluxRedux = "flux_redux" LlavaOnevision = "llava_onevision" Video = "video" + Unknown = "unknown" class SubModelType(str, Enum): @@ -107,6 +108,7 @@ class ModelFormat(str, Enum): BnbQuantizednf4b = "bnb_quantized_nf4b" GGUFQuantized = "gguf_quantized" Api = "api" + Unknown = "unknown" class SchedulerPredictionType(str, Enum):