Compare commits

...

7 Commits

17 changed files with 608 additions and 428 deletions

View File

@@ -1,105 +0,0 @@
import torch
from invokeai.backend.model_cache_v2.torch_module_overrides import CustomLinear, inject_custom_layers_into_module
class CachedModelV2:
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
print("CachedModelV2.__init__")
self._model = model
inject_custom_layers_into_module(self._model)
self._compute_device = compute_device
# Memoized values.
self._total_size_cache = None
self._cur_vram_bytes_cache = None
@property
def model(self) -> torch.nn.Module:
return self._model
def total_bytes(self) -> int:
if self._total_size_cache is None:
self._total_size_cache = sum(p.numel() * p.element_size() for p in self._model.parameters())
return self._total_size_cache
def cur_vram_bytes(self) -> int:
"""Return the size (in bytes) of the weights that are currently in VRAM."""
if self._cur_vram_bytes_cache is None:
self._cur_vram_bytes_cache = sum(
p.numel() * p.element_size()
for p in self._model.parameters()
if p.device.type == self._compute_device.type
)
return self._cur_vram_bytes_cache
def full_load_to_vram(self):
"""Load all weights into VRAM."""
raise NotImplementedError("Not implemented")
self._cur_vram_bytes_cache = self.total_bytes()
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
Returns:
The number of bytes loaded into VRAM.
"""
vram_bytes_loaded = 0
def to_vram(m: torch.nn.Module):
nonlocal vram_bytes_loaded
if not isinstance(m, CustomLinear):
# We don't handle offload of this type of module.
return
m_device = m.weight.device
m_bytes = sum(p.numel() * p.element_size() for p in m.parameters())
# Skip modules that are already on the compute device.
if m_device.type == self._compute_device.type:
return
# Check the size of the parameter.
if vram_bytes_loaded + m_bytes > vram_bytes_to_load:
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
# worth continuing to search for a smaller parameter that would fit?
return
vram_bytes_loaded += m_bytes
m.to(self._compute_device)
self._model.apply(to_vram)
self._cur_vram_bytes_cache = None
return vram_bytes_loaded
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded."""
vram_bytes_freed = 0
def from_vram(m: torch.nn.Module):
nonlocal vram_bytes_freed
if vram_bytes_freed >= vram_bytes_to_free:
return
m_device = m.weight.device
m_bytes = sum(p.numel() * p.element_size() for p in m.parameters())
if m_device.type != self._compute_device.type:
return
vram_bytes_freed += m_bytes
m.to("cpu")
self._model.apply(from_vram)
self._cur_vram_bytes_cache = None
return vram_bytes_freed

View File

@@ -1,18 +0,0 @@
import torch
from torch.utils._python_dispatch import TorchDispatchMode
def cast_to_device_and_run(func, args, kwargs, to_device: torch.device):
args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
return func(*args_on_device, **kwargs_on_device)
class TorchAutocastContext(TorchDispatchMode):
def __init__(self, to_device: torch.device):
self._to_device = to_device
def __torch_dispatch__(self, func, types, args, kwargs):
# print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
# print(f"Dispatch Log: {types}")
return cast_to_device_and_run(func, args, kwargs, self._to_device)

View File

@@ -1,16 +0,0 @@
import torch
from torch.overrides import TorchFunctionMode
def cast_to_device_and_run(func, args, kwargs, to_device: torch.device):
args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
return func(*args_on_device, **kwargs_on_device)
class TorchFunctionAutocastContext(TorchFunctionMode):
def __init__(self, to_device: torch.device):
self._to_device = to_device
def __torch_function__(self, func, types, args, kwargs=None):
return cast_to_device_and_run(func, args, kwargs or {}, self._to_device)

View File

@@ -1,26 +0,0 @@
from typing import TypeVar
import torch
T = TypeVar("T", torch.Tensor, None)
def cast_to_device(t: T, to_device: torch.device, non_blocking: bool = True) -> T:
if t is None:
return t
return t.to(to_device, non_blocking=non_blocking)
def inject_custom_layers_into_module(model: torch.nn.Module):
def inject_custom_layers(module: torch.nn.Module):
if isinstance(module, torch.nn.Linear):
module.__class__ = CustomLinear
model.apply(inject_custom_layers)
class CustomLinear(torch.nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.linear(input, weight, bias)

View File

@@ -68,14 +68,14 @@ class LoadedModelWithoutConfig:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
self._cache.lock(self._cache_record.key)
try:
yield (self._cache_record.state_dict, self._cache_record.model)
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
finally:
self._cache.unlock(self._cache_record.key)
@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
return self._cache_record.model
return self._cache_record.cached_model.model
class LoadedModel(LoadedModelWithoutConfig):

View File

@@ -1,38 +1,22 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
@dataclass
class CacheRecord:
"""
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
"""A class that represents a model in the model cache."""
# Cache key.
key: str
model: Any
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
# Model in memory.
cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
# If locks > 0, the model is actively being used, so we should do our best to keep it on the compute device.
_locks: int = 0
def lock(self) -> None:

View File

@@ -28,10 +28,22 @@ class CachedModelOnlyFullLoad:
def model(self) -> torch.nn.Module:
return self._model
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
"""Get a read-only copy of the model's state dict in RAM."""
# TODO(ryand): Document this better and implement it.
return None
def total_bytes(self) -> int:
"""Get the total size (in bytes) of all the weights in the model."""
return self._total_bytes
def cur_vram_bytes(self) -> int:
"""Get the size (in bytes) of the weights that are currently in VRAM."""
if self._is_in_vram:
return self._total_bytes
else:
return 0
def is_in_vram(self) -> bool:
"""Return true if the model is currently in VRAM."""
return self._is_in_vram

View File

@@ -1,8 +1,23 @@
import torch
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
add_autocast_to_module_forward,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
def set_nested_attr(obj: object, attr: str, value: object):
"""A helper function that extends setattr() to support nested attributes.
Example:
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
"""
attrs = attr.split(".")
for attr in attrs[:-1]:
obj = getattr(obj, attr)
setattr(obj, attrs[-1], value)
class CachedModelWithPartialLoad:
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
@@ -14,20 +29,48 @@ class CachedModelWithPartialLoad:
self._model = model
self._compute_device = compute_device
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
# Monkey-patch the model to add autocasting to the model's forward method.
add_autocast_to_module_forward(model, compute_device)
# TODO(ryand): Manage a read-only CPU copy of the model state dict.
# TODO(ryand): Add memoization for total_bytes and cur_vram_bytes?
self._total_bytes = sum(calc_tensor_size(p) for p in self._model.parameters())
self._cur_vram_bytes: int | None = None
@property
def model(self) -> torch.nn.Module:
return self._model
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
"""Get a read-only copy of the model's state dict in RAM."""
# TODO(ryand): Document this better.
return self._cpu_state_dict
def total_bytes(self) -> int:
"""Get the total size (in bytes) of all the weights in the model."""
return sum(calc_tensor_size(p) for p in self._model.parameters())
return self._total_bytes
def cur_vram_bytes(self) -> int:
"""Get the size (in bytes) of the weights that are currently in VRAM."""
return sum(calc_tensor_size(p) for p in self._model.parameters() if p.device.type == self._compute_device.type)
if self._cur_vram_bytes is None:
self._cur_vram_bytes = sum(
calc_tensor_size(p) for p in self._model.parameters() if p.device.type == self._compute_device.type
)
return self._cur_vram_bytes
def full_load_to_vram(self) -> int:
"""Load all weights into VRAM."""
return self.partial_load_to_vram(self.total_bytes())
def full_unload_from_vram(self) -> int:
"""Unload all weights from VRAM."""
return self.partial_unload_from_vram(self.total_bytes())
@torch.no_grad()
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
@@ -36,29 +79,39 @@ class CachedModelWithPartialLoad:
"""
vram_bytes_loaded = 0
# TODO(ryand): Should we use self._model.apply(...) instead and move modules around instead of moving tensors?
# This way we don't have to use the private _apply() method.
def to_vram(t: torch.Tensor):
nonlocal vram_bytes_loaded
# TODO(ryand): Iterate over buffers too?
for key, param in self._model.named_parameters():
# Skip parameters that are already on the compute device.
if t.device.type == self._compute_device.type:
return t
if param.device.type == self._compute_device.type:
continue
# Check the size of the parameter.
param_size = calc_tensor_size(t)
param_size = calc_tensor_size(param)
if vram_bytes_loaded + param_size > vram_bytes_to_load:
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
# worth continuing to search for a smaller parameter that would fit?
return t
continue
# Copy the parameter to the compute device.
# We use the 'overwrite' strategy from torch.nn.Module._apply().
# TODO(ryand): For some edge cases (e.g. quantized models?), we may need to support other strategies (e.g.
# swap).
assert isinstance(param, torch.nn.Parameter)
assert param.is_leaf
out_param = torch.nn.Parameter(param.to(self._compute_device, copy=True), requires_grad=param.requires_grad)
set_nested_attr(self._model, key, out_param)
# We did not port the param.grad handling from torch.nn.Module._apply(), because we do not expect to be
# handling gradients. We assert that this assumption is true.
assert param.grad is None
vram_bytes_loaded += param_size
return t.to(self._compute_device)
self._model._apply(to_vram)
if self._cur_vram_bytes is not None:
self._cur_vram_bytes += vram_bytes_loaded
return vram_bytes_loaded
@torch.no_grad()
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
@@ -67,18 +120,20 @@ class CachedModelWithPartialLoad:
"""
vram_bytes_freed = 0
def from_vram(t: torch.Tensor):
nonlocal vram_bytes_freed
# TODO(ryand): Iterate over buffers too?
for key, param in self._model.named_parameters():
if vram_bytes_freed >= vram_bytes_to_free:
return t
break
if t.device.type != self._compute_device.type:
return t
if param.device.type != self._compute_device.type:
continue
vram_bytes_freed += calc_tensor_size(t)
return t.to("cpu")
# Create a new parameter, but inject the existing CPU tensor into it.
out_param = torch.nn.Parameter(self._cpu_state_dict[key], requires_grad=param.requires_grad)
set_nested_attr(self._model, key, out_param)
vram_bytes_freed += calc_tensor_size(param)
self._model._apply(from_vram)
if self._cur_vram_bytes is not None:
self._cur_vram_bytes -= vram_bytes_freed
return vram_bytes_freed

View File

@@ -1,21 +1,23 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
import gc
import math
import time
from logging import Logger
from typing import Dict, List, Optional
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.prefix_logger_adapter import PrefixedLoggerAdapter
# Size of a GB in bytes.
GB = 2**30
@@ -24,7 +26,9 @@ GB = 2**30
MB = 2**20
# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
"""Get the cache key for a model based on the optional submodel type."""
if submodel_type:
return f"{model_key}:{submodel_type.value}"
else:
@@ -89,12 +93,15 @@ class ModelCache:
:param logger: InvokeAILogger to use (otherwise creates one)
"""
# allow lazy offloading only when vram cache enabled
# TODO(ryand): Think about what lazy_offloading should mean in the new model cache.
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._logger = PrefixedLoggerAdapter(
logger or InvokeAILogger.get_logger(self.__class__.__name__), "MODEL CACHE"
)
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
@@ -128,20 +135,34 @@ class ModelCache:
@stats.setter
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collectin cache statistics."""
"""Set the CacheStats object for collecting cache statistics."""
self._stats = stats
def put(self, key: str, model: AnyModel) -> None:
"""Add a model to the cache."""
if key in self._cached_models:
self._logger.debug(
f"Attempted to add model {key} ({model.__class__.__name__}), but it already exists in the cache. No action necessary."
)
return
size = calc_model_size_by_data(self._logger, model)
self.make_room(size)
running_on_cpu = self._execution_device == torch.device("cpu")
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)
# Wrap model.
if isinstance(model, torch.nn.Module):
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
else:
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
# running_on_cpu = self._execution_device == torch.device("cpu")
# state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
self._logger.debug(
f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)"
)
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
"""Retrieve a model from the cache.
@@ -157,6 +178,7 @@ class ModelCache:
else:
if self.stats:
self.stats.misses += 1
self._logger.debug(f"Cache miss: {key}")
raise IndexError(f"The model with key {key} is not in the cache.")
cache_entry = self._cached_models[key]
@@ -165,16 +187,18 @@ class ModelCache:
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self._get_cache_size())
self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.cached_model.total_bytes()
)
# this moves the entry to the top (right end) of the stack
self._cache_stack = [k for k in self._cache_stack if k != key]
self._cache_stack.append(key)
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
return cache_entry
def lock(self, key: str) -> None:
@@ -182,13 +206,13 @@ class ModelCache:
cache_entry = self._cached_models[key]
cache_entry.lock()
self._logger.debug(f"Locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
try:
if self._lazy_offloading:
self._offload_unlocked_models(cache_entry.size)
self._move_model_to_device(cache_entry, self._execution_device)
cache_entry.loaded = True
self._logger.debug(f"Locking {cache_entry.key} in {self._execution_device}")
self._print_cuda_stats()
self._load_locked_model(cache_entry)
self._logger.debug(
f"Finished locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
)
except torch.cuda.OutOfMemoryError:
self._logger.warning("Insufficient GPU memory to load model. Aborting")
cache_entry.unlock()
@@ -197,195 +221,291 @@ class ModelCache:
cache_entry.unlock()
raise
self._log_cache_state()
def unlock(self, key: str) -> None:
"""Unlock a model."""
cache_entry = self._cached_models[key]
cache_entry.unlock()
if not self._lazy_offloading:
self._offload_unlocked_models(0)
self._print_cuda_stats()
self._logger.debug(f"Unlocked model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
def _get_cache_size(self) -> int:
"""Get the total size of the models currently cached."""
total = 0
for cache_record in self._cached_models.values():
total += cache_record.size
return total
def _load_locked_model(self, cache_entry: CacheRecord) -> None:
"""Helper function for self.lock(). Loads a locked model into VRAM."""
vram_available = self._get_vram_available()
# The amount of additional VRAM that will be used if we fully load the model into VRAM.
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
model_total_bytes = cache_entry.cached_model.total_bytes()
model_vram_needed = model_total_bytes - model_cur_vram_bytes
self._logger.debug(
f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
)
# Make room for the model in VRAM.
# 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
# 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
# possible.
vram_bytes_freed = self._offload_unlocked_models(model_vram_needed)
self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")
# Check the updated vram_available after offloading.
vram_available = self._get_vram_available()
self._logger.debug(
f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
)
# Move as much of the model as possible into VRAM.
model_bytes_loaded = 0
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
model_bytes_loaded = cache_entry.cached_model.partial_load_to_vram(vram_available)
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
# Partial load is not supported, so we have not choice but to try and fit it all into VRAM.
model_bytes_loaded = cache_entry.cached_model.full_load_to_vram()
else:
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
vram_available = self._get_vram_available()
self._logger.debug(f"Loaded model onto execution device: model_bytes_loaded={(model_bytes_loaded/MB):.2f}MB, ")
self._logger.debug(
f"After loading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
)
def _get_vram_available(self) -> int:
"""Get the amount of VRAM available in the cache."""
return int(self._max_vram_cache_size * GB) - self._get_vram_in_use()
def _get_vram_in_use(self) -> int:
"""Get the amount of VRAM currently in use."""
return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
def _get_ram_available(self) -> int:
"""Get the amount of RAM available in the cache."""
return int(self._max_cache_size * GB) - self._get_ram_in_use()
def _get_ram_in_use(self) -> int:
"""Get the amount of RAM currently in use."""
return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
return None
def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
if submodel_type:
return f"{model_key}:{submodel_type.value}"
else:
return model_key
def _offload_unlocked_models(self, size_required: int) -> None:
"""Offload models from the execution_device to make room for size_required.
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
"""
reserved = self._max_vram_cache_size * GB
vram_in_use = torch.cuda.memory_allocated() + size_required
self._logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.loaded:
continue
if not cache_entry.is_locked:
self._move_model_to_device(cache_entry, self._storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self._logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
)
TorchDevice.empty_cache()
def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
"""Move model into the indicated device.
:param cache_entry: The CacheRecord for the model
:param target_device: The torch.device to move the model into
May raise a torch.cuda.OutOfMemoryError
"""
self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
# Some models don't have a `to` method, in which case they run in RAM/CPU.
if not hasattr(cache_entry.model, "to"):
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self._storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(target_device, copy=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self._logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
def _get_vram_state_str(self, model_cur_vram_bytes: int, model_total_bytes: int, vram_available: int) -> str:
"""Helper function for preparing a VRAM state log string."""
model_cur_vram_bytes_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
return (
f"model_total={model_total_bytes/MB:.0f} MB, "
+ f"model_vram={model_cur_vram_bytes/MB:.0f} MB ({model_cur_vram_bytes_percent:.1%} %), "
+ f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, "
+ f"vram_available={(vram_available/MB):.0f} MB, "
)
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
def _offload_unlocked_models(self, vram_bytes_to_free: int) -> int:
"""Offload models from the execution_device until vram_bytes_to_free bytes are freed, or all models are
offloaded. Of course, locked models are not offloaded.
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
Returns:
int: The number of bytes freed.
"""
self._logger.debug(f"Offloading unlocked models with goal of freeing {vram_bytes_to_free/MB:.2f}MB of VRAM.")
vram_bytes_freed = 0
# TODO(ryand): Give more thought to the offloading policy used here.
cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
for cache_entry in cache_entries_increasing_size:
if vram_bytes_freed >= vram_bytes_to_free:
break
if cache_entry.is_locked:
continue
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
cache_entry_bytes_freed = cache_entry.cached_model.partial_unload_from_vram(
vram_bytes_to_free - vram_bytes_freed
)
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
cache_entry_bytes_freed = cache_entry.cached_model.full_unload_from_vram()
else:
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
if cache_entry_bytes_freed > 0:
self._logger.debug(
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GB):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
f"Unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/MB):.0f} MB."
)
vram_bytes_freed += cache_entry_bytes_freed
TorchDevice.empty_cache()
return vram_bytes_freed
# def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
# """Move model into the indicated device.
# :param cache_entry: The CacheRecord for the model
# :param target_device: The torch.device to move the model into
# May raise a torch.cuda.OutOfMemoryError
# """
# self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
# source_device = cache_entry.device
# # Note: We compare device types only so that 'cuda' == 'cuda:0'.
# # This would need to be revised to support multi-GPU.
# if torch.device(source_device).type == torch.device(target_device).type:
# return
# # Some models don't have a `to` method, in which case they run in RAM/CPU.
# if not hasattr(cache_entry.model, "to"):
# return
# # This roundabout method for moving the model around is done to avoid
# # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# # When moving to VRAM, we copy (not move) each element of the state dict from
# # RAM to a new state dict in VRAM, and then inject it into the model.
# # This operation is slightly faster than running `to()` on the whole model.
# #
# # When the model needs to be removed from VRAM we simply delete the copy
# # of the state dict in VRAM, and reinject the state dict that is cached
# # in RAM into the model. So this operation is very fast.
# start_model_to_time = time.time()
# snapshot_before = self._capture_memory_snapshot()
# try:
# if cache_entry.state_dict is not None:
# assert hasattr(cache_entry.model, "load_state_dict")
# if target_device == self._storage_device:
# cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
# else:
# new_dict: Dict[str, torch.Tensor] = {}
# for k, v in cache_entry.state_dict.items():
# new_dict[k] = v.to(target_device, copy=True)
# cache_entry.model.load_state_dict(new_dict, assign=True)
# cache_entry.model.to(target_device)
# cache_entry.device = target_device
# except Exception as e: # blow away cache entry
# self._delete_cache_entry(cache_entry)
# raise e
# snapshot_after = self._capture_memory_snapshot()
# end_model_to_time = time.time()
# self._logger.debug(
# f"Moved model '{cache_entry.key}' from {source_device} to"
# f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
# f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
# )
# if (
# snapshot_before is not None
# and snapshot_after is not None
# and snapshot_before.vram is not None
# and snapshot_after.vram is not None
# ):
# vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# # If the estimated model size does not match the change in VRAM, log a warning.
# if not math.isclose(
# vram_change,
# cache_entry.size,
# rel_tol=0.1,
# abs_tol=10 * MB,
# ):
# self._logger.debug(
# f"Moving model '{cache_entry.key}' from {source_device} to"
# f" {target_device} caused an unexpected change in VRAM usage. The model's"
# " estimated size may be incorrect. Estimated model size:"
# f" {(cache_entry.size/GB):.3f} GB.\n"
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
# )
def _log_cache_state(self, title: str = "Model cache state:", include_entry_details: bool = True):
ram_size_bytes = self._max_cache_size * GB
ram_in_use_bytes = self._get_ram_in_use()
ram_in_use_bytes_percent = ram_in_use_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
ram_available_bytes = self._get_ram_available()
ram_available_bytes_percent = ram_available_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
vram_size_bytes = self._max_vram_cache_size * GB
vram_in_use_bytes = self._get_vram_in_use()
vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
vram_available_bytes = self._get_vram_available()
vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
log = f"{title}\n"
log_format = " {:<30} Limit: {:>7.1f} MB, Used: {:>7.1f} MB ({:>5.1%}), Available: {:>7.1f} MB ({:>5.1%})\n"
log += log_format.format(
f"Storage Device ({self._storage_device.type})",
ram_size_bytes / MB,
ram_in_use_bytes / MB,
ram_in_use_bytes_percent,
ram_available_bytes / MB,
ram_available_bytes_percent,
)
log += log_format.format(
f"Compute Device ({self._execution_device.type})",
vram_size_bytes / MB,
vram_in_use_bytes / MB,
vram_in_use_bytes_percent,
vram_available_bytes / MB,
vram_available_bytes_percent,
)
if torch.cuda.is_available():
log += " {:<30} {} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB)
log += " {:<30} {}\n".format("Total models:", len(self._cached_models))
if include_entry_details and len(self._cached_models) > 0:
log += " Models:\n"
log_format = (
" {:<80} total={:>7.1f} MB, vram={:>7.1f} MB ({:>5.1%}), ram={:>7.1f} MB ({:>5.1%}), locked={}\n"
)
for cache_record in self._cached_models.values():
total_bytes = cache_record.cached_model.total_bytes()
cur_vram_bytes = cache_record.cached_model.cur_vram_bytes()
cur_vram_bytes_percent = cur_vram_bytes / total_bytes if total_bytes > 0 else 0
cur_ram_bytes = total_bytes - cur_vram_bytes
cur_ram_bytes_percent = cur_ram_bytes / total_bytes if total_bytes > 0 else 0
log += log_format.format(
f"{cache_record.key} ({cache_record.cached_model.model.__class__.__name__}):",
total_bytes / MB,
cur_vram_bytes / MB,
cur_vram_bytes_percent,
cur_ram_bytes / MB,
cur_ram_bytes_percent,
cache_record.is_locked,
)
def _print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
ram = "%4.2fG" % (self._get_cache_size() / GB)
self._logger.debug(log)
in_ram_models = 0
in_vram_models = 0
locked_in_vram_models = 0
for cache_record in self._cached_models.values():
if hasattr(cache_record.model, "device"):
if cache_record.model.device == self._storage_device:
in_ram_models += 1
else:
in_vram_models += 1
if cache_record.is_locked:
locked_in_vram_models += 1
self._logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
def make_room(self, size: int) -> None:
def make_room(self, bytes_needed: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size.
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
bytes_needed = size
maximum_size = self._max_cache_size * GB # stored in GB, convert to bytes
current_size = self._get_cache_size()
self._logger.debug(f"Making room for {bytes_needed/MB:.2f}MB of RAM.")
self._log_cache_state(title="Before dropping models:")
if current_size + bytes_needed > maximum_size:
self._logger.debug(
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GB):.2f} GB"
)
self._logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
ram_bytes_available = self._get_ram_available()
ram_bytes_to_free = max(0, bytes_needed - ram_bytes_available)
ram_bytes_freed = 0
pos = 0
models_cleared = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self._logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
)
if not cache_entry.is_locked:
ram_bytes_freed += cache_entry.cached_model.total_bytes()
self._logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
f"Dropping {model_key} from RAM cache to free {(cache_entry.cached_model.total_bytes()/MB):.2f}MB."
)
current_size -= cache_entry.size
models_cleared += 1
self._delete_cache_entry(cache_entry)
del cache_entry
models_cleared += 1
else:
pos += 1
@@ -406,7 +526,8 @@ class ModelCache:
gc.collect()
TorchDevice.empty_cache()
self._logger.debug(f"After making room: cached_models={len(self._cached_models)}")
self._logger.debug(f"Dropped {models_cleared} models to free {ram_bytes_freed/MB:.2f}MB of RAM.")
self._log_cache_state(title="After dropping models:")
def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
self._cache_stack.remove(cache_entry.key)

View File

@@ -0,0 +1,33 @@
from typing import Any, Callable
import torch
from torch.overrides import TorchFunctionMode
def add_autocast_to_module_forward(m: torch.nn.Module, to_device: torch.device):
"""Monkey-patch m.forward(...) with a new forward(...) method that activates device autocasting for its duration."""
old_forward = m.forward
def new_forward(*args: Any, **kwargs: Any):
with TorchFunctionAutocastDeviceContext(to_device):
return old_forward(*args, **kwargs)
m.forward = new_forward
def _cast_to_device_and_run(
func: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], to_device: torch.device
):
args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
return func(*args_on_device, **kwargs_on_device)
class TorchFunctionAutocastDeviceContext(TorchFunctionMode):
def __init__(self, to_device: torch.device):
self._to_device = to_device
def __torch_function__(
self, func: Callable[..., Any], types, args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None
):
return _cast_to_device_and_run(func, args, kwargs or {}, self._to_device)

View File

@@ -0,0 +1,12 @@
import logging
from typing import Any, MutableMapping
# Issue with type hints related to LoggerAdapter: https://github.com/python/typeshed/issues/7855
class PrefixedLoggerAdapter(logging.LoggerAdapter): # type: ignore
def __init__(self, logger: logging.Logger, prefix: str):
super().__init__(logger, {})
self.prefix = prefix
def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple[str, MutableMapping[str, Any]]:
return f"[{self.prefix}] {msg}", kwargs

View File

@@ -1,24 +0,0 @@
import torch
from invokeai.backend.model_cache_v2.torch_autocast_context import TorchAutocastContext
class DummyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 10)
self.linear2 = torch.nn.Linear(10, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = self.linear2(x)
return x
def test_torch_autocast_context():
model = DummyModule()
with TorchAutocastContext(to_device=torch.device("cuda")):
x = torch.randn(10, 10, device="cuda")
y = model(x)
print(y.shape)

View File

@@ -4,7 +4,7 @@ import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.dummy_module import DummyModule
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
parameterize_mps_and_cuda = pytest.mark.parametrize(
("device"),

View File

@@ -4,7 +4,8 @@ import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.dummy_module import DummyModule
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
parameterize_mps_and_cuda = pytest.mark.parametrize(
("device"),
@@ -29,44 +30,145 @@ def test_cached_model_total_bytes(device: str):
linear_numel = 10 * 10 + 10
assert cached_model.total_bytes() == linear_numel * 4 * 2
cached_model.model.to(dtype=torch.float16)
assert cached_model.total_bytes() == linear_numel * 2 * 2
@parameterize_mps_and_cuda
def test_cached_model_cur_vram_bytes(device: str):
model = DummyModule()
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
assert cached_model.cur_vram_bytes() == 0
cached_model.model.to(device=torch.device(device))
# Full load the model into VRAM.
cached_model.full_load_to_vram()
assert cached_model.cur_vram_bytes() > 0
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
assert all(p.device.type == device for p in model.parameters())
@parameterize_mps_and_cuda
def test_cached_model_partial_load(device: str):
model = DummyModule()
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Partially load the model into VRAM.
target_vram_bytes = int(model_total_bytes * 0.6)
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
assert loaded_bytes > 0
assert loaded_bytes < model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
assert loaded_bytes == sum(calc_tensor_size(p) for p in model.parameters() if p.device.type == device)
@parameterize_mps_and_cuda
def test_cached_model_partial_unload(device: str):
model = DummyModule()
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Full load the model into VRAM.
cached_model.full_load_to_vram()
assert cached_model.cur_vram_bytes() == model_total_bytes
# Partially unload the model from VRAM.
bytes_to_free = int(model_total_bytes * 0.4)
freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
assert freed_bytes >= bytes_to_free
assert freed_bytes < model_total_bytes
assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
assert freed_bytes == sum(calc_tensor_size(p) for p in model.parameters() if p.device.type == "cpu")
@parameterize_mps_and_cuda
def test_cached_model_full_load(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Full load the model into VRAM.
loaded_bytes = cached_model.full_load_to_vram()
assert loaded_bytes > 0
assert loaded_bytes == model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
assert all(p.device.type == device for p in model.parameters())
@parameterize_mps_and_cuda
def test_cached_model_full_load_from_partial(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Partially load the model into VRAM.
target_vram_bytes = int(model_total_bytes * 0.6)
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
assert loaded_bytes > 0
assert loaded_bytes < model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
# Full load the rest of the model into VRAM.
loaded_bytes_2 = cached_model.full_load_to_vram()
assert loaded_bytes_2 > 0
assert loaded_bytes_2 < model_total_bytes
assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes()
assert loaded_bytes + loaded_bytes_2 == model_total_bytes
assert all(p.device.type == device for p in model.parameters())
@parameterize_mps_and_cuda
def test_cached_model_partial_unload(device: str):
def test_cached_model_full_unload_from_partial(device: str):
model = DummyModule()
model.to(device=torch.device(device))
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == model_total_bytes
bytes_to_free = int(model_total_bytes * 0.4)
freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
assert freed_bytes >= bytes_to_free
assert freed_bytes < model_total_bytes
assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
# Model starts in CPU memory.
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
# Partially load the model into VRAM.
target_vram_bytes = int(model_total_bytes * 0.6)
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
assert loaded_bytes > 0
assert loaded_bytes < model_total_bytes
assert loaded_bytes == cached_model.cur_vram_bytes()
# Full unload the model from VRAM.
unloaded_bytes = cached_model.full_unload_from_vram()
assert unloaded_bytes > 0
assert unloaded_bytes == loaded_bytes
assert cached_model.cur_vram_bytes() == 0
assert all(p.device.type == "cpu" for p in model.parameters())
@parameterize_mps_and_cuda
def test_cached_model_get_cpu_state_dict(device: str):
model = DummyModule()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
# Model starts in CPU memory.
assert cached_model.cur_vram_bytes() == 0
# The CPU state dict can be accessed and has the expected properties.
cpu_state_dict = cached_model.get_cpu_state_dict()
assert cpu_state_dict is not None
assert len(cpu_state_dict) == len(model.state_dict())
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
# Full load the model into VRAM.
cached_model.full_load_to_vram()
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
# The CPU state dict is still available, and still on the CPU.
cpu_state_dict = cached_model.get_cpu_state_dict()
assert cpu_state_dict is not None
assert len(cpu_state_dict) == len(model.state_dict())
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())

View File

@@ -0,0 +1,50 @@
import pytest
import torch
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
TorchFunctionAutocastDeviceContext,
add_autocast_to_module_forward,
)
from tests.backend.model_manager.load.model_cache.dummy_module import DummyModule
def test_torch_function_autocast_device_context():
if not torch.cuda.is_available():
pytest.skip("CUDA is not available.")
model = DummyModule()
# Model parameters should start off on the CPU.
assert all(p.device.type == "cpu" for p in model.parameters())
with TorchFunctionAutocastDeviceContext(to_device=torch.device("cuda")):
x = torch.randn(10, 10, device="cuda")
y = model(x)
# The model output should be on the GPU.
assert y.device.type == "cuda"
# The model parameters should still be on the CPU.
assert all(p.device.type == "cpu" for p in model.parameters())
def test_add_autocast_to_module_forward():
model = DummyModule()
assert all(p.device.type == "cpu" for p in model.parameters())
add_autocast_to_module_forward(model, torch.device("cuda"))
# After adding autocast, the model parameters should still be on the CPU.
assert all(p.device.type == "cpu" for p in model.parameters())
x = torch.randn(10, 10, device="cuda")
y = model(x)
# The model output should be on the GPU.
assert y.device.type == "cuda"
# The model parameters should still be on the CPU.
assert all(p.device.type == "cpu" for p in model.parameters())
# The autocast context should automatically be disabled after the model forward call completes.
# So, attempting to perform an operation with comflicting devices should raise an error.
with pytest.raises(RuntimeError):
_ = torch.randn(10, device="cuda") * torch.randn(10, device="cpu")