from pathlib import Path
from typing import Any, Optional, TypeAlias

import safetensors.torch
import torch
from picklescan.scanner import scan_file_path

from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.silence_warnings import SilenceWarnings

StateDict: TypeAlias = dict[str | int, Any]  # When are the keys int?


class ModelOnDisk:
    """A utility class representing a model stored on disk."""

    def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
        self.path = path
        # A single weight file is named by its stem; a model directory keeps its full name.
        if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
            self.name = path.stem
        else:
            self.name = path.name
        self.hash_algo = hash_algo
        # Having a cache helps users of ModelOnDisk (i.e. configs) to save state.
        # This prevents redundant computations during matching and parsing.
        self.cache = {"_CACHED_STATE_DICTS": {}}

    def hash(self) -> str:
        return ModelHash(algorithm=self.hash_algo).hash(self.path)

    def size(self) -> int:
        if self.path.is_file():
            return self.path.stat().st_size
        # For a directory, sum the sizes of all contained files.
        return sum(file.stat().st_size for file in self.path.rglob("*"))

    def component_paths(self) -> set[Path]:
        if self.path.is_file():
            return {self.path}
        extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
        return {f for f in self.path.rglob("*") if f.suffix in extensions}

    def repo_variant(self) -> Optional[ModelRepoVariant]:
        # Repo variants only apply to diffusers-style model directories.
        if self.path.is_file():
            return None
        weight_files = list(self.path.glob("**/*.safetensors"))
        weight_files.extend(list(self.path.glob("**/*.bin")))
        for x in weight_files:
            if ".fp16" in x.suffixes:
                return ModelRepoVariant.FP16
            if "openvino_model" in x.name:
                return ModelRepoVariant.OpenVINO
            if "flax_model" in x.name:
                return ModelRepoVariant.Flax
            if x.suffix == ".onnx":
                return ModelRepoVariant.ONNX
        return ModelRepoVariant.Default

    def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
        sd_cache = self.cache["_CACHED_STATE_DICTS"]
        if path in sd_cache:
            return sd_cache[path]

        # If no path was given, the model must have exactly one weight file.
        if not path:
            components = list(self.component_paths())
            match components:
                case []:
                    raise ValueError("No weight files found for this model")
                case [p]:
                    path = p
                case ps if len(ps) >= 2:
                    raise ValueError(
                        f"Multiple weight files found for this model: {ps}. "
                        f"Please specify the intended file using the 'path' argument"
                    )

        with SilenceWarnings():
            if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
                # Pickle-based formats can execute arbitrary code; scan for malware before loading.
                scan_result = scan_file_path(path)
                if scan_result.infected_files != 0 or scan_result.scan_err:
                    raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
                checkpoint = torch.load(path, map_location="cpu")
                assert isinstance(checkpoint, dict)
            elif path.suffix.endswith(".gguf"):
                checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
            elif path.suffix.endswith(".safetensors"):
                checkpoint = safetensors.torch.load_file(path)
            else:
                raise ValueError(f"Unrecognized model extension: {path.suffix}")

        # Some checkpoints nest the weights under a "state_dict" key; unwrap if present.
        state_dict = checkpoint.get("state_dict", checkpoint)
        sd_cache[path] = state_dict
        return state_dict
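

# A minimal usage sketch, not part of the original module: the checkpoint path
# below is hypothetical, and this block only illustrates the intended call
# sequence: construct a ModelOnDisk, then query its metadata and weights.
if __name__ == "__main__":
    model = ModelOnDisk(Path("/models/v1-5-pruned-emaonly.safetensors"))
    print(f"name: {model.name}")  # stem only, since .safetensors is a known weight extension
    print(f"size: {model.size()} bytes")  # file size, or the recursive sum for a directory
    print(f"hash: {model.hash()}")  # blake3_single digest by default
    print(f"variant: {model.repo_variant()}")  # None, because this is a single file

    # The first call loads and caches the state dict; repeated calls return
    # the cached dict instead of re-reading the file from disk.
    sd = model.load_state_dict()
    print(f"{len(sd)} keys in state dict")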