Files
InvokeAI/invokeai/backend/util/vae_working_memory.py
psychedelicious fd4c3bd27a refactor: estimate working vae memory during encode/decode
- Move the estimation logic to utility functions
- Estimate memory _within_ the encode and decode methods, ensuring we
_always_ estimate working memory when running a VAE
2025-08-18 21:43:14 +10:00

118 lines
5.0 KiB
Python

import logging
from typing import Literal

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny

from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
def estimate_vae_working_memory_sd15_sdxl(
    operation: Literal["encode", "decode"],
    image_tensor: torch.Tensor,
    vae: AutoencoderKL | AutoencoderTiny,
    tile_size: int | None,
    fp32: bool,
) -> int:
    """Estimate the working memory, in bytes, required to encode or decode the given tensor.

    Args:
        operation: "decode" selects the decode constants; any other value is treated as "encode".
        image_tensor: Tensor whose last two dimensions are taken as height and width. For "decode" both are
            multiplied by LATENT_SCALE_FACTOR; for "encode" they are used as-is. Ignored when tiling is enabled.
        vae: The VAE that will run. Only read (for `tile_sample_min_size`) when `tile_size` is 0.
        tile_size: None disables tiling. 0 means "use the VAE's `tile_sample_min_size`". Any other value is used
            directly as both the tile height and the tile width.
        fp32: Whether the VAE runs in FP32 (4 bytes per element) rather than 2 bytes per element.

    Returns:
        The estimated working memory, truncated to an int.
    """
    # It was found experimentally that the peak working memory scales linearly with the number of pixels and the
    # element size (precision). This estimate is accurate for both SD1 and SDXL.
    element_size = 4 if fp32 else 2

    # This constant is determined experimentally and takes into consideration both allocated and reserved memory.
    # See #8414. Encoding reportedly uses ~45% of the working memory of decoding; the encode constant used here is
    # half the decode constant.
    scaling_constant = 2200 if operation == "decode" else 1100

    if tile_size is not None:
        if tile_size == 0:
            # 0 is the "use the model default" sentinel.
            tile_size = vae.tile_sample_min_size
            # Narrow the type of the attribute read from the VAE.
            assert isinstance(tile_size, int)
        working_memory = tile_size * tile_size * element_size * scaling_constant
        # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
        # and number of tiles. We could make this more precise in the future, but this should be good enough for
        # most use cases.
        working_memory = working_memory * 1.25
    else:
        # The scale factor is only needed for the untiled estimate.
        latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
        h = latent_scale_factor_for_operation * image_tensor.shape[-2]
        w = latent_scale_factor_for_operation * image_tensor.shape[-1]
        working_memory = h * w * element_size * scaling_constant

    if fp32:
        # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
        working_memory += 250 * 2**20

    estimate = int(working_memory)
    # Debug-level log instead of printing to stdout on every VAE run.
    logging.getLogger(__name__).debug("estimate_vae_working_memory_sd15_sdxl: %d", estimate)
    return estimate
def estimate_vae_working_memory_cogview4(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
) -> int:
    """Estimate the working memory required by the invocation in bytes.

    Args:
        operation: "decode" selects the decode constants; any other value is treated as "encode".
        image_tensor: Tensor whose last two dimensions are taken as height and width. For "decode" both are
            multiplied by LATENT_SCALE_FACTOR; for "encode" they are used as-is.
        vae: The VAE that will run. The element size of its first parameter is used as the bytes-per-element.

    Returns:
        The estimated working memory as an int.
    """
    latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
    h = latent_scale_factor_for_operation * image_tensor.shape[-2]
    w = latent_scale_factor_for_operation * image_tensor.shape[-1]
    element_size = next(vae.parameters()).element_size()

    # This constant is determined experimentally and takes into consideration both allocated and reserved memory.
    # See #8414. Encoding reportedly uses ~45% of the working memory of decoding; the encode constant used here is
    # half the decode constant.
    scaling_constant = 2200 if operation == "decode" else 1100

    estimate = int(h * w * element_size * scaling_constant)
    # Debug-level log instead of printing to stdout on every VAE run.
    logging.getLogger(__name__).debug("estimate_vae_working_memory_cogview4: %d", estimate)
    return estimate
def estimate_vae_working_memory_flux(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoEncoder
) -> int:
    """Estimate the working memory required by the invocation in bytes.

    Args:
        operation: "decode" selects the decode constants; any other value is treated as "encode".
        image_tensor: Tensor whose last two dimensions are taken as height and width. For "decode" both are
            multiplied by LATENT_SCALE_FACTOR; for "encode" they are used as-is.
        vae: The VAE that will run. The element size of its first parameter is used as the bytes-per-element.

    Returns:
        The estimated working memory as an int.
    """
    latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
    out_h = latent_scale_factor_for_operation * image_tensor.shape[-2]
    out_w = latent_scale_factor_for_operation * image_tensor.shape[-1]
    element_size = next(vae.parameters()).element_size()

    # This constant is determined experimentally and takes into consideration both allocated and reserved memory.
    # See #8414. Encoding reportedly uses ~45% of the working memory of decoding; the encode constant used here is
    # half the decode constant.
    scaling_constant = 2200 if operation == "decode" else 1100

    estimate = int(out_h * out_w * element_size * scaling_constant)
    # Debug-level log instead of printing to stdout on every VAE run.
    logging.getLogger(__name__).debug("estimate_vae_working_memory_flux: %d", estimate)
    return estimate
def estimate_vae_working_memory_sd3(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
) -> int:
    """Estimate the working memory required by the invocation in bytes.

    Args:
        operation: "decode" selects the decode constants; any other value is treated as "encode".
        image_tensor: Tensor whose last two dimensions are taken as height and width. For "decode" both are
            multiplied by LATENT_SCALE_FACTOR; for "encode" they are used as-is.
        vae: The VAE that will run. The element size of its first parameter is used as the bytes-per-element.

    Returns:
        The estimated working memory as an int.
    """
    latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
    h = latent_scale_factor_for_operation * image_tensor.shape[-2]
    w = latent_scale_factor_for_operation * image_tensor.shape[-1]
    element_size = next(vae.parameters()).element_size()

    # This constant is determined experimentally and takes into consideration both allocated and reserved memory.
    # See #8414. Encoding reportedly uses ~45% of the working memory of decoding; the encode constant used here is
    # half the decode constant.
    scaling_constant = 2200 if operation == "decode" else 1100

    estimate = int(h * w * element_size * scaling_constant)
    # Debug-level log instead of printing to stdout on every VAE run.
    logging.getLogger(__name__).debug("estimate_vae_working_memory_sd3: %d", estimate)
    return estimate