mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): iterate on cache callbacks API
This commit is contained in:
@@ -70,6 +70,15 @@ class CacheHitCallback(Protocol):
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class CacheModelsClearedCallback(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
models_cleared: int,
|
||||
bytes_requested: int,
|
||||
bytes_freed: int,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class ModelCache:
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
@@ -162,18 +171,31 @@ class ModelCache:
|
||||
|
||||
self._on_cache_hit_callbacks: set[CacheHitCallback] = set()
|
||||
self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
|
||||
self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()
|
||||
|
||||
def register_on_cache_hit(self, cb: CacheHitCallback) -> None:
|
||||
def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
|
||||
self._on_cache_hit_callbacks.add(cb)
|
||||
|
||||
def register_on_cache_miss(self, cb: CacheMissCallback) -> None:
|
||||
def unsubscribe() -> None:
|
||||
self._on_cache_hit_callbacks.discard(cb)
|
||||
|
||||
return unsubscribe
|
||||
|
||||
def on_cache_miss(self, cb: CacheHitCallback) -> Callable[[], None]:
|
||||
self._on_cache_miss_callbacks.add(cb)
|
||||
|
||||
def unregister_on_cache_hit(self, cb: CacheHitCallback) -> None:
|
||||
self._on_cache_hit_callbacks.discard(cb)
|
||||
def unsubscribe() -> None:
|
||||
self._on_cache_miss_callbacks.discard(cb)
|
||||
|
||||
def unregister_on_cache_miss(self, cb: CacheMissCallback) -> None:
|
||||
self._on_cache_miss_callbacks.discard(cb)
|
||||
return unsubscribe
|
||||
|
||||
def on_cache_models_cleared(self, cb: CacheModelsClearedCallback) -> Callable[[], None]:
|
||||
self._on_cache_models_cleared_callbacks.add(cb)
|
||||
|
||||
def unsubscribe() -> None:
|
||||
self._on_cache_models_cleared_callbacks.discard(cb)
|
||||
|
||||
return unsubscribe
|
||||
|
||||
@property
|
||||
@synchronized
|
||||
@@ -693,6 +715,12 @@ class ModelCache:
|
||||
# immediately when their reference count hits 0.
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
for cb in self._on_cache_models_cleared_callbacks:
|
||||
cb(
|
||||
models_cleared=models_cleared,
|
||||
bytes_requested=bytes_needed,
|
||||
bytes_freed=ram_bytes_freed,
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user