From 4a4360a40c7483b9ff82850538e8fd744344af4c Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 18 Dec 2024 17:17:08 +0000 Subject: [PATCH] Add enable_partial_loading config. --- invokeai/app/services/config/config_default.py | 2 ++ invokeai/app/services/model_manager/model_manager_default.py | 1 + .../backend/model_manager/load/model_cache/model_cache.py | 4 +++- tests/backend/model_manager/model_manager_fixtures.py | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 56c8c7fc29..833c87bbe9 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -107,6 +107,7 @@ class InvokeAIAppConfig(BaseSettings): lazy_offload: Keep models in VRAM until their space is needed. log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour. device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value. + enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM. device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps` precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32` sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements. @@ -178,6 +179,7 @@ class InvokeAIAppConfig(BaseSettings): lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.") log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.") device_working_mem_gb: float = Field(default=2, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.") + enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.") # DEVICE device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.") diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 6f5dfdb77a..ef6754531d 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -83,6 +83,7 @@ class ModelManagerService(ModelManagerServiceBase): ram_cache = ModelCache( execution_device_working_mem_gb=app_config.device_working_mem_gb, + enable_partial_loading=app_config.enable_partial_loading, lazy_offloading=app_config.lazy_offload, logger=logger, execution_device=execution_device or TorchDevice.choose_torch_device(), diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index f1a831c56b..7c4ab0662e 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -72,6 +72,7 @@ class ModelCache: def __init__( self, execution_device_working_mem_gb: float, + enable_partial_loading: bool, execution_device: torch.device | str = "cuda", storage_device: torch.device | str = "cpu", lazy_offloading: bool = True, @@ -91,6 +92,7 @@ class ModelCache: """ # TODO(ryand): Think about what lazy_offloading should mean in the new model cache. self._lazy_offloading = lazy_offloading + self._enable_partial_loading = enable_partial_loading self._execution_device_working_mem_gb = execution_device_working_mem_gb self._execution_device: torch.device = torch.device(execution_device) self._storage_device: torch.device = torch.device(storage_device) @@ -127,7 +129,7 @@ class ModelCache: running_on_cpu = self._execution_device.type == "cpu" # Wrap model. - if isinstance(model, torch.nn.Module) and not running_on_cpu: + if isinstance(model, torch.nn.Module) and not running_on_cpu and self._enable_partial_loading: wrapped_model = CachedModelWithPartialLoad(model, self._execution_device) else: wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size) diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 7e5fa8e6cf..a7d5a08b92 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -92,6 +92,7 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase: def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase: ram_cache = ModelCache( execution_device_working_mem_gb=mm2_app_config.device_working_mem_gb, + enable_partial_loading=mm2_app_config.enable_partial_loading, execution_device="cpu", logger=InvokeAILogger.get_logger(), )