From 7d86f00d821add9da488ebc4042527db18e6a1a0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 11 Aug 2025 18:42:56 +1000 Subject: [PATCH] feat(mm): implement working memory estimation for VAE encode for all models Tell the model manager that we need some extra working memory for VAE encoding operations to prevent OOMs. See previous commit for investigation and determination of the magic numbers used. This safety measure is especially relevant now that we have FLUX Kontext and may be encoding rather large ref images. Without the working memory estimation we can OOM as we prepare for denoising. See #8405 for an example of this issue on a very low VRAM system. It's possible we can have the same issue on any GPU, though - just a matter of hitting the right combination of models loaded. --- .../invocations/cogview4_image_to_latents.py | 19 +++++++-- invokeai/app/invocations/flux_vae_encode.py | 17 ++++++-- invokeai/app/invocations/image_to_latents.py | 42 +++++++++++++++++-- .../app/invocations/sd3_image_to_latents.py | 19 +++++++-- .../flux/extensions/kontext_extension.py | 9 +++- 5 files changed, 93 insertions(+), 13 deletions(-) diff --git a/invokeai/app/invocations/cogview4_image_to_latents.py b/invokeai/app/invocations/cogview4_image_to_latents.py index 23f1c13e26..706fc7a0cb 100644 --- a/invokeai/app/invocations/cogview4_image_to_latents.py +++ b/invokeai/app/invocations/cogview4_image_to_latents.py @@ -36,9 +36,19 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard): image: ImageField = InputField(description="The image to encode.") vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection) + def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int: + """Estimate the working memory required by the invocation in bytes.""" + # Encode operations use approximately 50% of the memory required for decode operations + h = image_tensor.shape[-2] + w = image_tensor.shape[-1] + element_size = next(vae.parameters()).element_size() + scaling_constant = 1100 # 50% of decode scaling constant (2200) + working_memory = h * w * element_size * scaling_constant + return int(working_memory) + @staticmethod - def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor: - with vae_info as vae: + def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor: + with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae): assert isinstance(vae, AutoencoderKL) vae.disable_tiling() @@ -62,7 +72,10 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard): image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") vae_info = context.models.load(self.vae.vae) - latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor) + assert isinstance(vae_info.model, AutoencoderKL) + + estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model) + latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory) latents = latents.to("cpu") name = context.tensors.save(tensor=latents) diff --git a/invokeai/app/invocations/flux_vae_encode.py b/invokeai/app/invocations/flux_vae_encode.py index daf039b80d..7bb9f18e76 100644 --- a/invokeai/app/invocations/flux_vae_encode.py +++ b/invokeai/app/invocations/flux_vae_encode.py @@ -35,14 +35,24 @@ class FluxVaeEncodeInvocation(BaseInvocation): input=Input.Connection, ) + def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoEncoder) -> int: + """Estimate the working memory required by the invocation in bytes.""" + # Encode operations use approximately 50% of the memory required for decode operations + h = image_tensor.shape[-2] + w = image_tensor.shape[-1] + element_size = next(vae.parameters()).element_size() + scaling_constant = 1100 # 50% of decode scaling constant (2200) + working_memory = h * w * element_size * scaling_constant + return int(working_memory) + @staticmethod - def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor: + def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor: # TODO(ryand): Expose seed parameter at the invocation level. # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes. # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function # should be used for VAE encode sampling. generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0) - with vae_info as vae: + with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae): assert isinstance(vae, AutoEncoder) vae_dtype = next(iter(vae.parameters())).dtype image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype) @@ -60,7 +70,8 @@ class FluxVaeEncodeInvocation(BaseInvocation): image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") context.util.signal_progress("Running VAE") - latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor) + estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model) + latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory) latents = latents.to("cpu") name = context.tensors.save(tensor=latents) diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py index 7508c0716d..6c1360ea65 100644 --- a/invokeai/app/invocations/image_to_latents.py +++ b/invokeai/app/invocations/image_to_latents.py @@ -52,11 +52,43 @@ class ImageToLatentsInvocation(BaseInvocation): tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size) fp32: bool = InputField(default=False, description=FieldDescriptions.fp32) + def _estimate_working_memory( + self, image_tensor: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny + ) -> int: + """Estimate the working memory required by the invocation in bytes.""" + # Encode operations use approximately 50% of the memory required for decode operations + element_size = 4 if self.fp32 else 2 + scaling_constant = 1100 # 50% of decode scaling constant (2200) + + if use_tiling: + tile_size = self.tile_size + if tile_size == 0: + tile_size = vae.tile_sample_min_size + assert isinstance(tile_size, int) + h = tile_size + w = tile_size + working_memory = h * w * element_size * scaling_constant + + # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap + # and number of tiles. We could make this more precise in the future, but this should be good enough for + # most use cases. + working_memory = working_memory * 1.25 + else: + h = image_tensor.shape[-2] + w = image_tensor.shape[-1] + working_memory = h * w * element_size * scaling_constant + + if self.fp32: + # If we are running in FP32, then we should account for the likely increase in model size (~250MB). + working_memory += 250 * 2**20 + + return int(working_memory) + @staticmethod def vae_encode( - vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0 + vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0, estimated_working_memory: int = 0 ) -> torch.Tensor: - with vae_info as vae: + with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae): assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) orig_dtype = vae.dtype if upcast: @@ -113,14 +145,18 @@ class ImageToLatentsInvocation(BaseInvocation): image = context.images.get_pil(self.image.image_name) vae_info = context.models.load(self.vae.vae) + assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") + use_tiling = self.tiled or context.config.get().force_tiled_decode + estimated_working_memory = self._estimate_working_memory(image_tensor, use_tiling, vae_info.model) + context.util.signal_progress("Running VAE encoder") latents = self.vae_encode( - vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size + vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size, estimated_working_memory=estimated_working_memory ) latents = latents.to("cpu") diff --git a/invokeai/app/invocations/sd3_image_to_latents.py b/invokeai/app/invocations/sd3_image_to_latents.py index fc88e85aa5..12048bfce2 100644 --- a/invokeai/app/invocations/sd3_image_to_latents.py +++ b/invokeai/app/invocations/sd3_image_to_latents.py @@ -32,9 +32,19 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard): image: ImageField = InputField(description="The image to encode") vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection) + def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int: + """Estimate the working memory required by the invocation in bytes.""" + # Encode operations use approximately 50% of the memory required for decode operations + h = image_tensor.shape[-2] + w = image_tensor.shape[-1] + element_size = next(vae.parameters()).element_size() + scaling_constant = 1100 # 50% of decode scaling constant (2200) + working_memory = h * w * element_size * scaling_constant + return int(working_memory) + @staticmethod - def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor: - with vae_info as vae: + def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor: + with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae): assert isinstance(vae, AutoencoderKL) vae.disable_tiling() @@ -58,7 +68,10 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard): image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") vae_info = context.models.load(self.vae.vae) - latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor) + assert isinstance(vae_info.model, AutoencoderKL) + + estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model) + latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory) latents = latents.to("cpu") name = context.tensors.save(tensor=latents) diff --git a/invokeai/backend/flux/extensions/kontext_extension.py b/invokeai/backend/flux/extensions/kontext_extension.py index 6aabcb6cda..d62b393731 100644 --- a/invokeai/backend/flux/extensions/kontext_extension.py +++ b/invokeai/backend/flux/extensions/kontext_extension.py @@ -131,7 +131,14 @@ class KontextExtension: # Continue with VAE encoding # Don't sample from the distribution for reference images - use the mean (matching ComfyUI) - with vae_info as vae: + # Estimate working memory for encode operation (50% of decode memory requirements) + h = image_tensor.shape[-2] + w = image_tensor.shape[-1] + element_size = next(vae_info.model.parameters()).element_size() + scaling_constant = 1100 # 50% of decode scaling constant (2200) + estimated_working_memory = int(h * w * element_size * scaling_constant) + + with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae): assert isinstance(vae, AutoEncoder) vae_dtype = next(iter(vae.parameters())).dtype image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)