From 843f2d71d663ac95f970645489fbc2f7f74be9fb Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 27 Nov 2023 11:02:10 -0500 Subject: [PATCH] Copy CropLatentsInvocation from https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80. --- invokeai/app/invocations/latent.py | 53 ++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index ab59b41865..26294ed7f7 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1166,3 +1166,56 @@ class BlendLatentsInvocation(BaseInvocation): # context.services.latents.set(name, resized_latents) context.services.latents.save(name, blended_latents) return build_latents_output(latents_name=name, latents=blended_latents) + + +@invocation( + "lcrop", + title="Crop Latents", + tags=["latents", "crop"], + category="latents", + version="1.0.0", +) +class CropLatentsInvocation(BaseInvocation): + """Crops latents""" + + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + width: int = InputField( + ge=64, + multiple_of=_downsampling_factor, + description=FieldDescriptions.width, + ) + height: int = InputField( + ge=64, + multiple_of=_downsampling_factor, + description=FieldDescriptions.width, + ) + x_offset: int = InputField( + ge=0, + multiple_of=_downsampling_factor, + description="x-coordinate", + ) + y_offset: int = InputField( + ge=0, + multiple_of=_downsampling_factor, + description="y-coordinate", + ) + + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents = context.services.latents.get(self.latents.latents_name) + + x1 = self.x_offset // _downsampling_factor + y1 = self.y_offset // _downsampling_factor + x2 = x1 + (self.width // _downsampling_factor) + y2 = y1 + (self.height // _downsampling_factor) + + cropped_latents = latents[:, :, y1:y2, x1:x2] + + # resized_latents = resized_latents.to("cpu") + + name = f"{context.graph_execution_state_id}__{self.id}" + context.services.latents.save(name, cropped_latents) + + return build_latents_output(latents_name=name, latents=cropped_latents)