diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 2c9ad226a2..c94ff28117 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -2162,9 +2162,6 @@ def get_model_discriminator_value(v: Any) -> str:
 # when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
 AnyModelConfig = Annotated[
     Union[
-        # Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
-        # Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
-        # SD_1_2_XL_XLRefiner_CheckpointConfig
         Annotated[FLUX_Unquantized_CheckpointConfig, FLUX_Unquantized_CheckpointConfig.get_tag()],
         Annotated[FLUX_Quantized_BnB_NF4_CheckpointConfig, FLUX_Quantized_BnB_NF4_CheckpointConfig.get_tag()],
         Annotated[FLUX_Quantized_GGUF_CheckpointConfig, FLUX_Quantized_GGUF_CheckpointConfig.get_tag()],
diff --git a/invokeai/backend/model_manager/model_on_disk.py b/invokeai/backend/model_manager/model_on_disk.py
index 502ca596a6..a86e94d3a4 100644
--- a/invokeai/backend/model_manager/model_on_disk.py
+++ b/invokeai/backend/model_manager/model_on_disk.py
@@ -30,7 +30,8 @@ class ModelOnDisk:
         self.hash_algo = hash_algo
         # Having a cache helps users of ModelOnDisk (i.e. configs) to save state
         # This prevents redundant computations during matching and parsing
-        self.cache = {"_CACHED_STATE_DICTS": {}}
+        self._state_dict_cache: dict[Path, Any] = {}
+        self._metadata_cache: dict[Path, Any] = {}
 
     def hash(self) -> str:
         return ModelHash(algorithm=self.hash_algo).hash(self.path)
@@ -47,13 +48,18 @@ class ModelOnDisk:
         return {f for f in self.path.rglob("*") if f.suffix in extensions}
 
     def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
+        path = path or self.path
+        if path in self._metadata_cache:
+            return self._metadata_cache[path]
         try:
-            with safe_open(self.path, framework="pt", device="cpu") as f:
+            with safe_open(path, framework="pt", device="cpu") as f:
                 metadata = f.metadata()
                 assert isinstance(metadata, dict)
-                return metadata
         except Exception:
-            return {}
+            metadata = {}
+
+        self._metadata_cache[path] = metadata
+        return metadata
 
     def repo_variant(self) -> Optional[ModelRepoVariant]:
         if self.path.is_file():
@@ -73,10 +79,8 @@ class ModelOnDisk:
         return ModelRepoVariant.Default
 
     def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
-        sd_cache = self.cache["_CACHED_STATE_DICTS"]
-
-        if path in sd_cache:
-            return sd_cache[path]
+        if path in self._state_dict_cache:
+            return self._state_dict_cache[path]
 
         path = self.resolve_weight_file(path)
 
@@ -111,7 +115,7 @@ class ModelOnDisk:
             raise ValueError(f"Unrecognized model extension: {path.suffix}")
 
         state_dict = checkpoint.get("state_dict", checkpoint)
-        sd_cache[path] = state_dict
+        self._state_dict_cache[path] = state_dict
         return state_dict
 
     def resolve_weight_file(self, path: Optional[Path] = None) -> Path: