mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor: estimate working vae memory during encode/decode
- Move the estimation logic to utility functions - Estimate memory _within_ the encode and decode methods, ensuring we _always_ estimate working memory when running a VAE
This commit is contained in:
@@ -17,6 +17,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
|
||||
|
||||
# TODO(ryand): This is effectively a copy of SD3ImageToLatentsInvocation and a subset of ImageToLatentsInvocation. We
|
||||
# should refactor to avoid this duplication.
|
||||
@@ -36,18 +37,12 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image: ImageField = InputField(description="The image to encode.")
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
|
||||
|
||||
def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# Encode operations use approximately 50% of the memory required for decode operations
|
||||
h = image_tensor.shape[-2]
|
||||
w = image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1100 # 50% of decode scaling constant (2200)
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
estimated_working_memory = estimate_vae_working_memory_cogview4(
|
||||
operation="encode", image_tensor=image_tensor, vae=vae_info.model
|
||||
)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
|
||||
@@ -74,10 +69,7 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
|
||||
estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
|
||||
latents = self.vae_encode(
|
||||
vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
|
||||
)
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
|
||||
@@ -6,7 +6,6 @@ from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -20,6 +19,7 @@ from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_cogview4
|
||||
|
||||
# TODO(ryand): This is effectively a copy of SD3LatentsToImageInvocation and a subset of LatentsToImageInvocation. We
|
||||
# should refactor to avoid this duplication.
|
||||
@@ -39,22 +39,15 @@ class CogView4LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL))
|
||||
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
|
||||
estimated_working_memory = estimate_vae_working_memory_cogview4(
|
||||
operation="decode", image_tensor=latents, vae=vae_info.model
|
||||
)
|
||||
with (
|
||||
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
|
||||
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
|
||||
|
||||
@@ -3,7 +3,6 @@ from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -18,6 +17,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -39,17 +39,11 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
|
||||
assert isinstance(vae_info.model, AutoEncoder)
|
||||
estimated_working_memory = estimate_vae_working_memory_flux(
|
||||
operation="decode", image_tensor=latents, vae=vae_info.model
|
||||
)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
|
||||
@@ -15,6 +15,7 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -35,22 +36,16 @@ class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoEncoder) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# Encode operations use approximately 50% of the memory required for decode operations
|
||||
h = image_tensor.shape[-2]
|
||||
w = image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1100 # 50% of decode scaling constant (2200)
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
# TODO(ryand): Expose seed parameter at the invocation level.
|
||||
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
|
||||
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
|
||||
# should be used for VAE encode sampling.
|
||||
assert isinstance(vae_info.model, AutoEncoder)
|
||||
estimated_working_memory = estimate_vae_working_memory_flux(
|
||||
operation="encode", image_tensor=image_tensor, vae=vae_info.model
|
||||
)
|
||||
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
@@ -70,10 +65,7 @@ class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
context.util.signal_progress("Running VAE")
|
||||
estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
|
||||
latents = self.vae_encode(
|
||||
vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
|
||||
)
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
|
||||
@@ -27,6 +27,7 @@ from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -52,47 +53,23 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
|
||||
|
||||
def _estimate_working_memory(
|
||||
self, image_tensor: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# Encode operations use approximately 50% of the memory required for decode operations
|
||||
element_size = 4 if self.fp32 else 2
|
||||
scaling_constant = 1100 # 50% of decode scaling constant (2200)
|
||||
|
||||
if use_tiling:
|
||||
tile_size = self.tile_size
|
||||
if tile_size == 0:
|
||||
tile_size = vae.tile_sample_min_size
|
||||
assert isinstance(tile_size, int)
|
||||
h = tile_size
|
||||
w = tile_size
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
|
||||
# and number of tiles. We could make this more precise in the future, but this should be good enough for
|
||||
# most use cases.
|
||||
working_memory = working_memory * 1.25
|
||||
else:
|
||||
h = image_tensor.shape[-2]
|
||||
w = image_tensor.shape[-1]
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
if self.fp32:
|
||||
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
|
||||
working_memory += 250 * 2**20
|
||||
|
||||
return int(working_memory)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def vae_encode(
|
||||
cls,
|
||||
vae_info: LoadedModel,
|
||||
upcast: bool,
|
||||
tiled: bool,
|
||||
image_tensor: torch.Tensor,
|
||||
tile_size: int = 0,
|
||||
estimated_working_memory: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
|
||||
operation="encode",
|
||||
image_tensor=image_tensor,
|
||||
vae=vae_info.model,
|
||||
tile_size=tile_size if tiled else None,
|
||||
fp32=upcast,
|
||||
)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
orig_dtype = vae.dtype
|
||||
@@ -156,17 +133,13 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
use_tiling = self.tiled or context.config.get().force_tiled_decode
|
||||
estimated_working_memory = self._estimate_working_memory(image_tensor, use_tiling, vae_info.model)
|
||||
|
||||
context.util.signal_progress("Running VAE encoder")
|
||||
latents = self.vae_encode(
|
||||
vae_info=vae_info,
|
||||
upcast=self.fp32,
|
||||
tiled=self.tiled,
|
||||
tiled=self.tiled or context.config.get().force_tiled_decode,
|
||||
image_tensor=image_tensor,
|
||||
tile_size=self.tile_size,
|
||||
estimated_working_memory=estimated_working_memory,
|
||||
)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
|
||||
@@ -27,6 +27,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -53,39 +54,6 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
|
||||
|
||||
def _estimate_working_memory(
|
||||
self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision). This estimate is accurate for both SD1 and SDXL.
|
||||
element_size = 4 if self.fp32 else 2
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
|
||||
if use_tiling:
|
||||
tile_size = self.tile_size
|
||||
if tile_size == 0:
|
||||
tile_size = vae.tile_sample_min_size
|
||||
assert isinstance(tile_size, int)
|
||||
out_h = tile_size
|
||||
out_w = tile_size
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
|
||||
# and number of tiles. We could make this more precise in the future, but this should be good enough for
|
||||
# most use cases.
|
||||
working_memory = working_memory * 1.25
|
||||
else:
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
if self.fp32:
|
||||
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
|
||||
working_memory += 250 * 2**20
|
||||
|
||||
return int(working_memory)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
@@ -94,8 +62,13 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
|
||||
estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
|
||||
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
|
||||
operation="decode",
|
||||
image_tensor=latents,
|
||||
vae=vae_info.model,
|
||||
tile_size=self.tile_size if use_tiling else None,
|
||||
fp32=self.fp32,
|
||||
)
|
||||
with (
|
||||
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
|
||||
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
|
||||
|
||||
@@ -17,6 +17,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -32,18 +33,12 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image: ImageField = InputField(description="The image to encode")
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
|
||||
|
||||
def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# Encode operations use approximately 50% of the memory required for decode operations
|
||||
h = image_tensor.shape[-2]
|
||||
w = image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1100 # 50% of decode scaling constant (2200)
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
estimated_working_memory = estimate_vae_working_memory_sd3(
|
||||
operation="encode", image_tensor=image_tensor, vae=vae_info.model
|
||||
)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
|
||||
@@ -70,10 +65,7 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, AutoencoderKL)
|
||||
|
||||
estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
|
||||
latents = self.vae_encode(
|
||||
vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory
|
||||
)
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
|
||||
@@ -6,7 +6,6 @@ from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -20,6 +19,7 @@ from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd3
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -41,22 +41,15 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
return int(working_memory)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL))
|
||||
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
|
||||
estimated_working_memory = estimate_vae_working_memory_sd3(
|
||||
operation="decode", image_tensor=latents, vae=vae_info.model
|
||||
)
|
||||
with (
|
||||
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
|
||||
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
|
||||
|
||||
117
invokeai/backend/util/vae_working_memory.py
Normal file
117
invokeai/backend/util/vae_working_memory.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
|
||||
|
||||
def estimate_vae_working_memory_sd15_sdxl(
|
||||
operation: Literal["encode", "decode"],
|
||||
image_tensor: torch.Tensor,
|
||||
vae: AutoencoderKL | AutoencoderTiny,
|
||||
tile_size: int | None,
|
||||
fp32: bool,
|
||||
) -> int:
|
||||
"""Estimate the working memory required to encode or decode the given tensor."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision). This estimate is accurate for both SD1 and SDXL.
|
||||
element_size = 4 if fp32 else 2
|
||||
|
||||
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
|
||||
# Encoding uses ~45% the working memory as decoding.
|
||||
scaling_constant = 2200 if operation == "decode" else 1100
|
||||
|
||||
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
|
||||
|
||||
if tile_size is not None:
|
||||
if tile_size == 0:
|
||||
tile_size = vae.tile_sample_min_size
|
||||
assert isinstance(tile_size, int)
|
||||
h = tile_size
|
||||
w = tile_size
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
|
||||
# and number of tiles. We could make this more precise in the future, but this should be good enough for
|
||||
# most use cases.
|
||||
working_memory = working_memory * 1.25
|
||||
else:
|
||||
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
|
||||
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
if fp32:
|
||||
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
|
||||
working_memory += 250 * 2**20
|
||||
|
||||
print(f"estimate_vae_working_memory_sd15_sdxl: {int(working_memory)}")
|
||||
|
||||
return int(working_memory)
|
||||
|
||||
|
||||
def estimate_vae_working_memory_cogview4(
|
||||
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
|
||||
|
||||
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
|
||||
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
|
||||
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
|
||||
# Encoding uses ~45% the working memory as decoding.
|
||||
scaling_constant = 2200 if operation == "decode" else 1100
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
print(f"estimate_vae_working_memory_cogview4: {int(working_memory)}")
|
||||
|
||||
return int(working_memory)
|
||||
|
||||
|
||||
def estimate_vae_working_memory_flux(
|
||||
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoEncoder
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
|
||||
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
|
||||
|
||||
out_h = latent_scale_factor_for_operation * image_tensor.shape[-2]
|
||||
out_w = latent_scale_factor_for_operation * image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
|
||||
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
|
||||
# Encoding uses ~45% the working memory as decoding.
|
||||
scaling_constant = 2200 if operation == "decode" else 1100
|
||||
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
print(f"estimate_vae_working_memory_flux: {int(working_memory)}")
|
||||
|
||||
return int(working_memory)
|
||||
|
||||
|
||||
def estimate_vae_working_memory_sd3(
|
||||
operation: Literal["encode", "decode"], image_tensor: torch.Tensor, vae: AutoencoderKL
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# Encode operations use approximately 50% of the memory required for decode operations
|
||||
|
||||
latent_scale_factor_for_operation = LATENT_SCALE_FACTOR if operation == "decode" else 1
|
||||
|
||||
h = latent_scale_factor_for_operation * image_tensor.shape[-2]
|
||||
w = latent_scale_factor_for_operation * image_tensor.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
|
||||
# This constant is determined experimentally and takes into consideration both allocated and reserved memory. See #8414
|
||||
# Encoding uses ~45% the working memory as decoding.
|
||||
scaling_constant = 2200 if operation == "decode" else 1100
|
||||
|
||||
working_memory = h * w * element_size * scaling_constant
|
||||
|
||||
print(f"estimate_vae_working_memory_sd3: {int(working_memory)}")
|
||||
|
||||
return int(working_memory)
|
||||
Reference in New Issue
Block a user