reimplement and clean up probe class

This commit is contained in:
Lincoln Stein
2023-08-22 22:24:07 -04:00
parent f023e342ef
commit 6f9bf87a7a
13 changed files with 821 additions and 77 deletions

View File

@@ -79,7 +79,7 @@ class ModelProbe(object):
model_path: Path,
model: Optional[Union[Dict, ModelMixin]] = None,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> ModelProbeInfo:
) -> Optional[ModelProbeInfo]:
"""
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is

View File

@@ -14,3 +14,4 @@ from .config import ( # noqa F401
SubModelType,
)
from .model_install import ModelInstall # noqa F401
from .probe import ModelProbe, InvalidModelException # noqa F401

View File

@@ -171,13 +171,11 @@ class MainConfig(ModelConfigBase):
class MainCheckpointConfig(CheckpointConfig, MainConfig):
"""Model config for main checkpoint models."""
pass
class MainDiffusersConfig(DiffusersConfig, MainConfig):
"""Model config for main diffusers models."""
pass
class ONNXSD1Config(MainConfig):
@@ -263,9 +261,9 @@ class ModelConfigFactory(object):
if isinstance(class_to_return, dict): # additional level allowed
class_to_return = class_to_return[model_base]
return class_to_return.parse_obj(model_data)
except KeyError:
except KeyError as exc:
raise InvalidModelConfigException(
f"Unknown combination of model_format '{model_format}' and model_type '{model_type}'"
)
except ValidationError as e:
raise InvalidModelConfigException(f"Invalid model configuration passed: {str(e)}") from e
) from exc
except ValidationError as exc:
raise InvalidModelConfigException(f"Invalid model configuration passed: {str(exc)}") from exc

View File

@@ -1,55 +1,595 @@
# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
"""
Return descriptive information on Stable Diffusion models.
Module for probing a Stable Diffusion model and returning
its base type, model type, format and variant.
"""
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Callable
from picklescan.scanner import scan_file_path
import torch
import safetensors.torch
from invokeai.backend.model_management.models.base import (
read_checkpoint_meta
read_checkpoint_meta,
InvalidModelException,
)
import invokeai.configs.model_probe_templates as templates
from .config import (
ModelType,
BaseModelType,
ModelVariantType,
ModelFormat,
SchedulerPredictionType
SchedulerPredictionType,
)
from .util import SilenceWarnings, lora_token_vector_length
@dataclass
class ModelProbeInfo(object):
    """Fields describing a probed model."""

    model_type: ModelType  # e.g. Main, Vae, Lora, ControlNet, TextualInversion
    base_type: BaseModelType  # e.g. StableDiffusion1/2/XL
    variant_type: ModelVariantType  # normal, inpaint or depth
    prediction_type: SchedulerPredictionType  # epsilon or v-prediction
    upcast_attention: bool  # True for SD-2 v-prediction models
    format: ModelFormat  # checkpoint, diffusers, lycoris, onnx, ...
    image_size: int  # native resolution the model renders at
class ModelProbeBase(ABC):
    """Abstract interface for probing a checkpoint, safetensors or diffusers folder."""

    @classmethod
    @abstractmethod
    def probe(
        cls,
        model: Path,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> Optional[ModelProbeInfo]:
        """
        Probe model located at path and return ModelProbeInfo object.

        :param model: Path to a model checkpoint or folder.
        :param prediction_type_helper: An optional Callable that takes the model path
            and returns the SchedulerPredictionType.
        :returns: a ModelProbeInfo, or None when the model cannot be handled.
        """
class ProbeBase(ABC):
    """Base model for probing checkpoint and diffusers-style models."""

    @abstractmethod
    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType for the model."""

    def get_variant_type(self) -> ModelVariantType:
        """Return the ModelVariantType for the model."""

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        """Return the SchedulerPredictionType for the model."""

    def get_format(self) -> str:
        """Return the format for the model."""
class ModelProbe(ModelProbeBase):
    """Class to probe a checkpoint, safetensors or diffusers folder."""

    # Registry of prober classes, keyed first by format and then by ModelType.
    # Populated by the register_probe() calls at the bottom of this module.
    PROBES = {
        "diffusers": {},
        "checkpoint": {},
        "onnx": {},
    }

    # Map a diffusers config "_class_name" entry to our ModelType.
    CLASS2TYPE = {
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
    }

    @classmethod
    def register_probe(cls, format: ModelFormat, model_type: ModelType, probe_class: ProbeBase):
        """
        Register a probe subclass to use when interrogating a model.

        :param format: The ModelFormat of the model to be probed.
        :param model_type: The ModelType of the model to be probed.
        :param probe_class: The class of the prober (inherits from ProbeBase).
        """
        cls.PROBES[format][model_type] = probe_class

    @classmethod
    def probe(
        cls,
        model: Path,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> Optional[ModelProbeInfo]:
        """
        Probe the model and return a ModelProbeInfo object.

        :param model: Path to a model checkpoint or folder.
        :param prediction_type_helper: Optional callable mapping the model path
            to a SchedulerPredictionType.
        :returns: a ModelProbeInfo, or None when no prober is registered for the model.
        :raises InvalidModelException: if the model type cannot be determined.
        """
        model_type = (
            cls.get_model_type_from_folder(model) if model.is_dir() else cls.get_model_type_from_checkpoint(model)
        )
        format_type = "onnx" if model_type == ModelType.ONNX else "diffusers" if model.is_dir() else "checkpoint"
        probe_class = cls.PROBES[format_type].get(model_type)
        if not probe_class:
            return None
        probe = probe_class(model, prediction_type_helper)

        base_type = probe.get_base_type()
        prediction_type = probe.get_scheduler_prediction_type()
        # SD-2 v-prediction models need attention upcast and render at 768px;
        # SDXL and its refiner render at 1024px; everything else at 512px.
        sd2_v_prediction = (
            base_type == BaseModelType.StableDiffusion2
            and prediction_type == SchedulerPredictionType.VPrediction
        )
        if base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}:
            image_size = 1024
        elif sd2_v_prediction:
            image_size = 768
        else:
            image_size = 512
        # NOTE: the previous "try/except Exception: raise" around this body was
        # a no-op and has been removed; exceptions still propagate unchanged.
        return ModelProbeInfo(
            model_type=model_type,
            base_type=base_type,
            variant_type=probe.get_variant_type(),
            prediction_type=prediction_type,
            upcast_attention=sd2_v_prediction,
            format=probe.get_format(),
            image_size=image_size,
        )

    @classmethod
    def get_model_type_from_checkpoint(cls, model: Path) -> Optional[ModelType]:
        """
        Scan a checkpoint model and return its ModelType.

        :param model: path to the model checkpoint/safetensors file
        :returns: the ModelType, or None for unrecognized file suffixes.
        :raises InvalidModelException: if no known key pattern is found.
        """
        if model.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
            return None
        if model.name == "learned_embeds.bin":
            return ModelType.TextualInversion

        ckpt = read_checkpoint_meta(model, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)
        for key in ckpt.keys():
            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
                return ModelType.Main
            if any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
                return ModelType.Vae
            if any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
                return ModelType.Lora
            if any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
                return ModelType.Lora
            if any(key.startswith(v) for v in {"control_model", "input_blocks"}):
                return ModelType.ControlNet
            if key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion

        # No recognized key: a small dict of raw tensors is a
        # diffusers-style textual inversion embedding.
        if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
            return ModelType.TextualInversion
        raise InvalidModelException(f"Unable to determine model type for {model}")

    @classmethod
    def get_model_type_from_folder(cls, model: Path) -> Optional[ModelType]:
        """
        Get the model type of a hugging-face style folder.

        :param model: Path to model folder.
        :raises InvalidModelException: if the folder matches no known layout.
        """
        if (model / "unet/model.onnx").exists():
            return ModelType.ONNX
        if (model / "learned_embeds.bin").exists():
            return ModelType.TextualInversion
        if (model / "pytorch_lora_weights.bin").exists():
            return ModelType.Lora

        class_name = None
        # Pipelines carry model_index.json; single-module models carry config.json.
        for config_path in (model / "model_index.json", model / "config.json"):
            if config_path.exists():
                with open(config_path, "r") as file:
                    class_name = json.load(file)["_class_name"]
                break
        if class_name and (model_type := cls.CLASS2TYPE.get(class_name)):
            return model_type
        # give up
        raise InvalidModelException(f"Unable to determine model type for {model}")

    @classmethod
    def _scan_and_load_checkpoint(cls, model: Path) -> dict:
        """Load a checkpoint dict; pickle-based files are malware-scanned first."""
        with SilenceWarnings():
            if model.suffix.endswith((".ckpt", ".pt", ".bin")):
                cls._scan_model(model)
                return torch.load(model)
            return safetensors.torch.load_file(model)

    @classmethod
    def _scan_model(cls, model: Path):
        """
        Scan a model for malicious code.

        :param model: Path to the model to be scanned
        :raises InvalidModelException: if unsafe code is found.
        """
        scan_result = scan_file_path(model)
        if scan_result.infected_files != 0:
            # BUGFIX: this previously raised a bare string (a TypeError in
            # Python 3) containing an un-interpolated {model_name} placeholder.
            raise InvalidModelException(f"The model {model} is potentially infected by malware. Aborting import.")
# ##################################################3
# Checkpoint probing
# ##################################################3
class CheckpointProbeBase(ProbeBase):
    """Base class for probing checkpoint-style models."""

    def __init__(
        self, model: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None
    ) -> None:
        """
        Initialize the CheckpointProbeBase object.

        :param model: Path to the checkpoint or safetensors file.
        :param helper: Optional callable mapping the model path to a
            SchedulerPredictionType.
        """
        # BUGFIX: the return annotation previously read "-> BaseModelType";
        # __init__ always returns None.
        self.checkpoint = ModelProbe._scan_and_load_checkpoint(model)
        self.model = model
        self.helper = helper

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of a checkpoint-style model."""
        pass

    def get_format(self) -> str:
        """Return the format of a checkpoint-style model."""
        return "checkpoint"

    def get_variant_type(self) -> ModelVariantType:
        """Return the ModelVariantType of a checkpoint-style model."""
        model_type = ModelProbe.get_model_type_from_checkpoint(self.model)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        # The unet's input channel count distinguishes normal (4),
        # depth (5) and inpainting (9) variants.
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        if in_channels == 5:
            return ModelVariantType.Depth
        if in_channels == 4:
            return ModelVariantType.Normal
        # BUGFIX: the message previously referenced self.checkpoint_path,
        # which does not exist on this class.
        raise InvalidModelException(
            f"Cannot determine variant type (in_channels={in_channels}) at {self.model}"
        )
class PipelineCheckpointProbe(CheckpointProbeBase):
    """Probe a checkpoint-style main model."""

    def get_base_type(self) -> BaseModelType:
        """Return the ModelBaseType for the checkpoint-style main model."""
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        # The width of the cross-attention key projection identifies the base.
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            return BaseModelType.StableDiffusion2
        key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
            return BaseModelType.StableDiffusionXL
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        raise InvalidModelException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
        """
        Return the SchedulerPredictionType for the checkpoint-style main model.

        Returns None when the prediction type cannot be determined and no
        helper callable is available.
        """
        # BUGFIX: renamed local `type` (shadowed builtin) and widened the
        # return annotation to Optional — the final branch can return None.
        base_type = self.get_base_type()
        if base_type == BaseModelType.StableDiffusion1:
            return SchedulerPredictionType.Epsilon
        checkpoint = self.checkpoint
        state_dict = checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            # The original SD-2 releases are recognizable by their global_step.
            if "global_step" in checkpoint:
                if checkpoint["global_step"] == 220000:
                    return SchedulerPredictionType.Epsilon
                if checkpoint["global_step"] == 110000:
                    return SchedulerPredictionType.VPrediction
        if (
            self.model and self.helper and not self.model.with_suffix(".yaml").exists()
        ):  # if a .yaml config file exists, then this step not needed
            return self.helper(self.model)
        return None
class VaeCheckpointProbe(CheckpointProbeBase):
    """Probe a Checkpoint-style VAE model."""

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of the VAE model."""
        # No reliable fingerprint for standalone SD-2 VAEs is known,
        # so assume SD-1.
        return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
    """Probe for LoRA Checkpoint Files."""

    def get_format(self) -> str:
        """Return the format of the LoRA."""
        return "lycoris"

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of the LoRA."""
        # The token vector length maps one-to-one onto the base model.
        length = lora_token_vector_length(self.checkpoint)
        base = {
            768: BaseModelType.StableDiffusion1,
            1024: BaseModelType.StableDiffusion2,
            2048: BaseModelType.StableDiffusionXL,
        }.get(length)
        if base is None:
            raise InvalidModelException(f"Unknown LoRA type: {self.model}")
        return base
class TextualInversionCheckpointProbe(CheckpointProbeBase):
    """TextualInversion checkpoint prober."""

    def get_format(self) -> Optional[str]:
        """Return the format of a TextualInversion embedding (always None)."""
        return None

    def get_base_type(self) -> Optional[BaseModelType]:
        """
        Return BaseModelType of the checkpoint model.

        Returns None when the token dimension matches no known base.
        """
        # BUGFIX: widened return annotation to Optional — the final branch
        # can return None.
        checkpoint = self.checkpoint
        # NOTE(review): presence is tested on "string_to_token" but the value
        # is read from "string_to_param"; the two keys appear to travel
        # together in classic TI embeddings — confirm before changing.
        if "string_to_token" in checkpoint:
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        else:
            # diffusers-style: a single tensor keyed by the trigger word
            token_dim = list(checkpoint.values())[0].shape[0]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        if token_dim == 1024:
            return BaseModelType.StableDiffusion2
        return None
class ControlNetCheckpointProbe(CheckpointProbeBase):
    """Probe checkpoint-based ControlNet models."""

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of the model."""
        checkpoint = self.checkpoint
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            width = checkpoint[key_name].shape[-1]
            if width == 768:
                return BaseModelType.StableDiffusion1
            if width == 1024:
                return BaseModelType.StableDiffusion2
            # Key present but width unrecognized: defer to the helper.
            # BUGFIX: previously referenced self.checkpoint_path, which does
            # not exist on this class (the attribute is self.model).
            if self.model and self.helper:
                return self.helper(self.model)
        # BUGFIX: added the missing f-prefix so the path is interpolated.
        raise InvalidModelException(f"Unable to determine base type for {self.model}")
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
    """Class for probing folder-based models."""

    def __init__(self, model: Path, helper: Optional[Callable] = None):  # not used
        """
        Initialize the folder prober.

        :param model: Path to the model to be probed.
        :param helper: Callable for returning the SchedulerPredictionType (unused).
        """
        self.model = model

    def get_variant_type(self) -> ModelVariantType:
        """Return the model's variant type."""
        return ModelVariantType.Normal

    def get_format(self) -> str:
        """Return the model's format."""
        return "diffusers"
class PipelineFolderProbe(FolderProbeBase):
    """Probe a pipeline (main) folder."""

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of a pipeline folder."""
        with open(self.model / "unet" / "config.json", "r") as file:
            unet_conf = json.load(file)
        # The unet cross-attention width identifies the base model.
        base_model = {
            768: BaseModelType.StableDiffusion1,
            1024: BaseModelType.StableDiffusion2,
            1280: BaseModelType.StableDiffusionXLRefiner,
            2048: BaseModelType.StableDiffusionXL,
        }.get(unet_conf["cross_attention_dim"])
        if base_model is None:
            # BUGFIX: message previously referenced undefined self.folder_path.
            raise InvalidModelException(f"Unknown base model for {self.model}")
        return base_model

    def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
        """
        Return the SchedulerPredictionType of a diffusers-style sd-2 model.

        Returns None when the scheduler config names an unrecognized type.
        """
        with open(self.model / "scheduler" / "scheduler_config.json", "r") as file:
            scheduler_conf = json.load(file)
        if scheduler_conf["prediction_type"] == "v_prediction":
            return SchedulerPredictionType.VPrediction
        if scheduler_conf["prediction_type"] == "epsilon":
            return SchedulerPredictionType.Epsilon
        return None

    def get_variant_type(self) -> ModelVariantType:
        """Return the ModelVariantType for diffusers-style main models."""
        # This only works for pipelines! Any kind of exception results in
        # our returning the "normal" variant type.
        #
        # BUGFIX: the previous code branched on `if self.model:` and then read
        # `self.model.unet.config` — but self.model is a Path with no .unet
        # attribute, so the method always raised, swallowed the error, and
        # returned "normal". Read the unet config.json directly instead.
        try:
            config_file = self.model / "unet" / "config.json"
            with open(config_file, "r") as file:
                conf = json.load(file)
            in_channels = conf["in_channels"]
            if in_channels == 9:
                return ModelVariantType.Inpaint
            if in_channels == 5:
                return ModelVariantType.Depth
            if in_channels == 4:
                return ModelVariantType.Normal
        except Exception:
            pass
        return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
    """Probe a diffusers-style VAE model."""

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType for a diffusers-style VAE."""
        # BUGFIX: both references below previously used self.folder_path,
        # which does not exist on this class (the attribute is self.model).
        config_file = self.model / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.model}")
        with open(config_file, "r") as file:
            config = json.load(file)
        # SDXL VAEs are recognizable by their scaling factor and sample size.
        if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]:
            return BaseModelType.StableDiffusionXL
        return BaseModelType.StableDiffusion1
class TextualInversionFolderProbe(FolderProbeBase):
    """Probe a HuggingFace-style TextualInversion folder."""

    def get_format(self) -> Optional[str]:
        """Return the format of the TextualInversion (always None)."""
        return None

    def get_base_type(self) -> BaseModelType:
        """Return the ModelBaseType of the HuggingFace-style Textual Inversion Folder."""
        path = self.model / "learned_embeds.bin"
        if not path.exists():
            # BUGFIX: dropped the pointless f-prefix on a literal with no
            # placeholders (F541).
            raise InvalidModelException("This textual inversion folder does not contain a learned_embeds.bin file.")
        # Delegate to the checkpoint prober on the embedded weights file.
        return TextualInversionCheckpointProbe(path).get_base_type()
class ONNXFolderProbe(FolderProbeBase):
    """Probe an ONNX-format folder."""

    def get_format(self) -> str:
        """Return the format of the folder (always "onnx")."""
        return "onnx"

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of the ONNX folder."""
        return BaseModelType.StableDiffusion1

    def get_variant_type(self) -> ModelVariantType:
        """Return the ModelVariantType of the ONNX folder."""
        return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase):
    """Probe a ControlNet model folder."""

    def get_base_type(self) -> BaseModelType:
        """Return the BaseModelType of a ControlNet model folder."""
        # BUGFIX: both error messages previously referenced self.folder_path,
        # which does not exist on this class (the attribute is self.model).
        config_file = self.model / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.model}")
        with open(config_file, "r") as file:
            config = json.load(file)
        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        base_model = {
            768: BaseModelType.StableDiffusion1,
            1024: BaseModelType.StableDiffusion2,
            2048: BaseModelType.StableDiffusionXL,
        }.get(dimension)
        if not base_model:
            raise InvalidModelException(f"Unable to determine model base for {self.model}")
        return base_model
class LoRAFolderProbe(FolderProbeBase):
    """Probe a LoRA model folder."""

    def get_base_type(self) -> BaseModelType:
        """Get the ModelBaseType of a LoRA model folder."""
        # Prefer the safetensors weights file, falling back to .bin.
        for suffix in ("safetensors", "bin"):
            candidate = self.model / f"pytorch_lora_weights.{suffix}"
            if candidate.exists():
                # Delegate to the checkpoint prober on the weights file.
                return LoRACheckpointProbe(candidate).get_base_type()
        raise InvalidModelException("Unknown LoRA format encountered")
############## register probe classes ######
# Wire each (format, model type) pair to its prober implementation.
for _format, _model_type, _probe_class in [
    ("diffusers", ModelType.Main, PipelineFolderProbe),
    ("diffusers", ModelType.Vae, VaeFolderProbe),
    ("diffusers", ModelType.Lora, LoRAFolderProbe),
    ("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe),
    ("diffusers", ModelType.ControlNet, ControlNetFolderProbe),
    ("checkpoint", ModelType.Main, PipelineCheckpointProbe),
    ("checkpoint", ModelType.Vae, VaeCheckpointProbe),
    ("checkpoint", ModelType.Lora, LoRACheckpointProbe),
    ("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe),
    ("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe),
    ("onnx", ModelType.ONNX, ONNXFolderProbe),
]:
    ModelProbe.register_probe(_format, _model_type, _probe_class)

View File

@@ -0,0 +1,108 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""
Various utilities used by the model manager.
"""
from typing import Optional
import warnings
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class SilenceWarnings(object):
    """
    Context manager that silences warnings from transformers and diffusers.

    Usage:
    with SilenceWarnings():
        do_something_that_generates_warnings()
    """

    def __init__(self):
        """Capture the current verbosity levels so they can be restored on exit."""
        self.transformers_verbosity = transformers_logging.get_verbosity()
        self.diffusers_verbosity = diffusers_logging.get_verbosity()

    def __enter__(self):
        """Silence both libraries and suppress all Python warnings."""
        transformers_logging.set_verbosity_error()
        diffusers_logging.set_verbosity_error()
        warnings.simplefilter("ignore")

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Restore the verbosity levels captured at construction time."""
        transformers_logging.set_verbosity(self.transformers_verbosity)
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        warnings.simplefilter("default")
def lora_token_vector_length(checkpoint: dict) -> Optional[int]:
    """
    Given a checkpoint in memory, return the lora token vector length.

    :param checkpoint: The checkpoint
    :returns: the token vector length, or None if it cannot be determined.
    """

    def _get_shape_1(key, tensor, checkpoint):
        """Return the token-vector dimension encoded by one LoRA tensor, or None."""
        lora_token_vector_length = None

        if "." not in key:
            return lora_token_vector_length  # wrong key format
        model_key, lora_key = key.split(".", 1)

        # check lora/locon
        if lora_key == "lora_down.weight":
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
        elif lora_key in ["hada_w1_b", "hada_w2_b"]:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
        elif "lokr_" in lora_key:
            # BUGFIX: the fallback membership tests were missing the "."
            # separator (model_key + "lokr_w1_b"), so lokr models carrying
            # only the decomposed w1_b/w2_b tensors were never recognized.
            if model_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1"]
            elif model_key + ".lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if model_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
            elif model_key + ".lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif lora_key == "diff":
            lora_token_vector_length = tensor.shape[1]

        # ia3 can be detected only by shape[0] in text encoder
        elif lora_key == "weight" and "lora_unet_" not in model_key:
            lora_token_vector_length = tensor.shape[0]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_te") and "_self_attn_" in key:
            tmp_length = _get_shape_1(key, tensor, checkpoint)
            if key.startswith("lora_te_"):
                lora_token_vector_length = tmp_length
            elif key.startswith("lora_te1_"):
                lora_te1_length = tmp_length
            elif key.startswith("lora_te2_"):
                lora_te2_length = tmp_length

        # SDXL carries two text encoders; their lengths add together.
        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,64 +1,153 @@
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and write out a template .json file containing
its metadata for use in fast model probing.
Model template creator.
Scan a tree of checkpoint/safetensors/diffusers models and write out
a series of template .json file containing their metadata for use
in fast model probing.
"""
import argparse
import json
import hashlib
import sys
import traceback
from pathlib import Path
from typing import List, Optional
import torch
from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo
from invokeai.backend.model_management.model_search import ModelSearch
from invokeai.backend.model_manager import read_checkpoint_meta
class CreateTemplateScanner(ModelSearch):
    """Scan directory and create templates for each model found."""

    _dest: Path  # base of the templates output directory

    def __init__(self, directories: List[Path], dest: Path, **kwargs):  # noqa D401
        """Initialization routine.

        :param dest: Base of templates directory.
        """
        super().__init__(directories, **kwargs)
        self._dest = dest

    def on_model_found(self, model: Path):  # noqa D401
        """Called when a model is found during recursive search."""
        info: ModelProbeInfo = ModelProbe.probe(model)
        if not info:
            return
        self.write_template(model, info)

    def write_template(self, model: Path, info: ModelProbeInfo):
        """Write a .json template describing the model's layer shapes.

        :param model: Path to the probed model.
        :param info: Probe results used to place and label the template.
        """
        dest_path = Path(
            self._dest,
            "checkpoints" if model.is_file() else "diffusers",
            info.base_type.value,
            info.model_type.value,
        )
        template: dict = (
            self._make_checkpoint_template(model) if model.is_file() else self._make_diffusers_template(model)
        )
        if not template:
            print(f"Could not create template for {model}, got {template}")
            return
        # sort the dict to avoid differences due to insertion order
        template = dict(sorted(template.items()))
        dest_path.mkdir(parents=True, exist_ok=True)
        meta = dict(
            base_type=info.base_type.value,
            model_type=info.model_type.value,
            variant=info.variant_type.value,
            template=template,
        )
        payload = json.dumps(meta)
        # The filename is the hash of the payload, so identical architectures
        # collapse into a single template file.
        # BUGFIX: renamed local `hash`, which shadowed the builtin.
        digest = hashlib.md5(payload.encode("utf-8")).hexdigest()
        try:
            dest_file = dest_path / f"{digest}.json"
            if not dest_file.exists():
                with open(dest_file, "w", encoding="utf-8") as f:
                    f.write(payload)
                print(f"Template written out as {dest_file}")
        except OSError as e:
            print(f"An exception occurred while writing template: {str(e)}")

    def _make_checkpoint_template(self, model: Path) -> Optional[dict]:
        """Make template dict for a checkpoint-style model."""
        tmpl = None
        try:
            ckpt = read_checkpoint_meta(model)
            while "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            tmpl = {}
            for key, value in ckpt.items():
                if isinstance(value, torch.Tensor):
                    tmpl[key] = list(value.shape)
                elif isinstance(value, dict):  # handle one level of nesting - if more we should recurse
                    for subkey, subvalue in value.items():
                        if isinstance(subvalue, torch.Tensor):
                            # BUGFIX: store a plain list as at the top level,
                            # not a raw torch.Size, so serialized templates
                            # are uniform.
                            tmpl[f"{key}.{subkey}"] = list(subvalue.shape)
        except Exception as e:
            traceback.print_exception(e)
        return tmpl

    def _make_diffusers_template(self, model: Path) -> Optional[dict]:
        """
        Make template dict for a diffusers-style model.

        In case of a pipeline, template keys will be 'unet', 'text_encoder', 'text_encoder_2' and 'vae'.
        In case of another folder-style model, the template will simply contain the contents of config.json.
        """
        tmpl = None
        if (model / "model_index.json").exists():  # a pipeline
            tmpl = {}
            for subdir in ["unet", "text_encoder", "vae", "text_encoder_2"]:
                config = model / subdir / "config.json"
                try:
                    tmpl[subdir] = self._read_config(config)
                except FileNotFoundError:
                    pass
        elif (model / "learned_embeds.bin").exists():  # concepts model
            return self._make_checkpoint_template(model / "learned_embeds.bin")
        else:
            config = model / "config.json"
            try:
                tmpl = self._read_config(config)
            except FileNotFoundError:
                pass
        return tmpl

    def _read_config(self, config: Path) -> dict:
        """Return config.json contents minus the '_'-prefixed private keys."""
        with open(config, "r", encoding="utf-8") as f:
            return {x: y for x, y in json.load(f).items() if not x.startswith("_")}

    def on_search_completed(self):
        """Not used."""
        pass

    def on_search_started(self):
        """Not used."""
        pass
from invokeai.backend.model_manager import(
read_checkpoint_meta,
ModelType,
ModelVariantType,
BaseModelType,
)
parser = argparse.ArgumentParser(
description="Create a .json template from checkpoint/safetensors model",
description="Scan the provided path recursively and create .json templates for all models found.",
)
parser.add_argument('checkpoint', type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")
parser.add_argument("--base-type",
type=str,
choices=[x.value for x in BaseModelType],
help="Base model",
parser.add_argument("--scan",
type=Path,
help="Path to recursively scan for models"
)
parser.add_argument("--model-type",
type=str,
choices=[x.value for x in ModelType],
default='main',
help="Type of the model",
)
parser.add_argument("--variant",
type=str,
choices=[x.value for x in ModelVariantType],
default='normal',
help="Base type of the model",
parser.add_argument("--out",
type=Path,
dest="outdir",
default=Path(__file__).resolve().parents[1] / "invokeai/configs/model_probe_templates",
help="Destination for templates",
)
opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
ckpt = ckpt["state_dict"]
tmpl = {}
for key, tensor in ckpt.items():
tmpl[key] = list(tensor.shape)
meta = {
'base_type': opt.base_type,
'model_type': opt.model_type,
'variant': opt.variant,
'template': tmpl
}
try:
with open(opt.template, "w", encoding="utf-8") as f:
json.dump(meta, f)
print(f"Template written out as {opt.template}")
except OSError as e:
print(f"An exception occurred while writing template: {str(e)}")
scanner = CreateTemplateScanner([opt.scan], dest=opt.outdir)
scanner.search()

View File

@@ -1,8 +1,19 @@
#!/bin/env python
"""Little command-line utility for probing a model on disk."""
import argparse
import sys
from pathlib import Path
from invokeai.backend.model_management.model_probe import ModelProbe
from invokeai.backend.model_manager import (
ModelProbe,
SchedulerPredictionType,
InvalidModelException,
)
def helper(model_path: Path):
    """Fallback prediction-type helper: warn on stderr and guess v-prediction."""
    print('Warning: guessing "v_prediction" SchedulerPredictionType', file=sys.stderr)
    return SchedulerPredictionType.VPrediction
parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(
@@ -13,5 +24,8 @@ parser.add_argument(
args = parser.parse_args()
for path in args.model_path:
info = ModelProbe().probe(path)
print(f"{path}: {info}")
try:
info = ModelProbe().probe(path, helper)
print(f"{path}: {info}")
except InvalidModelException as exc:
print(exc)