mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 13:04:58 -05:00
feat(mm): implement working memory estimation for VAE encode for all models
Tell the model manager that we need some extra working memory for VAE encoding operations to prevent OOMs. See previous commit for investigation and determination of the magic numbers used. This safety measure is especially relevant now that we have FLUX Kontext and may be encoding rather large ref images. Without the working memory estimation we can OOM as we prepare for denoising. See #8405 for an example of this issue on a very low VRAM system. It's possible we can have the same issue on any GPU, though - just a matter of hitting the right combination of models loaded.
This commit is contained in:
@@ -36,9 +36,19 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image: ImageField = InputField(description="The image to encode.")
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
|
||||
|
||||
def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# Encode operations use approximately 50% of the memory required for decode operations
|
||||
h = image_tensor.shape[-2]
|
||||
w = image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1100 # 50% of decode scaling constant (2200)
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
|
||||
vae.disable_tiling()
|
||||
@@ -62,7 +72,10 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
|
||||
estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
|
||||
@@ -35,14 +35,24 @@ class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoEncoder) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# Encode operations use approximately 50% of the memory required for decode operations
|
||||
h = image_tensor.shape[-2]
|
||||
w = image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1100 # 50% of decode scaling constant (2200)
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
|
||||
# TODO(ryand): Expose seed parameter at the invocation level.
|
||||
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
|
||||
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
|
||||
# should be used for VAE encode sampling.
|
||||
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
|
||||
with vae_info as vae:
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
|
||||
@@ -60,7 +70,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
context.util.signal_progress("Running VAE")
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
|
||||
@@ -52,11 +52,43 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
|
||||
|
||||
def _estimate_working_memory(
|
||||
self, image_tensor: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# Encode operations use approximately 50% of the memory required for decode operations
|
||||
element_size = 4 if self.fp32 else 2
|
||||
scaling_constant = 1100 # 50% of decode scaling constant (2200)
|
||||
|
||||
if use_tiling:
|
||||
tile_size = self.tile_size
|
||||
if tile_size == 0:
|
||||
tile_size = vae.tile_sample_min_size
|
||||
assert isinstance(tile_size, int)
|
||||
h = tile_size
|
||||
w = tile_size
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
|
||||
# and number of tiles. We could make this more precise in the future, but this should be good enough for
|
||||
# most use cases.
|
||||
working_memory = working_memory * 1.25
|
||||
else:
|
||||
h = image_tensor.shape[-2]
|
||||
w = image_tensor.shape[-1]
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
if self.fp32:
|
||||
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
|
||||
working_memory += 250 * 2**20
|
||||
|
||||
return int(working_memory)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(
|
||||
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
|
||||
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0, estimated_working_memory: int = 0
|
||||
) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
orig_dtype = vae.dtype
|
||||
if upcast:
|
||||
@@ -113,14 +145,18 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
use_tiling = self.tiled or context.config.get().force_tiled_decode
|
||||
estimated_working_memory = self._estimate_working_memory(image_tensor, use_tiling, vae_info.model)
|
||||
|
||||
context.util.signal_progress("Running VAE encoder")
|
||||
latents = self.vae_encode(
|
||||
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
|
||||
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size, estimated_working_memory=estimated_working_memory
|
||||
)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
|
||||
@@ -32,9 +32,19 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image: ImageField = InputField(description="The image to encode")
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
|
||||
|
||||
def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# Encode operations use approximately 50% of the memory required for decode operations
|
||||
h = image_tensor.shape[-2]
|
||||
w = image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1100 # 50% of decode scaling constant (2200)
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
|
||||
vae.disable_tiling()
|
||||
@@ -58,7 +68,10 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
|
||||
estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
|
||||
@@ -131,7 +131,14 @@ class KontextExtension:
|
||||
|
||||
# Continue with VAE encoding
|
||||
# Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
|
||||
with vae_info as vae:
|
||||
# Estimate working memory for encode operation (50% of decode memory requirements)
|
||||
h = image_tensor.shape[-2]
|
||||
w = image_tensor.shape[-1]
|
||||
element_size = next(vae_info.model.parameters()).element_size()
|
||||
scaling_constant = 1100 # 50% of decode scaling constant (2200)
|
||||
estimated_working_memory = int(h * w * element_size * scaling_constant)
|
||||
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
|
||||
|
||||
Reference in New Issue
Block a user