diff --git a/invokeai/app/invocations/cogview4_image_to_latents.py b/invokeai/app/invocations/cogview4_image_to_latents.py
index 23f1c13e26..706fc7a0cb 100644
--- a/invokeai/app/invocations/cogview4_image_to_latents.py
+++ b/invokeai/app/invocations/cogview4_image_to_latents.py
@@ -36,9 +36,19 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to encode.")
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

+    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        working_memory = h * w * element_size * scaling_constant
+        return int(working_memory)
+
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
-        with vae_info as vae:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoencoderKL)

             vae.disable_tiling()
@@ -62,7 +72,10 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

         vae_info = context.models.load(self.vae.vae)
-        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+        assert isinstance(vae_info.model, AutoencoderKL)
+
+        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)

         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)
diff --git a/invokeai/app/invocations/flux_vae_encode.py b/invokeai/app/invocations/flux_vae_encode.py
index daf039b80d..7bb9f18e76 100644
--- a/invokeai/app/invocations/flux_vae_encode.py
+++ b/invokeai/app/invocations/flux_vae_encode.py
@@ -35,14 +35,24 @@ class FluxVaeEncodeInvocation(BaseInvocation):
         input=Input.Connection,
     )

+    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoEncoder) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        working_memory = h * w * element_size * scaling_constant
+        return int(working_memory)
+
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
         # TODO(ryand): Expose seed parameter at the invocation level.
         # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
         # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
         # should be used for VAE encode sampling.
         generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
-        with vae_info as vae:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoEncoder)
             vae_dtype = next(iter(vae.parameters())).dtype
             image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
@@ -60,7 +70,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

         context.util.signal_progress("Running VAE")
-        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)

         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)
diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py
index 7508c0716d..6c1360ea65 100644
--- a/invokeai/app/invocations/image_to_latents.py
+++ b/invokeai/app/invocations/image_to_latents.py
@@ -52,11 +52,43 @@ class ImageToLatentsInvocation(BaseInvocation):
     tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)

+    def _estimate_working_memory(
+        self, image_tensor: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
+    ) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        element_size = 4 if self.fp32 else 2
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+
+        if use_tiling:
+            tile_size = self.tile_size
+            if tile_size == 0:
+                tile_size = vae.tile_sample_min_size
+                assert isinstance(tile_size, int)
+            h = tile_size
+            w = tile_size
+            working_memory = h * w * element_size * scaling_constant
+
+            # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
+            # and number of tiles. We could make this more precise in the future, but this should be good enough for
+            # most use cases.
+            working_memory = working_memory * 1.25
+        else:
+            h = image_tensor.shape[-2]
+            w = image_tensor.shape[-1]
+            working_memory = h * w * element_size * scaling_constant
+
+        if self.fp32:
+            # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
+            working_memory += 250 * 2**20
+
+        return int(working_memory)
+
     @staticmethod
     def vae_encode(
-        vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
+        vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0, estimated_working_memory: int = 0
     ) -> torch.Tensor:
-        with vae_info as vae:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             orig_dtype = vae.dtype
             if upcast:
@@ -113,14 +145,18 @@ class ImageToLatentsInvocation(BaseInvocation):
         image = context.images.get_pil(self.image.image_name)

         vae_info = context.models.load(self.vae.vae)
+        assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))

         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

+        use_tiling = self.tiled or context.config.get().force_tiled_decode
+        estimated_working_memory = self._estimate_working_memory(image_tensor, use_tiling, vae_info.model)
+
         context.util.signal_progress("Running VAE encoder")
         latents = self.vae_encode(
-            vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
+            vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size, estimated_working_memory=estimated_working_memory
         )

         latents = latents.to("cpu")
diff --git a/invokeai/app/invocations/sd3_image_to_latents.py b/invokeai/app/invocations/sd3_image_to_latents.py
index fc88e85aa5..12048bfce2 100644
--- a/invokeai/app/invocations/sd3_image_to_latents.py
+++ b/invokeai/app/invocations/sd3_image_to_latents.py
@@ -32,9 +32,19 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to encode")
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

+    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        working_memory = h * w * element_size * scaling_constant
+        return int(working_memory)
+
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
-        with vae_info as vae:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoencoderKL)

             vae.disable_tiling()
@@ -58,7 +68,10 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

         vae_info = context.models.load(self.vae.vae)
-        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+        assert isinstance(vae_info.model, AutoencoderKL)
+
+        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)

         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)
diff --git a/invokeai/backend/flux/extensions/kontext_extension.py b/invokeai/backend/flux/extensions/kontext_extension.py
index 6aabcb6cda..d62b393731 100644
--- a/invokeai/backend/flux/extensions/kontext_extension.py
+++ b/invokeai/backend/flux/extensions/kontext_extension.py
@@ -131,7 +131,14 @@ class KontextExtension:

         # Continue with VAE encoding
         # Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
-        with vae_info as vae:
+        # Estimate working memory for encode operation (50% of decode memory requirements)
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae_info.model.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        estimated_working_memory = int(h * w * element_size * scaling_constant)
+
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoEncoder)
             vae_dtype = next(iter(vae.parameters())).dtype
             image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
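
Note (illustration, not part of the patch): every hunk above applies the same per-pixel encode estimate before loading the VAE. The sketch below restates that formula as a standalone helper so its behavior is easy to check in isolation; the function name estimate_vae_encode_working_memory and the module-level constant are hypothetical, while the actual patch keeps a private _estimate_working_memory on each invocation class.

    # Sketch only -- restates the encode working-memory estimate from the diff above.
    # The helper name and its tiling handling are assumptions for illustration.
    import torch

    ENCODE_SCALING_CONSTANT = 1100  # 50% of the decode scaling constant (2200)


    def estimate_vae_encode_working_memory(
        image_tensor: torch.Tensor,
        element_size: int,
        use_tiling: bool = False,
        tile_size: int = 0,
    ) -> int:
        """Rough working-memory estimate, in bytes, for encoding image_tensor with a VAE."""
        if use_tiling and tile_size > 0:
            # With tiling, memory scales with the tile area; add ~25% headroom for tile overlap.
            working_memory = tile_size * tile_size * element_size * ENCODE_SCALING_CONSTANT * 1.25
        else:
            h, w = image_tensor.shape[-2], image_tensor.shape[-1]
            working_memory = h * w * element_size * ENCODE_SCALING_CONSTANT
        return int(working_memory)


    # Example: a 1024x1024 image in fp16 (element_size=2) requests roughly
    # 1024 * 1024 * 2 * 1100 bytes, i.e. about 2.1 GiB of working memory, before the model loads.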