feat(mm): port spandrel to new API

This commit is contained in:
psychedelicious
2025-09-23 14:20:33 +10:00
parent 6877e0bd01
commit 4f413d2714
2 changed files with 55 additions and 46 deletions

View File

@@ -30,6 +30,7 @@ from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
import spandrel
import torch
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
@@ -56,6 +57,7 @@ from invokeai.backend.model_manager.taxonomy import (
variant_type_adapter,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
logger = logging.getLogger(__name__)
@@ -605,31 +607,37 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
class TextualInversionConfigBase(ABC, BaseModel):
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
KNOWN_SUFFIXES: ClassVar = {"bin", "safetensors", "pt", "ckpt"}
KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"}
@classmethod
def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
p = path or mod.path
try:
p = path or mod.path
if not p.exists():
return False
if p.is_dir():
return False
if p.name in [f"learned_embeds.{s}" for s in cls.KNOWN_SUFFIXES]:
return True
state_dict = mod.load_state_dict(p)
# Heuristic: textual inversion embeddings have these keys
if any(key in cls.KNOWN_KEYS for key in state_dict.keys()):
return True
# Heuristic: small state dict with all tensor values
if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
return True
if not p.exists():
return False
if p.is_dir():
except Exception:
return False
if p.name in {"learned_embeds.bin", "learned_embeds.safetensors"}:
return True
state_dict = mod.load_state_dict(p)
# Heuristic: textual inversion embeddings have these keys
if any(key in {"string_to_param", "emb_params"} for key in state_dict.keys()):
return True
# Heuristic: small state dict with all tensor values
if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
return True
return False
@classmethod
def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
p = path or mod.path
@@ -716,8 +724,8 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
if mod.path.is_file():
return MatchCertainty.NEVER
for filename in {"learned_embeds.bin", "learned_embeds.safetensors"}:
if cls.file_looks_like_embedding(mod, mod.path / filename):
for p in mod.path.iterdir():
if cls.file_looks_like_embedding(mod, p):
return MatchCertainty.MAYBE
return MatchCertainty.NEVER
@@ -929,7 +937,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProb
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase):
class SpandrelImageToImageConfig(ModelConfigBase):
"""Model config for Spandrel Image to Image models."""
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk
@@ -937,6 +945,31 @@ class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase):
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
    """Probe whether the on-disk model is a Spandrel image-to-image model.

    Returns EXACT when spandrel successfully loads the file, NEVER otherwise
    (including for directories and for files that fail to load).
    """
    # Spandrel image-to-image models are single checkpoint files, so a
    # directory (or missing path) can never match.
    if not mod.path.is_file():
        return MatchCertainty.NEVER
    try:
        # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
        # explored to avoid this:
        # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
        #    device. Unfortunately, some Spandrel models perform operations during initialization that are not
        #    supported on meta tensors.
        # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
        #    This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
        #    maintain it, and the risk of false positive detections is higher.
        SpandrelImageToImageModel.load_from_file(mod.path)
        # A successful spandrel load is treated as a definitive identification.
        return MatchCertainty.EXACT
    except spandrel.UnsupportedModelError:
        # Expected outcome for checkpoints that spandrel does not recognize;
        # fall through to NEVER.
        pass
    except Exception as e:
        # Any other failure (e.g. unreadable/corrupt file) is logged and
        # treated as a non-match rather than propagated — probing is
        # best-effort and must not abort the wider model scan.
        logger.warning(
            f"Encountered error while probing to determine if {mod.path} is a Spandrel model. Ignoring. Error: {e}"
        )
    return MatchCertainty.NEVER
class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for SigLIP."""

View File

@@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, Literal, Optional, Union
import picklescan.scanner as pscan
import safetensors.torch
import spandrel
import torch
import invokeai.backend.util.logging as logger
@@ -59,7 +58,6 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
)
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
CkptType = Dict[str | int, Any]
@@ -340,26 +338,6 @@ class ModelProbe(object):
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
# Check if the model can be loaded as a SpandrelImageToImageModel.
# This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
try:
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
# explored to avoid this:
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
# supported on meta tensors.
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
# maintain it, and the risk of false positive detections is higher.
SpandrelImageToImageModel.load_from_file(model_path)
return ModelType.SpandrelImageToImage
except spandrel.UnsupportedModelError:
pass
except Exception as e:
logger.warning(
f"Encountered error while probing to determine if {model_path} is a Spandrel model. Ignoring. Error: {e}"
)
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
@classmethod
@@ -1110,7 +1088,6 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)
@@ -1123,7 +1100,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)