diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py
index b6b4d3c7d9..a0c1083c82 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py
@@ -4,7 +4,7 @@ import threading
 import time
 from functools import wraps
 from logging import Logger
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Protocol
 
 import psutil
 import torch
@@ -54,6 +54,22 @@ def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
     return wrapper
 
 
+class CacheMissCallback(Protocol):
+    def __call__(
+        self,
+        model_key: str,
+        cache_overview: dict[str, int],
+    ) -> None: ...
+
+
+class CacheHitCallback(Protocol):
+    def __call__(
+        self,
+        model_key: str,
+        cache_overview: dict[str, int],
+    ) -> None: ...
+
+
 class ModelCache:
     """A cache for managing models in memory.
 
@@ -144,6 +160,21 @@ class ModelCache:
         # - Requests to empty the cache from a separate thread
         self._lock = threading.RLock()
 
+        self._on_cache_hit_callbacks: set[CacheHitCallback] = set()
+        self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
+
+    def register_on_cache_hit(self, cb: CacheHitCallback) -> None:
+        self._on_cache_hit_callbacks.add(cb)
+
+    def register_on_cache_miss(self, cb: CacheMissCallback) -> None:
+        self._on_cache_miss_callbacks.add(cb)
+
+    def unregister_on_cache_hit(self, cb: CacheHitCallback) -> None:
+        self._on_cache_hit_callbacks.discard(cb)
+
+    def unregister_on_cache_miss(self, cb: CacheMissCallback) -> None:
+        self._on_cache_miss_callbacks.discard(cb)
+
     @property
     @synchronized
     def stats(self) -> Optional[CacheStats]:
@@ -195,6 +226,15 @@ class ModelCache:
             f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size / MB:.2f}MB)"
         )
 
+    @synchronized
+    def _get_cache_overview(self) -> dict[str, int]:
+        overview: dict[str, int] = {}
+        for model_key, cache_entry in self._cached_models.items():
+            overview[model_key] = cache_entry.cached_model.total_bytes()
+            # Useful? cache_entry.cached_model.is_in_vram()
+
+        return overview
+
     @synchronized
     def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         """Retrieve a model from the cache.
@@ -208,6 +248,8 @@
             if self.stats:
                 self.stats.hits += 1
         else:
+            for cb in self._on_cache_miss_callbacks:
+                cb(model_key=key, cache_overview=self._get_cache_overview())
             if self.stats:
                 self.stats.misses += 1
             self._logger.debug(f"Cache miss: {key}")
@@ -229,6 +271,8 @@
         self._cache_stack.append(key)
         self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
 
+        for cb in self._on_cache_hit_callbacks:
+            cb(model_key=key, cache_overview=self._get_cache_overview())
         return cache_entry
 
     @synchronized
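
For review context, a minimal usage sketch of the new observer API. It assumes an already-constructed ModelCache instance named model_cache; the callback body and logger name are illustrative and not part of the patch:

import logging

logger = logging.getLogger("cache_observer")

def log_cache_miss(model_key: str, cache_overview: dict[str, int]) -> None:
    # cache_overview maps each cached model key to its size in bytes,
    # as built by ModelCache._get_cache_overview().
    total_bytes = sum(cache_overview.values())
    logger.info(
        "Cache miss for %s; %d models currently cached (%d bytes total)",
        model_key,
        len(cache_overview),
        total_bytes,
    )

# A plain function satisfies the CacheMissCallback Protocol because the
# cache invokes callbacks with keyword arguments (model_key=..., cache_overview=...).
model_cache.register_on_cache_miss(log_cache_miss)
# ... and when the observer is no longer needed:
model_cache.unregister_on_cache_miss(log_cache_miss)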