feat(mm): add UnknownModelConfig

This commit is contained in:
psychedelicious
2025-09-18 15:37:45 +10:00
parent bd4bb075a5
commit 3af504eee8
2 changed files with 34 additions and 13 deletions

View File

@@ -28,7 +28,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union
from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
@@ -109,6 +109,18 @@ class MatchSpeed(int, Enum):
SLOW = 2
class LegacyProbeMixin:
"""Mixin for classes using the legacy probe for model classification."""
@classmethod
def matches(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
@classmethod
def parse(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
class ModelConfigBase(ABC, BaseModel):
"""
Abstract Base class for model configurations.
@@ -152,15 +164,15 @@ class ModelConfigBase(ABC, BaseModel):
)
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
USING_LEGACY_PROBE: ClassVar[set] = set()
USING_CLASSIFY_API: ClassVar[set] = set()
USING_LEGACY_PROBE: ClassVar[set[Type["ModelConfigBase"]]] = set()
USING_CLASSIFY_API: ClassVar[set[Type["ModelConfigBase"]]] = set()
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, LegacyProbeMixin):
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
else:
elif cls is not UnknownModelConfig:
ModelConfigBase.USING_CLASSIFY_API.add(cls)
@staticmethod
@@ -170,7 +182,9 @@ class ModelConfigBase(ABC, BaseModel):
return concrete
@staticmethod
def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
def classify(
mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides
) -> "AnyModelConfig":
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
@@ -192,7 +206,10 @@ class ModelConfigBase(ABC, BaseModel):
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("Unable to determine model type")
try:
return UnknownModelConfig.from_model_on_disk(mod, **overrides)
except Exception:
raise InvalidModelConfigException("Unable to determine model type")
@classmethod
def get_tag(cls) -> Tag:
@@ -256,16 +273,17 @@ class ModelConfigBase(ABC, BaseModel):
return cls(**fields)
class LegacyProbeMixin:
"""Mixin for classes using the legacy probe for model classification."""
class UnknownModelConfig(ModelConfigBase):
type: Literal[ModelType.Unknown] = ModelType.Unknown
format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown
@classmethod
def matches(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
def matches(cls, *args, **kwargs) -> bool:
raise NotImplementedError("UnknownModelConfig cannot match anything")
@classmethod
def parse(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
def parse(cls, *args, **kwargs) -> dict[str, Any]:
raise NotImplementedError("UnknownModelConfig cannot parse anything")
class CheckpointConfigBase(ABC, BaseModel):
@@ -353,7 +371,7 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
metadata = mod.metadata()
return (
metadata.get("modelspec.sai_model_spec")
bool(metadata.get("modelspec.sai_model_spec"))
and metadata.get("ot_branch") == "omi_format"
and metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
)
@@ -751,6 +769,7 @@ AnyModelConfig = Annotated[
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
Annotated[ApiModelConfig, ApiModelConfig.get_tag()],
Annotated[VideoApiModelConfig, VideoApiModelConfig.get_tag()],
Annotated[UnknownModelConfig, UnknownModelConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]

View File

@@ -55,6 +55,7 @@ class ModelType(str, Enum):
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"
Video = "video"
Unknown = "unknown"
class SubModelType(str, Enum):
@@ -107,6 +108,7 @@ class ModelFormat(str, Enum):
BnbQuantizednf4b = "bnb_quantized_nf4b"
GGUFQuantized = "gguf_quantized"
Api = "api"
Unknown = "unknown"
class SchedulerPredictionType(str, Enum):