feat(mm): implement working memory estimation for VAE encode for all models

Tell the model manager that we need some extra working memory for VAE
encoding operations to prevent OOMs.

See the previous commit for the investigation that determined the magic
numbers used here.

This safety measure is especially relevant now that we have FLUX Kontext
and may be encoding rather large reference images. Without working memory
estimation, we can OOM while preparing for denoising.

See #8405 for an example of this issue on a very low-VRAM system. The
same issue can occur on any GPU, though; it's just a matter of hitting
the right combination of loaded models.
Author: psychedelicious
Date: 2025-08-11 18:42:56 +10:00
Parent: 7785061e7d
Commit: 7d86f00d82
5 changed files with 93 additions and 13 deletions
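
For a concrete sense of the numbers, here is a small standalone sketch of the heuristic the hunks below implement; the 1024x1024 image and 16-bit element size are illustrative assumptions, not values from the commit.

def estimate_encode_working_memory(height: int, width: int, element_size: int) -> int:
    # Encode needs roughly 50% of the decode working memory, hence 1100 vs the decode constant of 2200.
    scaling_constant = 1100
    return int(height * width * element_size * scaling_constant)

# Example: 1024x1024 image with a 16-bit VAE (element_size=2):
# 1024 * 1024 * 2 * 1100 = 2,306,867,200 bytes, roughly 2.1 GiB reserved before encoding.
print(estimate_encode_working_memory(1024, 1024, 2))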

@@ -36,9 +36,19 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to encode.")
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
 
+    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        working_memory = h * w * element_size * scaling_constant
+        return int(working_memory)
+
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
-        with vae_info as vae:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoencoderKL)
             vae.disable_tiling()
@@ -62,7 +72,10 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
         vae_info = context.models.load(self.vae.vae)
-        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+        assert isinstance(vae_info.model, AutoencoderKL)
+        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)
 
         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)

@@ -35,14 +35,24 @@ class FluxVaeEncodeInvocation(BaseInvocation):
         input=Input.Connection,
     )
 
+    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoEncoder) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        working_memory = h * w * element_size * scaling_constant
+        return int(working_memory)
+
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
         # TODO(ryand): Expose seed parameter at the invocation level.
         # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
         # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
         # should be used for VAE encode sampling.
         generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
-        with vae_info as vae:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoEncoder)
             vae_dtype = next(iter(vae.parameters())).dtype
             image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
@@ -60,7 +70,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
         context.util.signal_progress("Running VAE")
-        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)
 
         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)

@@ -52,11 +52,43 @@ class ImageToLatentsInvocation(BaseInvocation):
     tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
 
+    def _estimate_working_memory(
+        self, image_tensor: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
+    ) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        element_size = 4 if self.fp32 else 2
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+
+        if use_tiling:
+            tile_size = self.tile_size
+            if tile_size == 0:
+                tile_size = vae.tile_sample_min_size
+                assert isinstance(tile_size, int)
+            h = tile_size
+            w = tile_size
+            working_memory = h * w * element_size * scaling_constant
+
+            # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
+            # and number of tiles. We could make this more precise in the future, but this should be good enough for
+            # most use cases.
+            working_memory = working_memory * 1.25
+        else:
+            h = image_tensor.shape[-2]
+            w = image_tensor.shape[-1]
+            working_memory = h * w * element_size * scaling_constant
+
+        if self.fp32:
+            # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
+            working_memory += 250 * 2**20
+
+        return int(working_memory)
+
     @staticmethod
     def vae_encode(
-        vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
+        vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0, estimated_working_memory: int = 0
     ) -> torch.Tensor:
-        with vae_info as vae:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             orig_dtype = vae.dtype
             if upcast:
@@ -113,14 +145,18 @@ class ImageToLatentsInvocation(BaseInvocation):
         image = context.images.get_pil(self.image.image_name)
 
         vae_info = context.models.load(self.vae.vae)
+        assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
 
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
+        use_tiling = self.tiled or context.config.get().force_tiled_decode
+        estimated_working_memory = self._estimate_working_memory(image_tensor, use_tiling, vae_info.model)
+
         context.util.signal_progress("Running VAE encoder")
         latents = self.vae_encode(
-            vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
+            vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size, estimated_working_memory=estimated_working_memory
         )
 
         latents = latents.to("cpu")

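
The SD1/SDXL path above additionally accounts for tiling and fp32. Below is a standalone sketch of that logic with illustrative numbers; the 2048x2048 image and 512px tile size are assumptions for the example, not values from the commit.

def estimate_sd_encode_working_memory(h: int, w: int, fp32: bool, use_tiling: bool, tile_size: int) -> int:
    # Mirrors the logic of ImageToLatentsInvocation._estimate_working_memory in the hunk above.
    element_size = 4 if fp32 else 2
    scaling_constant = 1100  # 50% of decode scaling constant (2200)
    if use_tiling:
        # Tiled encoding scales with the tile size, plus 25% headroom for tile overlap and tile count.
        working_memory = tile_size * tile_size * element_size * scaling_constant * 1.25
    else:
        working_memory = h * w * element_size * scaling_constant
    if fp32:
        working_memory += 250 * 2**20  # likely increase in model size when upcasting (~250MB)
    return int(working_memory)

# 2048x2048 image, fp16, untiled: ~8.6 GiB of working memory.
print(estimate_sd_encode_working_memory(2048, 2048, fp32=False, use_tiling=False, tile_size=0))
# Same image encoded with 512px tiles: ~0.67 GiB.
print(estimate_sd_encode_working_memory(2048, 2048, fp32=False, use_tiling=True, tile_size=512))

Because the tiled estimate depends on the tile size rather than the full image, tiled encoding keeps the reserved working memory small even for very large inputs.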

@@ -32,9 +32,19 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to encode")
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
 
+    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        working_memory = h * w * element_size * scaling_constant
+        return int(working_memory)
+
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
-        with vae_info as vae:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoencoderKL)
             vae.disable_tiling()
@@ -58,7 +68,10 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
         vae_info = context.models.load(self.vae.vae)
-        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+        assert isinstance(vae_info.model, AutoencoderKL)
+        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)
 
         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)

@@ -131,7 +131,14 @@ class KontextExtension:
         # Continue with VAE encoding
         # Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
-        with vae_info as vae:
+        # Estimate working memory for encode operation (50% of decode memory requirements)
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae_info.model.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        estimated_working_memory = int(h * w * element_size * scaling_constant)
+
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoEncoder)
             vae_dtype = next(iter(vae.parameters())).dtype
             image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
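
To connect this back to the commit message's concern about large Kontext reference images, here is the same heuristic applied to a hypothetical reference image; the 1536x1536 size and 16-bit element size are assumptions for illustration, not values from the commit.

# Illustrative only: values are not taken from the commit.
h, w, element_size, scaling_constant = 1536, 1536, 2, 1100
estimated_working_memory = h * w * element_size * scaling_constant
print(f"{estimated_working_memory / 2**30:.1f} GiB")  # ~4.8 GiB requested before denoising begins

Reserving that much working memory up front is what lets the model manager offload other models before the encode, instead of OOMing once the VAE starts running.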