Drop models from the cache if they have been fully offloaded.

Ryan Dick
2025-01-15 01:44:51 +00:00
parent bb52317377
commit c8ac19f2f7

@@ -445,6 +445,7 @@ class ModelCache:
         vram_bytes_freed = 0
         # TODO(ryand): Give more thought to the offloading policy used here.
         cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
+        cache_entries_deleted = 0
         for cache_entry in cache_entries_increasing_size:
             # We do not fully trust the count of bytes freed, so we check again on each iteration.
             vram_available = self._get_vram_available(working_mem_bytes)
@@ -462,6 +463,16 @@
             )
             vram_bytes_freed += cache_entry_bytes_freed
+            if cache_entry.cached_model.cur_vram_bytes() == 0:
+                self._logger.debug(f"Fully unloaded {cache_entry.key} from VRAM. Dropping it from the RAM cache.")
+                self._delete_cache_entry(cache_entry)
+                # Delete the reference to the cache entry so that gc.collect() has the desired effect.
+                del cache_entry
+                cache_entries_deleted += 1
+
+        if cache_entries_deleted > 0:
+            gc.collect()
+            TorchDevice.empty_cache()
         return vram_bytes_freed
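
For context, the pattern this change applies can be illustrated outside the InvokeAI codebase. The sketch below is a hypothetical, simplified cache (TinyCache, drop_fully_offloaded, and _delete_entry are invented names, not InvokeAI APIs). It shows why the local reference must also be dropped before gc.collect() can actually free the model, and it calls torch.cuda.empty_cache() directly where the diff goes through the TorchDevice.empty_cache() wrapper.

    # A minimal sketch, assuming a plain dict-backed cache of torch modules.
    import gc

    import torch


    class TinyCache:
        def __init__(self) -> None:
            self._entries: dict[str, torch.nn.Module] = {}

        def _delete_entry(self, key: str) -> None:
            # Removing the dict reference alone is not enough if a local
            # variable still points at the model; the caller must drop
            # that reference too before collection can succeed.
            del self._entries[key]

        def drop_fully_offloaded(self) -> int:
            deleted = 0
            for key in list(self._entries.keys()):
                model = self._entries[key]
                on_gpu = any(p.device.type == "cuda" for p in model.parameters())
                if not on_gpu:
                    self._delete_entry(key)
                    # Drop the local reference so gc.collect() can free the model.
                    del model
                    deleted += 1
            if deleted > 0:
                gc.collect()
                if torch.cuda.is_available():
                    # Return cached, now-unused CUDA blocks to the driver.
                    torch.cuda.empty_cache()
            return deleted

Running gc.collect() and emptying the CUDA cache once after the loop, guarded by the deletion count, avoids paying the relatively expensive collection cost when nothing was dropped; that is the same guard the diff introduces with cache_entries_deleted.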