Allow expensive operations to request more working memory.

This commit is contained in:
Ryan Dick
2024-12-31 21:44:55 +00:00
parent a167632f09
commit fc4a22fe78
2 changed files with 12 additions and 6 deletions

View File

@@ -57,16 +57,22 @@ class LoadedModelWithoutConfig:
self._cache = cache
def __enter__(self) -> AnyModel:
self._cache.lock(self._cache_record)
self._cache.lock(self._cache_record, None)
return self.model
def __exit__(self, *args: Any, **kwargs: Any) -> None:
self._cache.unlock(self._cache_record)
@contextmanager
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
self._cache.lock(self._cache_record)
def model_on_device(
self, working_mem_bytes: Optional[int] = None
) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device.
:param working_mem_bytes: The amount of working memory to keep available on the compute device when loading the
model.
"""
self._cache.lock(self._cache_record, working_mem_bytes)
try:
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
finally:

View File

@@ -200,7 +200,7 @@ class ModelCache:
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
return cache_entry
def lock(self, cache_entry: CacheRecord) -> None:
def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
"""Lock a model for use and move it into VRAM."""
if cache_entry.key not in self._cached_models:
self._logger.info(
@@ -221,7 +221,7 @@ class ModelCache:
return
try:
self._load_locked_model(cache_entry)
self._load_locked_model(cache_entry, working_mem_bytes)
self._logger.debug(
f"Finished locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
)