feat(mm): port TIs to new API

This commit is contained in:
psychedelicious
2025-09-23 13:59:24 +10:00
parent e72c78f7d4
commit 8036bb0e8f

View File

@@ -30,6 +30,7 @@ from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
import torch
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
@@ -601,19 +602,137 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase):
class TextualInversionConfigBase(ABC, BaseModel):
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
@classmethod
def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
    """Report whether a single file appears to be a textual inversion embedding.

    Args:
        mod: Model-on-disk wrapper; supplies the default path and loads the state dict.
        path: File to inspect. Falls back to ``mod.path`` when omitted.

    Returns:
        ``True`` when the file has a well-known embedding filename, contains a
        known embedding key, or is a small state dict made only of tensors;
        ``False`` otherwise (including missing paths and directories).
    """
    candidate = path or mod.path

    # Only an existing regular file can be an embedding file.
    if not candidate.exists() or candidate.is_dir():
        return False

    # Well-known filenames are accepted without loading the file.
    if candidate.name in {"learned_embeds.bin", "learned_embeds.safetensors"}:
        return True

    weights = mod.load_state_dict(candidate)

    # Heuristic: textual inversion embeddings have these keys
    marker_keys = {"string_to_param", "emb_params"}
    if not marker_keys.isdisjoint(weights.keys()):
        return True

    # Heuristic: small state dict with all tensor values
    return len(weights) < 10 and all(isinstance(value, torch.Tensor) for value in weights.values())
@classmethod
def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
p = path or mod.path
try:
state_dict = mod.load_state_dict(p)
if "string_to_token" in state_dict:
token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
elif "emb_params" in state_dict:
token_dim = state_dict["emb_params"].shape[-1]
elif "clip_g" in state_dict:
token_dim = state_dict["clip_g"].shape[-1]
else:
token_dim = list(state_dict.values())[0].shape[0]
match token_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXL
case _:
pass
except Exception:
pass
raise InvalidModelConfigException(f"{p}: Could not determine base type")
class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
@classmethod
def get_tag(cls) -> Tag:
    """Return the discriminator tag for single-file textual inversion configs.

    The tag text is the model type value and the embedding-file format value
    joined by a dot.
    """
    type_part = ModelType.TextualInversion.value
    format_part = ModelFormat.EmbeddingFile.value
    return Tag(f"{type_part}.{format_part}")
class TextualInversionFolderConfig(LegacyProbeMixin, ModelConfigBase):
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Rate how likely ``mod`` is a single-file textual inversion embedding.

    Args:
        mod: Model-on-disk wrapper to classify.
        **overrides: Caller-supplied config overrides; only ``type`` and
            ``format`` are consulted here.

    Returns:
        ``OVERRIDE`` when the overrides name both the textual inversion type and
        the embedding-file format; ``MAYBE`` when the path is not a directory and
        the file looks like an embedding; ``NEVER`` otherwise.
    """
    forced_type = overrides.get("type")
    forced_format = overrides.get("format")
    if forced_type is ModelType.TextualInversion and forced_format is ModelFormat.EmbeddingFile:
        return MatchCertainty.OVERRIDE

    # Directories are never handled by the single-file config.
    if not mod.path.is_dir() and cls.file_looks_like_embedding(mod):
        return MatchCertainty.MAYBE

    return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Extract the config fields that can be derived from the embedding file.

    Args:
        mod: Model-on-disk wrapper pointing at the embedding file.

    Returns:
        A dict with a single ``"base"`` entry holding the detected base type.

    Raises:
        InvalidModelConfigException: If the base type cannot be determined.
    """
    try:
        detected_base = cls.get_base(mod)
    except Exception:
        # Fall through to the uniform error below.
        pass
    else:
        return {"base": detected_base}

    raise InvalidModelConfigException(f"{mod.path}: Could not determine base type")
class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
@classmethod
def get_tag(cls) -> Tag:
    """Return the discriminator tag for folder-based textual inversion configs.

    The tag text is the model type value and the embedding-folder format value
    joined by a dot.
    """
    type_part = ModelType.TextualInversion.value
    format_part = ModelFormat.EmbeddingFolder.value
    return Tag(f"{type_part}.{format_part}")
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Rate how likely ``mod`` is a folder holding a textual inversion embedding.

    Args:
        mod: Model-on-disk wrapper to classify.
        **overrides: Caller-supplied config overrides; only ``type`` and
            ``format`` are consulted here.

    Returns:
        ``OVERRIDE`` when the overrides name both the textual inversion type and
        the embedding-folder format; ``MAYBE`` when the path is not a file and
        one of the known embedding filenames inside it looks like an embedding;
        ``NEVER`` otherwise.
    """
    forced_type = overrides.get("type")
    forced_format = overrides.get("format")
    if forced_type is ModelType.TextualInversion and forced_format is ModelFormat.EmbeddingFolder:
        return MatchCertainty.OVERRIDE

    # A plain file is never handled by the folder config.
    if mod.path.is_file():
        return MatchCertainty.NEVER

    candidate_names = ("learned_embeds.bin", "learned_embeds.safetensors")
    if any(cls.file_looks_like_embedding(mod, mod.path / name) for name in candidate_names):
        return MatchCertainty.MAYBE

    return MatchCertainty.NEVER
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
    """Extract the config fields that can be derived from an embedding folder.

    Each known embedding filename inside the folder is tried in a fixed order;
    the first one whose base type can be determined wins.

    Args:
        mod: Model-on-disk wrapper pointing at the embedding folder.

    Returns:
        A dict with a single ``"base"`` entry holding the detected base type.

    Raises:
        InvalidModelConfigException: If no candidate file yields a base type.
    """
    last_error: Exception | None = None

    # A tuple keeps the probing order deterministic (iteration order of a set
    # of strings can differ between interpreter runs).
    for filename in ("learned_embeds.bin", "learned_embeds.safetensors"):
        # Guard each candidate separately: a missing or unreadable first file
        # must not prevent the other candidate from being tried.
        try:
            return {"base": cls.get_base(mod, mod.path / filename)}
        except Exception as e:
            last_error = e

    raise InvalidModelConfigException(f"{mod.path}: Could not determine base type") from last_error
class MainConfigBase(ABC, BaseModel):
type: Literal[ModelType.Main] = ModelType.Main