diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py index 06de530154..73e6106ff9 100644 --- a/invokeai/app/invocations/image_to_latents.py +++ b/invokeai/app/invocations/image_to_latents.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from functools import singledispatchmethod import einops @@ -24,6 +25,7 @@ from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager import LoadedModel from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor +from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params @invocation( @@ -49,7 +51,7 @@ class ImageToLatentsInvocation(BaseInvocation): @staticmethod def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor: with vae_info as vae: - assert isinstance(vae, torch.nn.Module) + assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) orig_dtype = vae.dtype if upcast: vae.to(dtype=torch.float32) @@ -76,14 +78,21 @@ class ImageToLatentsInvocation(BaseInvocation): vae.to(dtype=torch.float16) # latents = latents.half() + tiling_context = nullcontext() if tiled: + tiling_context = patch_vae_tiling_params( + vae, + tile_sample_min_size=512, + tile_latent_min_size=512 // 8, + tile_overlap_factor=0.25, + ) vae.enable_tiling() else: vae.disable_tiling() # non_noised_latents_from_image image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype) - with torch.inference_mode(): + with torch.inference_mode(), tiling_context: latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor) latents = vae.config.scaling_factor * latents diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py index 202e8bfa1b..3d714730dc 100644 --- a/invokeai/app/invocations/latents_to_image.py +++ b/invokeai/app/invocations/latents_to_image.py @@ -1,3 +1,5 @@ +from contextlib import nullcontext + import torch from diffusers.image_processor import VaeImageProcessor from diffusers.models.attention_processor import ( @@ -8,7 +10,6 @@ from diffusers.models.attention_processor import ( ) from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.constants import DEFAULT_PRECISION @@ -24,6 +25,7 @@ from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.stable_diffusion import set_seamless +from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params from invokeai.backend.util.devices import TorchDevice @@ -53,9 +55,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): latents = context.tensors.load(self.latents.latents_name) vae_info = context.models.load(self.vae.vae) - assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny)) + assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)) with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: - assert isinstance(vae, torch.nn.Module) + assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) @@ -82,7 +84,14 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): vae.to(dtype=torch.float16) latents = latents.half() + tiling_context = nullcontext() if self.tiled or context.config.get().force_tiled_decode: + tiling_context = patch_vae_tiling_params( + vae, + tile_sample_min_size=512, + tile_latent_min_size=512 // 8, + tile_overlap_factor=0.25, + ) vae.enable_tiling() else: vae.disable_tiling() @@ -90,7 +99,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): # clear memory as vae decode can request a lot TorchDevice.empty_cache() - with torch.inference_mode(): + with torch.inference_mode(), tiling_context: # copied from diffusers pipeline latents = latents / vae.config.scaling_factor image = vae.decode(latents, return_dict=False)[0] diff --git a/invokeai/backend/stable_diffusion/vae_tiling.py b/invokeai/backend/stable_diffusion/vae_tiling.py new file mode 100644 index 0000000000..1fa7a18708 --- /dev/null +++ b/invokeai/backend/stable_diffusion/vae_tiling.py @@ -0,0 +1,29 @@ +from contextlib import contextmanager + +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny + + +@contextmanager +def patch_vae_tiling_params( + vae: AutoencoderKL | AutoencoderTiny, + tile_sample_min_size: int, + tile_latent_min_size: int, + tile_overlap_factor: float, +): + # Record initial config. + orig_tile_sample_min_size = vae.tile_sample_min_size + orig_tile_latent_min_size = vae.tile_latent_min_size + orig_tile_overlap_factor = vae.tile_overlap_factor + + try: + # Apply target config. + vae.tile_sample_min_size = tile_sample_min_size + vae.tile_latent_min_size = tile_latent_min_size + vae.tile_overlap_factor = tile_overlap_factor + yield + finally: + # Restore initial config. + vae.tile_sample_min_size = orig_tile_sample_min_size + vae.tile_latent_min_size = orig_tile_latent_min_size + vae.tile_overlap_factor = orig_tile_overlap_factor