mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
tidy(mm): clean up ModelOnDisk caching
This commit is contained in:
@@ -2162,9 +2162,6 @@ def get_model_discriminator_value(v: Any) -> str:
|
||||
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
|
||||
AnyModelConfig = Annotated[
|
||||
Union[
|
||||
# Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||
# Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||
# SD_1_2_XL_XLRefiner_CheckpointConfig
|
||||
Annotated[FLUX_Unquantized_CheckpointConfig, FLUX_Unquantized_CheckpointConfig.get_tag()],
|
||||
Annotated[FLUX_Quantized_BnB_NF4_CheckpointConfig, FLUX_Quantized_BnB_NF4_CheckpointConfig.get_tag()],
|
||||
Annotated[FLUX_Quantized_GGUF_CheckpointConfig, FLUX_Quantized_GGUF_CheckpointConfig.get_tag()],
|
||||
|
||||
@@ -30,7 +30,8 @@ class ModelOnDisk:
|
||||
self.hash_algo = hash_algo
|
||||
# Having a cache helps users of ModelOnDisk (i.e. configs) to save state
|
||||
# This prevents redundant computations during matching and parsing
|
||||
self.cache = {"_CACHED_STATE_DICTS": {}}
|
||||
self._state_dict_cache: dict[Path, Any] = {}
|
||||
self._metadata_cache: dict[Path, Any] = {}
|
||||
|
||||
def hash(self) -> str:
|
||||
return ModelHash(algorithm=self.hash_algo).hash(self.path)
|
||||
@@ -47,13 +48,18 @@ class ModelOnDisk:
|
||||
return {f for f in self.path.rglob("*") if f.suffix in extensions}
|
||||
|
||||
def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
|
||||
path = path or self.path
|
||||
if path in self._metadata_cache:
|
||||
return self._metadata_cache[path]
|
||||
try:
|
||||
with safe_open(self.path, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
assert isinstance(metadata, dict)
|
||||
return metadata
|
||||
except Exception:
|
||||
return {}
|
||||
metadata = {}
|
||||
|
||||
self._metadata_cache[path] = metadata
|
||||
return metadata
|
||||
|
||||
def repo_variant(self) -> Optional[ModelRepoVariant]:
|
||||
if self.path.is_file():
|
||||
@@ -73,10 +79,8 @@ class ModelOnDisk:
|
||||
return ModelRepoVariant.Default
|
||||
|
||||
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
|
||||
sd_cache = self.cache["_CACHED_STATE_DICTS"]
|
||||
|
||||
if path in sd_cache:
|
||||
return sd_cache[path]
|
||||
if path in self._state_dict_cache:
|
||||
return self._state_dict_cache[path]
|
||||
|
||||
path = self.resolve_weight_file(path)
|
||||
|
||||
@@ -111,7 +115,7 @@ class ModelOnDisk:
|
||||
raise ValueError(f"Unrecognized model extension: {path.suffix}")
|
||||
|
||||
state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
sd_cache[path] = state_dict
|
||||
self._state_dict_cache[path] = state_dict
|
||||
return state_dict
|
||||
|
||||
def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
|
||||
|
||||
Reference in New Issue
Block a user