From 4f413d271440f3a9122fbdc43f83ff9d92b77302 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Sep 2025 14:20:33 +1000 Subject: [PATCH] feat(mm): port spandrel to new API --- invokeai/backend/model_manager/config.py | 77 +++++++++++++------ .../backend/model_manager/legacy_probe.py | 24 ------ 2 files changed, 55 insertions(+), 46 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 5111db6cd4..943641c968 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -30,6 +30,7 @@ from inspect import isabstract from pathlib import Path from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union +import spandrel import torch from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from typing_extensions import Annotated, Any, Dict @@ -56,6 +57,7 @@ from invokeai.backend.model_manager.taxonomy import ( variant_type_adapter, ) from invokeai.backend.model_manager.util.model_util import lora_token_vector_length +from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES logger = logging.getLogger(__name__) @@ -605,31 +607,37 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, class TextualInversionConfigBase(ABC, BaseModel): type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion + KNOWN_SUFFIXES: ClassVar = {"bin", "safetensors", "pt", "ckpt"} + KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"} + @classmethod def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool: - p = path or mod.path + try: + p = path or mod.path + + if not p.exists(): + return False + + if p.is_dir(): + return False + + if p.name in [f"learned_embeds.{s}" for s in cls.KNOWN_SUFFIXES]: + return True + + state_dict = mod.load_state_dict(p) + + # Heuristic: textual inversion embeddings have these keys + if any(key in cls.KNOWN_KEYS for key in state_dict.keys()): + return True + + # Heuristic: small state dict with all tensor values + if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()): + return True - if not p.exists(): return False - - if p.is_dir(): + except Exception: return False - if p.name in {"learned_embeds.bin", "learned_embeds.safetensors"}: - return True - - state_dict = mod.load_state_dict(p) - - # Heuristic: textual inversion embeddings have these keys - if any(key in {"string_to_param", "emb_params"} for key in state_dict.keys()): - return True - - # Heuristic: small state dict with all tensor values - if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()): - return True - - return False - @classmethod def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType: p = path or mod.path @@ -716,8 +724,8 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): if mod.path.is_file(): return MatchCertainty.NEVER - for filename in {"learned_embeds.bin", "learned_embeds.safetensors"}: - if cls.file_looks_like_embedding(mod, mod.path / filename): + for p in mod.path.iterdir(): + if cls.file_looks_like_embedding(mod, p): return MatchCertainty.MAYBE return MatchCertainty.NEVER @@ -929,7 +937,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProb format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers -class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase): +class SpandrelImageToImageConfig(ModelConfigBase): """Model config for Spandrel Image to Image models.""" _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk @@ -937,6 +945,31 @@ class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase): type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + @classmethod + def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: + if not mod.path.is_file(): + return MatchCertainty.NEVER + + try: + # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were + # explored to avoid this: + # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta + # device. Unfortunately, some Spandrel models perform operations during initialization that are not + # supported on meta tensors. + # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model. + # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to + # maintain it, and the risk of false positive detections is higher. + SpandrelImageToImageModel.load_from_file(mod.path) + return MatchCertainty.EXACT + except spandrel.UnsupportedModelError: + pass + except Exception as e: + logger.warning( + f"Encountered error while probing to determine if {mod.path} is a Spandrel model. Ignoring. Error: {e}" + ) + + return MatchCertainty.NEVER + class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase): """Model config for SigLIP.""" diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py index 7d95e33081..b6bda0c15f 100644 --- a/invokeai/backend/model_manager/legacy_probe.py +++ b/invokeai/backend/model_manager/legacy_probe.py @@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, Literal, Optional, Union import picklescan.scanner as pscan import safetensors.torch -import spandrel import torch import invokeai.backend.util.logging as logger @@ -59,7 +58,6 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u ) from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader -from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel from invokeai.backend.util.silence_warnings import SilenceWarnings CkptType = Dict[str | int, Any] @@ -340,26 +338,6 @@ class ModelProbe(object): if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): return ModelType.TextualInversion - # Check if the model can be loaded as a SpandrelImageToImageModel. - # This check is intentionally performed last, as it can be expensive (it requires loading the model from disk). - try: - # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were - # explored to avoid this: - # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta - # device. Unfortunately, some Spandrel models perform operations during initialization that are not - # supported on meta tensors. - # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model. - # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to - # maintain it, and the risk of false positive detections is higher. - SpandrelImageToImageModel.load_from_file(model_path) - return ModelType.SpandrelImageToImage - except spandrel.UnsupportedModelError: - pass - except Exception as e: - logger.warning( - f"Encountered error while probing to determine if {model_path} is a Spandrel model. Ignoring. Error: {e}" - ) - raise InvalidModelConfigException(f"Unable to determine model type for {model_path}") @classmethod @@ -1110,7 +1088,6 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe) ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe) ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe) ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe) @@ -1123,7 +1100,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)