tidy(mm): clean up ModelOnDisk caching

This commit is contained in:
psychedelicious
2025-10-01 13:02:18 +10:00
parent c53c731371
commit a0a4eb9a5a
2 changed files with 12 additions and 11 deletions

View File

@@ -2162,9 +2162,6 @@ def get_model_discriminator_value(v: Any) -> str:
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
AnyModelConfig = Annotated[
Union[
# Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
# Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
# SD_1_2_XL_XLRefiner_CheckpointConfig
Annotated[FLUX_Unquantized_CheckpointConfig, FLUX_Unquantized_CheckpointConfig.get_tag()],
Annotated[FLUX_Quantized_BnB_NF4_CheckpointConfig, FLUX_Quantized_BnB_NF4_CheckpointConfig.get_tag()],
Annotated[FLUX_Quantized_GGUF_CheckpointConfig, FLUX_Quantized_GGUF_CheckpointConfig.get_tag()],

View File

@@ -30,7 +30,8 @@ class ModelOnDisk:
self.hash_algo = hash_algo
# Having a cache helps users of ModelOnDisk (i.e. configs) to save state
# This prevents redundant computations during matching and parsing
self.cache = {"_CACHED_STATE_DICTS": {}}
self._state_dict_cache: dict[Path, Any] = {}
self._metadata_cache: dict[Path, Any] = {}
def hash(self) -> str:
return ModelHash(algorithm=self.hash_algo).hash(self.path)
@@ -47,13 +48,18 @@ class ModelOnDisk:
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
path = path or self.path
if path in self._metadata_cache:
return self._metadata_cache[path]
try:
with safe_open(self.path, framework="pt", device="cpu") as f:
metadata = f.metadata()
assert isinstance(metadata, dict)
return metadata
except Exception:
return {}
metadata = {}
self._metadata_cache[path] = metadata
return metadata
def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.path.is_file():
@@ -73,10 +79,8 @@ class ModelOnDisk:
return ModelRepoVariant.Default
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
sd_cache = self.cache["_CACHED_STATE_DICTS"]
if path in sd_cache:
return sd_cache[path]
if path in self._state_dict_cache:
return self._state_dict_cache[path]
path = self.resolve_weight_file(path)
@@ -111,7 +115,7 @@ class ModelOnDisk:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
sd_cache[path] = state_dict
self._state_dict_cache[path] = state_dict
return state_dict
def resolve_weight_file(self, path: Optional[Path] = None) -> Path: