fix(nodes): apply_tensor_mask_to_image transparent image handling

Fix an issue where if the input image is transparent in a region to be masked, that transparent region ends up opaque black. Need to respect the input image transparency by applying the mask to the alpha channel only.
This commit is contained in:
psychedelicious
2024-10-23 13:49:11 +10:00
parent ff72315db2
commit 48a471bfb8

View File

@@ -159,13 +159,15 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
version="1.0.0",
)
class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Convert a mask tensor to an image."""
"""Applies a tensor mask to an image.
The image is converted to RGBA and the mask is applied to the alpha channel."""
mask: TensorField = InputField(description="The mask tensor to apply.")
image: ImageField = InputField(description="The image to apply the mask to.")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGB")
image = context.images.get_pil(self.image.image_name, mode="RGBA")
mask = context.tensors.load(self.mask.tensor_name)
# Squeeze the channel dimension if it exists.
@@ -176,10 +178,22 @@ class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
if mask.dtype != torch.bool:
mask = mask > 0.5
mask_np = (mask.float() * 255).byte().cpu().numpy().astype(np.uint8)
# Apply the mask only to the alpha channel where the original alpha is non-zero. This preserves the original
# image's transparency - else the transparent regions would end up as opaque black.
# Separate the image into R, G, B, and A channels
image_np = pil_to_np(image)
masked_image_np = np.dstack([image_np, mask_np]) # Add the alpha channel
r, g, b, a = np.split(image_np, 4, axis=-1)
# Apply the mask to the alpha channel
new_alpha = np.where(a.squeeze() > 0, mask_np, a.squeeze())
# Stack the RGB channels with the modified alpha
masked_image_np = np.dstack([r.squeeze(), g.squeeze(), b.squeeze(), new_alpha])
# Convert back to an image (RGBA)
masked_image = Image.fromarray(masked_image_np, "RGBA")
masked_image = Image.fromarray(masked_image_np.astype(np.uint8), "RGBA")
image_dto = context.images.save(image=masked_image)
return ImageOutput.build(image_dto)