feat(nodes): simplify MaskFromIDInvocation

This commit is contained in:
psychedelicious
2025-01-21 12:10:10 +11:00
parent 8d2b4e2bf5
commit 14b5c871dc

View File

@@ -1000,7 +1000,7 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Mask from ID",
tags=["image", "mask", "id"],
category="image",
version="1.0.0",
version="1.0.1",
)
class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generate a mask for a particular color in an ID Map"""
@@ -1010,40 +1010,24 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
threshold: int = InputField(default=100, description="Threshold for color detection")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def rgba_to_hex(self, rgba_color: tuple[int, int, int, int]):
r, g, b, a = rgba_color
hex_code = "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, int(a * 255))
return hex_code
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
def id_to_mask(self, id_mask: Image.Image, color: tuple[int, int, int, int], threshold: int = 100):
if id_mask.mode != "RGB":
id_mask = id_mask.convert("RGB")
# Can directly just use the tuple but I'll leave this rgba_to_hex here
# incase anyone prefers using hex codes directly instead of the color picker
hex_color_str = self.rgba_to_hex(color)
rgb_color = numpy.array([int(hex_color_str[i : i + 2], 16) for i in (1, 3, 5)])
np_color = numpy.array(self.color.tuple())
# Maybe there's a faster way to calculate this distance but I can't think of any right now.
color_distance = numpy.linalg.norm(id_mask - rgb_color, axis=-1)
color_distance = numpy.linalg.norm(image - np_color, axis=-1)
# Create a mask based on the threshold and the distance calculated above
binary_mask = (color_distance < threshold).astype(numpy.uint8) * 255
binary_mask = (color_distance < self.threshold).astype(numpy.uint8) * 255
# Convert the mask back to PIL
binary_mask_pil = Image.fromarray(binary_mask)
return binary_mask_pil
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
mask = self.id_to_mask(image, self.color.tuple(), self.threshold)
if self.invert:
mask = ImageOps.invert(mask)
binary_mask_pil = ImageOps.invert(binary_mask_pil)
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
image_dto = context.images.save(image=binary_mask_pil, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)