mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 17:18:11 -05:00
Compare commits
23 Commits
v5.8.0
...
ryan/model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2144d21f80 | ||
|
|
958efa19d7 | ||
|
|
11af57def3 | ||
|
|
8b70a5b9bd | ||
|
|
5d9fdcd78d | ||
|
|
c7b84cf012 | ||
|
|
8e409e3436 | ||
|
|
987393853c | ||
|
|
91c5af1b95 | ||
|
|
5c67dd507a | ||
|
|
2ff928ec17 | ||
|
|
4327bbe77e | ||
|
|
ad1c0d37ef | ||
|
|
9708d87946 | ||
|
|
3ad44f7850 | ||
|
|
9a482981b2 | ||
|
|
6b02362b12 | ||
|
|
8fec4ec91c | ||
|
|
693e421970 | ||
|
|
dc14104bc8 | ||
|
|
f286a1d1f3 | ||
|
|
9dc86b2b71 | ||
|
|
2cab689b79 |
@@ -1364,7 +1364,6 @@ the in-memory loaded model:
|
||||
|----------------|-----------------|------------------|
|
||||
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
|
||||
| `model` | AnyModel | The instantiated model (details below) |
|
||||
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
|
||||
|
||||
### get_model_by_key(key, [submodel]) -> LoadedModel
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
@@ -20,7 +20,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_common import (
|
||||
NodeExecutionStatsSummary,
|
||||
)
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load.model_cache import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
|
||||
# Size of 1GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Callable, Optional
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
|
||||
|
||||
class ModelLoadServiceBase(ABC):
|
||||
@@ -24,7 +24,7 @@ class ModelLoadServiceBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load import (
|
||||
ModelLoaderRegistry,
|
||||
ModelLoaderRegistryBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@@ -30,7 +30,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
|
||||
):
|
||||
"""Initialize the model load service."""
|
||||
@@ -45,7 +45,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._invoker = invoker
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@@ -78,9 +78,8 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
cache_key = str(model_path)
|
||||
ram_cache = self.ram_cache
|
||||
try:
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
@@ -109,5 +108,5 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
)
|
||||
assert loader is not None
|
||||
raw_model = loader(model_path)
|
||||
ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
self._ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
|
||||
|
||||
@@ -16,7 +16,8 @@ from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBas
|
||||
from invokeai.app.services.model_load.model_load_default import ModelLoadService
|
||||
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
||||
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
|
||||
# This registers the subclasses that implement loaders of specific model types
|
||||
|
||||
@@ -5,7 +5,6 @@ Base class for model loading in InvokeAI.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, Optional, Tuple
|
||||
@@ -18,19 +17,17 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModelWithoutConfig:
|
||||
"""
|
||||
Context manager object that mediates transfer from RAM<->VRAM.
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM.
|
||||
|
||||
This is a context manager object that has two distinct APIs:
|
||||
|
||||
1. Older API (deprecated):
|
||||
Use the LoadedModel object directly as a context manager.
|
||||
It will move the model into VRAM (on CUDA devices), and
|
||||
Use the LoadedModel object directly as a context manager. It will move the model into VRAM (on CUDA devices), and
|
||||
return the model in a form suitable for passing to torch.
|
||||
Example:
|
||||
```
|
||||
@@ -40,13 +37,9 @@ class LoadedModelWithoutConfig:
|
||||
```
|
||||
|
||||
2. Newer API (recommended):
|
||||
Call the LoadedModel's `model_on_device()` method in a
|
||||
context. It returns a tuple consisting of a copy of
|
||||
the model's state dict in CPU RAM followed by a copy
|
||||
of the model in VRAM. The state dict is provided to allow
|
||||
LoRAs and other model patchers to return the model to
|
||||
its unpatched state without expensive copy and restore
|
||||
operations.
|
||||
Call the LoadedModel's `model_on_device()` method in a context. It returns a tuple consisting of a copy of the
|
||||
model's state dict in CPU RAM followed by a copy of the model in VRAM. The state dict is provided to allow LoRAs and
|
||||
other model patchers to return the model to its unpatched state without expensive copy and restore operations.
|
||||
|
||||
Example:
|
||||
```
|
||||
@@ -55,43 +48,42 @@ class LoadedModelWithoutConfig:
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
The state_dict should be treated as a read-only object and
|
||||
never modified. Also be aware that some loadable models do
|
||||
not have a state_dict, in which case this value will be None.
|
||||
The state_dict should be treated as a read-only object and never modified. Also be aware that some loadable models
|
||||
do not have a state_dict, in which case this value will be None.
|
||||
"""
|
||||
|
||||
_locker: ModelLockerBase
|
||||
def __init__(self, cache_record: CacheRecord, cache: ModelCache):
|
||||
self._cache_record = cache_record
|
||||
self._cache = cache
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
"""Context entry."""
|
||||
self._locker.lock()
|
||||
self._cache.lock(self._cache_record.key)
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Context exit."""
|
||||
self._locker.unlock()
|
||||
self._cache.unlock(self._cache_record.key)
|
||||
|
||||
@contextmanager
|
||||
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
|
||||
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
|
||||
locked_model = self._locker.lock()
|
||||
self._cache.lock(self._cache_record.key)
|
||||
try:
|
||||
state_dict = self._locker.get_state_dict()
|
||||
yield (state_dict, locked_model)
|
||||
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
|
||||
finally:
|
||||
self._locker.unlock()
|
||||
self._cache.unlock(self._cache_record.key)
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without locking it."""
|
||||
return self._locker.model
|
||||
return self._cache_record.cached_model.model
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel(LoadedModelWithoutConfig):
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
config: Optional[AnyModelConfig] = None
|
||||
def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, cache: ModelCache):
|
||||
super().__init__(cache_record=cache_record, cache=cache)
|
||||
self.config = config
|
||||
|
||||
|
||||
# TODO(MM2):
|
||||
@@ -110,7 +102,7 @@ class ModelLoaderBase(ABC):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
pass
|
||||
@@ -138,6 +130,6 @@ class ModelLoaderBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the ram cache associated with this loader."""
|
||||
pass
|
||||
|
||||
@@ -14,7 +14,8 @@ from invokeai.backend.model_manager import (
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@@ -28,7 +29,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
self._app_config = app_config
|
||||
@@ -54,11 +55,11 @@ class ModelLoader(ModelLoaderBase):
|
||||
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
|
||||
|
||||
with skip_torch_weight_init():
|
||||
locker = self._load_and_cache(model_config, submodel_type)
|
||||
return LoadedModel(config=model_config, _locker=locker)
|
||||
cache_record = self._load_and_cache(model_config, submodel_type)
|
||||
return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the ram cache associated with this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@@ -66,10 +67,10 @@ class ModelLoader(ModelLoaderBase):
|
||||
model_base = self._app_config.models_path
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
|
||||
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
|
||||
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
@@ -78,16 +79,11 @@ class ModelLoader(ModelLoaderBase):
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
config.key,
|
||||
submodel_type=submodel_type,
|
||||
get_model_cache_key(config.key, submodel_type),
|
||||
model=loaded_model,
|
||||
)
|
||||
|
||||
return self._ram_cache.get(
|
||||
key=config.key,
|
||||
submodel_type=submodel_type,
|
||||
stats_name=stats_name,
|
||||
)
|
||||
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
|
||||
|
||||
def get_size_fs(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Init file for ModelCache."""
|
||||
|
||||
from .model_cache_base import ModelCacheBase, CacheStats # noqa F401
|
||||
from .model_cache_default import ModelCache # noqa F401
|
||||
|
||||
_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheRecord:
|
||||
"""A class that represents a model in the model cache."""
|
||||
|
||||
# Cache key.
|
||||
key: str
|
||||
# Model in memory.
|
||||
cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
|
||||
# If locks > 0, the model is actively being used, so we should do our best to keep it on the compute device.
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self) -> None:
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def is_locked(self) -> bool:
|
||||
return self._locks > 0
|
||||
@@ -0,0 +1,15 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
"""Collect statistics on cache performance."""
|
||||
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
@@ -0,0 +1,81 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class CachedModelOnlyFullLoad:
|
||||
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
|
||||
|
||||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
|
||||
MPS memory, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
|
||||
"""Initialize a CachedModelOnlyFullLoad.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
|
||||
compute_device (torch.device): The compute device to move the model to.
|
||||
total_bytes (int): The total size (in bytes) of all the weights in the model.
|
||||
"""
|
||||
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
|
||||
self._model = model
|
||||
self._compute_device = compute_device
|
||||
self._total_bytes = total_bytes
|
||||
self._is_in_vram = False
|
||||
|
||||
@property
|
||||
def model(self) -> torch.nn.Module:
|
||||
return self._model
|
||||
|
||||
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
|
||||
"""Get a read-only copy of the model's state dict in RAM."""
|
||||
# TODO(ryand): Document this better and implement it.
|
||||
return None
|
||||
|
||||
def total_bytes(self) -> int:
|
||||
"""Get the total size (in bytes) of all the weights in the model."""
|
||||
return self._total_bytes
|
||||
|
||||
def cur_vram_bytes(self) -> int:
|
||||
"""Get the size (in bytes) of the weights that are currently in VRAM."""
|
||||
if self._is_in_vram:
|
||||
return self._total_bytes
|
||||
else:
|
||||
return 0
|
||||
|
||||
def is_in_vram(self) -> bool:
|
||||
"""Return true if the model is currently in VRAM."""
|
||||
return self._is_in_vram
|
||||
|
||||
def full_load_to_vram(self) -> int:
|
||||
"""Load all weights into VRAM (if supported by the model).
|
||||
|
||||
Returns:
|
||||
The number of bytes loaded into VRAM.
|
||||
"""
|
||||
if self._is_in_vram:
|
||||
# Already in VRAM.
|
||||
return 0
|
||||
|
||||
if not hasattr(self._model, "to"):
|
||||
# Model doesn't support moving to a device.
|
||||
return 0
|
||||
|
||||
self._model.to(self._compute_device)
|
||||
self._is_in_vram = True
|
||||
return self._total_bytes
|
||||
|
||||
def full_unload_from_vram(self) -> int:
|
||||
"""Unload all weights from VRAM.
|
||||
|
||||
Returns:
|
||||
The number of bytes unloaded from VRAM.
|
||||
"""
|
||||
if not self._is_in_vram:
|
||||
# Already in RAM.
|
||||
return 0
|
||||
|
||||
self._model.to("cpu")
|
||||
self._is_in_vram = False
|
||||
return self._total_bytes
|
||||
@@ -0,0 +1,139 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
|
||||
add_autocast_to_module_forward,
|
||||
)
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
|
||||
def set_nested_attr(obj: object, attr: str, value: object):
|
||||
"""A helper function that extends setattr() to support nested attributes.
|
||||
|
||||
Example:
|
||||
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
|
||||
"""
|
||||
attrs = attr.split(".")
|
||||
for attr in attrs[:-1]:
|
||||
obj = getattr(obj, attr)
|
||||
setattr(obj, attrs[-1], value)
|
||||
|
||||
|
||||
class CachedModelWithPartialLoad:
|
||||
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
|
||||
|
||||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
|
||||
MPS memory, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
|
||||
self._model = model
|
||||
self._compute_device = compute_device
|
||||
|
||||
# A CPU read-only copy of the model's state dict.
|
||||
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
|
||||
|
||||
# Monkey-patch the model to add autocasting to the model's forward method.
|
||||
add_autocast_to_module_forward(model, compute_device)
|
||||
|
||||
# TODO(ryand): Manage a read-only CPU copy of the model state dict.
|
||||
# TODO(ryand): Add memoization for total_bytes and cur_vram_bytes?
|
||||
|
||||
self._total_bytes = sum(calc_tensor_size(p) for p in self._model.parameters())
|
||||
self._cur_vram_bytes: int | None = None
|
||||
|
||||
@property
|
||||
def model(self) -> torch.nn.Module:
|
||||
return self._model
|
||||
|
||||
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
|
||||
"""Get a read-only copy of the model's state dict in RAM."""
|
||||
# TODO(ryand): Document this better.
|
||||
return self._cpu_state_dict
|
||||
|
||||
def total_bytes(self) -> int:
|
||||
"""Get the total size (in bytes) of all the weights in the model."""
|
||||
return self._total_bytes
|
||||
|
||||
def cur_vram_bytes(self) -> int:
|
||||
"""Get the size (in bytes) of the weights that are currently in VRAM."""
|
||||
if self._cur_vram_bytes is None:
|
||||
self._cur_vram_bytes = sum(
|
||||
calc_tensor_size(p) for p in self._model.parameters() if p.device.type == self._compute_device.type
|
||||
)
|
||||
return self._cur_vram_bytes
|
||||
|
||||
def full_load_to_vram(self) -> int:
|
||||
"""Load all weights into VRAM."""
|
||||
return self.partial_load_to_vram(self.total_bytes())
|
||||
|
||||
def full_unload_from_vram(self) -> int:
|
||||
"""Unload all weights from VRAM."""
|
||||
return self.partial_unload_from_vram(self.total_bytes())
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
|
||||
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
|
||||
|
||||
Returns:
|
||||
The number of bytes loaded into VRAM.
|
||||
"""
|
||||
vram_bytes_loaded = 0
|
||||
|
||||
# TODO(ryand): Iterate over buffers too?
|
||||
for key, param in self._model.named_parameters():
|
||||
# Skip parameters that are already on the compute device.
|
||||
if param.device.type == self._compute_device.type:
|
||||
continue
|
||||
|
||||
# Check the size of the parameter.
|
||||
param_size = calc_tensor_size(param)
|
||||
if vram_bytes_loaded + param_size > vram_bytes_to_load:
|
||||
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
|
||||
# worth continuing to search for a smaller parameter that would fit?
|
||||
continue
|
||||
|
||||
# Copy the parameter to the compute device.
|
||||
# We use the 'overwrite' strategy from torch.nn.Module._apply().
|
||||
# TODO(ryand): For some edge cases (e.g. quantized models?), we may need to support other strategies (e.g.
|
||||
# swap).
|
||||
assert isinstance(param, torch.nn.Parameter)
|
||||
assert param.is_leaf
|
||||
out_param = torch.nn.Parameter(param.to(self._compute_device, copy=True), requires_grad=param.requires_grad)
|
||||
set_nested_attr(self._model, key, out_param)
|
||||
# We did not port the param.grad handling from torch.nn.Module._apply(), because we do not expect to be
|
||||
# handling gradients. We assert that this assumption is true.
|
||||
assert param.grad is None
|
||||
|
||||
vram_bytes_loaded += param_size
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes += vram_bytes_loaded
|
||||
|
||||
return vram_bytes_loaded
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
|
||||
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
|
||||
|
||||
Returns:
|
||||
The number of bytes unloaded from VRAM.
|
||||
"""
|
||||
vram_bytes_freed = 0
|
||||
|
||||
# TODO(ryand): Iterate over buffers too?
|
||||
for key, param in self._model.named_parameters():
|
||||
if vram_bytes_freed >= vram_bytes_to_free:
|
||||
break
|
||||
|
||||
if param.device.type != self._compute_device.type:
|
||||
continue
|
||||
|
||||
# Create a new parameter, but inject the existing CPU tensor into it.
|
||||
out_param = torch.nn.Parameter(self._cpu_state_dict[key], requires_grad=param.requires_grad)
|
||||
set_nested_attr(self._model, key, out_param)
|
||||
vram_bytes_freed += calc_tensor_size(param)
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes -= vram_bytes_freed
|
||||
|
||||
return vram_bytes_freed
|
||||
534
invokeai/backend/model_manager/load/model_cache/model_cache.py
Normal file
534
invokeai/backend/model_manager/load/model_cache/model_cache.py
Normal file
@@ -0,0 +1,534 @@
|
||||
import gc
|
||||
from logging import Logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.prefix_logger_adapter import PrefixedLoggerAdapter
|
||||
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
|
||||
def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
|
||||
"""Get the cache key for a model based on the optional submodel type."""
|
||||
if submodel_type:
|
||||
return f"{model_key}:{submodel_type.value}"
|
||||
else:
|
||||
return model_key
|
||||
|
||||
|
||||
class ModelCache:
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
The cache is based on two levels of model storage:
|
||||
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
|
||||
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
|
||||
|
||||
The model cache is based on the following assumptions:
|
||||
- storage_device_mem_size > execution_device_mem_size
|
||||
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
|
||||
|
||||
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
|
||||
the execution_device.
|
||||
|
||||
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
|
||||
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
|
||||
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
|
||||
|
||||
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
|
||||
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
|
||||
configuration.
|
||||
|
||||
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
|
||||
the context, and unload outside the context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
|
||||
do_something_on_gpu(SD1)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
lazy_offloading: bool = True,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
# TODO(ryand): Think about what lazy_offloading should mean in the new model cache.
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._logger = PrefixedLoggerAdapter(
|
||||
logger or InvokeAILogger.get_logger(self.__class__.__name__), "MODEL CACHE"
|
||||
)
|
||||
self._log_memory_usage = log_memory_usage
|
||||
self._stats: Optional[CacheStats] = None
|
||||
|
||||
self._cached_models: Dict[str, CacheRecord] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
@property
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the cap on cache size."""
|
||||
return self._max_cache_size
|
||||
|
||||
@max_cache_size.setter
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the cap on vram cache size."""
|
||||
return self._max_vram_cache_size
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
def max_vram_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
self._max_vram_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
return self._stats
|
||||
|
||||
@stats.setter
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collecting cache statistics."""
|
||||
self._stats = stats
|
||||
|
||||
def put(self, key: str, model: AnyModel) -> None:
|
||||
"""Add a model to the cache."""
|
||||
if key in self._cached_models:
|
||||
self._logger.debug(
|
||||
f"Attempted to add model {key} ({model.__class__.__name__}), but it already exists in the cache. No action necessary."
|
||||
)
|
||||
return
|
||||
|
||||
size = calc_model_size_by_data(self._logger, model)
|
||||
self.make_room(size)
|
||||
|
||||
# Wrap model.
|
||||
if isinstance(model, torch.nn.Module):
|
||||
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
|
||||
else:
|
||||
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
|
||||
|
||||
# running_on_cpu = self._execution_device == torch.device("cpu")
|
||||
# state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
|
||||
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
self._logger.debug(
|
||||
f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)"
|
||||
)
|
||||
|
||||
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
|
||||
"""Retrieve a model from the cache.
|
||||
|
||||
:param key: Model key
|
||||
:param stats_name: A human-readable id for the model for the purposes of stats reporting.
|
||||
|
||||
Raises IndexError if the model is not in the cache.
|
||||
"""
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
self._logger.debug(f"Cache miss: {key}")
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.cached_model.total_bytes()
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
self._cache_stack = [k for k in self._cache_stack if k != key]
|
||||
self._cache_stack.append(key)
|
||||
|
||||
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
||||
|
||||
return cache_entry
|
||||
|
||||
def lock(self, key: str) -> None:
|
||||
"""Lock a model for use and move it into VRAM."""
|
||||
cache_entry = self._cached_models[key]
|
||||
cache_entry.lock()
|
||||
|
||||
self._logger.debug(f"Locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
||||
|
||||
try:
|
||||
self._load_locked_model(cache_entry)
|
||||
self._logger.debug(
|
||||
f"Finished locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
|
||||
)
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self._logger.warning("Insufficient GPU memory to load model. Aborting")
|
||||
cache_entry.unlock()
|
||||
raise
|
||||
except Exception:
|
||||
cache_entry.unlock()
|
||||
raise
|
||||
|
||||
self._log_cache_state()
|
||||
|
||||
def unlock(self, key: str) -> None:
|
||||
"""Unlock a model."""
|
||||
cache_entry = self._cached_models[key]
|
||||
cache_entry.unlock()
|
||||
self._logger.debug(f"Unlocked model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
||||
|
||||
def _load_locked_model(self, cache_entry: CacheRecord) -> None:
|
||||
"""Helper function for self.lock(). Loads a locked model into VRAM."""
|
||||
vram_available = self._get_vram_available()
|
||||
|
||||
# The amount of additional VRAM that will be used if we fully load the model into VRAM.
|
||||
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
|
||||
model_total_bytes = cache_entry.cached_model.total_bytes()
|
||||
model_vram_needed = model_total_bytes - model_cur_vram_bytes
|
||||
|
||||
self._logger.debug(
|
||||
f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
# Make room for the model in VRAM.
|
||||
# 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
|
||||
# 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
|
||||
# possible.
|
||||
vram_bytes_freed = self._offload_unlocked_models(model_vram_needed)
|
||||
self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")
|
||||
|
||||
# Check the updated vram_available after offloading.
|
||||
vram_available = self._get_vram_available()
|
||||
self._logger.debug(
|
||||
f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
# Move as much of the model as possible into VRAM.
|
||||
model_bytes_loaded = 0
|
||||
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
|
||||
model_bytes_loaded = cache_entry.cached_model.partial_load_to_vram(vram_available)
|
||||
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
|
||||
# Partial load is not supported, so we have not choice but to try and fit it all into VRAM.
|
||||
model_bytes_loaded = cache_entry.cached_model.full_load_to_vram()
|
||||
else:
|
||||
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
|
||||
|
||||
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
|
||||
vram_available = self._get_vram_available()
|
||||
self._logger.debug(f"Loaded model onto execution device: model_bytes_loaded={(model_bytes_loaded/MB):.2f}MB, ")
|
||||
self._logger.debug(
|
||||
f"After loading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
def _get_vram_available(self) -> int:
|
||||
"""Get the amount of VRAM available in the cache."""
|
||||
return int(self._max_vram_cache_size * GB) - self._get_vram_in_use()
|
||||
|
||||
def _get_vram_in_use(self) -> int:
|
||||
"""Get the amount of VRAM currently in use."""
|
||||
return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
|
||||
|
||||
def _get_ram_available(self) -> int:
|
||||
"""Get the amount of RAM available in the cache."""
|
||||
return int(self._max_cache_size * GB) - self._get_ram_in_use()
|
||||
|
||||
def _get_ram_in_use(self) -> int:
|
||||
"""Get the amount of RAM currently in use."""
|
||||
return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
return None
|
||||
|
||||
def _get_vram_state_str(self, model_cur_vram_bytes: int, model_total_bytes: int, vram_available: int) -> str:
|
||||
"""Helper function for preparing a VRAM state log string."""
|
||||
model_cur_vram_bytes_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
|
||||
return (
|
||||
f"model_total={model_total_bytes/MB:.0f} MB, "
|
||||
+ f"model_vram={model_cur_vram_bytes/MB:.0f} MB ({model_cur_vram_bytes_percent:.1%} %), "
|
||||
+ f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, "
|
||||
+ f"vram_available={(vram_available/MB):.0f} MB, "
|
||||
)
|
||||
|
||||
def _offload_unlocked_models(self, vram_bytes_to_free: int) -> int:
|
||||
"""Offload models from the execution_device until vram_bytes_to_free bytes are freed, or all models are
|
||||
offloaded. Of course, locked models are not offloaded.
|
||||
|
||||
Returns:
|
||||
int: The number of bytes freed.
|
||||
"""
|
||||
self._logger.debug(f"Offloading unlocked models with goal of freeing {vram_bytes_to_free/MB:.2f}MB of VRAM.")
|
||||
vram_bytes_freed = 0
|
||||
# TODO(ryand): Give more thought to the offloading policy used here.
|
||||
cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
|
||||
for cache_entry in cache_entries_increasing_size:
|
||||
if vram_bytes_freed >= vram_bytes_to_free:
|
||||
break
|
||||
if cache_entry.is_locked:
|
||||
continue
|
||||
|
||||
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
|
||||
cache_entry_bytes_freed = cache_entry.cached_model.partial_unload_from_vram(
|
||||
vram_bytes_to_free - vram_bytes_freed
|
||||
)
|
||||
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
|
||||
cache_entry_bytes_freed = cache_entry.cached_model.full_unload_from_vram()
|
||||
else:
|
||||
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
|
||||
if cache_entry_bytes_freed > 0:
|
||||
self._logger.debug(
|
||||
f"Unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/MB):.0f} MB."
|
||||
)
|
||||
vram_bytes_freed += cache_entry_bytes_freed
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
return vram_bytes_freed
|
||||
|
||||
# def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
|
||||
# """Move model into the indicated device.
|
||||
|
||||
# :param cache_entry: The CacheRecord for the model
|
||||
# :param target_device: The torch.device to move the model into
|
||||
|
||||
# May raise a torch.cuda.OutOfMemoryError
|
||||
# """
|
||||
# self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
# source_device = cache_entry.device
|
||||
|
||||
# # Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# # This would need to be revised to support multi-GPU.
|
||||
# if torch.device(source_device).type == torch.device(target_device).type:
|
||||
# return
|
||||
|
||||
# # Some models don't have a `to` method, in which case they run in RAM/CPU.
|
||||
# if not hasattr(cache_entry.model, "to"):
|
||||
# return
|
||||
|
||||
# # This roundabout method for moving the model around is done to avoid
|
||||
# # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# # When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# # RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# # This operation is slightly faster than running `to()` on the whole model.
|
||||
# #
|
||||
# # When the model needs to be removed from VRAM we simply delete the copy
|
||||
# # of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# # in RAM into the model. So this operation is very fast.
|
||||
# start_model_to_time = time.time()
|
||||
# snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
# try:
|
||||
# if cache_entry.state_dict is not None:
|
||||
# assert hasattr(cache_entry.model, "load_state_dict")
|
||||
# if target_device == self._storage_device:
|
||||
# cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
# else:
|
||||
# new_dict: Dict[str, torch.Tensor] = {}
|
||||
# for k, v in cache_entry.state_dict.items():
|
||||
# new_dict[k] = v.to(target_device, copy=True)
|
||||
# cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
# cache_entry.model.to(target_device)
|
||||
# cache_entry.device = target_device
|
||||
# except Exception as e: # blow away cache entry
|
||||
# self._delete_cache_entry(cache_entry)
|
||||
# raise e
|
||||
|
||||
# snapshot_after = self._capture_memory_snapshot()
|
||||
# end_model_to_time = time.time()
|
||||
# self._logger.debug(
|
||||
# f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
# f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
# f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
# )
|
||||
|
||||
# if (
|
||||
# snapshot_before is not None
|
||||
# and snapshot_after is not None
|
||||
# and snapshot_before.vram is not None
|
||||
# and snapshot_after.vram is not None
|
||||
# ):
|
||||
# vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# # If the estimated model size does not match the change in VRAM, log a warning.
|
||||
# if not math.isclose(
|
||||
# vram_change,
|
||||
# cache_entry.size,
|
||||
# rel_tol=0.1,
|
||||
# abs_tol=10 * MB,
|
||||
# ):
|
||||
# self._logger.debug(
|
||||
# f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
# f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
# " estimated size may be incorrect. Estimated model size:"
|
||||
# f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
# )
|
||||
|
||||
def _log_cache_state(self, title: str = "Model cache state:", include_entry_details: bool = True):
|
||||
ram_size_bytes = self._max_cache_size * GB
|
||||
ram_in_use_bytes = self._get_ram_in_use()
|
||||
ram_in_use_bytes_percent = ram_in_use_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
|
||||
ram_available_bytes = self._get_ram_available()
|
||||
ram_available_bytes_percent = ram_available_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
|
||||
|
||||
vram_size_bytes = self._max_vram_cache_size * GB
|
||||
vram_in_use_bytes = self._get_vram_in_use()
|
||||
vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
|
||||
vram_available_bytes = self._get_vram_available()
|
||||
vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
|
||||
|
||||
log = f"{title}\n"
|
||||
|
||||
log_format = " {:<30} Limit: {:>7.1f} MB, Used: {:>7.1f} MB ({:>5.1%}), Available: {:>7.1f} MB ({:>5.1%})\n"
|
||||
log += log_format.format(
|
||||
f"Storage Device ({self._storage_device.type})",
|
||||
ram_size_bytes / MB,
|
||||
ram_in_use_bytes / MB,
|
||||
ram_in_use_bytes_percent,
|
||||
ram_available_bytes / MB,
|
||||
ram_available_bytes_percent,
|
||||
)
|
||||
log += log_format.format(
|
||||
f"Compute Device ({self._execution_device.type})",
|
||||
vram_size_bytes / MB,
|
||||
vram_in_use_bytes / MB,
|
||||
vram_in_use_bytes_percent,
|
||||
vram_available_bytes / MB,
|
||||
vram_available_bytes_percent,
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
log += " {:<30} {} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB)
|
||||
log += " {:<30} {}\n".format("Total models:", len(self._cached_models))
|
||||
|
||||
if include_entry_details and len(self._cached_models) > 0:
|
||||
log += " Models:\n"
|
||||
log_format = (
|
||||
" {:<80} total={:>7.1f} MB, vram={:>7.1f} MB ({:>5.1%}), ram={:>7.1f} MB ({:>5.1%}), locked={}\n"
|
||||
)
|
||||
for cache_record in self._cached_models.values():
|
||||
total_bytes = cache_record.cached_model.total_bytes()
|
||||
cur_vram_bytes = cache_record.cached_model.cur_vram_bytes()
|
||||
cur_vram_bytes_percent = cur_vram_bytes / total_bytes if total_bytes > 0 else 0
|
||||
cur_ram_bytes = total_bytes - cur_vram_bytes
|
||||
cur_ram_bytes_percent = cur_ram_bytes / total_bytes if total_bytes > 0 else 0
|
||||
|
||||
log += log_format.format(
|
||||
f"{cache_record.key} ({cache_record.cached_model.model.__class__.__name__}):",
|
||||
total_bytes / MB,
|
||||
cur_vram_bytes / MB,
|
||||
cur_vram_bytes_percent,
|
||||
cur_ram_bytes / MB,
|
||||
cur_ram_bytes_percent,
|
||||
cache_record.is_locked,
|
||||
)
|
||||
|
||||
self._logger.debug(log)
|
||||
|
||||
def make_room(self, bytes_needed: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
self._logger.debug(f"Making room for {bytes_needed/MB:.2f}MB of RAM.")
|
||||
self._log_cache_state(title="Before dropping models:")
|
||||
|
||||
ram_bytes_available = self._get_ram_available()
|
||||
ram_bytes_to_free = max(0, bytes_needed - ram_bytes_available)
|
||||
|
||||
ram_bytes_freed = 0
|
||||
pos = 0
|
||||
models_cleared = 0
|
||||
while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
|
||||
if not cache_entry.is_locked:
|
||||
ram_bytes_freed += cache_entry.cached_model.total_bytes()
|
||||
self._logger.debug(
|
||||
f"Dropping {model_key} from RAM cache to free {(cache_entry.cached_model.total_bytes()/MB):.2f}MB."
|
||||
)
|
||||
self._delete_cache_entry(cache_entry)
|
||||
del cache_entry
|
||||
models_cleared += 1
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
if models_cleared > 0:
|
||||
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
|
||||
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
|
||||
# is high even if no garbage gets collected.)
|
||||
#
|
||||
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
|
||||
# - If models had to be cleared, it's a signal that we are close to our memory limit.
|
||||
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
|
||||
# collected.
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self._logger.debug(f"Dropped {models_cleared} models to free {ram_bytes_freed/MB:.2f}MB of RAM.")
|
||||
self._log_cache_state(title="After dropping models:")
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
@@ -1,221 +0,0 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from typing import Dict, Generic, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModel, SubModelType
|
||||
|
||||
|
||||
class ModelLockerBase(ABC):
|
||||
"""Base class for the model locker used by the loader."""
|
||||
|
||||
@abstractmethod
|
||||
def lock(self) -> AnyModel:
|
||||
"""Lock the contained model and move it into VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unlock(self) -> None:
|
||||
"""Unlock the contained model, and remove it from VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model."""
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheRecord(Generic[T]):
|
||||
"""
|
||||
Elements of the cache:
|
||||
|
||||
key: Unique key for each model, same as used in the models database.
|
||||
model: Model in memory.
|
||||
state_dict: A read-only copy of the model's state dict in RAM. It will be
|
||||
used as a template for creating a copy in the VRAM.
|
||||
size: Size of the model
|
||||
loaded: True if the model's state dict is currently in VRAM
|
||||
|
||||
Before a model is executed, the state_dict template is copied into VRAM,
|
||||
and then injected into the model. When the model is finished, the VRAM
|
||||
copy of the state dict is deleted, and the RAM version is reinjected
|
||||
into the model.
|
||||
|
||||
The state_dict should be treated as a read-only attribute. Do not attempt
|
||||
to patch or otherwise modify it. Instead, patch the copy of the state_dict
|
||||
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
|
||||
context manager call `model_on_device()`.
|
||||
"""
|
||||
|
||||
key: str
|
||||
model: T
|
||||
device: torch.device
|
||||
state_dict: Optional[Dict[str, torch.Tensor]]
|
||||
size: int
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
"""Lock this record."""
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Unlock this record."""
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self) -> bool:
|
||||
"""Return true if record is locked."""
|
||||
return self._locks > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
"""Collect statistics on cache performance."""
|
||||
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelCacheBase(ABC, Generic[T]):
|
||||
"""Virtual base class for RAM model cache."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def storage_device(self) -> torch.device:
|
||||
"""Return the storage device (e.g. "CPU" for RAM)."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the maximum size the RAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self, value: float) -> float:
|
||||
"""Set the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload from VRAM any models not actively in use."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
pass
|
||||
|
||||
@stats.setter
|
||||
@abstractmethod
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collectin cache statistics."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self) -> Logger:
|
||||
"""Return the logger used by the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
model: T,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> ModelLockerBase:
|
||||
"""
|
||||
Retrieve model using key and optional submodel_type.
|
||||
|
||||
:param key: Opaque model key
|
||||
:param submodel_type: Type of the submodel to fetch
|
||||
:param stats_name: A human-readable id for the model for the purposes of
|
||||
stats reporting.
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log debugging information on CUDA usage."""
|
||||
pass
|
||||
@@ -1,426 +0,0 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
""" """
|
||||
|
||||
import gc
|
||||
import math
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from logging import Logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
|
||||
CacheRecord,
|
||||
CacheStats,
|
||||
ModelCacheBase,
|
||||
ModelLockerBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
The cache is based on two levels of model storage:
|
||||
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
|
||||
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
|
||||
|
||||
The model cache is based on the following assumptions:
|
||||
- storage_device_mem_size > execution_device_mem_size
|
||||
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
|
||||
|
||||
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
|
||||
the execution_device.
|
||||
|
||||
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
|
||||
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
|
||||
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
|
||||
|
||||
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
|
||||
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
|
||||
configuration.
|
||||
|
||||
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
|
||||
the context, and unload outside the context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
|
||||
do_something_on_gpu(SD1)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
lazy_offloading: bool = True,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
self._log_memory_usage = log_memory_usage
|
||||
self._stats: Optional[CacheStats] = None
|
||||
|
||||
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
"""Return the logger used by the cache."""
|
||||
return self._logger
|
||||
|
||||
@property
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
return self._lazy_offloading
|
||||
|
||||
@property
|
||||
def storage_device(self) -> torch.device:
|
||||
"""Return the storage device (e.g. "CPU" for RAM)."""
|
||||
return self._storage_device
|
||||
|
||||
@property
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
return self._execution_device
|
||||
|
||||
@property
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the cap on cache size."""
|
||||
return self._max_cache_size
|
||||
|
||||
@max_cache_size.setter
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the cap on vram cache size."""
|
||||
return self._max_vram_cache_size
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
def max_vram_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
self._max_vram_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
return self._stats
|
||||
|
||||
@stats.setter
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collectin cache statistics."""
|
||||
self._stats = stats
|
||||
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
total = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
total += cache_record.size
|
||||
return total
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
model: AnyModel,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
return
|
||||
size = calc_model_size_by_data(self.logger, model)
|
||||
self.make_room(size)
|
||||
|
||||
running_on_cpu = self.execution_device == torch.device("cpu")
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> ModelLockerBase:
|
||||
"""
|
||||
Retrieve model using key and optional submodel_type.
|
||||
|
||||
:param key: Opaque model key
|
||||
:param submodel_type: Type of the submodel to fetch
|
||||
:param stats_name: A human-readable id for the model for the purposes of
|
||||
stats reporting.
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
)
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
return None
|
||||
|
||||
def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
|
||||
if submodel_type:
|
||||
return f"{model_key}:{submodel_type.value}"
|
||||
else:
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload models from the execution_device to make room for size_required.
|
||||
|
||||
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
|
||||
"""
|
||||
reserved = self._max_vram_cache_size * GB
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
if not cache_entry.locked:
|
||||
self.move_model_to_device(cache_entry, self.storage_device)
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device.
|
||||
|
||||
:param cache_entry: The CacheRecord for the model
|
||||
:param target_device: The torch.device to move the model into
|
||||
|
||||
May raise a torch.cuda.OutOfMemoryError
|
||||
"""
|
||||
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
source_device = cache_entry.device
|
||||
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
# Some models don't have a `to` method, in which case they run in RAM/CPU.
|
||||
if not hasattr(cache_entry.model, "to"):
|
||||
return
|
||||
|
||||
# This roundabout method for moving the model around is done to avoid
|
||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# This operation is slightly faster than running `to()` on the whole model.
|
||||
#
|
||||
# When the model needs to be removed from VRAM we simply delete the copy
|
||||
# of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# in RAM into the model. So this operation is very fast.
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
try:
|
||||
if cache_entry.state_dict is not None:
|
||||
assert hasattr(cache_entry.model, "load_state_dict")
|
||||
if target_device == self.storage_device:
|
||||
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(target_device, copy=True)
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device)
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if (
|
||||
snapshot_before is not None
|
||||
and snapshot_after is not None
|
||||
and snapshot_before.vram is not None
|
||||
and snapshot_after.vram is not None
|
||||
):
|
||||
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# If the estimated model size does not match the change in VRAM, log a warning.
|
||||
if not math.isclose(
|
||||
vram_change,
|
||||
cache_entry.size,
|
||||
rel_tol=0.1,
|
||||
abs_tol=10 * MB,
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
|
||||
ram = "%4.2fG" % (self.cache_size() / GB)
|
||||
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
locked_in_vram_models = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
if hasattr(cache_record.model, "device"):
|
||||
if cache_record.model.device == self.storage_device:
|
||||
in_ram_models += 1
|
||||
else:
|
||||
in_vram_models += 1
|
||||
if cache_record.locked:
|
||||
locked_in_vram_models += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
|
||||
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
bytes_needed = size
|
||||
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
|
||||
current_size = self.cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GB):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
|
||||
pos = 0
|
||||
models_cleared = 0
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
|
||||
)
|
||||
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
self._delete_cache_entry(cache_entry)
|
||||
del cache_entry
|
||||
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
if models_cleared > 0:
|
||||
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
|
||||
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
|
||||
# is high even if no garbage gets collected.)
|
||||
#
|
||||
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
|
||||
# - If models had to be cleared, it's a signal that we are close to our memory limit.
|
||||
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
|
||||
# collected.
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
@@ -1,64 +0,0 @@
|
||||
"""
|
||||
Base class and implementation of a class that moves models in and out of VRAM.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
|
||||
CacheRecord,
|
||||
ModelCacheBase,
|
||||
ModelLockerBase,
|
||||
)
|
||||
|
||||
|
||||
class ModelLocker(ModelLockerBase):
|
||||
"""Internal class that mediates movement in and out of GPU."""
|
||||
|
||||
def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
|
||||
"""
|
||||
Initialize the model locker.
|
||||
|
||||
:param cache: The ModelCache object
|
||||
:param cache_entry: The entry in the model cache
|
||||
"""
|
||||
self._cache = cache
|
||||
self._cache_entry = cache_entry
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without moving it around."""
|
||||
return self._cache_entry.model
|
||||
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
return self._cache_entry.state_dict
|
||||
|
||||
def lock(self) -> AnyModel:
|
||||
"""Move the model into the execution device (GPU) and lock it."""
|
||||
self._cache_entry.lock()
|
||||
try:
|
||||
if self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
|
||||
self._cache_entry.loaded = True
|
||||
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
|
||||
self._cache.print_cuda_stats()
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
except Exception:
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
|
||||
return self.model
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Call upon exit from context."""
|
||||
self._cache_entry.unlock()
|
||||
if not self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(0)
|
||||
self._cache.print_cuda_stats()
|
||||
@@ -0,0 +1,33 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
from torch.overrides import TorchFunctionMode
|
||||
|
||||
|
||||
def add_autocast_to_module_forward(m: torch.nn.Module, to_device: torch.device):
|
||||
"""Monkey-patch m.forward(...) with a new forward(...) method that activates device autocasting for its duration."""
|
||||
old_forward = m.forward
|
||||
|
||||
def new_forward(*args: Any, **kwargs: Any):
|
||||
with TorchFunctionAutocastDeviceContext(to_device):
|
||||
return old_forward(*args, **kwargs)
|
||||
|
||||
m.forward = new_forward
|
||||
|
||||
|
||||
def _cast_to_device_and_run(
|
||||
func: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], to_device: torch.device
|
||||
):
|
||||
args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
|
||||
kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
return func(*args_on_device, **kwargs_on_device)
|
||||
|
||||
|
||||
class TorchFunctionAutocastDeviceContext(TorchFunctionMode):
|
||||
def __init__(self, to_device: torch.device):
|
||||
self._to_device = to_device
|
||||
|
||||
def __torch_function__(
|
||||
self, func: Callable[..., Any], types, args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None
|
||||
):
|
||||
return _cast_to_device_and_run(func, args, kwargs or {}, self._to_device)
|
||||
@@ -26,7 +26,7 @@ from invokeai.backend.model_manager import (
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class LoRALoader(ModelLoader):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
super().__init__(app_config, logger, ram_cache)
|
||||
|
||||
@@ -25,6 +25,7 @@ from invokeai.backend.model_manager.config import (
|
||||
DiffusersConfigBase,
|
||||
MainCheckpointConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
@@ -132,5 +133,5 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
if subtype == submodel_type:
|
||||
continue
|
||||
if submodel := getattr(pipeline, subtype.value, None):
|
||||
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
|
||||
self._ram_cache.put(get_model_cache_key(config.key, subtype), model=submodel)
|
||||
return getattr(pipeline, submodel_type.value)
|
||||
|
||||
12
invokeai/backend/util/prefix_logger_adapter.py
Normal file
12
invokeai/backend/util/prefix_logger_adapter.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import logging
|
||||
from typing import Any, MutableMapping
|
||||
|
||||
|
||||
# Issue with type hints related to LoggerAdapter: https://github.com/python/typeshed/issues/7855
|
||||
class PrefixedLoggerAdapter(logging.LoggerAdapter): # type: ignore
|
||||
def __init__(self, logger: logging.Logger, prefix: str):
|
||||
super().__init__(logger, {})
|
||||
self.prefix = prefix
|
||||
|
||||
def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple[str, MutableMapping[str, Any]]:
|
||||
return f"[{self.prefix}] {msg}", kwargs
|
||||
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
|
||||
|
||||
parameterize_mps_and_cuda = pytest.mark.parametrize(
|
||||
("device"),
|
||||
[
|
||||
pytest.param(
|
||||
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
|
||||
),
|
||||
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert cached_model.total_bytes() == 100
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_is_in_vram(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert not cached_model.is_in_vram()
|
||||
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.is_in_vram()
|
||||
|
||||
cached_model.full_unload_from_vram()
|
||||
assert not cached_model.is_in_vram()
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_unload(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
assert cached_model.full_load_to_vram() == 100
|
||||
assert cached_model.is_in_vram()
|
||||
assert all(p.device.type == device for p in cached_model.model.parameters())
|
||||
|
||||
assert cached_model.full_unload_from_vram() == 100
|
||||
assert not cached_model.is_in_vram()
|
||||
assert all(p.device.type == "cpu" for p in cached_model.model.parameters())
|
||||
@@ -0,0 +1,174 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
|
||||
|
||||
parameterize_mps_and_cuda = pytest.mark.parametrize(
|
||||
("device"),
|
||||
[
|
||||
pytest.param(
|
||||
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
|
||||
),
|
||||
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str):
|
||||
if device == "cuda" and not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available.")
|
||||
if device == "mps" and not torch.backends.mps.is_available():
|
||||
pytest.skip("MPS is not available.")
|
||||
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
linear_numel = 10 * 10 + 10
|
||||
assert cached_model.total_bytes() == linear_numel * 4 * 2
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_cur_vram_bytes(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() > 0
|
||||
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
|
||||
assert all(p.device.type == device for p in model.parameters())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_load(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert loaded_bytes == sum(calc_tensor_size(p) for p in model.parameters() if p.device.type == device)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_unload(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() == model_total_bytes
|
||||
|
||||
# Partially unload the model from VRAM.
|
||||
bytes_to_free = int(model_total_bytes * 0.4)
|
||||
freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
|
||||
assert freed_bytes >= bytes_to_free
|
||||
assert freed_bytes < model_total_bytes
|
||||
assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
|
||||
assert freed_bytes == sum(calc_tensor_size(p) for p in model.parameters() if p.device.type == "cpu")
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
loaded_bytes = cached_model.full_load_to_vram()
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes == model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert all(p.device.type == device for p in model.parameters())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_from_partial(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
|
||||
# Full load the rest of the model into VRAM.
|
||||
loaded_bytes_2 = cached_model.full_load_to_vram()
|
||||
assert loaded_bytes_2 > 0
|
||||
assert loaded_bytes_2 < model_total_bytes
|
||||
assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes()
|
||||
assert loaded_bytes + loaded_bytes_2 == model_total_bytes
|
||||
assert all(p.device.type == device for p in model.parameters())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_unload_from_partial(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
|
||||
# Full unload the model from VRAM.
|
||||
unloaded_bytes = cached_model.full_unload_from_vram()
|
||||
assert unloaded_bytes > 0
|
||||
assert unloaded_bytes == loaded_bytes
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_get_cpu_state_dict(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# The CPU state dict can be accessed and has the expected properties.
|
||||
cpu_state_dict = cached_model.get_cpu_state_dict()
|
||||
assert cpu_state_dict is not None
|
||||
assert len(cpu_state_dict) == len(model.state_dict())
|
||||
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
|
||||
|
||||
# The CPU state dict is still available, and still on the CPU.
|
||||
cpu_state_dict = cached_model.get_cpu_state_dict()
|
||||
assert cpu_state_dict is not None
|
||||
assert len(cpu_state_dict) == len(model.state_dict())
|
||||
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
|
||||
13
tests/backend/model_manager/load/model_cache/dummy_module.py
Normal file
13
tests/backend/model_manager/load/model_cache/dummy_module.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import torch
|
||||
|
||||
|
||||
class DummyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(10, 10)
|
||||
self.linear2 = torch.nn.Linear(10, 10)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
|
||||
TorchFunctionAutocastDeviceContext,
|
||||
add_autocast_to_module_forward,
|
||||
)
|
||||
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
|
||||
|
||||
|
||||
def test_torch_function_autocast_device_context():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available.")
|
||||
|
||||
model = DummyModule()
|
||||
# Model parameters should start off on the CPU.
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
with TorchFunctionAutocastDeviceContext(to_device=torch.device("cuda")):
|
||||
x = torch.randn(10, 10, device="cuda")
|
||||
y = model(x)
|
||||
|
||||
# The model output should be on the GPU.
|
||||
assert y.device.type == "cuda"
|
||||
|
||||
# The model parameters should still be on the CPU.
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
|
||||
def test_add_autocast_to_module_forward():
|
||||
model = DummyModule()
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
add_autocast_to_module_forward(model, torch.device("cuda"))
|
||||
# After adding autocast, the model parameters should still be on the CPU.
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
x = torch.randn(10, 10, device="cuda")
|
||||
y = model(x)
|
||||
|
||||
# The model output should be on the GPU.
|
||||
assert y.device.type == "cuda"
|
||||
|
||||
# The model parameters should still be on the CPU.
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
# The autocast context should automatically be disabled after the model forward call completes.
|
||||
# So, attempting to perform an operation with comflicting devices should raise an error.
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = torch.randn(10, device="cuda") * torch.randn(10, device="cpu")
|
||||
@@ -25,7 +25,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelVariantType,
|
||||
VAEDiffusersConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||
HFTestLoraMetadata,
|
||||
|
||||
Reference in New Issue
Block a user