Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-04-23 03:00:31 -04:00)
Allow invocations to request more working VRAM when loading a model via the ModelCache.
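The new `working_mem_bytes` argument flows from the invocation-facing `context.models.load(...)` call down through `ModelLoadService`, `ModelLoader`, and `LoadedModel` to `ModelCache.lock()`, where it is factored into how much VRAM the cache may occupy. A minimal caller-side sketch (not a complete invocation; the identifier and the 6 GB budget mirror the `LatentsToImageInvocation` change below):

```python
# Sketch of the caller-side API introduced by this commit.
GB = 2**30

def invoke(self, context: InvocationContext) -> ImageOutput:
    # Ask the ModelCache to keep 6 GB of VRAM free for activations while the VAE is locked.
    vae_info = context.models.load(self.vae.vae, working_mem_bytes=6 * GB)
    with vae_info as vae:
        ...  # run the decode inside the reserved working-memory budget
```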
@@ -57,7 +57,10 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)

-        vae_info = context.models.load(self.vae.vae)
+        # Reserve 6GB of VRAM for the VAE.
+        # Experimentally, this was found to be sufficient for decoding a 1024x1024 image.
+        # TODO(ryand): Set the requested working memory dynamically based on the image size (and self.fp32).
+        vae_info = context.models.load(self.vae.vae, working_mem_bytes=6 * 2**30)
         assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
         with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             context.util.signal_progress("Running VAE decoder")

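The TODO above leaves the budget fixed at 6 GB; a dynamic estimate would scale with output resolution and precision. A hypothetical sketch of that idea (the 8x latent-to-pixel factor matches SD VAEs, but the per-pixel multiplier is an assumed placeholder, not part of this commit):

```python
# Hypothetical helper illustrating the TODO: derive working memory from image size
# and precision. The per-pixel multiplier is an assumption, not a measured value.
def estimate_vae_working_mem_bytes(latent_height: int, latent_width: int, fp32: bool) -> int:
    element_size = 4 if fp32 else 2                      # bytes per element
    out_h, out_w = latent_height * 8, latent_width * 8   # SD VAEs upscale latents 8x
    bytes_per_pixel = 64 * element_size                  # assumed budget for intermediate activations
    return out_h * out_w * bytes_per_pixel
```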
@@ -14,12 +14,19 @@ class ModelLoadServiceBase(ABC):
     """Wrapper around AnyModelLoader."""

     @abstractmethod
-    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+    def load_model(
+        self,
+        model_config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+        working_mem_bytes: Optional[int] = None,
+    ) -> LoadedModel:
         """
         Given a model's configuration, load it and return the LoadedModel object.

         :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
         :param submodel: For main (pipeline models), the submodel to fetch.
+        :param working_mem_bytes: The number of bytes of working memory to keep on the GPU while this model is loaded on the
+            GPU.
         """

     @property

@@ -49,7 +49,12 @@ class ModelLoadService(ModelLoadServiceBase):
         """Return the RAM cache used by this loader."""
         return self._ram_cache

-    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+    def load_model(
+        self,
+        model_config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+        working_mem_bytes: Optional[int] = None,
+    ) -> LoadedModel:
         """
         Given a model's configuration, load it and return the LoadedModel object.

@@ -67,7 +72,7 @@ class ModelLoadService(ModelLoadServiceBase):
             app_config=self._app_config,
             logger=self._logger,
             ram_cache=self._ram_cache,
-        ).load_model(model_config, submodel_type)
+        ).load_model(model_config, submodel_type, working_mem_bytes)

         if hasattr(self, "_invoker"):
             self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)

@@ -361,7 +361,10 @@ class ModelsInterface(InvocationContextInterface):
         return self._services.model_manager.store.exists(identifier.key)

     def load(
-        self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
+        self,
+        identifier: Union[str, "ModelIdentifierField"],
+        submodel_type: Optional[SubModelType] = None,
+        working_mem_bytes: Optional[int] = None,
     ) -> LoadedModel:
         """Load a model.

@@ -386,7 +389,7 @@ class ModelsInterface(InvocationContextInterface):
         if submodel_type:
             message += f" ({submodel_type.value})"
         self._util.signal_progress(message)
-        return self._services.model_manager.load.load_model(model, submodel_type)
+        return self._services.model_manager.load.load_model(model, submodel_type, working_mem_bytes)

     def load_by_attrs(
         self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None

@@ -52,12 +52,13 @@ class LoadedModelWithoutConfig:
     do not have a state_dict, in which case this value will be None.
     """

-    def __init__(self, cache_record: CacheRecord, cache: ModelCache):
+    def __init__(self, cache_record: CacheRecord, cache: ModelCache, working_mem_bytes: Optional[int] = None):
         self._cache_record = cache_record
         self._cache = cache
+        self._working_mem_bytes = working_mem_bytes

     def __enter__(self) -> AnyModel:
-        self._cache.lock(self._cache_record.key)
+        self._cache.lock(self._cache_record.key, self._working_mem_bytes)
         return self.model

     def __exit__(self, *args: Any, **kwargs: Any) -> None:

@@ -66,7 +67,7 @@ class LoadedModelWithoutConfig:
     @contextmanager
     def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
         """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
-        self._cache.lock(self._cache_record.key)
+        self._cache.lock(self._cache_record.key, self._working_mem_bytes)
         try:
             yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
         finally:

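Both context-manager entry points now forward the stored request, so a single `load(...)` call fixes the working-memory reserve for every subsequent lock. A small usage sketch (`run_inference` and the 4 GiB figure are placeholders for illustration):

```python
# Sketch of both lock paths forwarding the stored working-memory request.
loaded = context.models.load(self.vae.vae, working_mem_bytes=4 * 2**30)

with loaded as model:                                       # __enter__ -> cache.lock(key, 4 GiB)
    run_inference(model)

with loaded.model_on_device() as (cpu_state_dict, model):   # same 4 GiB reserve applies here
    run_inference(model)
```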
@@ -81,8 +82,14 @@ class LoadedModelWithoutConfig:
 class LoadedModel(LoadedModelWithoutConfig):
     """Context manager object that mediates transfer from RAM<->VRAM."""

-    def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, cache: ModelCache):
-        super().__init__(cache_record=cache_record, cache=cache)
+    def __init__(
+        self,
+        config: Optional[AnyModelConfig],
+        cache_record: CacheRecord,
+        cache: ModelCache,
+        working_mem_bytes: Optional[int] = None,
+    ):
+        super().__init__(cache_record=cache_record, cache=cache, working_mem_bytes=working_mem_bytes)
         self.config = config


@@ -108,7 +115,12 @@ class ModelLoaderBase(ABC):
         pass

     @abstractmethod
-    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+    def load_model(
+        self,
+        model_config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+        working_mem_bytes: Optional[int] = None,
+    ) -> LoadedModel:
         """
         Return a model given its confguration.

@@ -38,7 +38,12 @@ class ModelLoader(ModelLoaderBase):
         self._torch_dtype = TorchDevice.choose_torch_dtype()
         self._torch_device = TorchDevice.choose_torch_device()

-    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+    def load_model(
+        self,
+        model_config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+        working_mem_bytes: Optional[int] = None,
+    ) -> LoadedModel:
         """
         Return a model given its configuration.

@@ -56,7 +61,9 @@ class ModelLoader(ModelLoaderBase):

         with skip_torch_weight_init():
             cache_record = self._load_and_cache(model_config, submodel_type)
-        return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
+        return LoadedModel(
+            config=model_config, cache_record=cache_record, cache=self._ram_cache, working_mem_bytes=working_mem_bytes
+        )

     @property
     def ram_cache(self) -> ModelCache:

@@ -177,8 +177,12 @@ class ModelCache:

         return cache_entry

-    def lock(self, key: str) -> None:
-        """Lock a model for use and move it into VRAM."""
+    def lock(self, key: str, working_mem_bytes: Optional[int]) -> None:
+        """Lock a model for use and move it into VRAM.
+
+        :param working_mem_bytes: The number of bytes of working memory to keep on the GPU while this model is loaded on
+            the GPU. If None, self._execution_device_working_mem_gb is used.
+        """
         cache_entry = self._cached_models[key]
         cache_entry.lock()

@@ -189,7 +193,7 @@ class ModelCache:
             return

         try:
-            self._load_locked_model(cache_entry)
+            self._load_locked_model(cache_entry, working_mem_bytes)
             self._logger.debug(
                 f"Finished locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
             )

@@ -209,9 +213,9 @@ class ModelCache:
         cache_entry.unlock()
         self._logger.debug(f"Unlocked model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")

-    def _load_locked_model(self, cache_entry: CacheRecord) -> None:
+    def _load_locked_model(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int] = None) -> None:
         """Helper function for self.lock(). Loads a locked model into VRAM."""
-        vram_available = self._get_vram_available()
+        vram_available = self._get_vram_available(working_mem_bytes)

         # Calculate model_vram_needed, the amount of additional VRAM that will be used if we fully load the model into
         # VRAM.

@@ -234,7 +238,7 @@ class ModelCache:
         self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")

         # Check the updated vram_available after offloading.
-        vram_available = self._get_vram_available()
+        vram_available = self._get_vram_available(working_mem_bytes)
         self._logger.debug(
             f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
         )

@@ -250,16 +254,19 @@ class ModelCache:
             raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")

         model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
-        vram_available = self._get_vram_available()
+        vram_available = self._get_vram_available(working_mem_bytes)
         self._logger.debug(f"Loaded model onto execution device: model_bytes_loaded={(model_bytes_loaded/MB):.2f}MB, ")
         self._logger.debug(
             f"After loading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
         )

-    def _get_vram_available(self) -> int:
+    def _get_vram_available(self, working_mem_bytes: Optional[int] = None) -> int:
         """Calculate the amount of additional VRAM available for the cache to use (takes into account the working
         memory).
         """
+        working_mem_bytes_default = int(self._execution_device_working_mem_gb * GB)
+        working_mem_bytes = max(working_mem_bytes or working_mem_bytes_default, working_mem_bytes_default)
+
         if self._execution_device.type == "cuda":
             vram_reserved = torch.cuda.memory_reserved(self._execution_device)
             vram_free, _vram_total = torch.cuda.mem_get_info(self._execution_device)

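Note that a per-call request can only raise the reserve: `_get_vram_available()` takes the max of the requested value and the cache-wide default derived from `self._execution_device_working_mem_gb`, so an invocation cannot shrink the configured working memory. A standalone illustration of that rule (the 3 GB default is illustrative, not the project's configured value):

```python
# Standalone illustration of the clamping rule in _get_vram_available().
from typing import Optional

GB = 2**30

def effective_working_mem_bytes(requested: Optional[int], default_bytes: int = 3 * GB) -> int:
    # A request below the default is clamped up to the default; a larger request is honored.
    return max(requested or default_bytes, default_bytes)

assert effective_working_mem_bytes(None) == 3 * GB      # no request -> default
assert effective_working_mem_bytes(1 * GB) == 3 * GB    # smaller request clamped to default
assert effective_working_mem_bytes(6 * GB) == 6 * GB    # larger request honored
```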
@@ -272,7 +279,7 @@ class ModelCache:
         else:
             raise ValueError(f"Unsupported execution device: {self._execution_device.type}")

-        vram_total_available_to_cache = vram_available_to_process - int(self._execution_device_working_mem_gb * GB)
+        vram_total_available_to_cache = vram_available_to_process - working_mem_bytes
         vram_cur_available_to_cache = vram_total_available_to_cache - self._get_vram_in_use()
         return vram_cur_available_to_cache
