mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): port TIs to new API
This commit is contained in:
@@ -30,6 +30,7 @@ from inspect import isabstract
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
@@ -601,19 +602,137 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
|
||||
|
||||
class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
class TextualInversionConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
|
||||
@classmethod
|
||||
def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
|
||||
p = path or mod.path
|
||||
|
||||
if not p.exists():
|
||||
return False
|
||||
|
||||
if p.is_dir():
|
||||
return False
|
||||
|
||||
if p.name in {"learned_embeds.bin", "learned_embeds.safetensors"}:
|
||||
return True
|
||||
|
||||
state_dict = mod.load_state_dict(p)
|
||||
|
||||
# Heuristic: textual inversion embeddings have these keys
|
||||
if any(key in {"string_to_param", "emb_params"} for key in state_dict.keys()):
|
||||
return True
|
||||
|
||||
# Heuristic: small state dict with all tensor values
|
||||
if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
|
||||
p = path or mod.path
|
||||
|
||||
try:
|
||||
state_dict = mod.load_state_dict(p)
|
||||
if "string_to_token" in state_dict:
|
||||
token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
|
||||
elif "emb_params" in state_dict:
|
||||
token_dim = state_dict["emb_params"].shape[-1]
|
||||
elif "clip_g" in state_dict:
|
||||
token_dim = state_dict["clip_g"].shape[-1]
|
||||
else:
|
||||
token_dim = list(state_dict.values())[0].shape[0]
|
||||
|
||||
match token_dim:
|
||||
case 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
case 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
case 1280:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
case _:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise InvalidModelConfigException(f"{p}: Could not determine base type")
|
||||
|
||||
|
||||
class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
    """Model config for textual inversion embeddings."""

    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile

    @classmethod
    def get_tag(cls) -> Tag:
        """Return the discriminator tag, built as `<type value>.<format value>`."""
        tag_name = f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}"
        return Tag(tag_name)

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Rate how likely `mod` is a single-file textual inversion embedding.

        Returns OVERRIDE when the overrides name both this type and this
        format, NEVER for directories or files that do not look like an
        embedding, and MAYBE otherwise.
        """
        requested_type = overrides.get("type")
        requested_format = overrides.get("format")

        # Explicit type + format overrides win over any probing.
        if requested_type is ModelType.TextualInversion and requested_format is ModelFormat.EmbeddingFile:
            return MatchCertainty.OVERRIDE

        # This config only describes single files; the directory check comes
        # first so the file heuristic is never run on a directory.
        if mod.path.is_dir() or not cls.file_looks_like_embedding(mod):
            return MatchCertainty.NEVER

        return MatchCertainty.MAYBE

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Return the parsed config fields (`base`) for the embedding file.

        Raises:
            InvalidModelConfigException: If the base type cannot be determined.
        """
        try:
            fields = {"base": cls.get_base(mod)}
        except Exception:
            fields = None

        if fields is None:
            raise InvalidModelConfigException(f"{mod.path}: Could not determine base type")

        return fields
|
||||
|
||||
|
||||
class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
    """Model config for textual inversion embeddings."""

    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder

    @classmethod
    def get_tag(cls) -> Tag:
        """Return the discriminator tag, built as `<type value>.<format value>`."""
        return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")

    @classmethod
    def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
        """Rate how likely `mod` is a folder containing a textual inversion embedding.

        Args:
            mod: The model on disk being probed.
            **overrides: Optional field overrides; `type` and `format` are consulted.

        Returns:
            OVERRIDE when the overrides name both this type and this format,
            NEVER when `mod.path` is a file or holds no recognisable embedding
            file, and MAYBE when one of the well-known embedding files inside
            the folder looks like an embedding.
        """
        is_embedding_override = overrides.get("type") is ModelType.TextualInversion
        is_folder_override = overrides.get("format") is ModelFormat.EmbeddingFolder

        if is_embedding_override and is_folder_override:
            return MatchCertainty.OVERRIDE

        # This config only describes folders.
        if mod.path.is_file():
            return MatchCertainty.NEVER

        # A tuple (not a set) keeps the probing order deterministic.
        for filename in ("learned_embeds.bin", "learned_embeds.safetensors"):
            if cls.file_looks_like_embedding(mod, mod.path / filename):
                return MatchCertainty.MAYBE

        return MatchCertainty.NEVER

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Return the parsed config fields (`base`) for the embedding folder.

        Each well-known embedding filename inside the folder is tried in turn;
        the first one whose base type can be determined wins.

        Args:
            mod: The model on disk being parsed.

        Returns:
            A dict with the single key `base`.

        Raises:
            InvalidModelConfigException: If no candidate file yields a base
                type. The last underlying error is chained as the cause.
        """
        last_error: Exception | None = None

        # The try/except must sit inside the loop: a failure on one candidate
        # (e.g. the file simply is not present) must not stop the other
        # candidate from being tried. A tuple keeps the order deterministic.
        for filename in ("learned_embeds.bin", "learned_embeds.safetensors"):
            try:
                return {"base": cls.get_base(mod, mod.path / filename)}
            except Exception as e:
                last_error = e

        raise InvalidModelConfigException(f"{mod.path}: Could not determine base type") from last_error
|
||||
|
||||
|
||||
class MainConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
|
||||
Reference in New Issue
Block a user