"""Utilities for parsing model files, used mostly by legacy_probe.py"""

import json
from pathlib import Path
from typing import Dict, Optional, Union

import picklescan.scanner as pscan
import safetensors.torch
import torch

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_manager.taxonomy import ClipVariantType
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.logging import InvokeAILogger

logger = InvokeAILogger.get_logger()


def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
    checkpoint = {}
    device = torch.device("meta")
    with open(path, "rb") as f:
        definition_len = int.from_bytes(f.read(8), "little")
        definition_json = f.read(definition_len)
        definition = json.loads(definition_json)

        if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
            "pt",
            "torch",
            "pytorch",
        }:
            raise Exception("Only PyTorch safetensors files are supported")
        definition.pop("__metadata__", None)

        for key, info in definition.items():
            dtype = {
                "I8": torch.int8,
                "I16": torch.int16,
                "I32": torch.int32,
                "I64": torch.int64,
                "F16": torch.float16,
                "F32": torch.float32,
                "F64": torch.float64,
            }[info["dtype"]]

            # Allocate on the meta device: records shape/dtype without reading any data.
            checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)

    return checkpoint
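

# A minimal usage sketch of the header reader above, for illustration only (the
# path is hypothetical and this helper is never called by the module): because
# every tensor is allocated on the "meta" device, we get names, shapes, and
# dtypes without reading any tensor data from disk.
def _example_inspect_safetensors_header() -> None:
    meta_sd = _fast_safetensors_reader("/path/to/model.safetensors")  # hypothetical path
    for name, tensor in meta_sd.items():
        # Meta tensors carry metadata only, so this prints instantly even for
        # multi-gigabyte files.
        print(f"{name}: shape={tuple(tensor.shape)} dtype={tensor.dtype}")

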
def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str, torch.Tensor]:
    if str(path).endswith(".safetensors"):
        try:
            path_str = path.as_posix() if isinstance(path, Path) else path
            checkpoint = _fast_safetensors_reader(path_str)
        except Exception:
            # TODO: create an issue for "meta" device support?
            checkpoint = safetensors.torch.load_file(path, device="cpu")
    elif str(path).endswith(".gguf"):
        # The GGUF reader uses a numpy memmap, so these tensors are not loaded
        # into memory during this function.
        checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
    else:
        if scan:
            scan_result = pscan.scan_file_path(path)
            # When unsafe_disable_picklescan is enabled, detections and scan
            # failures are logged as warnings instead of raising. The setting is
            # disabled by default.
            if scan_result.infected_files != 0:
                if get_config().unsafe_disable_picklescan:
                    logger.warning(
                        f"The model {path} is potentially infected by malware, but picklescan is disabled. "
                        "Proceeding with caution."
                    )
                else:
                    raise RuntimeError(f"The model {path} is potentially infected by malware. Aborting import.")
            if scan_result.scan_err:
                if get_config().unsafe_disable_picklescan:
                    logger.warning(
                        f"Error scanning the model at {path} for malware, but picklescan is disabled. "
                        "Proceeding with caution."
                    )
                else:
                    raise RuntimeError(f"Error scanning the model at {path} for malware. Aborting import.")

        checkpoint = torch.load(path, map_location=torch.device("meta"))
    return checkpoint
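

# A hedged usage sketch (hypothetical path; this helper is never called):
# probing a checkpoint usually starts from its key names, which reveal the
# architecture without loading tensor data.
def _example_probe_checkpoint() -> None:
    sd = read_checkpoint_meta("/path/to/model.safetensors")  # hypothetical path
    # Tally keys per top-level prefix for a rough picture of the layout.
    prefixes: Dict[str, int] = {}
    for key in sd:
        prefix = key.split(".", 1)[0]
        prefixes[prefix] = prefixes.get(prefix, 0) + 1
    print(prefixes)

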
def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
    """
    Given a checkpoint in memory, return the LoRA token vector length.

    :param checkpoint: The checkpoint
    """

    def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
        lora_token_vector_length = None

        if "." not in key:
            return lora_token_vector_length  # wrong key format
        model_key, lora_key = key.split(".", 1)

        # check lora/locon
        if lora_key == "lora_down.weight":
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2, as they are used only in 4d shapes)
        elif lora_key in ["hada_w1_b", "hada_w2_b"]:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2, as it is used only in 4d shapes)
        elif "lokr_" in lora_key:
            if model_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1"]
            elif model_key + ".lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if model_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
            elif model_key + ".lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif lora_key == "diff":
            lora_token_vector_length = tensor.shape[1]

        # ia3 can be detected only by shape[0] in the text encoder
        elif lora_key == "weight" and "lora_unet_" not in model_key:
            lora_token_vector_length = tensor.shape[0]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_unet_") and (
            "time_emb_proj.lora_down" in key
        ):  # recognizes the format at https://civitai.com/models/224641
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_te") and "_self_attn_" in key:
            tmp_length = _get_shape_1(key, tensor, checkpoint)
            if key.startswith("lora_te_"):
                lora_token_vector_length = tmp_length
            elif key.startswith("lora_te1_"):
                lora_te1_length = tmp_length
            elif key.startswith("lora_te2_"):
                lora_te2_length = tmp_length

        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length
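

# An illustrative sketch (not part of the module's API): the token vector
# length identifies the base model a LoRA targets, e.g. 768 for an SD-1 text
# encoder and 2048 (768 + 1280) for SDXL's two text encoders. The single-key
# state dict below is a hypothetical example in kohya-style naming.
def _example_lora_token_vector_length() -> None:
    fake_lora = {
        "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight": torch.empty(
            4, 768, device="meta"
        ),
    }
    print(lora_token_vector_length(fake_lora))  # -> 768

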
def convert_bundle_to_flux_transformer_checkpoint(
    transformer_state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    original_state_dict: dict[str, torch.Tensor] = {}
    keys_to_remove: list[str] = []

    for k, v in transformer_state_dict.items():
        if not k.startswith("model.diffusion_model"):
            keys_to_remove.append(k)  # This can be removed in the future if we only want to delete transformer keys
            continue
        if k.endswith("scale"):
            # Scale math must be done at bfloat16 due to our current flux model
            # support limitations at inference time.
            v = v.to(dtype=torch.bfloat16)
        new_key = k.replace("model.diffusion_model.", "")
        original_state_dict[new_key] = v
        keys_to_remove.append(k)

    # Remove processed keys from the original dictionary, leaving others in case
    # other model state dicts need to be pulled.
    for k in keys_to_remove:
        del transformer_state_dict[k]

    return original_state_dict
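

# A hedged sketch of the bundle conversion above (hypothetical keys; this
# helper is never called): "model.diffusion_model." keys are re-homed under a
# stripped name, "scale" tensors are cast to bfloat16, and every processed key
# is deleted from the input dict.
def _example_convert_flux_bundle() -> None:
    bundle = {
        "model.diffusion_model.img_in.weight": torch.empty(8, 8),
        "model.diffusion_model.double_blocks_0.scale": torch.empty(8),
        "vae.encoder.weight": torch.empty(4, 4),
    }
    transformer_sd = convert_bundle_to_flux_transformer_checkpoint(bundle)
    print(sorted(transformer_sd))  # -> ['double_blocks_0.scale', 'img_in.weight']
    print(transformer_sd["double_blocks_0.scale"].dtype)  # -> torch.bfloat16
    print(bundle)  # -> {} (all keys were processed and removed)

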
def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
    try:
        path = Path(location)
        config_path = path / "config.json"
        if not config_path.exists():
            config_path = path / "text_encoder" / "config.json"
        if not config_path.exists():
            return ClipVariantType.L
        with open(config_path) as file:
            clip_conf = json.load(file)
        hidden_size = clip_conf.get("hidden_size", -1)
        match hidden_size:
            case 1280:
                return ClipVariantType.G
            case 768:
                return ClipVariantType.L
            case _:
                return ClipVariantType.L
    except Exception:
        return ClipVariantType.L
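

# A minimal sketch of the variant detection above, assuming a hypothetical
# diffusers-style folder: a config.json with hidden_size 1280 is classified as
# CLIP-G; 768, any other size, a missing config, or any error falls back to
# CLIP-L.
def _example_clip_variant() -> None:
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / "config.json").write_text(json.dumps({"hidden_size": 1280}))
        print(get_clip_variant_type(tmp))  # -> ClipVariantType.G
    print(get_clip_variant_type("/nonexistent"))  # -> ClipVariantType.L (fallback)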