From 9b763b9e4c65a4286ad867fb431cb1053795f0ec Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 5 Jan 2024 10:31:58 -0500 Subject: [PATCH] Fix issue with seamless context managers when seamless is not configured. --- invokeai/app/invocations/latent.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 3428fd1c2e..f7a1b4bee3 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1,5 +1,6 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) +import contextlib from contextlib import ExitStack from functools import singledispatchmethod from typing import List, Literal, Optional, Union @@ -716,10 +717,23 @@ class DenoiseLatentsInvocation(BaseInvocation): **self.unet.unet.model_dump(), context=context, ) + + # Prepare seamless context, if configured. + seamless_context = contextlib.nullcontext() + seamless_config = self.unet.seamless + if seamless_config is not None: + seamless_context = set_seamless( + model=unet_info.context.model, + axes=seamless_config.axes, + skipped_layers=seamless_config.skipped_layers, + skip_second_resnet=seamless_config.skip_second_resnet, + skip_conv2=seamless_config.skip_conv2, + ) + with ( ExitStack() as exit_stack, ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), - set_seamless(unet_info.context.model, **self.unet.seamless.dict()), + seamless_context, unet_info as unet, # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet(unet, _lora_loader()), @@ -826,7 +840,19 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): context=context, ) - with set_seamless(vae_info.context.model, **self.vae.seamless.dict()), vae_info as vae: + # Prepare seamless context, if configured. + seamless_context = contextlib.nullcontext() + seamless_config = self.vae.seamless + if seamless_config is not None: + seamless_context = set_seamless( + model=vae_info.context.model, + axes=seamless_config.axes, + skipped_layers=seamless_config.skipped_layers, + skip_second_resnet=seamless_config.skip_second_resnet, + skip_conv2=seamless_config.skip_conv2, + ) + + with seamless_context, vae_info as vae: latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32)