mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
reimplement and clean up probe class
@@ -79,7 +79,7 @@ class ModelProbe(object):
        model_path: Path,
        model: Optional[Union[Dict, ModelMixin]] = None,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
-    ) -> ModelProbeInfo:
+    ) -> Optional[ModelProbeInfo]:
        """
        Probe the model at model_path and return sufficient information about it
        to place it somewhere in the models directory hierarchy. If the model is
@@ -14,3 +14,4 @@ from .config import ( # noqa F401
    SubModelType,
)
from .model_install import ModelInstall  # noqa F401
+from .probe import ModelProbe, InvalidModelException  # noqa F401
@@ -171,13 +171,11 @@ class MainConfig(ModelConfigBase):
class MainCheckpointConfig(CheckpointConfig, MainConfig):
    """Model config for main checkpoint models."""

-    pass


class MainDiffusersConfig(DiffusersConfig, MainConfig):
    """Model config for main diffusers models."""

-    pass


class ONNXSD1Config(MainConfig):
@@ -263,9 +261,9 @@ class ModelConfigFactory(object):
            if isinstance(class_to_return, dict):  # additional level allowed
                class_to_return = class_to_return[model_base]
            return class_to_return.parse_obj(model_data)
-        except KeyError:
+        except KeyError as exc:
            raise InvalidModelConfigException(
                f"Unknown combination of model_format '{model_format}' and model_type '{model_type}'"
-            )
-        except ValidationError as e:
-            raise InvalidModelConfigException(f"Invalid model configuration passed: {str(e)}") from e
+            ) from exc
+        except ValidationError as exc:
+            raise InvalidModelConfigException(f"Invalid model configuration passed: {str(exc)}") from exc
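The chaining added above ("from exc") preserves the original KeyError or ValidationError as __cause__ on the InvalidModelConfigException. A minimal sketch of how a caller can surface that context; the make_config entry point and the model_data value are placeholders for illustration, not shown in this hunk:

    # hypothetical caller; make_config() and model_data are assumed, not from this diff
    from invokeai.backend.model_manager.config import (
        InvalidModelConfigException,
        ModelConfigFactory,
    )

    model_data = {}  # placeholder: raw config dict describing a model
    try:
        config = ModelConfigFactory.make_config(model_data)
    except InvalidModelConfigException as exc:
        # exc.__cause__ carries the original KeyError or ValidationError
        print(f"Rejected model configuration: {exc} (cause: {exc.__cause__!r})")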
@@ -1,55 +1,595 @@
# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
"""
-Return descriptive information on Stable Diffusion models.
+Module for probing a Stable Diffusion model and returning
+its base type, model type, format and variant.
"""

import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Callable
from picklescan.scanner import scan_file_path

import torch
import safetensors.torch

from invokeai.backend.model_management.models.base import (
-    read_checkpoint_meta
+    read_checkpoint_meta,
+    InvalidModelException,
)
import invokeai.configs.model_probe_templates as templates

from .config import (
    ModelType,
    BaseModelType,
    ModelVariantType,
    ModelFormat,
-    SchedulerPredictionType
+    SchedulerPredictionType,
)
from .util import SilenceWarnings, lora_token_vector_length


@dataclass
class ModelProbeInfo(object):
    """Fields describing a probed model."""

    model_type: ModelType
    base_type: BaseModelType
    variant_type: ModelVariantType
    prediction_type: SchedulerPredictionType
    upcast_attention: bool
    format: ModelFormat
    image_size: int


-class ModelProbe(object):
-    """
-    Class to probe a checkpoint, safetensors or diffusers folder.
-    """
-
-    def __init__(self):
-        pass
+class ModelProbeBase(ABC):
+    """Class to probe a checkpoint, safetensors or diffusers folder."""

    @classmethod
-    def heuristic_probe(
+    @abstractmethod
+    def probe(
        cls,
        model: Path,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
-    ) -> ModelProbeInfo:
+    ) -> Optional[ModelProbeInfo]:
        """
        Probe model located at path and return ModelProbeInfo object.
        A Callable may be passed to return the SchedulerPredictionType.

        :param model: Path to a model checkpoint or folder.
        :param prediction_type_helper: An optional Callable that takes the model path
            and returns the SchedulerPredictionType.
        """
        pass
class ProbeBase(ABC):
    """Base class for probing checkpoint and diffusers-style models."""

    @abstractmethod
    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType for the model."""
        pass

    def get_variant_type(self) -> ModelVariantType:
        """Return the ModelVariantType for the model."""
        pass

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        """Return the SchedulerPredictionType for the model."""
        pass

    def get_format(self) -> str:
        """Return the format for the model."""
        pass
class ModelProbe(ModelProbeBase):
    """Class to probe a checkpoint, safetensors or diffusers folder."""

    PROBES = {
        "diffusers": {},
        "checkpoint": {},
        "onnx": {},
    }

    CLASS2TYPE = {
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
    }

    @classmethod
    def register_probe(
        cls, format: ModelFormat, model_type: ModelType, probe_class: ProbeBase
    ):
        """
        Register a probe subclass to use when interrogating a model.

        :param format: The ModelFormat of the model to be probed.
        :param model_type: The ModelType of the model to be probed.
        :param probe_class: The class of the prober (inherits from ProbeBase).
        """
        cls.PROBES[format][model_type] = probe_class
    @classmethod
    def probe(
        cls,
        model: Path,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> Optional[ModelProbeInfo]:
        """Probe model."""
        model_type = (
            cls.get_model_type_from_folder(model)
            if model.is_dir()
            else cls.get_model_type_from_checkpoint(model)
        )
        format_type = (
            "onnx" if model_type == ModelType.ONNX
            else "diffusers" if model.is_dir()
            else "checkpoint"
        )

        probe_class = cls.PROBES[format_type].get(model_type)
        if not probe_class:
            return None
        probe = probe_class(model, prediction_type_helper)
        base_type = probe.get_base_type()
        variant_type = probe.get_variant_type()
        prediction_type = probe.get_scheduler_prediction_type()
        format = probe.get_format()
        model_info = ModelProbeInfo(
            model_type=model_type,
            base_type=base_type,
            variant_type=variant_type,
            prediction_type=prediction_type,
            upcast_attention=(
                base_type == BaseModelType.StableDiffusion2
                and prediction_type == SchedulerPredictionType.VPrediction
            ),
            format=format,
            image_size=1024
            if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
            else 768
            if (
                base_type == BaseModelType.StableDiffusion2
                and prediction_type == SchedulerPredictionType.VPrediction
            )
            else 512,
        )

        return model_info
    @classmethod
    def get_model_type_from_checkpoint(cls, model: Path) -> Optional[ModelType]:
        """
        Scan a checkpoint model and return its ModelType.

        :param model: path to the model checkpoint/safetensors file
        """
        if model.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
            return None

        if model.name == "learned_embeds.bin":
            return ModelType.TextualInversion

        ckpt = read_checkpoint_meta(model, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)

        for key in ckpt.keys():
            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
                return ModelType.Main
            elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
                return ModelType.Vae
            elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
                return ModelType.Lora
            elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
                return ModelType.Lora
            elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
                return ModelType.ControlNet
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion

        else:
            # diffusers-ti
            if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
                return ModelType.TextualInversion

        raise InvalidModelException(f"Unable to determine model type for {model}")

    @classmethod
    def get_model_type_from_folder(cls, model: Path) -> Optional[ModelType]:
        """
        Get the model type of a hugging-face style folder.

        :param model: Path to model folder.
        """
        class_name = None
        if (model / "unet/model.onnx").exists():
            return ModelType.ONNX
        if (model / "learned_embeds.bin").exists():
            return ModelType.TextualInversion
        if (model / "pytorch_lora_weights.bin").exists():
            return ModelType.Lora

        i = model / "model_index.json"
        c = model / "config.json"
        config_path = i if i.exists() else c if c.exists() else None

        if config_path:
            with open(config_path, "r") as file:
                conf = json.load(file)
            class_name = conf["_class_name"]

        if class_name and (type := cls.CLASS2TYPE.get(class_name)):
            return type

        # give up
        raise InvalidModelException(f"Unable to determine model type for {model}")

    @classmethod
    def _scan_and_load_checkpoint(cls, model: Path) -> dict:
        with SilenceWarnings():
            if model.suffix.endswith((".ckpt", ".pt", ".bin")):
                cls._scan_model(model)
                return torch.load(model)
            else:
                return safetensors.torch.load_file(model)
    @classmethod
    def _scan_model(cls, model: Path):
        """
        Scan a model for malicious code.

        :param model: Path to the model to be scanned
        Raises an Exception if unsafe code is found.
        """
        # scan model
        scan_result = scan_file_path(model)
        if scan_result.infected_files != 0:
            raise InvalidModelException(f"The model {model} is potentially infected by malware. Aborting import.")


# ####################################################
# Checkpoint probing
# ####################################################


class CheckpointProbeBase(ProbeBase):
    """Base class for probing checkpoint-style models."""

    def __init__(
        self, model: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None
    ):
        """Initialize the CheckpointProbeBase object."""
        self.checkpoint = ModelProbe._scan_and_load_checkpoint(model)
        self.model = model
        self.helper = helper

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of a checkpoint-style model."""
        pass

    def get_format(self) -> str:
        """Return the format of a checkpoint-style model."""
        return "checkpoint"

    def get_variant_type(self) -> ModelVariantType:
        """Return the ModelVariantType of a checkpoint-style model."""
        model_type = ModelProbe.get_model_type_from_checkpoint(self.model)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        elif in_channels == 4:
            return ModelVariantType.Normal
        else:
            raise InvalidModelException(
                f"Cannot determine variant type (in_channels={in_channels}) at {self.model}"
            )
class PipelineCheckpointProbe(CheckpointProbeBase):
    """Probe a checkpoint-style main model."""

    def get_base_type(self) -> BaseModelType:
        """Return the ModelBaseType for the checkpoint-style main model."""
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            return BaseModelType.StableDiffusion2
        key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
            return BaseModelType.StableDiffusionXL
        elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        else:
            raise InvalidModelException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
        """Return the SchedulerPredictionType for the checkpoint-style main model."""
        type = self.get_base_type()
        if type == BaseModelType.StableDiffusion1:
            return SchedulerPredictionType.Epsilon
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            if "global_step" in checkpoint:
                if checkpoint["global_step"] == 220000:
                    return SchedulerPredictionType.Epsilon
                elif checkpoint["global_step"] == 110000:
                    return SchedulerPredictionType.VPrediction
        if (
            self.model and self.helper and not self.model.with_suffix(".yaml").exists()
        ):  # if a .yaml config file exists, then this step not needed
            return self.helper(self.model)
        else:
            return None
class VaeCheckpointProbe(CheckpointProbeBase):
    """Probe a checkpoint-style VAE model."""

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of the VAE model."""
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1


class LoRACheckpointProbe(CheckpointProbeBase):
    """Probe for LoRA checkpoint files."""

    def get_format(self) -> str:
        """Return the format of the LoRA."""
        return "lycoris"

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of the LoRA."""
        checkpoint = self.checkpoint
        token_vector_length = lora_token_vector_length(checkpoint)

        if token_vector_length == 768:
            return BaseModelType.StableDiffusion1
        elif token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        elif token_vector_length == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"Unknown LoRA type: {self.model}")
class TextualInversionCheckpointProbe(CheckpointProbeBase):
    """TextualInversion checkpoint prober."""

    def get_format(self) -> Optional[str]:
        """Return the format of a TextualInversion embedding."""
        return None

    def get_base_type(self) -> Optional[BaseModelType]:
        """Return BaseModelType of the checkpoint model."""
        checkpoint = self.checkpoint
        if "string_to_token" in checkpoint:
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[0]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2
        else:
            return None


class ControlNetCheckpointProbe(CheckpointProbeBase):
    """Probe checkpoint-based ControlNet models."""

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of the model."""
        checkpoint = self.checkpoint
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1
            elif checkpoint[key_name].shape[-1] == 1024:
                return BaseModelType.StableDiffusion2
            elif self.model and self.helper:
                return self.helper(self.model)
        raise InvalidModelException(f"Unable to determine base type for {self.model}")
########################################################
# classes for probing folders
########################################################


class FolderProbeBase(ProbeBase):
    """Class for probing folder-based models."""

    def __init__(self, model: Path, helper: Optional[Callable] = None):  # helper is not used
        """
        Initialize the folder prober.

        :param model: Path to the model to be probed.
        :param helper: Callable for returning the SchedulerPredictionType (unused).
        """
        self.model = model

    def get_variant_type(self) -> ModelVariantType:
        """Return the model's variant type."""
        return ModelVariantType.Normal

    def get_format(self) -> str:
        """Return the model's format."""
        return "diffusers"
class PipelineFolderProbe(FolderProbeBase):
    """Probe a pipeline (main) folder."""

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of a pipeline folder."""
        with open(self.model / "unet" / "config.json", "r") as file:
            unet_conf = json.load(file)
        if unet_conf["cross_attention_dim"] == 768:
            return BaseModelType.StableDiffusion1
        elif unet_conf["cross_attention_dim"] == 1024:
            return BaseModelType.StableDiffusion2
        elif unet_conf["cross_attention_dim"] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        elif unet_conf["cross_attention_dim"] == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"Unknown base model for {self.model}")

    def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
        """Return the SchedulerPredictionType of a diffusers-style sd-2 model."""
        with open(self.model / "scheduler" / "scheduler_config.json", "r") as file:
            scheduler_conf = json.load(file)
        if scheduler_conf["prediction_type"] == "v_prediction":
            return SchedulerPredictionType.VPrediction
        elif scheduler_conf["prediction_type"] == "epsilon":
            return SchedulerPredictionType.Epsilon
        else:
            return None

    def get_variant_type(self) -> ModelVariantType:
        """Return the ModelVariantType for diffusers-style main models."""
        # This only works for pipelines! Any kind of
        # exception results in our returning the
        # "normal" variant type
        try:
            config_file = self.model / "unet" / "config.json"
            with open(config_file, "r") as file:
                conf = json.load(file)

            in_channels = conf["in_channels"]
            if in_channels == 9:
                return ModelVariantType.Inpaint
            elif in_channels == 5:
                return ModelVariantType.Depth
            elif in_channels == 4:
                return ModelVariantType.Normal
        except Exception:
            pass
        return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
    """Probe a diffusers-style VAE model."""

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType for a diffusers-style VAE."""
        config_file = self.model / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.model}")
        with open(config_file, "r") as file:
            config = json.load(file)
        return (
            BaseModelType.StableDiffusionXL
            if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
            else BaseModelType.StableDiffusion1
        )


class TextualInversionFolderProbe(FolderProbeBase):
    """Probe a HuggingFace-style TextualInversion folder."""

    def get_format(self) -> Optional[str]:
        """Return the format of the TextualInversion."""
        return None

    def get_base_type(self) -> BaseModelType:
        """Return the ModelBaseType of the HuggingFace-style Textual Inversion folder."""
        path = self.model / "learned_embeds.bin"
        if not path.exists():
            raise InvalidModelException("This textual inversion folder does not contain a learned_embeds.bin file.")
        return TextualInversionCheckpointProbe(path).get_base_type()
class ONNXFolderProbe(FolderProbeBase):
    """Probe an ONNX-format folder."""

    def get_format(self) -> str:
        """Return the format of the folder (always "onnx")."""
        return "onnx"

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of the ONNX folder."""
        return BaseModelType.StableDiffusion1

    def get_variant_type(self) -> ModelVariantType:
        """Return the ModelVariantType of the ONNX folder."""
        return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase):
    """Probe a ControlNet model folder."""

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of a ControlNet model folder."""
        config_file = self.model / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.model}")
        with open(config_file, "r") as file:
            config = json.load(file)
        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        base_model = (
            BaseModelType.StableDiffusion1
            if dimension == 768
            else BaseModelType.StableDiffusion2
            if dimension == 1024
            else BaseModelType.StableDiffusionXL
            if dimension == 2048
            else None
        )
        if not base_model:
            raise InvalidModelException(f"Unable to determine model base for {self.model}")
        return base_model
class LoRAFolderProbe(FolderProbeBase):
    """Probe a LoRA model folder."""

    def get_base_type(self) -> BaseModelType:
        """Get the ModelBaseType of a LoRA model folder."""
        model_file = None
        for suffix in ["safetensors", "bin"]:
            base_file = self.model / f"pytorch_lora_weights.{suffix}"
            if base_file.exists():
                model_file = base_file
                break
        if not model_file:
            raise InvalidModelException("Unknown LoRA format encountered")
        return LoRACheckpointProbe(model_file).get_base_type()


############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
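Taken together, these registrations wire every built-in prober into ModelProbe.PROBES. A minimal usage sketch (the model path below is illustrative):

    from pathlib import Path

    from invokeai.backend.model_manager import ModelProbe, InvalidModelException

    try:
        info = ModelProbe.probe(Path("/path/to/some-model.safetensors"))  # illustrative path
        if info:  # probe() returns None when no prober is registered for the type
            print(info.model_type, info.base_type, info.variant_type, info.format)
    except InvalidModelException as exc:
        print(f"Not a recognizable model: {exc}")

A third-party prober follows the same pattern as the built-ins: subclass ProbeBase (or CheckpointProbeBase / FolderProbeBase), implement get_base_type() and friends, and call ModelProbe.register_probe(format, model_type, probe_class).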
invokeai/backend/model_manager/util.py (new file, 108 lines)
@@ -0,0 +1,108 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""
Various utilities used by the model manager.
"""
from typing import Optional
import warnings

from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging


class SilenceWarnings(object):
    """
    Context manager that silences warnings from transformers and diffusers.

    Usage:
    with SilenceWarnings():
        do_something_that_generates_warnings()
    """

    def __init__(self):
        """Initialize SilenceWarnings context."""
        self.transformers_verbosity = transformers_logging.get_verbosity()
        self.diffusers_verbosity = diffusers_logging.get_verbosity()

    def __enter__(self):
        """Entry into the context."""
        transformers_logging.set_verbosity_error()
        diffusers_logging.set_verbosity_error()
        warnings.simplefilter("ignore")

    def __exit__(self, type, value, traceback):
        """Exit from the context."""
        transformers_logging.set_verbosity(self.transformers_verbosity)
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        warnings.simplefilter("default")
def lora_token_vector_length(checkpoint: dict) -> Optional[int]:
    """
    Given a checkpoint in memory, return the lora token vector length.

    :param checkpoint: The checkpoint
    """

    def _get_shape_1(key, tensor, checkpoint):
        lora_token_vector_length = None

        if "." not in key:
            return lora_token_vector_length  # wrong key format
        model_key, lora_key = key.split(".", 1)

        # check lora/locon
        if lora_key == "lora_down.weight":
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
        elif lora_key in ["hada_w1_b", "hada_w2_b"]:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
        elif "lokr_" in lora_key:
            if model_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1"]
            elif model_key + ".lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if model_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
            elif model_key + ".lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif lora_key == "diff":
            lora_token_vector_length = tensor.shape[1]

        # ia3 can be detected only by shape[0] in text encoder
        elif lora_key == "weight" and "lora_unet_" not in model_key:
            lora_token_vector_length = tensor.shape[0]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_te") and "_self_attn_" in key:
            tmp_length = _get_shape_1(key, tensor, checkpoint)
            if key.startswith("lora_te_"):
                lora_token_vector_length = tmp_length
            elif key.startswith("lora_te1_"):
                lora_te1_length = tmp_length
            elif key.startswith("lora_te2_"):
                lora_te2_length = tmp_length

        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length
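A short usage sketch for the two utilities above (the checkpoint path is illustrative):

    import safetensors.torch

    from invokeai.backend.model_manager.util import SilenceWarnings, lora_token_vector_length

    with SilenceWarnings():  # mute transformers/diffusers logging while loading
        checkpoint = safetensors.torch.load_file("/path/to/some-lora.safetensors")  # illustrative path

    # 768 -> SD-1, 1024 -> SD-2, 2048 -> SDXL (see LoRACheckpointProbe above)
    print(lora_token_vector_length(checkpoint))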
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,64 +1,153 @@
#!/usr/bin/env python
"""
-Read a checkpoint/safetensors file and write out a template .json file containing
-its metadata for use in fast model probing.
+Model template creator.
+
+Scan a tree of checkpoint/safetensors/diffusers models and write out
+a series of template .json files containing their metadata for use
+in fast model probing.
"""

import argparse
import json
import hashlib
import sys
import traceback
from pathlib import Path
from typing import List, Optional
import torch

from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo
from invokeai.backend.model_management.model_search import ModelSearch
from invokeai.backend.model_manager import read_checkpoint_meta


class CreateTemplateScanner(ModelSearch):
    """Scan directory and create templates for each model found."""

    _dest: Path

    def __init__(self, directories: List[Path], dest: Path, **kwargs):  # noqa D401
        """Initialization routine.

        :param dest: Base of templates directory.
        """
        super().__init__(directories, **kwargs)
        self._dest = dest

    def on_model_found(self, model: Path):  # noqa D401
        """Called when a model is found during recursive search."""
        info: ModelProbeInfo = ModelProbe.probe(model)
        if not info:
            return
        self.write_template(model, info)

    def write_template(self, model: Path, info: ModelProbeInfo):
        """Write template for a checkpoint file."""
        dest_path = Path(
            self._dest,
            "checkpoints" if model.is_file() else "diffusers",
            info.base_type.value,
            info.model_type.value,
        )
        template: dict = (
            self._make_checkpoint_template(model)
            if model.is_file()
            else self._make_diffusers_template(model)
        )
        if not template:
            print(f"Could not create template for {model}, got {template}")
            return

        # sort the dict to avoid differences due to insertion order
        template = dict(sorted(template.items()))

        dest_path.mkdir(parents=True, exist_ok=True)
        meta = dict(
            base_type=info.base_type.value,
            model_type=info.model_type.value,
            variant=info.variant_type.value,
            template=template,
        )
        payload = json.dumps(meta)
        hash = hashlib.md5(payload.encode("utf-8")).hexdigest()
        try:
            dest_file = dest_path / f"{hash}.json"
            if not dest_file.exists():
                with open(dest_file, "w", encoding="utf-8") as f:
                    f.write(payload)
                print(f"Template written out as {dest_file}")
        except OSError as e:
            print(f"An exception occurred while writing template: {str(e)}")

    def _make_checkpoint_template(self, model: Path) -> Optional[dict]:
        """Make template dict for a checkpoint-style model."""
        tmpl = None
        try:
            ckpt = read_checkpoint_meta(model)
            while "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            tmpl = {}
            for key, value in ckpt.items():
                if isinstance(value, torch.Tensor):
                    tmpl[key] = list(value.shape)
                elif isinstance(value, dict):  # handle one level of nesting - if more we should recurse
                    for subkey, subvalue in value.items():
                        if isinstance(subvalue, torch.Tensor):
                            tmpl[f"{key}.{subkey}"] = list(subvalue.shape)
        except Exception as e:
            traceback.print_exception(e)
        return tmpl

    def _make_diffusers_template(self, model: Path) -> Optional[dict]:
        """
        Make template dict for a diffusers-style model.

        In case of a pipeline, template keys will be 'unet', 'text_encoder', 'text_encoder_2' and 'vae'.
        In case of another folder-style model, the template will simply contain the contents of config.json.
        """
        tmpl = None
        if (model / "model_index.json").exists():  # a pipeline
            tmpl = {}
            for subdir in ["unet", "text_encoder", "vae", "text_encoder_2"]:
                config = model / subdir / "config.json"
                try:
                    tmpl[subdir] = self._read_config(config)
                except FileNotFoundError:
                    pass
        elif (model / "learned_embeds.bin").exists():  # concepts model
            return self._make_checkpoint_template(model / "learned_embeds.bin")
        else:
            config = model / "config.json"
            try:
                tmpl = self._read_config(config)
            except FileNotFoundError:
                pass
        return tmpl

    def _read_config(self, config: Path) -> dict:
        with open(config, "r", encoding="utf-8") as f:
            return {x: y for x, y in json.load(f).items() if not x.startswith("_")}

    def on_search_completed(self):
        """Not used."""
        pass

    def on_search_started(self):
        """Not used."""
        pass
-from invokeai.backend.model_manager import (
-    read_checkpoint_meta,
-    ModelType,
-    ModelVariantType,
-    BaseModelType,
-)

parser = argparse.ArgumentParser(
-    description="Create a .json template from checkpoint/safetensors model",
+    description="Scan the provided path recursively and create .json templates for all models found.",
)
-parser.add_argument('checkpoint', type=Path, help="Path to the input checkpoint/safetensors file")
-parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")
-parser.add_argument("--base-type",
-                    type=str,
-                    choices=[x.value for x in BaseModelType],
-                    help="Base model",
-)
-parser.add_argument("--model-type",
-                    type=str,
-                    choices=[x.value for x in ModelType],
-                    default='main',
-                    help="Type of the model",
-)
-parser.add_argument("--variant",
-                    type=str,
-                    choices=[x.value for x in ModelVariantType],
-                    default='normal',
-                    help="Base type of the model",
-)
+parser.add_argument("--scan",
+                    type=Path,
+                    help="Path to recursively scan for models"
+)
+parser.add_argument("--out",
+                    type=Path,
+                    dest="outdir",
+                    default=Path(__file__).resolve().parents[1] / "invokeai/configs/model_probe_templates",
+                    help="Destination for templates",
+)

opt = parser.parse_args()
-ckpt = read_checkpoint_meta(opt.checkpoint)
-while "state_dict" in ckpt:
-    ckpt = ckpt["state_dict"]
-
-tmpl = {}
-
-for key, tensor in ckpt.items():
-    tmpl[key] = list(tensor.shape)
-
-meta = {
-    'base_type': opt.base_type,
-    'model_type': opt.model_type,
-    'variant': opt.variant,
-    'template': tmpl
-}
-
-try:
-    with open(opt.template, "w", encoding="utf-8") as f:
-        json.dump(meta, f)
-    print(f"Template written out as {opt.template}")
-except OSError as e:
-    print(f"An exception occurred while writing template: {str(e)}")
+scanner = CreateTemplateScanner([opt.scan], dest=opt.outdir)
+scanner.search()
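After the rewrite the script takes a directory to scan rather than a single checkpoint. An illustrative invocation (the script filename here is an assumption; --out falls back to invokeai/configs/model_probe_templates when omitted):

    python scripts/make_probe_templates.py --scan ~/invokeai/models --out invokeai/configs/model_probe_templates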
@@ -1,8 +1,19 @@
#!/bin/env python

+"""Little command-line utility for probing a model on disk."""

import argparse
+import sys
from pathlib import Path
-from invokeai.backend.model_management.model_probe import ModelProbe
+from invokeai.backend.model_manager import (
+    ModelProbe,
+    SchedulerPredictionType,
+    InvalidModelException,
+)


+def helper(model_path: Path):
+    print('Warning: guessing "v_prediction" SchedulerPredictionType', file=sys.stderr)
+    return SchedulerPredictionType.VPrediction

parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(

@@ -13,5 +24,8 @@ parser.add_argument(
args = parser.parse_args()

for path in args.model_path:
-    info = ModelProbe().probe(path)
-    print(f"{path}: {info}")
+    try:
+        info = ModelProbe().probe(path, helper)
+        print(f"{path}: {info}")
+    except InvalidModelException as exc:
+        print(exc)
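An illustrative run (the script filename here is an assumption), probing two models and falling back to the v-prediction helper when no .yaml config sits beside a checkpoint:

    python scripts/probe-model.py /path/to/model.safetensors /path/to/diffusers-model-folder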