WIP - add device_working_mem_gb config

This commit is contained in:
Ryan Dick
2024-12-18 03:31:37 +00:00
parent e0c899104b
commit 79a4d0890f
7 changed files with 72 additions and 73 deletions

View File

@@ -106,6 +106,7 @@ class InvokeAIAppConfig(BaseSettings):
vram: Amount of VRAM reserved for model storage (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
@@ -176,6 +177,7 @@ class InvokeAIAppConfig(BaseSettings):
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
device_working_mem_gb: float = Field(default=2, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")

View File

@@ -22,7 +22,6 @@ class ModelCacheStatsSummary:
"""The stats for the model cache."""
high_water_mark_gb: float
cache_size_gb: float
total_usage_gb: float
cache_hits: int
cache_misses: int
@@ -79,7 +78,7 @@ class InvocationStatsSummary:
_str += f" Model cache misses: {self.model_cache_stats.cache_misses}\n"
_str += f" Models cached: {self.model_cache_stats.models_cached}\n"
_str += f" Models cleared from cache: {self.model_cache_stats.models_cleared}\n"
_str += f" Cache high water mark: {self.model_cache_stats.high_water_mark_gb:4.2f}/{self.model_cache_stats.cache_size_gb:4.2f}G\n"
_str += f" Cache high water mark: {self.model_cache_stats.high_water_mark_gb:4.2f}G\n"
return _str

View File

@@ -111,7 +111,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
cache_hits=cache_stats.hits,
cache_misses=cache_stats.misses,
high_water_mark_gb=cache_stats.high_watermark / GB,
cache_size_gb=cache_stats.cache_size / GB,
total_usage_gb=sum(list(cache_stats.loaded_model_sizes.values())) / GB,
models_cached=cache_stats.in_cache,
models_cleared=cache_stats.cleared,

View File

@@ -82,8 +82,7 @@ class ModelManagerService(ModelManagerServiceBase):
logger.setLevel(app_config.log_level.upper())
ram_cache = ModelCache(
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
execution_device_working_mem_gb=app_config.device_working_mem_gb,
lazy_offloading=app_config.lazy_offload,
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),

View File

@@ -11,5 +11,4 @@ class CacheStats(object):
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)

View File

@@ -2,6 +2,7 @@ import gc
from logging import Logger
from typing import Dict, List, Optional
import psutil
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
@@ -70,19 +71,15 @@ class ModelCache:
def __init__(
self,
max_cache_size: float,
max_vram_cache_size: float,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
execution_device_working_mem_gb: float,
execution_device: torch.device | str = "cuda",
storage_device: torch.device | str = "cpu",
lazy_offloading: bool = True,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
"""
Initialize the model RAM cache.
"""Initialize the model RAM cache.
:param max_cache_size: Maximum size of the storage_device cache in GBs.
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
@@ -92,13 +89,11 @@ class ModelCache:
behaviour.
:param logger: InvokeAILogger to use (otherwise creates one)
"""
# allow lazy offloading only when vram cache enabled
# TODO(ryand): Think about what lazy_offloading should mean in the new model cache.
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device
self._lazy_offloading = lazy_offloading
self._execution_device_working_mem_gb = execution_device_working_mem_gb
self._execution_device: torch.device = torch.device(execution_device)
self._storage_device: torch.device = torch.device(storage_device)
self._logger = PrefixedLoggerAdapter(
logger or InvokeAILogger.get_logger(self.__class__.__name__), "MODEL CACHE"
)
@@ -108,26 +103,6 @@ class ModelCache:
self._cached_models: Dict[str, CacheRecord] = {}
self._cache_stack: List[str] = []
@property
def max_cache_size(self) -> float:
"""Return the cap on cache size."""
return self._max_cache_size
@max_cache_size.setter
def max_cache_size(self, value: float) -> None:
"""Set the cap on cache size."""
self._max_cache_size = value
@property
def max_vram_cache_size(self) -> float:
"""Return the cap on vram cache size."""
return self._max_vram_cache_size
@max_vram_cache_size.setter
def max_vram_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
self._max_vram_cache_size = value
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
@@ -149,14 +124,14 @@ class ModelCache:
size = calc_model_size_by_data(self._logger, model)
self.make_room(size)
running_on_cpu = self._execution_device.type == "cpu"
# Wrap model.
if isinstance(model, torch.nn.Module):
if isinstance(model, torch.nn.Module) and not running_on_cpu:
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
else:
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
# running_on_cpu = self._execution_device == torch.device("cpu")
# state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@@ -186,7 +161,6 @@ class ModelCache:
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
@@ -208,6 +182,10 @@ class ModelCache:
self._logger.debug(f"Locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
if self._execution_device.type == "cpu":
# Models don't need to be loaded into VRAM if we're running on CPU.
return
try:
self._load_locked_model(cache_entry)
self._logger.debug(
@@ -277,16 +255,38 @@ class ModelCache:
)
def _get_vram_available(self) -> int:
"""Get the amount of VRAM available in the cache."""
return int(self._max_vram_cache_size * GB) - self._get_vram_in_use()
"""Calculate the amount of additional VRAM available for the cache to use (takes into account the working
memory).
"""
if self._execution_device.type == "cuda":
vram_reserved = torch.cuda.memory_reserved(self._execution_device)
vram_free, _vram_total = torch.cuda.mem_get_info(self._execution_device)
vram_available_to_process = vram_free + vram_reserved
elif self._execution_device.type == "mps":
# TODO(ryand): Would it be better to use psutil.virtual_memory().total here? I haven't looked into the
# behaviors of some of these functions when multiple processes are using MPS memory.
vram_reserved = torch.mps.driver_allocated_memory()
vram_total: int = torch.mps.recommended_max_memory()
vram_available_to_process = vram_total
else:
raise ValueError(f"Unsupported execution device: {self._execution_device.type}")
vram_total_available_to_cache = vram_available_to_process - int(self._execution_device_working_mem_gb * GB)
vram_cur_available_to_cache = vram_total_available_to_cache - self._get_vram_in_use()
return vram_cur_available_to_cache
def _get_vram_in_use(self) -> int:
"""Get the amount of VRAM currently in use."""
"""Get the amount of VRAM currently in use by the cache."""
return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
def _get_ram_available(self) -> int:
"""Get the amount of RAM available in the cache."""
return int(self._max_cache_size * GB) - self._get_ram_in_use()
"""Get the amount of RAM available for the cache to use, while keeping memory pressure under control."""
virtual_memory = psutil.virtual_memory()
ram_total = virtual_memory.total
ram_available = virtual_memory.available
ram_used = ram_total - ram_available
# Aim to keep 10% of RAM free.
return int(ram_total * 0.9) - ram_used
def _get_ram_in_use(self) -> int:
"""Get the amount of RAM currently in use."""
@@ -303,7 +303,7 @@ class ModelCache:
return (
f"model_total={model_total_bytes/MB:.0f} MB, "
+ f"model_vram={model_cur_vram_bytes/MB:.0f} MB ({model_cur_vram_bytes_percent:.1%} %), "
+ f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, "
# + f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, "
+ f"vram_available={(vram_available/MB):.0f} MB, "
)
@@ -422,21 +422,15 @@ class ModelCache:
# )
def _log_cache_state(self, title: str = "Model cache state:", include_entry_details: bool = True):
ram_size_bytes = self._max_cache_size * GB
ram_in_use_bytes = self._get_ram_in_use()
ram_in_use_bytes_percent = ram_in_use_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
ram_available_bytes = self._get_ram_available()
ram_available_bytes_percent = ram_available_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
vram_size_bytes = self._max_vram_cache_size * GB
vram_in_use_bytes = self._get_vram_in_use()
vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
vram_available_bytes = self._get_vram_available()
vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
log = f"{title}\n"
log_format = " {:<30} Limit: {:>7.1f} MB, Used: {:>7.1f} MB ({:>5.1%}), Available: {:>7.1f} MB ({:>5.1%})\n"
ram_in_use_bytes = self._get_ram_in_use()
ram_available_bytes = self._get_ram_available()
ram_size_bytes = ram_in_use_bytes + ram_available_bytes
ram_in_use_bytes_percent = ram_in_use_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
ram_available_bytes_percent = ram_available_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
log += log_format.format(
f"Storage Device ({self._storage_device.type})",
ram_size_bytes / MB,
@@ -445,17 +439,24 @@ class ModelCache:
ram_available_bytes / MB,
ram_available_bytes_percent,
)
log += log_format.format(
f"Compute Device ({self._execution_device.type})",
vram_size_bytes / MB,
vram_in_use_bytes / MB,
vram_in_use_bytes_percent,
vram_available_bytes / MB,
vram_available_bytes_percent,
)
if self._execution_device.type != "cpu":
vram_in_use_bytes = self._get_vram_in_use()
vram_available_bytes = self._get_vram_available()
vram_size_bytes = vram_in_use_bytes + vram_available_bytes
vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
log += log_format.format(
f"Compute Device ({self._execution_device.type})",
vram_size_bytes / MB,
vram_in_use_bytes / MB,
vram_in_use_bytes_percent,
vram_available_bytes / MB,
vram_available_bytes_percent,
)
if torch.cuda.is_available():
log += " {:<30} {} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB)
log += " {:<30} {:.1f} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB)
log += " {:<30} {}\n".format("Total models:", len(self._cached_models))
if include_entry_details and len(self._cached_models) > 0:

View File

@@ -91,9 +91,9 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
ram_cache = ModelCache(
execution_device_working_mem_gb=mm2_app_config.device_working_mem_gb,
execution_device="cpu",
logger=InvokeAILogger.get_logger(),
max_cache_size=mm2_app_config.ram,
max_vram_cache_size=mm2_app_config.vram,
)
return ModelLoadService(
app_config=mm2_app_config,