Mirror of https://github.com/invoke-ai/InvokeAI.git
Allow expensive operations to request more working memory.
@@ -57,16 +57,22 @@ class LoadedModelWithoutConfig:
         self._cache = cache
 
     def __enter__(self) -> AnyModel:
-        self._cache.lock(self._cache_record)
+        self._cache.lock(self._cache_record, None)
         return self.model
 
     def __exit__(self, *args: Any, **kwargs: Any) -> None:
         self._cache.unlock(self._cache_record)
 
     @contextmanager
-    def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
-        """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
-        self._cache.lock(self._cache_record)
+    def model_on_device(
+        self, working_mem_bytes: Optional[int] = None
+    ) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
+        """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device.
+
+        :param working_mem_bytes: The amount of working memory to keep available on the compute device when loading the
+            model.
+        """
+        self._cache.lock(self._cache_record, working_mem_bytes)
         try:
             yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
         finally:
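For orientation (not part of the commit): a caller that is about to run a memory-hungry step, such as VAE decoding, could request headroom through the new parameter. The `loaded_model` variable, the `latents` input, and the 3 GiB figure below are illustrative assumptions, not code from this change.

    # Illustrative usage; `loaded_model` is assumed to be a LoadedModelWithoutConfig
    # obtained from the model manager, and 3 GiB is an arbitrary example budget.
    working_mem = 3 * 1024**3  # bytes to keep free on the execution device

    with loaded_model.model_on_device(working_mem_bytes=working_mem) as (state_dict, model):
        # The cache keeps ~3 GiB free on the compute device while the model is
        # locked, leaving room for the operation's intermediate tensors.
        output = model(latents)  # `latents` stands in for the operation's input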
@@ -200,7 +200,7 @@ class ModelCache:
         self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
         return cache_entry
 
-    def lock(self, cache_entry: CacheRecord) -> None:
+    def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
         """Lock a model for use and move it into VRAM."""
         if cache_entry.key not in self._cached_models:
             self._logger.info(
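Note that at the cache layer the new argument has no default, so every caller of `lock()` must state its working-memory intent explicitly; this is an inference from the signatures in this commit, not something the commit message says. The two call sites shown in the first hunk end up as:

    self._cache.lock(self._cache_record, None)                # __enter__: no extra headroom requested
    self._cache.lock(self._cache_record, working_mem_bytes)   # model_on_device: caller-chosen reserve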
@@ -221,7 +221,7 @@ class ModelCache:
             return
 
         try:
-            self._load_locked_model(cache_entry)
+            self._load_locked_model(cache_entry, working_mem_bytes)
             self._logger.debug(
                 f"Finished locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
             )
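The commit does not include `_load_locked_model` itself, but the docstring above suggests the argument shrinks the VRAM budget used while moving the model onto the device. A minimal sketch of that arithmetic follows, with an assumed default reserve and a hypothetical helper name; nothing below is InvokeAI's actual code.

    from typing import Optional

    DEFAULT_WORKING_MEM_BYTES = 1 * 1024**3  # assumed 1 GiB fallback reserve

    def vram_budget_for_load(total_vram_bytes: int, working_mem_bytes: Optional[int]) -> int:
        """Bytes of VRAM the loader may fill after reserving working memory.

        Hypothetical helper: working_mem_bytes=None falls back to the default
        reserve; a larger request leaves less room for model weights.
        """
        reserve = working_mem_bytes if working_mem_bytes is not None else DEFAULT_WORKING_MEM_BYTES
        return max(0, total_vram_bytes - reserve)

    # On a 12 GiB card, requesting 3 GiB of working memory leaves 9 GiB for weights.
    assert vram_budget_for_load(12 * 1024**3, 3 * 1024**3) == 9 * 1024**3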