mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): port spandrel to new API
This commit is contained in:
@@ -30,6 +30,7 @@ from inspect import isabstract
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union
|
||||
|
||||
import spandrel
|
||||
import torch
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
@@ -56,6 +57,7 @@ from invokeai.backend.model_manager.taxonomy import (
|
||||
variant_type_adapter,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -605,31 +607,37 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
|
||||
class TextualInversionConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
|
||||
KNOWN_SUFFIXES: ClassVar = {"bin", "safetensors", "pt", "ckpt"}
|
||||
KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"}
|
||||
|
||||
@classmethod
|
||||
def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
|
||||
p = path or mod.path
|
||||
try:
|
||||
p = path or mod.path
|
||||
|
||||
if not p.exists():
|
||||
return False
|
||||
|
||||
if p.is_dir():
|
||||
return False
|
||||
|
||||
if p.name in [f"learned_embeds.{s}" for s in cls.KNOWN_SUFFIXES]:
|
||||
return True
|
||||
|
||||
state_dict = mod.load_state_dict(p)
|
||||
|
||||
# Heuristic: textual inversion embeddings have these keys
|
||||
if any(key in cls.KNOWN_KEYS for key in state_dict.keys()):
|
||||
return True
|
||||
|
||||
# Heuristic: small state dict with all tensor values
|
||||
if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
|
||||
return True
|
||||
|
||||
if not p.exists():
|
||||
return False
|
||||
|
||||
if p.is_dir():
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if p.name in {"learned_embeds.bin", "learned_embeds.safetensors"}:
|
||||
return True
|
||||
|
||||
state_dict = mod.load_state_dict(p)
|
||||
|
||||
# Heuristic: textual inversion embeddings have these keys
|
||||
if any(key in {"string_to_param", "emb_params"} for key in state_dict.keys()):
|
||||
return True
|
||||
|
||||
# Heuristic: small state dict with all tensor values
|
||||
if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
|
||||
p = path or mod.path
|
||||
@@ -716,8 +724,8 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
if mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
for filename in {"learned_embeds.bin", "learned_embeds.safetensors"}:
|
||||
if cls.file_looks_like_embedding(mod, mod.path / filename):
|
||||
for p in mod.path.iterdir():
|
||||
if cls.file_looks_like_embedding(mod, p):
|
||||
return MatchCertainty.MAYBE
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
@@ -929,7 +937,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProb
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
"""Model config for Spandrel Image to Image models."""
|
||||
|
||||
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk
|
||||
@@ -937,6 +945,31 @@ class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
if not mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
try:
|
||||
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
|
||||
# explored to avoid this:
|
||||
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
|
||||
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
|
||||
# supported on meta tensors.
|
||||
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
|
||||
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
|
||||
# maintain it, and the risk of false positive detections is higher.
|
||||
SpandrelImageToImageModel.load_from_file(mod.path)
|
||||
return MatchCertainty.EXACT
|
||||
except spandrel.UnsupportedModelError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Encountered error while probing to determine if {mod.path} is a Spandrel model. Ignoring. Error: {e}"
|
||||
)
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
|
||||
class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for SigLIP."""
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, Literal, Optional, Union
|
||||
|
||||
import picklescan.scanner as pscan
|
||||
import safetensors.torch
|
||||
import spandrel
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
@@ -59,7 +58,6 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
|
||||
)
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
CkptType = Dict[str | int, Any]
|
||||
@@ -340,26 +338,6 @@ class ModelProbe(object):
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
# Check if the model can be loaded as a SpandrelImageToImageModel.
|
||||
# This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
|
||||
try:
|
||||
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
|
||||
# explored to avoid this:
|
||||
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
|
||||
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
|
||||
# supported on meta tensors.
|
||||
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
|
||||
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
|
||||
# maintain it, and the risk of false positive detections is higher.
|
||||
SpandrelImageToImageModel.load_from_file(model_path)
|
||||
return ModelType.SpandrelImageToImage
|
||||
except spandrel.UnsupportedModelError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Encountered error while probing to determine if {model_path} is a Spandrel model. Ignoring. Error: {e}"
|
||||
)
|
||||
|
||||
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@classmethod
|
||||
@@ -1110,7 +1088,6 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)
|
||||
@@ -1123,7 +1100,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)
|
||||
|
||||
Reference in New Issue
Block a user