feat(mm): iterate on cache callbacks API

This commit is contained in:
psychedelicious
2025-05-15 13:28:51 +10:00
parent a33da450fd
commit 823ca214e6

View File

@@ -70,6 +70,15 @@ class CacheHitCallback(Protocol):
) -> None: ...
class CacheModelsClearedCallback(Protocol):
def __call__(
self,
models_cleared: int,
bytes_requested: int,
bytes_freed: int,
) -> None: ...
class ModelCache:
"""A cache for managing models in memory.
@@ -162,18 +171,31 @@ class ModelCache:
self._on_cache_hit_callbacks: set[CacheHitCallback] = set()
self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()
def register_on_cache_hit(self, cb: CacheHitCallback) -> None:
def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
self._on_cache_hit_callbacks.add(cb)
def register_on_cache_miss(self, cb: CacheMissCallback) -> None:
def unsubscribe() -> None:
self._on_cache_hit_callbacks.discard(cb)
return unsubscribe
def on_cache_miss(self, cb: CacheHitCallback) -> Callable[[], None]:
self._on_cache_miss_callbacks.add(cb)
def unregister_on_cache_hit(self, cb: CacheHitCallback) -> None:
self._on_cache_hit_callbacks.discard(cb)
def unsubscribe() -> None:
self._on_cache_miss_callbacks.discard(cb)
def unregister_on_cache_miss(self, cb: CacheMissCallback) -> None:
self._on_cache_miss_callbacks.discard(cb)
return unsubscribe
def on_cache_models_cleared(self, cb: CacheModelsClearedCallback) -> Callable[[], None]:
self._on_cache_models_cleared_callbacks.add(cb)
def unsubscribe() -> None:
self._on_cache_models_cleared_callbacks.discard(cb)
return unsubscribe
@property
@synchronized
@@ -693,6 +715,12 @@ class ModelCache:
# immediately when their reference count hits 0.
if self.stats:
self.stats.cleared = models_cleared
for cb in self._on_cache_models_cleared_callbacks:
cb(
models_cleared=models_cleared,
bytes_requested=bytes_needed,
bytes_freed=ram_bytes_freed,
)
gc.collect()
TorchDevice.empty_cache()