diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index b552328153..5111db6cd4 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -30,6 +30,7 @@ from inspect import isabstract from pathlib import Path from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union +import torch from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from typing_extensions import Annotated, Any, Dict @@ -601,19 +602,137 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, type: Literal[ModelType.ControlNet] = ModelType.ControlNet -class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase): +class TextualInversionConfigBase(ABC, BaseModel): + type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion + + @classmethod + def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool: + p = path or mod.path + + if not p.exists(): + return False + + if p.is_dir(): + return False + + if p.name in {"learned_embeds.bin", "learned_embeds.safetensors"}: + return True + + state_dict = mod.load_state_dict(p) + + # Heuristic: textual inversion embeddings have these keys + if any(key in {"string_to_param", "emb_params"} for key in state_dict.keys()): + return True + + # Heuristic: small state dict with all tensor values + if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()): + return True + + return False + + @classmethod + def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType: + p = path or mod.path + + try: + state_dict = mod.load_state_dict(p) + if "string_to_token" in state_dict: + token_dim = list(state_dict["string_to_param"].values())[0].shape[-1] + elif "emb_params" in state_dict: + token_dim = state_dict["emb_params"].shape[-1] + elif "clip_g" in state_dict: + token_dim = state_dict["clip_g"].shape[-1] + else: + token_dim = list(state_dict.values())[0].shape[0] + + match token_dim: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXL + case _: + pass + except Exception: + pass + + raise InvalidModelConfigException(f"{p}: Could not determine base type") + + +class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): """Model config for textual inversion embeddings.""" - type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile + @classmethod + def get_tag(cls) -> Tag: + return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}") -class TextualInversionFolderConfig(LegacyProbeMixin, ModelConfigBase): + @classmethod + def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: + is_embedding_override = overrides.get("type") is ModelType.TextualInversion + is_file_override = overrides.get("format") is ModelFormat.EmbeddingFile + + if is_embedding_override and is_file_override: + return MatchCertainty.OVERRIDE + + if mod.path.is_dir(): + return MatchCertainty.NEVER + + if cls.file_looks_like_embedding(mod): + return MatchCertainty.MAYBE + + return MatchCertainty.NEVER + + @classmethod + def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: + try: + base = cls.get_base(mod) + return {"base": base} + except Exception: + pass + + raise InvalidModelConfigException(f"{mod.path}: Could not determine base type") + + +class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): """Model config for textual inversion embeddings.""" - type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder + @classmethod + def get_tag(cls) -> Tag: + return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}") + + @classmethod + def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: + is_embedding_override = overrides.get("type") is ModelType.TextualInversion + is_folder_override = overrides.get("format") is ModelFormat.EmbeddingFolder + + if is_embedding_override and is_folder_override: + return MatchCertainty.OVERRIDE + + if mod.path.is_file(): + return MatchCertainty.NEVER + + for filename in {"learned_embeds.bin", "learned_embeds.safetensors"}: + if cls.file_looks_like_embedding(mod, mod.path / filename): + return MatchCertainty.MAYBE + + return MatchCertainty.NEVER + + @classmethod + def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: + try: + for filename in {"learned_embeds.bin", "learned_embeds.safetensors"}: + base = cls.get_base(mod, mod.path / filename) + return {"base": base} + except Exception: + pass + + raise InvalidModelConfigException(f"{mod.path}: Could not determine base type") + class MainConfigBase(ABC, BaseModel): type: Literal[ModelType.Main] = ModelType.Main