mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Do not apply the autocast context when a model is fully loaded onto the GPU, because the context adds some runtime overhead.
This commit is contained in:
@@ -2,6 +2,7 @@ import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
|
||||
add_autocast_to_module_forward,
|
||||
remove_autocast_from_module_forward,
|
||||
)
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
@@ -32,12 +33,11 @@ class CachedModelWithPartialLoad:
|
||||
# A CPU read-only copy of the model's state dict.
|
||||
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
|
||||
|
||||
# Monkey-patch the model to add autocasting to the model's forward method.
|
||||
add_autocast_to_module_forward(model, compute_device)
|
||||
|
||||
self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
|
||||
self._cur_vram_bytes: int | None = None
|
||||
|
||||
self._update_model_autocast_context()
|
||||
|
||||
@property
def model(self) -> torch.nn.Module:
    """Return the torch module stored on this instance as ``_model``."""
    return self._model
|
||||
@@ -104,6 +104,12 @@ class CachedModelWithPartialLoad:
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes += vram_bytes_loaded
|
||||
|
||||
if self._cur_vram_bytes == self.total_bytes():
|
||||
# HACK(ryand): The model should already be on the compute device, but we have to call this to ensure that
|
||||
# all non-persistent buffers are moved (i.e. buffers that are not registered in the state dict).
|
||||
self._model.to(self._compute_device)
|
||||
|
||||
self._update_model_autocast_context()
|
||||
return vram_bytes_loaded
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -133,4 +139,17 @@ class CachedModelWithPartialLoad:
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes -= vram_bytes_freed
|
||||
|
||||
self._update_model_autocast_context()
|
||||
return vram_bytes_freed
|
||||
|
||||
def _update_model_autocast_context(self):
    """Add or remove the autocast context on the model to match its current VRAM residency.

    Meant to be called every time the model's VRAM usage changes.
    """
    fully_in_vram = self.cur_vram_bytes() == self.total_bytes()

    if not fully_in_vram:
        # Not everything is in VRAM: monkey-patch the model's forward method to add autocasting.
        add_autocast_to_module_forward(self._model, self._compute_device)
        return

    # The whole model is in VRAM, so the autocast context is removed: it causes some runtime overhead.
    remove_autocast_from_module_forward(self._model)
|
||||
|
||||
Reference in New Issue
Block a user