from typing import Literal

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny

from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.backend.flux.modules.autoencoder import AutoEncoder


def estimate_vae_working_memory_sd15_sdxl(
    operation: Literal["encode", "decode"],
    image_tensor: torch.Tensor,
    vae: AutoencoderKL | AutoencoderTiny,
    tile_size: int | None,
    fp32: bool,
) -> int:
    """Estimate the working memory (in bytes) required to encode or decode the given tensor."""
    # It was found experimentally that the peak working memory scales linearly with the number of pixels and the
    # element size (precision). This estimate is accurate for both SD1 and SDXL.
    element_size = 4 if fp32 else 2

    # This constant was determined experimentally and takes both allocated and reserved memory into consideration.
    # See #8414. Encoding uses ~45% as much working memory as decoding.
    scaling_constant = 2200 if operation == "decode" else 1100

    # When decoding, the input tensor is in latent space, so its spatial dimensions must be scaled up to image
    # space. When encoding, the input is already an image.
    latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1

    if tile_size is not None:
        # A tile_size of 0 means "use the VAE's default tile size".
        if tile_size == 0:
            tile_size = vae.tile_sample_min_size
        assert isinstance(tile_size, int)
        h = tile_size
        w = tile_size
        working_memory = h * w * element_size * scaling_constant

        # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile
        # overlap and the number of tiles. We could make this more precise in the future, but this should be good
        # enough for most use cases.
        working_memory = working_memory * 1.25
    else:
        h = latent_scale_factor_for_operation * image_tensor.shape[-2]
        w = latent_scale_factor_for_operation * image_tensor.shape[-1]
        working_memory = h * w * element_size * scaling_constant

    if fp32:
        # If we are running in FP32, account for the likely increase in model size (~250MB).
        working_memory += 250 * 2**20

    print(f"estimate_vae_working_memory_sd15_sdxl: {int(working_memory)}")

    return int(working_memory)
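
# A worked example for the estimator above (illustrative numbers, assuming
# LATENT_SCALE_FACTOR == 8; the tensor shape is hypothetical, not from the source):
#
#   latent = torch.zeros(1, 4, 64, 64)  # a 64x64 latent, i.e. a 512x512 image
#   estimate_vae_working_memory_sd15_sdxl("decode", latent, vae, tile_size=None, fp32=False)
#   # -> (8 * 64) * (8 * 64) * 2 * 2200 = 1_153_433_600 bytes (~1.07 GiB)
#
# With tiling enabled, the estimate is based on the tile size instead and multiplied by
# 1.25; with fp32=True, element_size doubles and 250 MiB is added for model size.
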
def estimate_vae_working_memory_cogview4(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
) -> int:
    """Estimate the working memory required by the invocation in bytes."""
    latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1

    h = latent_scale_factor_for_operation * image_tensor.shape[-2]
    w = latent_scale_factor_for_operation * image_tensor.shape[-1]
    element_size = next(vae.parameters()).element_size()

    # This constant was determined experimentally and takes both allocated and reserved memory into consideration.
    # See #8414. Encoding uses ~45% as much working memory as decoding.
    scaling_constant = 2200 if operation == "decode" else 1100
    working_memory = h * w * element_size * scaling_constant

    print(f"estimate_vae_working_memory_cogview4: {int(working_memory)}")

    return int(working_memory)
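
# Unlike the SD1.5/SDXL estimator, this reads the element size from the loaded model's
# parameters rather than deriving it from an fp32 flag. A minimal sketch of that idiom
# (a dummy module, not the InvokeAI API):
#
#   dummy = torch.nn.Linear(1, 1).to(torch.bfloat16)
#   next(dummy.parameters()).element_size()  # -> 2 bytes per element
#   dummy.to(torch.float32)
#   next(dummy.parameters()).element_size()  # -> 4 bytes per element
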
def estimate_vae_working_memory_flux(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoEncoder
) -> int:
    """Estimate the working memory required by the invocation in bytes."""
    latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1

    out_h = latent_scale_factor_for_operation * image_tensor.shape[-2]
    out_w = latent_scale_factor_for_operation * image_tensor.shape[-1]
    element_size = next(vae.parameters()).element_size()

    # This constant was determined experimentally and takes both allocated and reserved memory into consideration.
    # See #8414. Encoding uses ~45% as much working memory as decoding.
    scaling_constant = 2200 if operation == "decode" else 1100

    working_memory = out_h * out_w * element_size * scaling_constant

    print(f"estimate_vae_working_memory_flux: {int(working_memory)}")

    return int(working_memory)
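
# The encode direction, for contrast (illustrative numbers, not from the source): the
# input is already image-space, so no latent scale factor is applied and the smaller
# scaling constant is used:
#
#   image = torch.zeros(1, 3, 1024, 1024)  # a 1024x1024 RGB image
#   estimate_vae_working_memory_flux("encode", image, vae)
#   # -> 1024 * 1024 * element_size * 1100; ~2.15 GiB for a bf16 VAE (element_size == 2)
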
def estimate_vae_working_memory_sd3(
    operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
) -> int:
    """Estimate the working memory required by the invocation in bytes."""
    latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1

    h = latent_scale_factor_for_operation * image_tensor.shape[-2]
    w = latent_scale_factor_for_operation * image_tensor.shape[-1]
    element_size = next(vae.parameters()).element_size()

    # This constant was determined experimentally and takes both allocated and reserved memory into consideration.
    # See #8414. Encoding uses ~45% as much working memory as decoding.
    scaling_constant = 2200 if operation == "decode" else 1100

    working_memory = h * w * element_size * scaling_constant

    print(f"estimate_vae_working_memory_sd3: {int(working_memory)}")

    return int(working_memory)
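
# A minimal sanity-check sketch (hypothetical, not part of the module): exercises one
# estimator against diffusers' default AutoencoderKL config, assuming LATENT_SCALE_FACTOR == 8.
if __name__ == "__main__":
    dummy_vae = AutoencoderKL()  # constructed in fp32, so element_size == 4
    latent = torch.zeros(1, 4, 64, 64)  # a 64x64 latent, i.e. a 512x512 image
    estimate = estimate_vae_working_memory_sd3("decode", latent, dummy_vae)
    assert estimate == (8 * 64) * (8 * 64) * 4 * 2200  # ~2.15 GiB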