mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 14:37:57 -05:00
53 lines
2.4 KiB
Python
53 lines
2.4 KiB
Python
import logging
|
|
import os
|
|
import sys
|
|
|
|
|
|
def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str, logger: logging.Logger) -> None:
    """Configure the PyTorch CUDA memory allocator through ``PYTORCH_CUDA_ALLOC_CONF``.

    See
    https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
    for supported configurations.

    The environment variable has to be in place before torch is imported, so this function must
    run before anything else in the process imports torch. It imports torch itself once the
    variable has been set, in order to verify the result.

    If ``PYTORCH_CUDA_ALLOC_CONF`` is already present in the environment, the existing value is
    left untouched: an info message is logged when it equals the requested configuration, a
    warning when it differs, and the function returns without importing torch.

    Args:
        pytorch_cuda_alloc_conf: The value to assign to ``PYTORCH_CUDA_ALLOC_CONF``.
        logger: Logger that receives the status messages.

    Raises:
        RuntimeError: If torch has already been imported, if no CUDA device is available, or if
            the allocator backend reported by torch does not match the requested configuration.
    """
    if "torch" in sys.modules:
        raise RuntimeError("configure_torch_cuda_allocator() must be called before importing torch.")

    # A pre-existing PYTORCH_CUDA_ALLOC_CONF always wins; report it and leave the environment alone.
    prev_cuda_alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF")
    if prev_cuda_alloc_conf is not None:
        if prev_cuda_alloc_conf == pytorch_cuda_alloc_conf:
            logger.info(
                f"PYTORCH_CUDA_ALLOC_CONF is already set to '{pytorch_cuda_alloc_conf}'. Skipping configuration."
            )
        else:
            logger.warning(
                f"Attempted to configure the PyTorch CUDA memory allocator with '{pytorch_cuda_alloc_conf}', but PYTORCH_CUDA_ALLOC_CONF is already set to "
                f"'{prev_cuda_alloc_conf}'. Skipping configuration."
            )
        return

    # Configure the PyTorch CUDA memory allocator.
    # NOTE: It is important that this happens before torch is imported.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = pytorch_cuda_alloc_conf

    # Deliberately imported here rather than at module level: importing this module must not
    # pull in torch before the environment variable above has been set.
    import torch

    # Relevant docs: https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Attempted to configure the PyTorch CUDA memory allocator, but no CUDA devices are available."
        )

    # Verify that the torch allocator was properly configured. Any configuration that does not
    # mention cudaMallocAsync is expected to report the "native" backend.
    allocator_backend = torch.cuda.get_allocator_backend()
    expected_backend = "cudaMallocAsync" if "cudaMallocAsync" in pytorch_cuda_alloc_conf else "native"
    if allocator_backend != expected_backend:
        raise RuntimeError(
            f"Failed to configure the PyTorch CUDA memory allocator. Expected backend: '{expected_backend}', but got "
            f"'{allocator_backend}'. Verify that 1) the pytorch_cuda_alloc_conf is set correctly, and 2) that torch is "
            "not imported before calling configure_torch_cuda_allocator()."
        )

    # Log the backend that was just verified instead of querying torch a second time.
    logger.info(f"PyTorch CUDA memory allocator: {allocator_backend}")
|