Compare commits

...

1 Commit

2 changed files with 20 additions and 8 deletions

View File

@@ -57,20 +57,20 @@ class LoadedModelWithoutConfig:
self._cache = cache
def __enter__(self) -> AnyModel:
self._cache.lock(self._cache_record.key)
self._cache.lock(self._cache_record)
return self.model
def __exit__(self, *args: Any, **kwargs: Any) -> None:
self._cache.unlock(self._cache_record.key)
self._cache.unlock(self._cache_record)
@contextmanager
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
self._cache.lock(self._cache_record.key)
self._cache.lock(self._cache_record)
try:
yield (self._cache_record.state_dict, self._cache_record.model)
finally:
self._cache.unlock(self._cache_record.key)
self._cache.unlock(self._cache_record)
@property
def model(self) -> AnyModel:

View File

@@ -194,9 +194,15 @@ class ModelCache:
return cache_entry
def lock(self, key: str) -> None:
def lock(self, cache_entry: CacheRecord) -> None:
"""Lock a model for use and move it into VRAM."""
cache_entry = self._cached_models[key]
if cache_entry.key not in self._cached_models:
self._logger.info(
f"Locking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
"already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
"in the invocation code."
)
# cache_entry = self._cached_models[key]
cache_entry.lock()
try:
@@ -214,9 +220,15 @@ class ModelCache:
cache_entry.unlock()
raise
def unlock(self, key: str) -> None:
def unlock(self, cache_entry: CacheRecord) -> None:
"""Unlock a model."""
cache_entry = self._cached_models[key]
if cache_entry.key not in self._cached_models:
self._logger.info(
f"Unlocking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
"already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
"in the invocation code."
)
# cache_entry = self._cached_models[key]
cache_entry.unlock()
if not self._lazy_offloading:
self._offload_unlocked_models(0)