From 8488ab01347488d9a41b360aaad1658c2dab4b7a Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Fri, 3 Nov 2023 15:19:45 -0400
Subject: [PATCH] Reduce frequency that we call gc.collect() in the model
 cache.

---
 .../backend/model_management/model_cache.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 1c3ad06e8e..2385fd9bec 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -438,6 +438,7 @@ class ModelCache(object):
         self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
 
         pos = 0
+        models_cleared = 0
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
             model_key = self._cache_stack[pos]
             cache_entry = self._cached_models[model_key]
@@ -482,6 +483,7 @@ class ModelCache(object):
                     f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                 )
                 current_size -= cache_entry.size
+                models_cleared += 1
                 if self.stats:
                     self.stats.cleared += 1
                 del self._cache_stack[pos]
@@ -491,7 +493,20 @@ class ModelCache(object):
             else:
                 pos += 1
 
-        gc.collect()
+        if models_cleared > 0:
+            # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not,
+            # but there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The
+            # time cost is high even if no garbage gets collected.)
+            #
+            # Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
+            # - If models had to be cleared, it's a signal that we are close to our memory limit.
+            # - If models were cleared, there's a good chance that there's a significant amount of garbage to be
+            #   collected.
+            #
+            # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned
+            # up immediately when their reference count hits 0.
+            gc.collect()
+
         torch.cuda.empty_cache()
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()
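
As a minimal standalone sketch (not part of the patch), the following script illustrates the two claims in the new comment: CPython frees most objects immediately via reference counting, so gc.collect() is only needed to break reference cycles, and a full collection pass costs real time even when there is little to collect. The Node class and the timing harness are illustrative assumptions, not code from the repository.

import gc
import time
import weakref


class Node:
    # Hypothetical stand-in for any Python object; not from the repo.
    def __init__(self):
        self.ref = None


gc.disable()  # make the demonstration deterministic; explicit gc.collect() still works

# No cycle: the object is reclaimed as soon as its refcount hits 0.
n = Node()
alive = weakref.ref(n)
del n
print(alive() is None)  # True -- freed by refcounting, no collector needed

# Reference cycle: refcounting alone can never free these objects.
a, b = Node(), Node()
a.ref, b.ref = b, a
alive = weakref.ref(a)
del a, b
print(alive() is None)  # False -- the cycle keeps both objects alive

start = time.perf_counter()
gc.collect()  # the full pass that this patch reserves for the model-clearing case
elapsed_ms = (time.perf_counter() - start) * 1000
print(alive() is None)  # True -- the cycle collector reclaimed the pair
print(f"gc.collect() took {elapsed_ms:.2f} ms")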