diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 3d1c925570..0e8e2f7f3f 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -146,7 +146,6 @@ class TextToLatentsInvocation(BaseInvocation):
     # TODO: consider making prompt optional to enable providing prompt through a link
     # fmt: off
     prompt: Optional[str] = Field(description="The prompt to generate an image from")
-    seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
@@ -363,9 +362,87 @@ class LatentsToImageInvocation(BaseInvocation):
                 session_id=context.graph_execution_state_id, node=self
             )
 
+            torch.cuda.empty_cache()
+
             context.services.images.save(image_type, image_name, image, metadata)
             return build_image_output(
-                image_type=image_type,
-                image_name=image_name,
-                image=image
+                image_type=image_type, image_name=image_name, image=image
             )
+
+
+LATENTS_INTERPOLATION_MODE = Literal[
+    "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
+]
+
+
+class ResizeLatentsInvocation(BaseInvocation):
+    """Resizes latents to explicit width/height."""
+
+    type: Literal["lresize"] = "lresize"
+
+    # Inputs
+    latents: Optional[LatentsField] = Field(description="The latents to resize")
+    width: int = Field(ge=64, multiple_of=8, description="The width to resize to")
+    height: int = Field(ge=64, multiple_of=8, description="The height to resize to")
+    downsample: int = Field(
+        default=8, ge=1, description="The downsampling factor (leave at 8 for SD)"
+    )
+    mode: LATENTS_INTERPOLATION_MODE = Field(
+        default="bilinear", description="The interpolation mode"
+    )
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = context.services.latents.get(self.latents.latents_name)
+        # resizing
+        resized_latents = torch.nn.functional.interpolate(
+            latents,
+            size=(
+                self.height // self.downsample,
+                self.width // self.downsample,
+            ),
+            mode=self.mode,
+        )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.set(name, resized_latents)
+        return LatentsOutput(latents=LatentsField(latents_name=name))
+
+
+class ScaleLatentsInvocation(BaseInvocation):
+    """Scales latents by a given factor."""
+
+    type: Literal["lscale"] = "lscale"
+
+    # Inputs
+    latents: Optional[LatentsField] = Field(description="The latents to resize")
+    scale: int = Field(
+        default=2, ge=1, description="The factor by which to scale the latents"
+    )
+    mode: LATENTS_INTERPOLATION_MODE = Field(
+        default="bilinear", description="The interpolation mode"
+    )
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = context.services.latents.get(self.latents.latents_name)
+
+        (_, _, h, w) = latents.size()
+
+        # resizing
+        resized_latents = torch.nn.functional.interpolate(
+            latents,
+            size=(
+                h * self.scale,
+                w * self.scale,
+            ),
+            mode=self.mode,
+        )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.set(name, resized_latents)
+        return LatentsOutput(latents=LatentsField(latents_name=name))
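
For review context (not part of the diff), here is a minimal standalone sketch of the tensor math both `invoke()` methods perform. The latent shape, target size, and scale factor below are illustrative assumptions: SD latents are 4-channel tensors at 1/8 of image resolution, which is why the node defaults `downsample` to 8.

```python
# Illustrative sketch only -- dummy shapes, not InvokeAI code.
import torch
import torch.nn.functional as F

# A batch of SD-style latents for a 512x512 image: 4 channels at 1/8 resolution.
latents = torch.randn(1, 4, 64, 64)

# ResizeLatentsInvocation path: explicit pixel target divided by the downsample factor.
resized = F.interpolate(latents, size=(768 // 8, 768 // 8), mode="bilinear")
print(resized.shape)  # torch.Size([1, 4, 96, 96])

# ScaleLatentsInvocation path: multiply the current latent height/width by `scale`.
_, _, h, w = latents.size()
scaled = F.interpolate(latents, size=(h * 2, w * 2), mode="bilinear")
print(scaled.shape)  # torch.Size([1, 4, 128, 128])
```

One note on `LATENTS_INTERPOLATION_MODE`: `torch.nn.functional.interpolate` only accepts `"linear"` for 3D and `"trilinear"` for 5D inputs, so on the 4D latents here those two choices will raise an error at runtime; the 2D modes (`nearest`, `bilinear`, `bicubic`, `area`, `nearest-exact`) are the ones that actually apply.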