mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): add UnknownModelConfig
This commit is contained in:
@@ -28,7 +28,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import isabstract
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, Optional, TypeAlias, Union
|
||||
from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
@@ -109,6 +109,18 @@ class MatchSpeed(int, Enum):
|
||||
SLOW = 2
|
||||
|
||||
|
||||
class LegacyProbeMixin:
|
||||
"""Mixin for classes using the legacy probe for model classification."""
|
||||
|
||||
@classmethod
|
||||
def matches(cls, *args, **kwargs):
|
||||
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
|
||||
|
||||
@classmethod
|
||||
def parse(cls, *args, **kwargs):
|
||||
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
|
||||
|
||||
|
||||
class ModelConfigBase(ABC, BaseModel):
|
||||
"""
|
||||
Abstract Base class for model configurations.
|
||||
@@ -152,15 +164,15 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
)
|
||||
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
|
||||
|
||||
USING_LEGACY_PROBE: ClassVar[set] = set()
|
||||
USING_CLASSIFY_API: ClassVar[set] = set()
|
||||
USING_LEGACY_PROBE: ClassVar[set[Type["ModelConfigBase"]]] = set()
|
||||
USING_CLASSIFY_API: ClassVar[set[Type["ModelConfigBase"]]] = set()
|
||||
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if issubclass(cls, LegacyProbeMixin):
|
||||
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
|
||||
else:
|
||||
elif cls is not UnknownModelConfig:
|
||||
ModelConfigBase.USING_CLASSIFY_API.add(cls)
|
||||
|
||||
@staticmethod
|
||||
@@ -170,7 +182,9 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
return concrete
|
||||
|
||||
@staticmethod
|
||||
def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
|
||||
def classify(
|
||||
mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides
|
||||
) -> "AnyModelConfig":
|
||||
"""
|
||||
Returns the best matching ModelConfig instance from a model's file/folder path.
|
||||
Raises InvalidModelConfigException if no valid configuration is found.
|
||||
@@ -192,7 +206,10 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
else:
|
||||
return config_cls.from_model_on_disk(mod, **overrides)
|
||||
|
||||
raise InvalidModelConfigException("Unable to determine model type")
|
||||
try:
|
||||
return UnknownModelConfig.from_model_on_disk(mod, **overrides)
|
||||
except Exception:
|
||||
raise InvalidModelConfigException("Unable to determine model type")
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
@@ -256,16 +273,17 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
class LegacyProbeMixin:
|
||||
"""Mixin for classes using the legacy probe for model classification."""
|
||||
class UnknownModelConfig(ModelConfigBase):
|
||||
type: Literal[ModelType.Unknown] = ModelType.Unknown
|
||||
format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown
|
||||
|
||||
@classmethod
|
||||
def matches(cls, *args, **kwargs):
|
||||
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
|
||||
def matches(cls, *args, **kwargs) -> bool:
|
||||
raise NotImplementedError("UnknownModelConfig cannot match anything")
|
||||
|
||||
@classmethod
|
||||
def parse(cls, *args, **kwargs):
|
||||
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
|
||||
def parse(cls, *args, **kwargs) -> dict[str, Any]:
|
||||
raise NotImplementedError("UnknownModelConfig cannot parse anything")
|
||||
|
||||
|
||||
class CheckpointConfigBase(ABC, BaseModel):
|
||||
@@ -353,7 +371,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
metadata = mod.metadata()
|
||||
return (
|
||||
metadata.get("modelspec.sai_model_spec")
|
||||
bool(metadata.get("modelspec.sai_model_spec"))
|
||||
and metadata.get("ot_branch") == "omi_format"
|
||||
and metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
|
||||
)
|
||||
@@ -751,6 +769,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
|
||||
Annotated[ApiModelConfig, ApiModelConfig.get_tag()],
|
||||
Annotated[VideoApiModelConfig, VideoApiModelConfig.get_tag()],
|
||||
Annotated[UnknownModelConfig, UnknownModelConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
@@ -55,6 +55,7 @@ class ModelType(str, Enum):
|
||||
FluxRedux = "flux_redux"
|
||||
LlavaOnevision = "llava_onevision"
|
||||
Video = "video"
|
||||
Unknown = "unknown"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
@@ -107,6 +108,7 @@ class ModelFormat(str, Enum):
|
||||
BnbQuantizednf4b = "bnb_quantized_nf4b"
|
||||
GGUFQuantized = "gguf_quantized"
|
||||
Api = "api"
|
||||
Unknown = "unknown"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
|
||||
Reference in New Issue
Block a user