refactor(mm): add model config parsing utils

This commit is contained in:
psychedelicious
2025-09-24 17:42:56 +10:00
parent fd47da6842
commit 3488975b2b
4 changed files with 316 additions and 323 deletions

View File

@@ -39,7 +39,7 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
InvalidModelConfigException,
ModelConfigBase,
ModelConfigFactory,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
@@ -612,7 +612,11 @@ class ModelInstallService(ModelInstallServiceBase):
try:
return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo) # type: ignore
except InvalidModelConfigException:
return ModelConfigBase.classify(mod=model_path, fields=deepcopy(fields), hash_algo=hash_algo)
return ModelConfigFactory.from_model_on_disk(
mod=model_path,
overrides=deepcopy(fields),
hash_algo=hash_algo,
)
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
@@ -633,7 +637,7 @@ class ModelInstallService(ModelInstallServiceBase):
info.path = model_path.as_posix()
if isinstance(info, CheckpointConfigBase):
if isinstance(info, CheckpointConfigBase) and info.config_path is not None:
# Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
# invoke-managed legacy config dir, we use a relative path.
legacy_config_path = self.app_config.legacy_conf_path / info.config_path

View File

@@ -86,34 +86,86 @@ class NotAMatch(Exception):
reason: The reason why the model did not match.
"""
def __init__(self, config_class: "Type[AnyModelConfig]", reason: str):
def __init__(
self,
config_class: type,
reason: str,
):
super().__init__(f"{config_class.__name__} does not match: {reason}")
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
def get_class_name_from_config(config: dict[str, Any]) -> Optional[str]:
if "_class_name" in config:
return config["_class_name"]
elif "architectures" in config:
return config["architectures"][0]
else:
return None
def get_config_or_raise(
config_class: type,
config_path: Path,
) -> dict[str, Any]:
"""Load the config file at the given path, or raise NotAMatch if it cannot be loaded."""
if not config_path.exists():
raise NotAMatch(config_class, f"missing config file: {config_path}")
try:
config = load_json(config_path)
return config
except Exception as e:
raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e
def validate_overrides(
config_class: "Type[AnyModelConfig]", overrides: dict[str, Any], allowed: dict[str, Any]
def raise_for_class_names(
config_class: type,
config_path: Path,
valid_class_names: set[str],
) -> None:
for key, value in allowed.items():
if key not in overrides:
"""Raise NotAMatch if the config file is missing or does not contain a valid class name."""
config = get_config_or_raise(config_class, config_path)
try:
if "_class_name" in config:
config_class_name = config["_class_name"]
elif "architectures" in config:
config_class_name = config["architectures"][0]
else:
raise ValueError("missing _class_name or architectures field")
except Exception as e:
raise NotAMatch(config_class, f"unable to determine class name from config file: {config_path}") from e
if config_class_name not in valid_class_names:
raise NotAMatch(config_class, f"model class is not one of {valid_class_names}, got {config_class_name}")
def matches_overrides(
config_class: "Type[AnyModelConfig]",
provided_overrides: dict[str, Any],
valid_overrides: dict[str, Any],
) -> bool:
"""Check if the provided overrides match the valid overrides for this config class.
Args:
config_class: The config class that is being tested.
provided_overrides: The overrides provided by the user.
valid_overrides: The overrides that are valid for this config class.
Returns:
True if all provided overrides match the valid overrides, False if some valid overrides are missing.
Raises:
NotAMatch if any override does not match the allowed value.
"""
is_perfect_match = True
for key, value in valid_overrides.items():
if key not in provided_overrides:
is_perfect_match = False
continue
if overrides[key] != value:
if provided_overrides[key] != value:
raise NotAMatch(
config_class,
f"override {key}={overrides[key]} does not match required value {key}={value}",
f"override {key}={provided_overrides[key]} does not match required value {key}={value}",
)
return is_perfect_match
class SubmodelDefinition(BaseModel):
path_or_prefix: str
@@ -327,36 +379,32 @@ def load_json(path: Path) -> dict[str, Any]:
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
VALID_OVERRIDES: ClassVar = {
"type": ModelType.T5Encoder,
"format": ModelFormat.T5Encoder,
}
VALID_CLASS_NAMES: ClassVar = {
"T5EncoderModel",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.T5Encoder:
raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder")
if format_override is not None and format_override is not ModelFormat.T5Encoder:
raise NotAMatch(cls, f"format override is {format_override}, not T5Encoder")
if type_override is ModelType.T5Encoder and format_override is ModelFormat.T5Encoder:
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
# Heuristic: Look for the T5EncoderModel class name in the config
try:
config = load_json(mod.path / "text_encoder_2" / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e
if config_class_name != "T5EncoderModel":
raise NotAMatch(cls, "model class is not T5EncoderModel")
raise_for_class_names(
config_class=cls,
config_path=mod.path / "text_encoder_2" / "config.json",
valid_class_names=cls.VALID_CLASS_NAMES,
)
# Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models)
has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()
@@ -370,33 +418,30 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
VALID_OVERRIDES: ClassVar = {
"type": ModelType.T5Encoder,
"format": ModelFormat.BnbQuantizedLlmInt8b,
}
VALID_CLASS_NAMES: ClassVar = {
"T5EncoderModel",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.T5Encoder:
raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder")
if format_override is not None and format_override is not ModelFormat.BnbQuantizedLlmInt8b:
raise NotAMatch(cls, f"format override is {format_override}, not BnbQuantizedLlmInt8b")
if type_override is ModelType.T5Encoder and format_override is ModelFormat.BnbQuantizedLlmInt8b:
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
# Heuristic: Look for the T5EncoderModel class name in the config
try:
config = load_json(mod.path / "text_encoder_2" / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e
if config_class_name != "T5EncoderModel":
raise NotAMatch(cls, "model class is not T5EncoderModel")
raise_for_class_names(
config_class=cls,
config_path=mod.path / "text_encoder_2" / "config.json",
valid_class_names=cls.VALID_CLASS_NAMES,
)
# Heuristic: look for the quantization in the filename name
filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix())
@@ -413,18 +458,18 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LoRA,
"format": ModelFormat.OMI,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.LoRA:
raise NotAMatch(cls, f"type override is {type_override}, not LoRA")
if format_override is not None and format_override is not ModelFormat.OMI:
raise NotAMatch(cls, f"format override is {format_override}, not OMI")
if type_override is ModelType.LoRA and format_override is ModelFormat.OMI:
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
# Heuristic: OMI LoRAs are always files, never directories
@@ -446,12 +491,12 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
if not is_omi_lora_heuristic:
raise NotAMatch(cls, "model does not match OMI LoRA heuristics")
base = fields.get("base") or cls.get_base_or_raise(mod)
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@classmethod
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
metadata = mod.metadata()
architecture = metadata["modelspec.architecture"]
@@ -468,18 +513,18 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LoRA,
"format": ModelFormat.LyCORIS,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.LoRA:
raise NotAMatch(cls, f"type override is {type_override}, not LoRA")
if format_override is not None and format_override is not ModelFormat.LyCORIS:
raise NotAMatch(cls, f"format override is {format_override}, not LyCORIS")
if type_override is ModelType.LoRA and format_override is ModelFormat.LyCORIS:
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
# Heuristic: LyCORIS LoRAs are always files, never directories
@@ -544,18 +589,18 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LoRA,
"format": ModelFormat.Diffusers,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.LoRA:
raise NotAMatch(cls, f"type override is {type_override}, not LoRA")
if format_override is not None and format_override is not ModelFormat.Diffusers:
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
if type_override is ModelType.LoRA and format_override is ModelFormat.Diffusers:
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
# Heuristic: Diffusers LoRAs are always directories, never files
@@ -583,8 +628,31 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
VALID_OVERRIDES: ClassVar = {
"type": ModelType.VAE,
"format": ModelFormat.Checkpoint,
}
@classmethod
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
if mod.path.is_dir():
raise NotAMatch(cls, "model path is a directory, not a file")
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
for regexp, basetype in [
(r"xl", BaseModelType.StableDiffusionXL),
@@ -597,36 +665,41 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
raise NotAMatch(cls, "cannot determine base type")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.VAE:
raise NotAMatch(cls, f"type override is {type_override}, not VAE")
if format_override is not None and format_override is not ModelFormat.Checkpoint:
raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint")
if type_override is ModelType.VAE and format_override is ModelFormat.Checkpoint:
return cls(**fields)
if mod.path.is_dir():
raise NotAMatch(cls, "model path is a directory, not a file")
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
base = fields.get("base") or cls.get_base_or_raise(mod)
return cls(**fields, base=base)
class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
CLASS_NAMES: ClassVar = {"AutoencoderKL", "AutoencoderTiny"}
VALID_OVERRIDES: ClassVar = {
"type": ModelType.VAE,
"format": ModelFormat.Diffusers,
}
VALID_CLASS_NAMES: ClassVar = {
"AutoencoderKL",
"AutoencoderTiny",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
raise_for_class_names(
config_class=cls,
config_path=mod.path / "config.json",
valid_class_names=cls.VALID_CLASS_NAMES,
)
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@classmethod
def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool:
@@ -648,7 +721,8 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
return name
@classmethod
def get_base(cls, mod: ModelOnDisk, config: dict[str, Any]) -> BaseModelType:
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
config = get_config_or_raise(cls, mod.path / "config.json")
if cls._config_looks_like_sdxl(config):
return BaseModelType.StableDiffusionXL
elif cls._name_looks_like_sdxl(mod):
@@ -657,39 +731,6 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
# TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.
return BaseModelType.StableDiffusion1
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.VAE:
raise NotAMatch(cls, f"type override is {type_override}, not VAE")
if format_override is not None and format_override is not ModelFormat.Diffusers:
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
if type_override is ModelType.VAE and format_override is ModelFormat.Diffusers:
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
try:
config = load_json(mod.path / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load config.json") from e
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from config") from e
if config_class_name not in cls.CLASS_NAMES:
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
base = fields.get("base") or cls.get_base(mod, config)
return cls(**fields, base=base)
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""
@@ -710,7 +751,7 @@ class TextualInversionConfigBase(ABC, BaseModel):
KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"}
@classmethod
def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
try:
p = path or mod.path
@@ -738,11 +779,15 @@ class TextualInversionConfigBase(ABC, BaseModel):
return False
@classmethod
def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
p = path or mod.path
try:
state_dict = mod.load_state_dict(p)
except Exception as e:
raise NotAMatch(cls, f"unable to load state dict from {p}: {e}") from e
try:
if "string_to_token" in state_dict:
token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
elif "emb_params" in state_dict:
@@ -751,49 +796,18 @@ class TextualInversionConfigBase(ABC, BaseModel):
token_dim = state_dict["clip_g"].shape[-1]
else:
token_dim = list(state_dict.values())[0].shape[0]
except Exception as e:
raise NotAMatch(cls, f"unable to determine token dimension from state dict in {p}: {e}") from e
match token_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXL
case _:
pass
except Exception:
pass
raise InvalidModelConfigException(f"{p}: Could not determine base type")
@classmethod
def get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
p = path or mod.path
try:
state_dict = mod.load_state_dict(p)
if "string_to_token" in state_dict:
token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
elif "emb_params" in state_dict:
token_dim = state_dict["emb_params"].shape[-1]
elif "clip_g" in state_dict:
token_dim = state_dict["clip_g"].shape[-1]
else:
token_dim = list(state_dict.values())[0].shape[0]
match token_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXL
case _:
pass
except Exception:
pass
raise InvalidModelConfigException(f"{p}: Could not determine base type")
match token_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatch(cls, f"unrecognized token dimension {token_dim}")
class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
@@ -801,31 +815,31 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
VALID_OVERRIDES: ClassVar = {
"type": ModelType.TextualInversion,
"format": ModelFormat.EmbeddingFile,
}
@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.TextualInversion:
raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion")
if format_override is not None and format_override is not ModelFormat.EmbeddingFile:
raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFile")
if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFile:
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
if mod.path.is_dir():
raise NotAMatch(cls, "model path is a directory, not a file")
if not cls.file_looks_like_embedding(mod):
if not cls._file_looks_like_embedding(mod):
raise NotAMatch(cls, "model does not look like a textual inversion embedding file")
base = fields.get("base") or cls.get_base_or_raise(mod)
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@@ -834,30 +848,30 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
VALID_OVERRIDES: ClassVar = {
"type": ModelType.TextualInversion,
"format": ModelFormat.EmbeddingFolder,
}
@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.TextualInversion:
raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion")
if format_override is not None and format_override is not ModelFormat.EmbeddingFolder:
raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFolder")
if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFolder:
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
for p in mod.weight_files():
if cls.file_looks_like_embedding(mod, p):
base = fields.get("base") or cls.get_base_or_raise(mod, p)
if cls._file_looks_like_embedding(mod, p):
base = fields.get("base") or cls._get_base_or_raise(mod, p)
return cls(**fields, base=base)
raise NotAMatch(cls, "model does not look like a textual inversion embedding folder")
@@ -937,7 +951,7 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
base: Literal[BaseModelType.Any] = BaseModelType.Any
CLASS_NAMES: ClassVar = {
VALID_CLASS_NAMES: ClassVar = {
"CLIPModel",
"CLIPTextModel",
"CLIPTextModelWithProjection",
@@ -963,47 +977,37 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
variant: Literal[ClipVariantType.G] = ClipVariantType.G
VALID_OVERRIDES: ClassVar = {
"type": ModelType.CLIPEmbed,
"format": ModelFormat.Diffusers,
"variant": ClipVariantType.G,
}
@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
variant_override = fields.get("variant")
if type_override is not None and type_override is not ModelType.CLIPEmbed:
raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed")
if format_override is not None and format_override is not ModelFormat.Diffusers:
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
if variant_override is not None and variant_override is not ClipVariantType.G:
raise NotAMatch(cls, f"variant override is {variant_override}, not G")
if (
type_override is ModelType.CLIPEmbed
and format_override is ModelFormat.Diffusers
and variant_override is ClipVariantType.G
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
try:
config = load_json(mod.path / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load config.json") from e
config_path = mod.path / "config.json"
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from config") from e
raise_for_class_names(
config_class=cls,
config_path=config_path,
valid_class_names=cls.VALID_CLASS_NAMES,
)
if config_class_name not in cls.CLASS_NAMES:
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
config = get_config_or_raise(cls, config_path)
clip_variant = cls.get_clip_variant_type(config)
@@ -1018,48 +1022,37 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
variant: Literal[ClipVariantType.L] = ClipVariantType.L
VALID_OVERRIDES: ClassVar = {
"type": ModelType.CLIPEmbed,
"format": ModelFormat.Diffusers,
"variant": ClipVariantType.L,
}
@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
variant_override = fields.get("variant")
if type_override is not None and type_override is not ModelType.CLIPEmbed:
raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed")
if format_override is not None and format_override is not ModelFormat.Diffusers:
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
if variant_override is not None and variant_override is not ClipVariantType.L:
raise NotAMatch(cls, f"variant override is {variant_override}, not L")
if (
type_override is ModelType.CLIPEmbed
and format_override is ModelFormat.Diffusers
and variant_override is ClipVariantType.L
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
try:
config = load_json(mod.path / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load config.json") from e
config_path = mod.path / "config.json"
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from config") from e
if config_class_name not in cls.CLASS_NAMES:
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
raise_for_class_names(
config_class=cls,
config_path=config_path,
valid_class_names=cls.VALID_CLASS_NAMES,
)
config = get_config_or_raise(cls, config_path)
clip_variant = cls.get_clip_variant_type(config)
if clip_variant is not ClipVariantType.L:
@@ -1089,25 +1082,18 @@ class SpandrelImageToImageConfig(ModelConfigBase):
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
VALID_OVERRIDES: ClassVar = {
"type": ModelType.SpandrelImageToImage,
"format": ModelFormat.Checkpoint,
"base": BaseModelType.Any,
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
base_override = fields.get("base")
if type_override is not None and type_override is not ModelType.SpandrelImageToImage:
raise NotAMatch(cls, f"type override is {type_override}, not SpandrelImageToImage")
if format_override is not None and format_override is not ModelFormat.Checkpoint:
raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint")
if base_override is not None and base_override is not BaseModelType.Any:
raise NotAMatch(cls, f"base override is {base_override}, not Any")
if (
type_override is ModelType.SpandrelImageToImage
and format_override is ModelFormat.Checkpoint
and base_override is BaseModelType.Any
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
@@ -1151,40 +1137,36 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
base: Literal[BaseModelType.Any] = BaseModelType.Any
variant: Literal[ModelVariantType.Normal] = ModelVariantType.Normal
VALID_OVERRIDES: ClassVar = {
"type": ModelType.LlavaOnevision,
"format": ModelFormat.Diffusers,
}
VALID_CLASS_NAMES: ClassVar = {
"LlavaOnevisionForConditionalGeneration",
}
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
type_override = fields.get("type")
format_override = fields.get("format")
if type_override is not None and type_override is not ModelType.LlavaOnevision:
raise NotAMatch(cls, f"type override is {type_override}, not LlavaOnevision")
if format_override is not None and format_override is not ModelFormat.Diffusers:
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
if type_override is ModelType.LlavaOnevision and format_override is ModelFormat.Diffusers:
if matches_overrides(
config_class=cls,
provided_overrides=fields,
valid_overrides=cls.VALID_OVERRIDES,
):
return cls(**fields)
if mod.path.is_file():
raise NotAMatch(cls, "model path is a file, not a directory")
# Heuristic: Look for the LlavaOnevisionForConditionalGeneration class name in the config
try:
config = load_json(mod.path / "config.json")
except Exception as e:
raise NotAMatch(cls, "unable to load config.json") from e
config_path = mod.path / "config.json"
try:
config_class_name = get_class_name_from_config(config)
except Exception as e:
raise NotAMatch(cls, "unable to determine class name from config.json") from e
raise_for_class_names(
config_class=cls,
config_path=config_path,
valid_class_names=cls.VALID_CLASS_NAMES,
)
if config_class_name != "LlavaOnevisionForConditionalGeneration":
raise NotAMatch(cls, "model class is not LlavaOnevisionForConditionalGeneration")
base = fields.get("base") or BaseModelType.Any
variant = fields.get("variant") or ModelVariantType.Normal
return cls(**fields, base=base, variant=variant)
return cls(**fields)
class ApiModelConfig(MainConfigBase, ModelConfigBase):

View File

@@ -7,7 +7,8 @@ from pathlib import Path
from typing import get_args
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager import InvalidModelConfigException, ModelConfigBase, ModelProbe
from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe
from invokeai.backend.model_manager.config import ModelConfigFactory
algos = ", ".join(set(get_args(HASHING_ALGORITHMS)))
@@ -30,7 +31,7 @@ def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
try:
return ModelProbe.probe(path, hash_algo=hash_algo)
except InvalidModelConfigException:
return ModelConfigBase.classify(path, hash_algo)
return ModelConfigFactory.from_model_on_disk(mod=path, hash_algo=hash_algo,)
for path in args.model_path:

View File

@@ -132,7 +132,10 @@ class MinimalConfigExample(ModelConfigBase):
def test_minimal_working_example(datadir: Path):
model_path = datadir / "minimal_config_model.json"
overrides = {"base": BaseModelType.StableDiffusion1}
config = ModelConfigBase.classify(model_path, **overrides)
config = ModelConfigFactory.from_model_on_disk(
mod=model_path,
overrides=overrides,
)
assert isinstance(config, MinimalConfigExample)
assert config.base == BaseModelType.StableDiffusion1
@@ -160,7 +163,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
try:
stripped_mod = StrippedModelOnDisk(path)
new_config = ModelConfigBase.classify(stripped_mod, hash=fake_hash, key=fake_key)
new_config = ModelConfigFactory.from_model_on_disk(
mod=stripped_mod,
overrides={"hash": fake_hash, "key": fake_key},
)
except InvalidModelConfigException:
pass