mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor(mm): add model config parsing utils
This commit is contained in:
@@ -39,7 +39,7 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelConfigBase,
|
||||
ModelConfigFactory,
|
||||
)
|
||||
from invokeai.backend.model_manager.legacy_probe import ModelProbe
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
@@ -612,7 +612,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
try:
|
||||
return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo) # type: ignore
|
||||
except InvalidModelConfigException:
|
||||
return ModelConfigBase.classify(mod=model_path, fields=deepcopy(fields), hash_algo=hash_algo)
|
||||
return ModelConfigFactory.from_model_on_disk(
|
||||
mod=model_path,
|
||||
overrides=deepcopy(fields),
|
||||
hash_algo=hash_algo,
|
||||
)
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
||||
@@ -633,7 +637,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
info.path = model_path.as_posix()
|
||||
|
||||
if isinstance(info, CheckpointConfigBase):
|
||||
if isinstance(info, CheckpointConfigBase) and info.config_path is not None:
|
||||
# Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
|
||||
# invoke-managed legacy config dir, we use a relative path.
|
||||
legacy_config_path = self.app_config.legacy_conf_path / info.config_path
|
||||
|
||||
@@ -86,34 +86,86 @@ class NotAMatch(Exception):
|
||||
reason: The reason why the model did not match.
|
||||
"""
|
||||
|
||||
def __init__(self, config_class: "Type[AnyModelConfig]", reason: str):
|
||||
def __init__(
|
||||
self,
|
||||
config_class: type,
|
||||
reason: str,
|
||||
):
|
||||
super().__init__(f"{config_class.__name__} does not match: {reason}")
|
||||
|
||||
|
||||
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
||||
|
||||
|
||||
def get_class_name_from_config(config: dict[str, Any]) -> Optional[str]:
|
||||
if "_class_name" in config:
|
||||
return config["_class_name"]
|
||||
elif "architectures" in config:
|
||||
return config["architectures"][0]
|
||||
else:
|
||||
return None
|
||||
def get_config_or_raise(
|
||||
config_class: type,
|
||||
config_path: Path,
|
||||
) -> dict[str, Any]:
|
||||
"""Load the config file at the given path, or raise NotAMatch if it cannot be loaded."""
|
||||
if not config_path.exists():
|
||||
raise NotAMatch(config_class, f"missing config file: {config_path}")
|
||||
|
||||
try:
|
||||
config = load_json(config_path)
|
||||
return config
|
||||
except Exception as e:
|
||||
raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e
|
||||
|
||||
|
||||
def validate_overrides(
|
||||
config_class: "Type[AnyModelConfig]", overrides: dict[str, Any], allowed: dict[str, Any]
|
||||
def raise_for_class_names(
|
||||
config_class: type,
|
||||
config_path: Path,
|
||||
valid_class_names: set[str],
|
||||
) -> None:
|
||||
for key, value in allowed.items():
|
||||
if key not in overrides:
|
||||
"""Raise NotAMatch if the config file is missing or does not contain a valid class name."""
|
||||
|
||||
config = get_config_or_raise(config_class, config_path)
|
||||
|
||||
try:
|
||||
if "_class_name" in config:
|
||||
config_class_name = config["_class_name"]
|
||||
elif "architectures" in config:
|
||||
config_class_name = config["architectures"][0]
|
||||
else:
|
||||
raise ValueError("missing _class_name or architectures field")
|
||||
except Exception as e:
|
||||
raise NotAMatch(config_class, f"unable to determine class name from config file: {config_path}") from e
|
||||
|
||||
if config_class_name not in valid_class_names:
|
||||
raise NotAMatch(config_class, f"model class is not one of {valid_class_names}, got {config_class_name}")
|
||||
|
||||
|
||||
def matches_overrides(
|
||||
config_class: "Type[AnyModelConfig]",
|
||||
provided_overrides: dict[str, Any],
|
||||
valid_overrides: dict[str, Any],
|
||||
) -> bool:
|
||||
"""Check if the provided overrides match the valid overrides for this config class.
|
||||
|
||||
Args:
|
||||
config_class: The config class that is being tested.
|
||||
provided_overrides: The overrides provided by the user.
|
||||
valid_overrides: The overrides that are valid for this config class.
|
||||
|
||||
Returns:
|
||||
True if all provided overrides match the valid overrides, False if some valid overrides are missing.
|
||||
|
||||
Raises:
|
||||
NotAMatch if any override does not match the allowed value.
|
||||
"""
|
||||
is_perfect_match = True
|
||||
for key, value in valid_overrides.items():
|
||||
if key not in provided_overrides:
|
||||
is_perfect_match = False
|
||||
continue
|
||||
if overrides[key] != value:
|
||||
if provided_overrides[key] != value:
|
||||
raise NotAMatch(
|
||||
config_class,
|
||||
f"override {key}={overrides[key]} does not match required value {key}={value}",
|
||||
f"override {key}={provided_overrides[key]} does not match required value {key}={value}",
|
||||
)
|
||||
|
||||
return is_perfect_match
|
||||
|
||||
|
||||
class SubmodelDefinition(BaseModel):
|
||||
path_or_prefix: str
|
||||
@@ -327,36 +379,32 @@ def load_json(path: Path) -> dict[str, Any]:
|
||||
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.T5Encoder,
|
||||
"format": ModelFormat.T5Encoder,
|
||||
}
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"T5EncoderModel",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.T5Encoder:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.T5Encoder:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not T5Encoder")
|
||||
|
||||
if type_override is ModelType.T5Encoder and format_override is ModelFormat.T5Encoder:
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
|
||||
# Heuristic: Look for the T5EncoderModel class name in the config
|
||||
try:
|
||||
config = load_json(mod.path / "text_encoder_2" / "config.json")
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e
|
||||
|
||||
try:
|
||||
config_class_name = get_class_name_from_config(config)
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e
|
||||
|
||||
if config_class_name != "T5EncoderModel":
|
||||
raise NotAMatch(cls, "model class is not T5EncoderModel")
|
||||
raise_for_class_names(
|
||||
config_class=cls,
|
||||
config_path=mod.path / "text_encoder_2" / "config.json",
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
|
||||
# Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models)
|
||||
has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()
|
||||
@@ -370,33 +418,30 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
|
||||
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.T5Encoder,
|
||||
"format": ModelFormat.BnbQuantizedLlmInt8b,
|
||||
}
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"T5EncoderModel",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.T5Encoder:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.BnbQuantizedLlmInt8b:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not BnbQuantizedLlmInt8b")
|
||||
|
||||
if type_override is ModelType.T5Encoder and format_override is ModelFormat.BnbQuantizedLlmInt8b:
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
# Heuristic: Look for the T5EncoderModel class name in the config
|
||||
try:
|
||||
config = load_json(mod.path / "text_encoder_2" / "config.json")
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e
|
||||
|
||||
try:
|
||||
config_class_name = get_class_name_from_config(config)
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e
|
||||
|
||||
if config_class_name != "T5EncoderModel":
|
||||
raise NotAMatch(cls, "model class is not T5EncoderModel")
|
||||
raise_for_class_names(
|
||||
config_class=cls,
|
||||
config_path=mod.path / "text_encoder_2" / "config.json",
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
|
||||
# Heuristic: look for the quantization in the filename name
|
||||
filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix())
|
||||
@@ -413,18 +458,18 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
|
||||
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.LoRA,
|
||||
"format": ModelFormat.OMI,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.LoRA:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not LoRA")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.OMI:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not OMI")
|
||||
|
||||
if type_override is ModelType.LoRA and format_override is ModelFormat.OMI:
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
# Heuristic: OMI LoRAs are always files, never directories
|
||||
@@ -446,12 +491,12 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
if not is_omi_lora_heuristic:
|
||||
raise NotAMatch(cls, "model does not match OMI LoRA heuristics")
|
||||
|
||||
base = fields.get("base") or cls.get_base_or_raise(mod)
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
metadata = mod.metadata()
|
||||
architecture = metadata["modelspec.architecture"]
|
||||
|
||||
@@ -468,18 +513,18 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.LoRA,
|
||||
"format": ModelFormat.LyCORIS,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.LoRA:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not LoRA")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.LyCORIS:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not LyCORIS")
|
||||
|
||||
if type_override is ModelType.LoRA and format_override is ModelFormat.LyCORIS:
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
# Heuristic: LyCORIS LoRAs are always files, never directories
|
||||
@@ -544,18 +589,18 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.LoRA,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.LoRA:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not LoRA")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.Diffusers:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
|
||||
|
||||
if type_override is ModelType.LoRA and format_override is ModelFormat.Diffusers:
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
# Heuristic: Diffusers LoRAs are always directories, never files
|
||||
@@ -583,8 +628,31 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.VAE,
|
||||
"format": ModelFormat.Checkpoint,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_dir():
|
||||
raise NotAMatch(cls, "model path is a directory, not a file")
|
||||
|
||||
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
|
||||
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
|
||||
for regexp, basetype in [
|
||||
(r"xl", BaseModelType.StableDiffusionXL),
|
||||
@@ -597,36 +665,41 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase):
|
||||
|
||||
raise NotAMatch(cls, "cannot determine base type")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.VAE:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not VAE")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.Checkpoint:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint")
|
||||
|
||||
if type_override is ModelType.VAE and format_override is ModelFormat.Checkpoint:
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_dir():
|
||||
raise NotAMatch(cls, "model path is a directory, not a file")
|
||||
|
||||
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
|
||||
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
|
||||
|
||||
base = fields.get("base") or cls.get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
|
||||
class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
|
||||
"""Model config for standalone VAE models (diffusers version)."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
CLASS_NAMES: ClassVar = {"AutoencoderKL", "AutoencoderTiny"}
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.VAE,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"AutoencoderKL",
|
||||
"AutoencoderTiny",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
|
||||
raise_for_class_names(
|
||||
config_class=cls,
|
||||
config_path=mod.path / "config.json",
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool:
|
||||
@@ -648,7 +721,8 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def get_base(cls, mod: ModelOnDisk, config: dict[str, Any]) -> BaseModelType:
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
config = get_config_or_raise(cls, mod.path / "config.json")
|
||||
if cls._config_looks_like_sdxl(config):
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif cls._name_looks_like_sdxl(mod):
|
||||
@@ -657,39 +731,6 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase):
|
||||
# TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.VAE:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not VAE")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.Diffusers:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
|
||||
|
||||
if type_override is ModelType.VAE and format_override is ModelFormat.Diffusers:
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
|
||||
try:
|
||||
config = load_json(mod.path / "config.json")
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to load config.json") from e
|
||||
|
||||
try:
|
||||
config_class_name = get_class_name_from_config(config)
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to determine class name from config") from e
|
||||
|
||||
if config_class_name not in cls.CLASS_NAMES:
|
||||
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
|
||||
|
||||
base = fields.get("base") or cls.get_base(mod, config)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
@@ -710,7 +751,7 @@ class TextualInversionConfigBase(ABC, BaseModel):
|
||||
KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"}
|
||||
|
||||
@classmethod
|
||||
def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
|
||||
def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
|
||||
try:
|
||||
p = path or mod.path
|
||||
|
||||
@@ -738,11 +779,15 @@ class TextualInversionConfigBase(ABC, BaseModel):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
|
||||
p = path or mod.path
|
||||
|
||||
try:
|
||||
state_dict = mod.load_state_dict(p)
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, f"unable to load state dict from {p}: {e}") from e
|
||||
|
||||
try:
|
||||
if "string_to_token" in state_dict:
|
||||
token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
|
||||
elif "emb_params" in state_dict:
|
||||
@@ -751,49 +796,18 @@ class TextualInversionConfigBase(ABC, BaseModel):
|
||||
token_dim = state_dict["clip_g"].shape[-1]
|
||||
else:
|
||||
token_dim = list(state_dict.values())[0].shape[0]
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, f"unable to determine token dimension from state dict in {p}: {e}") from e
|
||||
|
||||
match token_dim:
|
||||
case 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
case 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
case 1280:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
case _:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise InvalidModelConfigException(f"{p}: Could not determine base type")
|
||||
|
||||
@classmethod
|
||||
def get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
|
||||
p = path or mod.path
|
||||
|
||||
try:
|
||||
state_dict = mod.load_state_dict(p)
|
||||
if "string_to_token" in state_dict:
|
||||
token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
|
||||
elif "emb_params" in state_dict:
|
||||
token_dim = state_dict["emb_params"].shape[-1]
|
||||
elif "clip_g" in state_dict:
|
||||
token_dim = state_dict["clip_g"].shape[-1]
|
||||
else:
|
||||
token_dim = list(state_dict.values())[0].shape[0]
|
||||
|
||||
match token_dim:
|
||||
case 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
case 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
case 1280:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
case _:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise InvalidModelConfigException(f"{p}: Could not determine base type")
|
||||
match token_dim:
|
||||
case 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
case 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
case 1280:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized token dimension {token_dim}")
|
||||
|
||||
|
||||
class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
@@ -801,31 +815,31 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
|
||||
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.TextualInversion,
|
||||
"format": ModelFormat.EmbeddingFile,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.TextualInversion:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.EmbeddingFile:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFile")
|
||||
|
||||
if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFile:
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_dir():
|
||||
raise NotAMatch(cls, "model path is a directory, not a file")
|
||||
|
||||
if not cls.file_looks_like_embedding(mod):
|
||||
if not cls._file_looks_like_embedding(mod):
|
||||
raise NotAMatch(cls, "model does not look like a textual inversion embedding file")
|
||||
|
||||
base = fields.get("base") or cls.get_base_or_raise(mod)
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
|
||||
@@ -834,30 +848,30 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
|
||||
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.TextualInversion,
|
||||
"format": ModelFormat.EmbeddingFolder,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.TextualInversion:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.EmbeddingFolder:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFolder")
|
||||
|
||||
if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFolder:
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
|
||||
for p in mod.weight_files():
|
||||
if cls.file_looks_like_embedding(mod, p):
|
||||
base = fields.get("base") or cls.get_base_or_raise(mod, p)
|
||||
if cls._file_looks_like_embedding(mod, p):
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod, p)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
raise NotAMatch(cls, "model does not look like a textual inversion embedding folder")
|
||||
@@ -937,7 +951,7 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
base: Literal[BaseModelType.Any] = BaseModelType.Any
|
||||
|
||||
CLASS_NAMES: ClassVar = {
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"CLIPModel",
|
||||
"CLIPTextModel",
|
||||
"CLIPTextModelWithProjection",
|
||||
@@ -963,47 +977,37 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
|
||||
variant: Literal[ClipVariantType.G] = ClipVariantType.G
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.CLIPEmbed,
|
||||
"format": ModelFormat.Diffusers,
|
||||
"variant": ClipVariantType.G,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
variant_override = fields.get("variant")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.CLIPEmbed:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.Diffusers:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
|
||||
|
||||
if variant_override is not None and variant_override is not ClipVariantType.G:
|
||||
raise NotAMatch(cls, f"variant override is {variant_override}, not G")
|
||||
|
||||
if (
|
||||
type_override is ModelType.CLIPEmbed
|
||||
and format_override is ModelFormat.Diffusers
|
||||
and variant_override is ClipVariantType.G
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
|
||||
try:
|
||||
config = load_json(mod.path / "config.json")
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to load config.json") from e
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
try:
|
||||
config_class_name = get_class_name_from_config(config)
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to determine class name from config") from e
|
||||
raise_for_class_names(
|
||||
config_class=cls,
|
||||
config_path=config_path,
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
|
||||
if config_class_name not in cls.CLASS_NAMES:
|
||||
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
|
||||
config = get_config_or_raise(cls, config_path)
|
||||
|
||||
clip_variant = cls.get_clip_variant_type(config)
|
||||
|
||||
@@ -1018,48 +1022,37 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
|
||||
variant: Literal[ClipVariantType.L] = ClipVariantType.L
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.CLIPEmbed,
|
||||
"format": ModelFormat.Diffusers,
|
||||
"variant": ClipVariantType.L,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
variant_override = fields.get("variant")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.CLIPEmbed:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.Diffusers:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
|
||||
|
||||
if variant_override is not None and variant_override is not ClipVariantType.L:
|
||||
raise NotAMatch(cls, f"variant override is {variant_override}, not L")
|
||||
|
||||
if (
|
||||
type_override is ModelType.CLIPEmbed
|
||||
and format_override is ModelFormat.Diffusers
|
||||
and variant_override is ClipVariantType.L
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
|
||||
try:
|
||||
config = load_json(mod.path / "config.json")
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to load config.json") from e
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
try:
|
||||
config_class_name = get_class_name_from_config(config)
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to determine class name from config") from e
|
||||
|
||||
if config_class_name not in cls.CLASS_NAMES:
|
||||
raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}")
|
||||
raise_for_class_names(
|
||||
config_class=cls,
|
||||
config_path=config_path,
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
|
||||
config = get_config_or_raise(cls, config_path)
|
||||
clip_variant = cls.get_clip_variant_type(config)
|
||||
|
||||
if clip_variant is not ClipVariantType.L:
|
||||
@@ -1089,25 +1082,18 @@ class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.SpandrelImageToImage,
|
||||
"format": ModelFormat.Checkpoint,
|
||||
"base": BaseModelType.Any,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
base_override = fields.get("base")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.SpandrelImageToImage:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not SpandrelImageToImage")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.Checkpoint:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint")
|
||||
|
||||
if base_override is not None and base_override is not BaseModelType.Any:
|
||||
raise NotAMatch(cls, f"base override is {base_override}, not Any")
|
||||
|
||||
if (
|
||||
type_override is ModelType.SpandrelImageToImage
|
||||
and format_override is ModelFormat.Checkpoint
|
||||
and base_override is BaseModelType.Any
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
@@ -1151,40 +1137,36 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
base: Literal[BaseModelType.Any] = BaseModelType.Any
|
||||
variant: Literal[ModelVariantType.Normal] = ModelVariantType.Normal
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.LlavaOnevision,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"LlavaOnevisionForConditionalGeneration",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
type_override = fields.get("type")
|
||||
format_override = fields.get("format")
|
||||
|
||||
if type_override is not None and type_override is not ModelType.LlavaOnevision:
|
||||
raise NotAMatch(cls, f"type override is {type_override}, not LlavaOnevision")
|
||||
|
||||
if format_override is not None and format_override is not ModelFormat.Diffusers:
|
||||
raise NotAMatch(cls, f"format override is {format_override}, not Diffusers")
|
||||
|
||||
if type_override is ModelType.LlavaOnevision and format_override is ModelFormat.Diffusers:
|
||||
if matches_overrides(
|
||||
config_class=cls,
|
||||
provided_overrides=fields,
|
||||
valid_overrides=cls.VALID_OVERRIDES,
|
||||
):
|
||||
return cls(**fields)
|
||||
|
||||
if mod.path.is_file():
|
||||
raise NotAMatch(cls, "model path is a file, not a directory")
|
||||
|
||||
# Heuristic: Look for the LlavaOnevisionForConditionalGeneration class name in the config
|
||||
try:
|
||||
config = load_json(mod.path / "config.json")
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to load config.json") from e
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
try:
|
||||
config_class_name = get_class_name_from_config(config)
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "unable to determine class name from config.json") from e
|
||||
raise_for_class_names(
|
||||
config_class=cls,
|
||||
config_path=config_path,
|
||||
valid_class_names=cls.VALID_CLASS_NAMES,
|
||||
)
|
||||
|
||||
if config_class_name != "LlavaOnevisionForConditionalGeneration":
|
||||
raise NotAMatch(cls, "model class is not LlavaOnevisionForConditionalGeneration")
|
||||
|
||||
base = fields.get("base") or BaseModelType.Any
|
||||
variant = fields.get("variant") or ModelVariantType.Normal
|
||||
return cls(**fields, base=base, variant=variant)
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
class ApiModelConfig(MainConfigBase, ModelConfigBase):
|
||||
|
||||
@@ -7,7 +7,8 @@ from pathlib import Path
|
||||
from typing import get_args
|
||||
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.backend.model_manager import InvalidModelConfigException, ModelConfigBase, ModelProbe
|
||||
from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe
|
||||
from invokeai.backend.model_manager.config import ModelConfigFactory
|
||||
|
||||
algos = ", ".join(set(get_args(HASHING_ALGORITHMS)))
|
||||
|
||||
@@ -30,7 +31,7 @@ def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
|
||||
try:
|
||||
return ModelProbe.probe(path, hash_algo=hash_algo)
|
||||
except InvalidModelConfigException:
|
||||
return ModelConfigBase.classify(path, hash_algo)
|
||||
return ModelConfigFactory.from_model_on_disk(mod=path, hash_algo=hash_algo,)
|
||||
|
||||
|
||||
for path in args.model_path:
|
||||
|
||||
@@ -132,7 +132,10 @@ class MinimalConfigExample(ModelConfigBase):
|
||||
def test_minimal_working_example(datadir: Path):
|
||||
model_path = datadir / "minimal_config_model.json"
|
||||
overrides = {"base": BaseModelType.StableDiffusion1}
|
||||
config = ModelConfigBase.classify(model_path, **overrides)
|
||||
config = ModelConfigFactory.from_model_on_disk(
|
||||
mod=model_path,
|
||||
overrides=overrides,
|
||||
)
|
||||
|
||||
assert isinstance(config, MinimalConfigExample)
|
||||
assert config.base == BaseModelType.StableDiffusion1
|
||||
@@ -160,7 +163,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
|
||||
|
||||
try:
|
||||
stripped_mod = StrippedModelOnDisk(path)
|
||||
new_config = ModelConfigBase.classify(stripped_mod, hash=fake_hash, key=fake_key)
|
||||
new_config = ModelConfigFactory.from_model_on_disk(
|
||||
mod=stripped_mod,
|
||||
overrides={"hash": fake_hash, "key": fake_key},
|
||||
)
|
||||
except InvalidModelConfigException:
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user