Do not apply the autocast context when models are fully loaded onto the GPU - it adds some overhead.

This commit is contained in:
Ryan Dick
2024-12-18 21:51:39 +00:00
parent 4ce2042d65
commit e684e49299

View File

@@ -2,6 +2,7 @@ import torch
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
add_autocast_to_module_forward,
remove_autocast_from_module_forward,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
@@ -32,12 +33,11 @@ class CachedModelWithPartialLoad:
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
# Monkey-patch the model to add autocasting to the model's forward method.
add_autocast_to_module_forward(model, compute_device)
self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
self._cur_vram_bytes: int | None = None
self._update_model_autocast_context()
@property
def model(self) -> torch.nn.Module:
    """Return the torch module stored on this object as ``self._model``."""
    return self._model
@@ -104,6 +104,12 @@ class CachedModelWithPartialLoad:
if self._cur_vram_bytes is not None:
self._cur_vram_bytes += vram_bytes_loaded
if self._cur_vram_bytes == self.total_bytes():
# HACK(ryand): The model should already be on the compute device, but we have to call this to ensure that
# all non-persistent buffers are moved (i.e. buffers that are not registered in the state dict).
self._model.to(self._compute_device)
self._update_model_autocast_context()
return vram_bytes_loaded
@torch.no_grad()
@@ -133,4 +139,17 @@ class CachedModelWithPartialLoad:
if self._cur_vram_bytes is not None:
self._cur_vram_bytes -= vram_bytes_freed
self._update_model_autocast_context()
return vram_bytes_freed
def _update_model_autocast_context(self):
    """Add or remove the autocast wrapper on the model's forward method.

    Meant to be called every time the model's VRAM usage changes. The wrapper is removed when the
    current VRAM byte count equals the model's total byte count, and (re-)applied otherwise.
    """
    fully_in_vram = self.cur_vram_bytes() == self.total_bytes()

    if not fully_in_vram:
        # The model is not fully loaded into VRAM: monkey-patch its forward method so that it runs
        # with autocasting to the compute device.
        add_autocast_to_module_forward(self._model, self._compute_device)
        return

    # The model is fully loaded into VRAM. The autocast context causes some runtime overhead, so
    # take it off.
    remove_autocast_from_module_forward(self._model)