mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add enable_partial_loading config.
This commit is contained in:
@@ -107,6 +107,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
lazy_offload: Keep models in VRAM until their space is needed.
|
||||
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
|
||||
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
|
||||
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
|
||||
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
||||
@@ -178,6 +179,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
|
||||
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
|
||||
device_working_mem_gb: float = Field(default=2, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
|
||||
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
|
||||
|
||||
# DEVICE
|
||||
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
|
||||
|
||||
@@ -83,6 +83,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
|
||||
ram_cache = ModelCache(
|
||||
execution_device_working_mem_gb=app_config.device_working_mem_gb,
|
||||
enable_partial_loading=app_config.enable_partial_loading,
|
||||
lazy_offloading=app_config.lazy_offload,
|
||||
logger=logger,
|
||||
execution_device=execution_device or TorchDevice.choose_torch_device(),
|
||||
|
||||
@@ -72,6 +72,7 @@ class ModelCache:
|
||||
def __init__(
|
||||
self,
|
||||
execution_device_working_mem_gb: float,
|
||||
enable_partial_loading: bool,
|
||||
execution_device: torch.device | str = "cuda",
|
||||
storage_device: torch.device | str = "cpu",
|
||||
lazy_offloading: bool = True,
|
||||
@@ -91,6 +92,7 @@ class ModelCache:
|
||||
"""
|
||||
# TODO(ryand): Think about what lazy_offloading should mean in the new model cache.
|
||||
self._lazy_offloading = lazy_offloading
|
||||
self._enable_partial_loading = enable_partial_loading
|
||||
self._execution_device_working_mem_gb = execution_device_working_mem_gb
|
||||
self._execution_device: torch.device = torch.device(execution_device)
|
||||
self._storage_device: torch.device = torch.device(storage_device)
|
||||
@@ -127,7 +129,7 @@ class ModelCache:
|
||||
running_on_cpu = self._execution_device.type == "cpu"
|
||||
|
||||
# Wrap model.
|
||||
if isinstance(model, torch.nn.Module) and not running_on_cpu:
|
||||
if isinstance(model, torch.nn.Module) and not running_on_cpu and self._enable_partial_loading:
|
||||
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
|
||||
else:
|
||||
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
|
||||
|
||||
@@ -92,6 +92,7 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
|
||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
|
||||
ram_cache = ModelCache(
|
||||
execution_device_working_mem_gb=mm2_app_config.device_working_mem_gb,
|
||||
enable_partial_loading=mm2_app_config.enable_partial_loading,
|
||||
execution_device="cpu",
|
||||
logger=InvokeAILogger.get_logger(),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user