# All nodes in this file are originally pulled from https://github.com/dwringer/composition-nodes
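# This module defines the following invocation nodes: Adjust Image Hue Plus, Enhance Image,
# Equivalent Achromatic Lightness, Image Layer Blend, Image Compositor, Image Dilate or Erode,
# and Image Value Thresholds.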

import os
from ast import literal_eval as tuple_from_string
from functools import reduce
from io import BytesIO
from math import pi as PI
from typing import Literal, Optional

import cv2
import numpy
import torch
from PIL import Image, ImageChops, ImageCms, ImageColor, ImageDraw, ImageEnhance, ImageOps
from torchvision.transforms.functional import to_pil_image as pil_image_from_tensor

from invokeai.app.invocations.primitives import ImageOutput
from invokeai.backend.image_util.composition import (
    CIELAB_TO_UPLAB_ICC_PATH,
    MAX_FLOAT,
    equivalent_achromatic_lightness,
    gamut_clip_tensor,
    hsl_from_srgb,
    linear_srgb_from_oklab,
    linear_srgb_from_srgb,
    okhsl_from_srgb,
    okhsv_from_srgb,
    oklab_from_linear_srgb,
    remove_nans,
    srgb_from_hsl,
    srgb_from_linear_srgb,
    srgb_from_okhsl,
    srgb_from_okhsv,
    tensor_from_pil_image,
)
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.invocation_api import (
    BaseInvocation,
    ImageField,
    InputField,
    InvocationContext,
    WithBoard,
    WithMetadata,
    invocation,
)

HUE_COLOR_SPACES = Literal[
    "HSV / HSL / RGB",
    "Okhsl",
    "Okhsv",
    "*Oklch / Oklab",
    "*LCh / CIELab",
    "*UPLab (w/CIELab_to_UPLab.icc)",
]


@invocation(
    "invokeai_img_hue_adjust_plus",
    title="Adjust Image Hue Plus",
    tags=["image", "hue", "oklab", "cielab", "uplab", "lch", "hsv", "hsl", "lab"],
    category="image",
    version="1.2.0",
)
class InvokeAdjustImageHuePlusInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Adjusts the Hue of an image by rotating it in the selected color space. Originally created by @dwringer"""

    image: ImageField = InputField(description="The image to adjust")
    space: HUE_COLOR_SPACES = InputField(
        default="HSV / HSL / RGB",
        description="Color space in which to rotate hue by polar coords (*: non-invertible)",
    )
    degrees: float = InputField(default=0.0, description="Degrees by which to rotate image hue")
    preserve_lightness: bool = InputField(default=False, description="Whether to preserve CIELAB lightness values")
    ok_adaptive_gamut: float = InputField(
        ge=0, default=0.05, description="Higher preserves chroma at the expense of lightness (Oklab)"
    )
    ok_high_precision: bool = InputField(
        default=True, description="Use more steps in computing gamut (Oklab/Okhsv/Okhsl)"
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image_in = context.images.get_pil(self.image.image_name)
        image_out = None
        space = self.space.split()[0].lower().strip("*")

        # Keep the mode and alpha channel for restoration after shifting the hue:
        image_mode = image_in.mode
        original_mode = image_mode
        alpha_channel = None
        if (image_mode == "RGBA") or (image_mode == "LA") or (image_mode == "PA"):
            alpha_channel = image_in.getchannel("A")
        elif (image_mode == "RGBa") or (image_mode == "La") or (image_mode == "Pa"):
            alpha_channel = image_in.getchannel("a")
        if (image_mode == "RGBA") or (image_mode == "RGBa"):
            image_mode = "RGB"
        elif (image_mode == "LA") or (image_mode == "La"):
            image_mode = "L"
        elif image_mode == "PA":
            image_mode = "P"

        image_in = image_in.convert("RGB")

        # Keep the CIELAB L* lightness channel for restoration if Preserve Lightness is selected:
        (channel_l, channel_a, channel_b, profile_srgb, profile_lab, profile_uplab, lab_transform, uplab_transform) = (
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
        if self.preserve_lightness or (space == "lch") or (space == "uplab"):
            profile_srgb = ImageCms.createProfile("sRGB")
            if space == "uplab":
                with open(CIELAB_TO_UPLAB_ICC_PATH, "rb") as f:
                    profile_uplab = ImageCms.getOpenProfile(f)
            if profile_uplab is None:
                profile_lab = ImageCms.createProfile("LAB", colorTemp=6500)
            else:
                profile_lab = ImageCms.createProfile("LAB", colorTemp=5000)

            lab_transform = ImageCms.buildTransformFromOpenProfiles(
                profile_srgb, profile_lab, "RGB", "LAB", renderingIntent=2, flags=0x2400
            )
            image_out = ImageCms.applyTransform(image_in, lab_transform)
            if profile_uplab is not None:
                uplab_transform = ImageCms.buildTransformFromOpenProfiles(
                    profile_lab, profile_uplab, "LAB", "LAB", renderingIntent=2, flags=0x2400
                )
                image_out = ImageCms.applyTransform(image_out, uplab_transform)

            channel_l = image_out.getchannel("L")
            channel_a = image_out.getchannel("A")
            channel_b = image_out.getchannel("B")

        if space == "hsv":
            hsv_tensor = image_resized_to_grid_as_tensor(image_in.convert("HSV"), normalize=False, multiple_of=1)
            hsv_tensor[0, :, :] = torch.remainder(torch.add(hsv_tensor[0, :, :], torch.div(self.degrees, 360.0)), 1.0)
            image_out = pil_image_from_tensor(hsv_tensor, mode="HSV").convert("RGB")

        elif space == "okhsl":
            rgb_tensor = image_resized_to_grid_as_tensor(image_in.convert("RGB"), normalize=False, multiple_of=1)
            hsl_tensor = okhsl_from_srgb(rgb_tensor, steps=(3 if self.ok_high_precision else 1))
            hsl_tensor[0, :, :] = torch.remainder(torch.add(hsl_tensor[0, :, :], torch.div(self.degrees, 360.0)), 1.0)
            rgb_tensor = srgb_from_okhsl(hsl_tensor, alpha=0.0)
            image_out = pil_image_from_tensor(rgb_tensor, mode="RGB")

        elif space == "okhsv":
            rgb_tensor = image_resized_to_grid_as_tensor(image_in.convert("RGB"), normalize=False, multiple_of=1)
            hsv_tensor = okhsv_from_srgb(rgb_tensor, steps=(3 if self.ok_high_precision else 1))
            hsv_tensor[0, :, :] = torch.remainder(torch.add(hsv_tensor[0, :, :], torch.div(self.degrees, 360.0)), 1.0)
            rgb_tensor = srgb_from_okhsv(hsv_tensor, alpha=0.0)
            image_out = pil_image_from_tensor(rgb_tensor, mode="RGB")

        elif (space == "lch") or (space == "uplab"):
            # Channels a and b were already extracted, above.

            a_tensor = image_resized_to_grid_as_tensor(channel_a, normalize=True, multiple_of=1)
            b_tensor = image_resized_to_grid_as_tensor(channel_b, normalize=True, multiple_of=1)

            # L*a*b* to L*C*h
            c_tensor = torch.sqrt(torch.add(torch.pow(a_tensor, 2.0), torch.pow(b_tensor, 2.0)))
            h_tensor = torch.atan2(b_tensor, a_tensor)

            # Rotate h
            rot_rads = (self.degrees / 180.0) * PI

            h_rot = torch.add(h_tensor, rot_rads)
            h_rot = torch.sub(torch.remainder(torch.add(h_rot, PI), 2 * PI), PI)
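            # (the line above keeps the rotated hue angle within [-pi, pi))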

            # L*C*h to L*a*b*
            a_tensor = torch.mul(c_tensor, torch.cos(h_rot))
            b_tensor = torch.mul(c_tensor, torch.sin(h_rot))

            # -1..1 -> 0..1 for all elts of a, b
            a_tensor = torch.div(torch.add(a_tensor, 1.0), 2.0)
            b_tensor = torch.div(torch.add(b_tensor, 1.0), 2.0)

            a_img = pil_image_from_tensor(a_tensor)
            b_img = pil_image_from_tensor(b_tensor)

            image_out = Image.merge("LAB", (channel_l, a_img, b_img))

            if profile_uplab is not None:
                deuplab_transform = ImageCms.buildTransformFromOpenProfiles(
                    profile_uplab, profile_lab, "LAB", "LAB", renderingIntent=2, flags=0x2400
                )
                image_out = ImageCms.applyTransform(image_out, deuplab_transform)

            rgb_transform = ImageCms.buildTransformFromOpenProfiles(
                profile_lab, profile_srgb, "LAB", "RGB", renderingIntent=2, flags=0x2400
            )
            image_out = ImageCms.applyTransform(image_out, rgb_transform)

        elif space == "oklch":
            rgb_tensor = image_resized_to_grid_as_tensor(image_in.convert("RGB"), normalize=False, multiple_of=1)

            linear_srgb_tensor = linear_srgb_from_srgb(rgb_tensor)

            lab_tensor = oklab_from_linear_srgb(linear_srgb_tensor)

            # L*a*b* to L*C*h
            c_tensor = torch.sqrt(torch.add(torch.pow(lab_tensor[1, :, :], 2.0), torch.pow(lab_tensor[2, :, :], 2.0)))
            h_tensor = torch.atan2(lab_tensor[2, :, :], lab_tensor[1, :, :])

            # Rotate h
            rot_rads = (self.degrees / 180.0) * PI

            h_rot = torch.add(h_tensor, rot_rads)
            h_rot = torch.remainder(torch.add(h_rot, 2 * PI), 2 * PI)

            # L*C*h to L*a*b*
            lab_tensor[1, :, :] = torch.mul(c_tensor, torch.cos(h_rot))
            lab_tensor[2, :, :] = torch.mul(c_tensor, torch.sin(h_rot))

            linear_srgb_tensor = linear_srgb_from_oklab(lab_tensor)

            rgb_tensor = srgb_from_linear_srgb(
                linear_srgb_tensor, alpha=self.ok_adaptive_gamut, steps=(3 if self.ok_high_precision else 1)
            )

            image_out = pil_image_from_tensor(rgb_tensor, mode="RGB")

        # Not all modes can convert directly to LAB using pillow:
        # image_out = image_out.convert("RGB")

        # Restore the L* channel if required:
        if self.preserve_lightness and (not ((space == "lch") or (space == "uplab"))):
            if profile_uplab is None:
                profile_lab = ImageCms.createProfile("LAB", colorTemp=6500)
            else:
                profile_lab = ImageCms.createProfile("LAB", colorTemp=5000)

            lab_transform = ImageCms.buildTransformFromOpenProfiles(
                profile_srgb, profile_lab, "RGB", "LAB", renderingIntent=2, flags=0x2400
            )

            image_out = ImageCms.applyTransform(image_out, lab_transform)

            if profile_uplab is not None:
                uplab_transform = ImageCms.buildTransformFromOpenProfiles(
                    profile_lab, profile_uplab, "LAB", "LAB", renderingIntent=2, flags=0x2400
                )
                image_out = ImageCms.applyTransform(image_out, uplab_transform)

            image_out = Image.merge("LAB", tuple([channel_l] + [image_out.getchannel(c) for c in "AB"]))

            if profile_uplab is not None:
                deuplab_transform = ImageCms.buildTransformFromOpenProfiles(
                    profile_uplab, profile_lab, "LAB", "LAB", renderingIntent=2, flags=0x2400
                )
                image_out = ImageCms.applyTransform(image_out, deuplab_transform)

            rgb_transform = ImageCms.buildTransformFromOpenProfiles(
                profile_lab, profile_srgb, "LAB", "RGB", renderingIntent=2, flags=0x2400
            )
            image_out = ImageCms.applyTransform(image_out, rgb_transform)

        # Restore the original image mode, with alpha channel if required:
        image_out = image_out.convert(image_mode)
        if "a" in original_mode.lower():
            image_out = Image.merge(
                original_mode, tuple([image_out.getchannel(c) for c in image_mode] + [alpha_channel])
            )

        image_dto = context.images.save(image_out)

        return ImageOutput.build(image_dto)


@invocation(
    "invokeai_img_enhance",
    title="Enhance Image",
    tags=["enhance", "image"],
    category="image",
    version="1.2.1",
)
class InvokeImageEnhanceInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Applies processing from PIL's ImageEnhance module. Originally created by @dwringer"""

    image: ImageField = InputField(description="The image for which to apply processing")
    invert: bool = InputField(default=False, description="Whether to invert the image colors")
    color: float = InputField(ge=0, default=1.0, description="Color enhancement factor")
    contrast: float = InputField(ge=0, default=1.0, description="Contrast enhancement factor")
    brightness: float = InputField(ge=0, default=1.0, description="Brightness enhancement factor")
    sharpness: float = InputField(ge=0, default=1.0, description="Sharpness enhancement factor")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image_out = context.images.get_pil(self.image.image_name)
        if self.invert:
            if image_out.mode not in ("L", "RGB"):
                image_out = image_out.convert("RGB")
            image_out = ImageOps.invert(image_out)
        if self.color != 1.0:
            color_enhancer = ImageEnhance.Color(image_out)
            image_out = color_enhancer.enhance(self.color)
        if self.contrast != 1.0:
            contrast_enhancer = ImageEnhance.Contrast(image_out)
            image_out = contrast_enhancer.enhance(self.contrast)
        if self.brightness != 1.0:
            brightness_enhancer = ImageEnhance.Brightness(image_out)
            image_out = brightness_enhancer.enhance(self.brightness)
        if self.sharpness != 1.0:
            sharpness_enhancer = ImageEnhance.Sharpness(image_out)
            image_out = sharpness_enhancer.enhance(self.sharpness)
        image_dto = context.images.save(image_out)

        return ImageOutput.build(image_dto)


@invocation(
    "invokeai_ealightness",
    title="Equivalent Achromatic Lightness",
    tags=["image", "channel", "mask", "cielab", "lab"],
    category="image",
    version="1.2.0",
)
class InvokeEquivalentAchromaticLightnessInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Calculate Equivalent Achromatic Lightness from image. Originally created by @dwringer"""

    image: ImageField = InputField(description="Image from which to get channel")

    # The chroma, C*, and the hue, h, in the CIELAB color space are obtained by
    # C* = sqrt((a*)^2 + (b*)^2) and h = arctan(b*/a*).
    # k = 0.1644, 0.0603, 0.1307, 0.0060
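    # The lightness adjustment computed below is based on L_EAL = L* + (f_by + f_r) * C*, with
    # f_by = k0 * |sin((h - 90deg) / 2)| + k1 and f_r = k2 * |cos(h)| + k3
    # (f_r applied only where -90deg <= h <= 90deg).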

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image_in = context.images.get_pil(self.image.image_name)

        if image_in.mode == "L":
            image_in = image_in.convert("RGB")

        image_out = image_in.convert("LAB")
        channel_l = image_out.getchannel("L")
        channel_a = image_out.getchannel("A")
        channel_b = image_out.getchannel("B")

        l_tensor = image_resized_to_grid_as_tensor(channel_l, normalize=False, multiple_of=1)
        l_max = torch.ones(l_tensor.shape)
        l_min = torch.zeros(l_tensor.shape)
        a_tensor = image_resized_to_grid_as_tensor(channel_a, normalize=True, multiple_of=1)
        b_tensor = image_resized_to_grid_as_tensor(channel_b, normalize=True, multiple_of=1)

        c_tensor = torch.sqrt(torch.add(torch.pow(a_tensor, 2.0), torch.pow(b_tensor, 2.0)))
        h_tensor = torch.atan2(b_tensor, a_tensor)

        k = [0.1644, 0.0603, 0.1307, 0.0060]

        h_minus_90 = torch.sub(h_tensor, PI / 2.0)
        h_minus_90 = torch.sub(torch.remainder(torch.add(h_minus_90, 3 * PI), 2 * PI), PI)

        f_by = torch.add(k[0] * torch.abs(torch.sin(torch.div(h_minus_90, 2.0))), k[1])
        f_r_0 = torch.add(k[2] * torch.abs(torch.cos(h_tensor)), k[3])

        f_r = torch.zeros(l_tensor.shape)
        mask_hi = torch.ge(h_tensor, -1 * (PI / 2.0))
        mask_lo = torch.le(h_tensor, PI / 2.0)
        mask = torch.logical_and(mask_hi, mask_lo)
        f_r[mask] = f_r_0[mask]

        l_adjustment = torch.tensordot(torch.add(f_by, f_r), c_tensor, dims=([1, 2], [1, 2]))
        l_max = torch.add(l_max, l_adjustment)
        l_min = torch.add(l_min, l_adjustment)
        image_tensor = torch.add(l_tensor, l_adjustment)

        image_tensor = torch.div(torch.sub(image_tensor, l_min.min()), l_max.max() - l_min.min())

        image_out = pil_image_from_tensor(image_tensor)

        image_dto = context.images.save(image_out)

        return ImageOutput.build(image_dto)


BLEND_MODES = Literal[
    "Normal",
    "Lighten Only",
    "Darken Only",
    "Lighten Only (EAL)",
    "Darken Only (EAL)",
    "Hue",
    "Saturation",
    "Color",
    "Luminosity",
    "Linear Dodge (Add)",
    "Subtract",
    "Multiply",
    "Divide",
    "Screen",
    "Overlay",
    "Linear Burn",
    "Difference",
    "Hard Light",
    "Soft Light",
    "Vivid Light",
    "Linear Light",
    "Color Burn",
    "Color Dodge",
]

BLEND_COLOR_SPACES = Literal[
    "RGB", "Linear RGB", "HSL (RGB)", "HSV (RGB)", "Okhsl", "Okhsv", "Oklch (Oklab)", "LCh (CIELab)"
]


@invocation(
    "invokeai_img_blend",
    title="Image Layer Blend",
    tags=["image", "blend", "layer", "alpha", "composite", "dodge", "burn"],
    category="image",
    version="1.2.0",
)
class InvokeImageBlendInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Blend two images together, with optional opacity, mask, and blend modes. Originally created by @dwringer"""

    layer_upper: ImageField = InputField(description="The top image to blend", ui_order=1)
    blend_mode: BLEND_MODES = InputField(default="Normal", description="Available blend modes", ui_order=2)
    opacity: float = InputField(ge=0, default=1.0, description="Desired opacity of the upper layer", ui_order=3)
    mask: Optional[ImageField] = InputField(
        default=None, description="Optional mask, used to restrict areas from blending", ui_order=4
    )
    fit_to_width: bool = InputField(default=False, description="Scale upper layer to fit base width", ui_order=5)
    fit_to_height: bool = InputField(default=True, description="Scale upper layer to fit base height", ui_order=6)
    layer_base: ImageField = InputField(description="The bottom image to blend", ui_order=7)
    color_space: BLEND_COLOR_SPACES = InputField(
        default="RGB", description="Available color spaces for blend computations", ui_order=8
    )
    adaptive_gamut: float = InputField(
        ge=0,
        default=0.0,
        description="Adaptive gamut clipping (0=off). Higher prioritizes chroma over lightness",
        ui_order=9,
    )
    high_precision: bool = InputField(
        default=True, description="Use more steps in computing gamut when possible", ui_order=10
    )

    def scale_and_pad_or_crop_to_base(self, image_upper: Image.Image, image_base: Image.Image):
"""Rescale upper image based on self.fill_x and self.fill_y params"""

        aspect_base = image_base.width / image_base.height
        aspect_upper = image_upper.width / image_upper.height
        if self.fit_to_width and self.fit_to_height:
            image_upper = image_upper.resize((image_base.width, image_base.height))
        elif (self.fit_to_width and (aspect_base < aspect_upper)) or (
            self.fit_to_height and (aspect_upper <= aspect_base)
        ):
            image_upper = ImageOps.pad(
                image_upper, (image_base.width, image_base.height), color=tuple([0 for band in image_upper.getbands()])
            )
        elif (self.fit_to_width and (aspect_upper <= aspect_base)) or (
            self.fit_to_height and (aspect_base < aspect_upper)
        ):
            image_upper = ImageOps.fit(image_upper, (image_base.width, image_base.height))
        return image_upper

    def image_convert_with_xform(self, image_in: Image.Image, from_mode: str, to_mode: str):
        """Use PIL ImageCms color management to convert 3-channel image from one mode to another"""

        def fixed_mode(mode: str):
            if mode.lower() == "srgb":
                return "rgb"
            elif mode.lower() == "cielab":
                return "lab"
            else:
                return mode.lower()

        from_mode, to_mode = fixed_mode(from_mode), fixed_mode(to_mode)

        profile_srgb = None
        profile_uplab = None
        profile_lab = None
        if (from_mode.lower() == "rgb") or (to_mode.lower() == "rgb"):
            profile_srgb = ImageCms.createProfile("sRGB")
        if (from_mode.lower() == "uplab") or (to_mode.lower() == "uplab"):
            if os.path.isfile("CIELab_to_UPLab.icc"):
                profile_uplab = ImageCms.getOpenProfile("CIELab_to_UPLab.icc")
        if (from_mode.lower() in ["lab", "cielab", "uplab"]) or (to_mode.lower() in ["lab", "cielab", "uplab"]):
            if profile_uplab is None:
                profile_lab = ImageCms.createProfile("LAB", colorTemp=6500)
            else:
                profile_lab = ImageCms.createProfile("LAB", colorTemp=5000)

        xform_rgb_to_lab = None
        xform_uplab_to_lab = None
        xform_lab_to_uplab = None
        xform_lab_to_rgb = None
        if from_mode == "rgb":
            xform_rgb_to_lab = ImageCms.buildTransformFromOpenProfiles(
                profile_srgb, profile_lab, "RGB", "LAB", renderingIntent=2, flags=0x2400
            )
        elif from_mode == "uplab":
            xform_uplab_to_lab = ImageCms.buildTransformFromOpenProfiles(
                profile_uplab, profile_lab, "LAB", "LAB", renderingIntent=2, flags=0x2400
            )
        if to_mode == "uplab":
            xform_lab_to_uplab = ImageCms.buildTransformFromOpenProfiles(
                profile_lab, profile_uplab, "LAB", "LAB", renderingIntent=2, flags=0x2400
            )
        elif to_mode == "rgb":
            xform_lab_to_rgb = ImageCms.buildTransformFromOpenProfiles(
                profile_lab, profile_srgb, "LAB", "RGB", renderingIntent=2, flags=0x2400
            )

        image_out = None
        if (from_mode == "rgb") and (to_mode == "lab"):
            image_out = ImageCms.applyTransform(image_in, xform_rgb_to_lab)
        elif (from_mode == "rgb") and (to_mode == "uplab"):
            image_out = ImageCms.applyTransform(image_in, xform_rgb_to_lab)
            image_out = ImageCms.applyTransform(image_out, xform_lab_to_uplab)
        elif (from_mode == "lab") and (to_mode == "uplab"):
            image_out = ImageCms.applyTransform(image_in, xform_lab_to_uplab)
        elif (from_mode == "lab") and (to_mode == "rgb"):
            image_out = ImageCms.applyTransform(image_in, xform_lab_to_rgb)
        elif (from_mode == "uplab") and (to_mode == "lab"):
            image_out = ImageCms.applyTransform(image_in, xform_uplab_to_lab)
        elif (from_mode == "uplab") and (to_mode == "rgb"):
            image_out = ImageCms.applyTransform(image_in, xform_uplab_to_lab)
            image_out = ImageCms.applyTransform(image_out, xform_lab_to_rgb)

        return image_out

    def prepare_tensors_from_images(
        self,
        image_upper: Image.Image,
        image_lower: Image.Image,
        mask_image: Optional[Image.Image] = None,
        required: Optional[list[str]] = None,
    ):
        """Convert image to the necessary image space representations for blend calculations"""
        required = required or ["hsv", "hsl", "lch", "oklch", "okhsl", "okhsv", "l_eal"]
        alpha_upper, alpha_lower = None, None
        if image_upper.mode == "RGBA":
            # Prepare tensors to compute blend
            image_rgba_upper = image_upper.convert("RGBA")
            alpha_upper = image_rgba_upper.getchannel("A")
            image_upper = image_upper.convert("RGB")
        else:
            if not (image_upper.mode == "RGB"):
                image_upper = image_upper.convert("RGB")
        if image_lower.mode == "RGBA":
            # Prepare tensors to compute blend
            image_rgba_lower = image_lower.convert("RGBA")
            alpha_lower = image_rgba_lower.getchannel("A")
            image_lower = image_lower.convert("RGB")
        else:
            if not (image_lower.mode == "RGB"):
                image_lower = image_lower.convert("RGB")

        image_lab_upper, image_lab_lower = None, None
        upper_lab_tensor, lower_lab_tensor = None, None
        upper_lch_tensor, lower_lch_tensor = None, None
        if "lch" in required:
            image_lab_upper, image_lab_lower = (
                self.image_convert_with_xform(image_upper, "rgb", "lab"),
                self.image_convert_with_xform(image_lower, "rgb", "lab"),
            )

            upper_lab_tensor = torch.stack(
                [
                    tensor_from_pil_image(image_lab_upper.getchannel("L"), normalize=False)[0, :, :],
                    tensor_from_pil_image(image_lab_upper.getchannel("A"), normalize=True)[0, :, :],
                    tensor_from_pil_image(image_lab_upper.getchannel("B"), normalize=True)[0, :, :],
                ]
            )
            lower_lab_tensor = torch.stack(
                [
                    tensor_from_pil_image(image_lab_lower.getchannel("L"), normalize=False)[0, :, :],
                    tensor_from_pil_image(image_lab_lower.getchannel("A"), normalize=True)[0, :, :],
                    tensor_from_pil_image(image_lab_lower.getchannel("B"), normalize=True)[0, :, :],
                ]
            )
            upper_lch_tensor = torch.stack(
                [
                    upper_lab_tensor[0, :, :],
                    torch.sqrt(
                        torch.add(torch.pow(upper_lab_tensor[1, :, :], 2.0), torch.pow(upper_lab_tensor[2, :, :], 2.0))
                    ),
                    torch.atan2(upper_lab_tensor[2, :, :], upper_lab_tensor[1, :, :]),
                ]
            )
            lower_lch_tensor = torch.stack(
                [
                    lower_lab_tensor[0, :, :],
                    torch.sqrt(
                        torch.add(torch.pow(lower_lab_tensor[1, :, :], 2.0), torch.pow(lower_lab_tensor[2, :, :], 2.0))
                    ),
                    torch.atan2(lower_lab_tensor[2, :, :], lower_lab_tensor[1, :, :]),
                ]
            )

        upper_l_eal_tensor, lower_l_eal_tensor = None, None
        if "l_eal" in required:
            upper_l_eal_tensor = equivalent_achromatic_lightness(upper_lch_tensor)
            lower_l_eal_tensor = equivalent_achromatic_lightness(lower_lch_tensor)

        image_hsv_upper, image_hsv_lower = None, None
        upper_hsv_tensor, lower_hsv_tensor = None, None
        if "hsv" in required:
            image_hsv_upper, image_hsv_lower = image_upper.convert("HSV"), image_lower.convert("HSV")
            upper_hsv_tensor = torch.stack(
                [
                    tensor_from_pil_image(image_hsv_upper.getchannel("H"), normalize=False)[0, :, :],
                    tensor_from_pil_image(image_hsv_upper.getchannel("S"), normalize=False)[0, :, :],
                    tensor_from_pil_image(image_hsv_upper.getchannel("V"), normalize=False)[0, :, :],
                ]
            )
            lower_hsv_tensor = torch.stack(
                [
                    tensor_from_pil_image(image_hsv_lower.getchannel("H"), normalize=False)[0, :, :],
                    tensor_from_pil_image(image_hsv_lower.getchannel("S"), normalize=False)[0, :, :],
                    tensor_from_pil_image(image_hsv_lower.getchannel("V"), normalize=False)[0, :, :],
                ]
            )

        upper_rgb_tensor = tensor_from_pil_image(image_upper, normalize=False)
        lower_rgb_tensor = tensor_from_pil_image(image_lower, normalize=False)

        alpha_upper_tensor, alpha_lower_tensor = None, None
        if alpha_upper is None:
            alpha_upper_tensor = torch.ones(upper_rgb_tensor[0, :, :].shape)
        else:
            alpha_upper_tensor = tensor_from_pil_image(alpha_upper, normalize=False)[0, :, :]
        if alpha_lower is None:
            alpha_lower_tensor = torch.ones(lower_rgb_tensor[0, :, :].shape)
        else:
            alpha_lower_tensor = tensor_from_pil_image(alpha_lower, normalize=False)[0, :, :]

        mask_tensor = None
        if mask_image is not None:
            mask_tensor = tensor_from_pil_image(mask_image.convert("L"), normalize=False)[0, :, :]

        upper_hsl_tensor, lower_hsl_tensor = None, None
        if "hsl" in required:
            upper_hsl_tensor = hsl_from_srgb(upper_rgb_tensor)
            lower_hsl_tensor = hsl_from_srgb(lower_rgb_tensor)

        upper_okhsl_tensor, lower_okhsl_tensor = None, None
        if "okhsl" in required:
            upper_okhsl_tensor = okhsl_from_srgb(upper_rgb_tensor, steps=(3 if self.high_precision else 1))
            lower_okhsl_tensor = okhsl_from_srgb(lower_rgb_tensor, steps=(3 if self.high_precision else 1))

        upper_okhsv_tensor, lower_okhsv_tensor = None, None
        if "okhsv" in required:
            upper_okhsv_tensor = okhsv_from_srgb(upper_rgb_tensor, steps=(3 if self.high_precision else 1))
            lower_okhsv_tensor = okhsv_from_srgb(lower_rgb_tensor, steps=(3 if self.high_precision else 1))

        upper_rgb_l_tensor = linear_srgb_from_srgb(upper_rgb_tensor)
        lower_rgb_l_tensor = linear_srgb_from_srgb(lower_rgb_tensor)

        upper_oklab_tensor, lower_oklab_tensor = None, None
        upper_oklch_tensor, lower_oklch_tensor = None, None
        if "oklch" in required:
            upper_oklab_tensor = oklab_from_linear_srgb(upper_rgb_l_tensor)
            lower_oklab_tensor = oklab_from_linear_srgb(lower_rgb_l_tensor)

            upper_oklch_tensor = torch.stack(
                [
                    upper_oklab_tensor[0, :, :],
                    torch.sqrt(
                        torch.add(
                            torch.pow(upper_oklab_tensor[1, :, :], 2.0), torch.pow(upper_oklab_tensor[2, :, :], 2.0)
                        )
                    ),
                    torch.atan2(upper_oklab_tensor[2, :, :], upper_oklab_tensor[1, :, :]),
                ]
            )
            lower_oklch_tensor = torch.stack(
                [
                    lower_oklab_tensor[0, :, :],
                    torch.sqrt(
                        torch.add(
                            torch.pow(lower_oklab_tensor[1, :, :], 2.0), torch.pow(lower_oklab_tensor[2, :, :], 2.0)
                        )
                    ),
                    torch.atan2(lower_oklab_tensor[2, :, :], lower_oklab_tensor[1, :, :]),
                ]
            )

        return (
            upper_rgb_l_tensor,
            lower_rgb_l_tensor,
            upper_rgb_tensor,
            lower_rgb_tensor,
            alpha_upper_tensor,
            alpha_lower_tensor,
            mask_tensor,
            upper_hsv_tensor,
            lower_hsv_tensor,
            upper_hsl_tensor,
            lower_hsl_tensor,
            upper_lab_tensor,
            lower_lab_tensor,
            upper_lch_tensor,
            lower_lch_tensor,
            upper_l_eal_tensor,
            lower_l_eal_tensor,
            upper_oklab_tensor,
            lower_oklab_tensor,
            upper_oklch_tensor,
            lower_oklch_tensor,
            upper_okhsv_tensor,
            lower_okhsv_tensor,
            upper_okhsl_tensor,
            lower_okhsl_tensor,
        )

    def apply_blend(self, image_tensors: torch.Tensor):
        """Apply the selected blend mode using the appropriate color space representations"""

        blend_mode = self.blend_mode
        color_space = self.color_space.split()[0]
        if (color_space in ["RGB", "Linear"]) and (blend_mode in ["Hue", "Saturation", "Luminosity", "Color"]):
            color_space = "HSL"

        def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with: float = MAX_FLOAT):
            """Keep elements of the tensor finite"""

            rgb_tensor = remove_nans(rgb_tensor, replace_with=replace_with)

            if 0 < self.adaptive_gamut:
                rgb_tensor = gamut_clip_tensor(
                    rgb_tensor, alpha=self.adaptive_gamut, steps=(3 if self.high_precision else 1)
                )
                rgb_tensor = remove_nans(rgb_tensor, replace_with=replace_with)
            if clamp:  # Use of MAX_FLOAT seems to lead to NaN's coming back in some cases:
                rgb_tensor = rgb_tensor.clamp(0.0, 1.0)

            return rgb_tensor
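
        # Each entry converts a blend result from its working color space back to linear sRGB: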
        reassembly_function = {
            "RGB": lambda t: linear_srgb_from_srgb(t),
            "Linear": lambda t: t,
            "HSL": lambda t: linear_srgb_from_srgb(srgb_from_hsl(t)),
            "HSV": lambda t: linear_srgb_from_srgb(
                tensor_from_pil_image(
                    pil_image_from_tensor(t.clamp(0.0, 1.0), mode="HSV").convert("RGB"), normalize=False
                )
            ),
            "Okhsl": lambda t: linear_srgb_from_srgb(
                srgb_from_okhsl(t, alpha=self.adaptive_gamut, steps=(3 if self.high_precision else 1))
            ),
            "Okhsv": lambda t: linear_srgb_from_srgb(
                srgb_from_okhsv(t, alpha=self.adaptive_gamut, steps=(3 if self.high_precision else 1))
            ),
            "Oklch": lambda t: linear_srgb_from_oklab(
                torch.stack(
                    [
                        t[0, :, :],
                        torch.mul(t[1, :, :], torch.cos(t[2, :, :])),
                        torch.mul(t[1, :, :], torch.sin(t[2, :, :])),
                    ]
                )
            ),
            "LCh": lambda t: linear_srgb_from_srgb(
                tensor_from_pil_image(
                    self.image_convert_with_xform(
                        Image.merge(
                            "LAB",
                            tuple(
                                pil_image_from_tensor(u)
                                for u in [
                                    t[0, :, :].clamp(0.0, 1.0),
                                    torch.div(torch.add(torch.mul(t[1, :, :], torch.cos(t[2, :, :])), 1.0), 2.0),
                                    torch.div(torch.add(torch.mul(t[1, :, :], torch.sin(t[2, :, :])), 1.0), 2.0),
                                ]
                            ),
                        ),
                        "lab",
                        "rgb",
                    ),
                    normalize=False,
                )
            ),
        }[color_space]

        (
            upper_rgb_l_tensor,  # linear-light sRGB
            lower_rgb_l_tensor,  # linear-light sRGB
            upper_rgb_tensor,
            lower_rgb_tensor,
            alpha_upper_tensor,
            alpha_lower_tensor,
            mask_tensor,
            upper_hsv_tensor,  # h_rgb, s_hsv, v_hsv
            lower_hsv_tensor,
            upper_hsl_tensor,  # , s_hsl, l_hsl
            lower_hsl_tensor,
            upper_lab_tensor,  # l_lab, a_lab, b_lab
            lower_lab_tensor,
            upper_lch_tensor,  # , c_lab, h_lab
            lower_lch_tensor,
            upper_l_eal_tensor,  # l_eal
            lower_l_eal_tensor,
            upper_oklab_tensor,  # l_oklab, a_oklab, b_oklab
            lower_oklab_tensor,
            upper_oklch_tensor,  # , c_oklab, h_oklab
            lower_oklch_tensor,
            upper_okhsv_tensor,  # h_okhsv, s_okhsv, v_okhsv
            lower_okhsv_tensor,
            upper_okhsl_tensor,  # h_okhsl, s_okhsl, l_r_oklab
            lower_okhsl_tensor,
        ) = image_tensors

        current_space_tensors = {
            "RGB": [upper_rgb_tensor, lower_rgb_tensor],
            "Linear": [upper_rgb_l_tensor, lower_rgb_l_tensor],
            "HSL": [upper_hsl_tensor, lower_hsl_tensor],
            "HSV": [upper_hsv_tensor, lower_hsv_tensor],
            "Okhsl": [upper_okhsl_tensor, lower_okhsl_tensor],
            "Okhsv": [upper_okhsv_tensor, lower_okhsv_tensor],
            "Oklch": [upper_oklch_tensor, lower_oklch_tensor],
            "LCh": [upper_lch_tensor, lower_lch_tensor],
        }[color_space]
        upper_space_tensor = current_space_tensors[0]
        lower_space_tensor = current_space_tensors[1]
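
        # Per-space channel indices used when a blend mode needs to act on lightness, saturation, or hue alone: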
        lightness_index = {
            "RGB": None,
            "Linear": None,
            "HSL": 2,
            "HSV": 2,
            "Okhsl": 2,
            "Okhsv": 2,
            "Oklch": 0,
            "LCh": 0,
        }[color_space]

        saturation_index = {
            "RGB": None,
            "Linear": None,
            "HSL": 1,
            "HSV": 1,
            "Okhsl": 1,
            "Okhsv": 1,
            "Oklch": 1,
            "LCh": 1,
        }[color_space]

        hue_index = {
            "RGB": None,
            "Linear": None,
            "HSL": 0,
            "HSV": 0,
            "Okhsl": 0,
            "Okhsv": 0,
            "Oklch": 2,
            "LCh": 2,
        }[color_space]

        if blend_mode == "Normal":
            upper_rgb_l_tensor = reassembly_function(upper_space_tensor)

        elif blend_mode == "Multiply":
            upper_rgb_l_tensor = reassembly_function(torch.mul(lower_space_tensor, upper_space_tensor))

        elif blend_mode == "Screen":
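            # Screen: result = 1 - (1 - upper) * (1 - lower), written below with torch.add/torch.mul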
            upper_rgb_l_tensor = reassembly_function(
                torch.add(
                    torch.mul(
                        torch.mul(
                            torch.add(torch.mul(upper_space_tensor, -1.0), 1.0),
                            torch.add(torch.mul(lower_space_tensor, -1.0), 1.0),
                        ),
                        -1.0,
                    ),
                    1.0,
                )
            )

        elif (blend_mode == "Overlay") or (blend_mode == "Hard Light"):
            subject_of_cond_tensor = lower_space_tensor if (blend_mode == "Overlay") else upper_space_tensor
            if lightness_index is None:
                upper_space_tensor = torch.where(
                    torch.lt(subject_of_cond_tensor, 0.5),
                    torch.mul(torch.mul(lower_space_tensor, upper_space_tensor), 2.0),
                    torch.add(
                        torch.mul(
                            torch.mul(
                                torch.mul(
                                    torch.add(torch.mul(lower_space_tensor, -1.0), 1.0),
                                    torch.add(torch.mul(upper_space_tensor, -1.0), 1.0),
                                ),
                                2.0,
                            ),
                            -1.0,
                        ),
                        1.0,
                    ),
                )
            else:  # TODO: Currently blending only the lightness channel, not really ideal.
                upper_space_tensor[lightness_index, :, :] = torch.where(
                    torch.lt(subject_of_cond_tensor[lightness_index, :, :], 0.5),
                    torch.mul(
                        torch.mul(lower_space_tensor[lightness_index, :, :], upper_space_tensor[lightness_index, :, :]),
                        2.0,
                    ),
                    torch.add(
                        torch.mul(
                            torch.mul(
                                torch.mul(
                                    torch.add(torch.mul(lower_space_tensor[lightness_index, :, :], -1.0), 1.0),
                                    torch.add(torch.mul(upper_space_tensor[lightness_index, :, :], -1.0), 1.0),
                                ),
                                2.0,
                            ),
                            -1.0,
                        ),
                        1.0,
                    ),
                )
            upper_rgb_l_tensor = adaptive_clipped(reassembly_function(upper_space_tensor))

        elif blend_mode == "Soft Light":
            if lightness_index is None:
                g_tensor = torch.where(
                    torch.le(lower_space_tensor, 0.25),
                    torch.mul(
                        torch.add(
                            torch.mul(torch.sub(torch.mul(lower_space_tensor, 16.0), 12.0), lower_space_tensor), 4.0
                        ),
                        lower_space_tensor,
                    ),
                    torch.sqrt(lower_space_tensor),
                )
                lower_space_tensor = torch.where(
                    torch.le(upper_space_tensor, 0.5),
                    torch.sub(
                        lower_space_tensor,
                        torch.mul(
                            torch.mul(torch.add(torch.mul(lower_space_tensor, -1.0), 1.0), lower_space_tensor),
                            torch.add(torch.mul(torch.mul(upper_space_tensor, 2.0), -1.0), 1.0),
                        ),
                    ),
                    torch.add(
                        lower_space_tensor,
                        torch.mul(
                            torch.sub(torch.mul(upper_space_tensor, 2.0), 1.0), torch.sub(g_tensor, lower_space_tensor)
                        ),
                    ),
                )
            else:
                print(
                    "\r\nCOND SHAPE:"
                    + str(torch.le(lower_space_tensor[lightness_index, :, :], 0.25).unsqueeze(0).shape)
                    + "\r\n"
                )
                g_tensor = torch.where(  # Calculates all 3 channels but only one is currently used
                    torch.le(lower_space_tensor[lightness_index, :, :], 0.25).expand(upper_space_tensor.shape),
                    torch.mul(
                        torch.add(
                            torch.mul(torch.sub(torch.mul(lower_space_tensor, 16.0), 12.0), lower_space_tensor), 4.0
                        ),
                        lower_space_tensor,
                    ),
                    torch.sqrt(lower_space_tensor),
                )
                lower_space_tensor[lightness_index, :, :] = torch.where(
                    torch.le(upper_space_tensor[lightness_index, :, :], 0.5),
                    torch.sub(
                        lower_space_tensor[lightness_index, :, :],
                        torch.mul(
                            torch.mul(
                                torch.add(torch.mul(lower_space_tensor[lightness_index, :, :], -1.0), 1.0),
                                lower_space_tensor[lightness_index, :, :],
                            ),
                            torch.add(torch.mul(torch.mul(upper_space_tensor[lightness_index, :, :], 2.0), -1.0), 1.0),
                        ),
                    ),
                    torch.add(
                        lower_space_tensor[lightness_index, :, :],
                        torch.mul(
                            torch.sub(torch.mul(upper_space_tensor[lightness_index, :, :], 2.0), 1.0),
                            torch.sub(g_tensor[lightness_index, :, :], lower_space_tensor[lightness_index, :, :]),
                        ),
                    ),
                )
            upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))

        elif blend_mode == "Linear Dodge (Add)":
            lower_space_tensor = torch.add(lower_space_tensor, upper_space_tensor)
            if hue_index is not None:
                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
            upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))

        elif blend_mode == "Color Dodge":
            lower_space_tensor = torch.div(lower_space_tensor, torch.add(torch.mul(upper_space_tensor, -1.0), 1.0))
            if hue_index is not None:
                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
            upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))

        elif blend_mode == "Divide":
            lower_space_tensor = torch.div(lower_space_tensor, upper_space_tensor)
            if hue_index is not None:
                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
            upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))

        elif blend_mode == "Linear Burn":
            # We compute the result in the lower image's current space tensor and return that:
            if lightness_index is None:  # Elementwise
                lower_space_tensor = torch.sub(torch.add(lower_space_tensor, upper_space_tensor), 1.0)
            else:  # Operate only on the selected lightness channel
                lower_space_tensor[lightness_index, :, :] = torch.sub(
                    torch.add(lower_space_tensor[lightness_index, :, :], upper_space_tensor[lightness_index, :, :]), 1.0
                )
            upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))

        elif blend_mode == "Color Burn":
            upper_rgb_l_tensor = adaptive_clipped(
                reassembly_function(
                    torch.add(
                        torch.mul(
                            torch.min(
                                torch.div(torch.add(torch.mul(lower_space_tensor, -1.0), 1.0), upper_space_tensor),
                                torch.ones(lower_space_tensor.shape),
                            ),
                            -1.0,
                        ),
                        1.0,
                    )
                )
            )
        elif blend_mode == "Vivid Light":
            if lightness_index is None:
                lower_space_tensor = adaptive_clipped(
                    reassembly_function(
                        torch.where(
                            torch.lt(upper_space_tensor, 0.5),
                            torch.div(
                                torch.add(
                                    torch.mul(
                                        torch.div(
                                            torch.add(torch.mul(lower_space_tensor, -1.0), 1.0), upper_space_tensor
                                        ),
                                        -1.0,
                                    ),
                                    1.0,
                                ),
                                2.0,
                            ),
                            torch.div(
                                torch.div(lower_space_tensor, torch.add(torch.mul(upper_space_tensor, -1.0), 1.0)), 2.0
                            ),
                        )
                    )
                )
            else:
                lower_space_tensor[lightness_index, :, :] = torch.where(
                    torch.lt(upper_space_tensor[lightness_index, :, :], 0.5),
                    torch.div(
                        torch.add(
                            torch.mul(
                                torch.div(
                                    torch.add(torch.mul(lower_space_tensor[lightness_index, :, :], -1.0), 1.0),
                                    upper_space_tensor[lightness_index, :, :],
                                ),
                                -1.0,
                            ),
                            1.0,
                        ),
                        2.0,
                    ),
                    torch.div(
                        torch.div(
                            lower_space_tensor[lightness_index, :, :],
                            torch.add(torch.mul(upper_space_tensor[lightness_index, :, :], -1.0), 1.0),
                        ),
                        2.0,
                    ),
                )
            upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))

        elif blend_mode == "Linear Light":
            if lightness_index is None:
                lower_space_tensor = torch.sub(torch.add(lower_space_tensor, torch.mul(upper_space_tensor, 2.0)), 1.0)
            else:
                lower_space_tensor[lightness_index, :, :] = torch.sub(
                    torch.add(
                        lower_space_tensor[lightness_index, :, :],
                        torch.mul(upper_space_tensor[lightness_index, :, :], 2.0),
                    ),
                    1.0,
                )
            upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))

        elif blend_mode == "Subtract":
            lower_space_tensor = torch.sub(lower_space_tensor, upper_space_tensor)
            if hue_index is not None:
                lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
            upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))

        elif blend_mode == "Difference":
            upper_rgb_l_tensor = adaptive_clipped(
                reassembly_function(torch.abs(torch.sub(lower_space_tensor, upper_space_tensor)))
            )

        elif (blend_mode == "Darken Only") or (blend_mode == "Lighten Only"):
            extrema_fn = torch.min if (blend_mode == "Darken Only") else torch.max
            comparator_fn = torch.ge if (blend_mode == "Darken Only") else torch.lt
            if lightness_index is None:
                upper_space_tensor = torch.stack(
                    [
                        extrema_fn(upper_space_tensor[0, :, :], lower_space_tensor[0, :, :]),
                        extrema_fn(upper_space_tensor[1, :, :], lower_space_tensor[1, :, :]),
                        extrema_fn(upper_space_tensor[2, :, :], lower_space_tensor[2, :, :]),
                    ]
                )
            else:
                upper_space_tensor = torch.where(
                    comparator_fn(
                        upper_space_tensor[lightness_index, :, :], lower_space_tensor[lightness_index, :, :]
                    ).expand(upper_space_tensor.shape),
                    lower_space_tensor,
                    upper_space_tensor,
                )
            upper_rgb_l_tensor = reassembly_function(upper_space_tensor)

        elif blend_mode in [
            "Hue",
            "Saturation",
            "Color",
            "Luminosity",
        ]:
            if blend_mode == "Hue":  # l, c: lower / h: upper
                upper_space_tensor[lightness_index, :, :] = lower_space_tensor[lightness_index, :, :]
                upper_space_tensor[saturation_index, :, :] = lower_space_tensor[saturation_index, :, :]
            elif blend_mode == "Saturation":  # l, h: lower / c: upper
                upper_space_tensor[lightness_index, :, :] = lower_space_tensor[lightness_index, :, :]
                upper_space_tensor[hue_index, :, :] = lower_space_tensor[hue_index, :, :]
            elif blend_mode == "Color":  # l: lower / c, h: upper
                upper_space_tensor[lightness_index, :, :] = lower_space_tensor[lightness_index, :, :]
            elif blend_mode == "Luminosity":  # h, c: lower / l: upper
                upper_space_tensor[saturation_index, :, :] = lower_space_tensor[saturation_index, :, :]
                upper_space_tensor[hue_index, :, :] = lower_space_tensor[hue_index, :, :]
            upper_rgb_l_tensor = reassembly_function(upper_space_tensor)

        elif blend_mode in ["Lighten Only (EAL)", "Darken Only (EAL)"]:
            comparator_fn = torch.lt if (blend_mode == "Lighten Only (EAL)") else torch.ge
            upper_space_tensor = torch.where(
                comparator_fn(upper_l_eal_tensor, lower_l_eal_tensor).expand(upper_space_tensor.shape),
                lower_space_tensor,
                upper_space_tensor,
            )
            upper_rgb_l_tensor = reassembly_function(upper_space_tensor)

        return upper_rgb_l_tensor

    def alpha_composite(
        self,
        upper_tensor: torch.Tensor,
        alpha_upper_tensor: torch.Tensor,
        lower_tensor: torch.Tensor,
        alpha_lower_tensor: torch.Tensor,
        mask_tensor: Optional[torch.Tensor] = None,
    ):
        """Alpha compositing of upper on lower tensor with alpha channels, mask and scalar"""

        upper_tensor = remove_nans(upper_tensor)

        alpha_upper_tensor = torch.mul(alpha_upper_tensor, self.opacity)
        if mask_tensor is not None:
            alpha_upper_tensor = torch.mul(alpha_upper_tensor, torch.add(torch.mul(mask_tensor, -1.0), 1.0))

        alpha_tensor = torch.add(
            alpha_upper_tensor, torch.mul(alpha_lower_tensor, torch.add(torch.mul(alpha_upper_tensor, -1.0), 1.0))
        )

        return (
            torch.div(
                torch.add(
                    torch.mul(upper_tensor, alpha_upper_tensor),
                    torch.mul(
                        torch.mul(lower_tensor, alpha_lower_tensor), torch.add(torch.mul(alpha_upper_tensor, -1.0), 1.0)
                    ),
                ),
                alpha_tensor,
            ),
            alpha_tensor,
        )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Main execution of the ImageBlendInvocation node"""

        image_upper = context.images.get_pil(self.layer_upper.image_name)
        image_base = context.images.get_pil(self.layer_base.image_name)

        # Keep the modes for restoration after processing:
        image_mode_base = image_base.mode

        # Get rid of ICC profiles by converting to sRGB, but save for restoration:
        cms_profile_srgb = None
        if "icc_profile" in image_upper.info:
            cms_profile_upper = BytesIO(image_upper.info["icc_profile"])
            cms_profile_srgb = ImageCms.createProfile("sRGB")
            cms_xform = ImageCms.buildTransformFromOpenProfiles(
                cms_profile_upper, cms_profile_srgb, image_upper.mode, "RGBA"
            )
            image_upper = ImageCms.applyTransform(image_upper, cms_xform)

        cms_profile_base = None
        icc_profile_bytes = None
        if "icc_profile" in image_base.info:
            icc_profile_bytes = image_base.info["icc_profile"]
            cms_profile_base = BytesIO(icc_profile_bytes)
            if cms_profile_srgb is None:
                cms_profile_srgb = ImageCms.createProfile("sRGB")
            cms_xform = ImageCms.buildTransformFromOpenProfiles(
                cms_profile_base, cms_profile_srgb, image_base.mode, "RGBA"
            )
            image_base = ImageCms.applyTransform(image_base, cms_xform)

        image_mask = None
        if self.mask is not None:
            image_mask = context.images.get_pil(self.mask.image_name)
        color_space = self.color_space.split()[0]

        image_upper = self.scale_and_pad_or_crop_to_base(image_upper, image_base)
        if image_mask is not None:
            image_mask = self.scale_and_pad_or_crop_to_base(image_mask, image_base)

        tensor_requirements = []

        # Hue, Saturation, Color, and Luminosity won't work in sRGB, require HSL
        if self.blend_mode in ["Hue", "Saturation", "Color", "Luminosity"] and self.color_space in [
            "RGB",
            "Linear RGB",
        ]:
            tensor_requirements = ["hsl"]

        if self.blend_mode in ["Lighten Only (EAL)", "Darken Only (EAL)"]:
            tensor_requirements = tensor_requirements + ["lch", "l_eal"]

        tensor_requirements += {
            "Linear": [],
            "RGB": [],
            "HSL": ["hsl"],
            "HSV": ["hsv"],
            "Okhsl": ["okhsl"],
            "Okhsv": ["okhsv"],
            "Oklch": ["oklch"],
            "LCh": ["lch"],
        }[color_space]

        image_tensors = (
            upper_rgb_l_tensor,  # linear-light sRGB
            lower_rgb_l_tensor,  # linear-light sRGB
            upper_rgb_tensor,
            lower_rgb_tensor,
            alpha_upper_tensor,
            alpha_lower_tensor,
            mask_tensor,
            upper_hsv_tensor,
            lower_hsv_tensor,
            upper_hsl_tensor,
            lower_hsl_tensor,
            upper_lab_tensor,
            lower_lab_tensor,
            upper_lch_tensor,
            lower_lch_tensor,
            upper_l_eal_tensor,
            lower_l_eal_tensor,
            upper_oklab_tensor,
            lower_oklab_tensor,
            upper_oklch_tensor,
            lower_oklch_tensor,
            upper_okhsv_tensor,
            lower_okhsv_tensor,
            upper_okhsl_tensor,
            lower_okhsl_tensor,
        ) = self.prepare_tensors_from_images(
            image_upper, image_base, mask_image=image_mask, required=tensor_requirements
        )

        # if not (self.blend_mode == "Normal"):
        upper_rgb_l_tensor = self.apply_blend(image_tensors)

        output_tensor, alpha_tensor = self.alpha_composite(
            srgb_from_linear_srgb(
                upper_rgb_l_tensor, alpha=self.adaptive_gamut, steps=(3 if self.high_precision else 1)
            ),
            alpha_upper_tensor,
            lower_rgb_tensor,
            alpha_lower_tensor,
            mask_tensor=mask_tensor,
        )

        # Restore alpha channel and base mode:
        output_tensor = torch.stack(
            [output_tensor[0, :, :], output_tensor[1, :, :], output_tensor[2, :, :], alpha_tensor]
        )
        image_out = pil_image_from_tensor(output_tensor, mode="RGBA")

        # Restore ICC profile if base image had one:
        if cms_profile_base is not None:
            cms_xform = ImageCms.buildTransformFromOpenProfiles(
                cms_profile_srgb, BytesIO(icc_profile_bytes), "RGBA", image_out.mode
            )
            image_out = ImageCms.applyTransform(image_out, cms_xform)
        else:
            image_out = image_out.convert(image_mode_base)

        image_dto = context.images.save(image_out)

        return ImageOutput.build(image_dto)


@invocation(
    "invokeai_img_composite",
    title="Image Compositor",
    tags=["image", "compose", "chroma", "key"],
    category="image",
    version="1.2.0",
)
class InvokeImageCompositorInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Removes backdrop from subject image then overlays subject on background image. Originally created by @dwringer"""

    image_subject: ImageField = InputField(description="Image of the subject on a plain monochrome background")
    image_background: ImageField = InputField(description="Image of a background scene")
    chroma_key: str = InputField(
        default="", description="Can be empty for corner flood select, or CSS-3 color or tuple"
    )
    threshold: int = InputField(ge=0, default=50, description="Subject isolation flood-fill threshold")
    fill_x: bool = InputField(default=False, description="Scale base subject image to fit background width")
    fill_y: bool = InputField(default=True, description="Scale base subject image to fit background height")
    x_offset: int = InputField(default=0, description="x-offset for the subject")
    y_offset: int = InputField(default=0, description="y-offset for the subject")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image_background = context.images.get_pil(self.image_background.image_name).convert(mode="RGBA")
        image_subject = context.images.get_pil(self.image_subject.image_name).convert(mode="RGBA")

        if image_subject.height == 0 or image_subject.width == 0:
            raise ValueError("The subject image has zero height or width")
        if image_background.height == 0 or image_background.width == 0:
            raise ValueError("The background image has zero height or width")

        # Handle backdrop removal:
        chroma_key = self.chroma_key.strip()
        if 0 < len(chroma_key):
            # Remove pixels by chroma key:
            if chroma_key[0] == "(":
                chroma_key = tuple_from_string(chroma_key)
                while len(chroma_key) < 3:
                    chroma_key = tuple(list(chroma_key) + [0])
                if len(chroma_key) == 3:
                    chroma_key = tuple(list(chroma_key) + [255])
            else:
                chroma_key = ImageColor.getcolor(chroma_key, "RGBA")
            threshold = self.threshold**2.0  # to compare vs squared color distance from key
            pixels = image_subject.load()
            if pixels is None:
                raise ValueError("Unable to load pixels from subject image")
            for i in range(image_subject.width):
                for j in range(image_subject.height):
                    if (
                        reduce(
                            lambda a, b: a + b, [(pixels[i, j][k] - chroma_key[k]) ** 2 for k in range(len(chroma_key))]
                        )
                        < threshold
                    ):
                        pixels[i, j] = tuple([0 for k in range(len(chroma_key))])
        else:
            # Remove pixels by flood select from corners:
            ImageDraw.floodfill(image_subject, (0, 0), (0, 0, 0, 0), thresh=self.threshold)
            ImageDraw.floodfill(image_subject, (0, image_subject.height - 1), (0, 0, 0, 0), thresh=self.threshold)
            ImageDraw.floodfill(image_subject, (image_subject.width - 1, 0), (0, 0, 0, 0), thresh=self.threshold)
            ImageDraw.floodfill(
                image_subject, (image_subject.width - 1, image_subject.height - 1), (0, 0, 0, 0), thresh=self.threshold
            )

        # Scale and position the subject:
        aspect_background = image_background.width / image_background.height
        aspect_subject = image_subject.width / image_subject.height
        if self.fill_x and self.fill_y:
            image_subject = image_subject.resize((image_background.width, image_background.height))
        elif (self.fill_x and (aspect_background < aspect_subject)) or (
            self.fill_y and (aspect_subject <= aspect_background)
        ):
            image_subject = ImageOps.pad(
                image_subject, (image_background.width, image_background.height), color=(0, 0, 0, 0)
            )
        elif (self.fill_x and (aspect_subject <= aspect_background)) or (
            self.fill_y and (aspect_background < aspect_subject)
        ):
            image_subject = ImageOps.fit(image_subject, (image_background.width, image_background.height))
        if (self.x_offset != 0) or (self.y_offset != 0):
            image_subject = ImageChops.offset(image_subject, self.x_offset, yoffset=-1 * self.y_offset)

        new_image = Image.alpha_composite(image_background, image_subject)
        new_image = new_image.convert(mode="RGB")
        image_dto = context.images.save(new_image)

        return ImageOutput.build(image_dto)


DILATE_ERODE_MODES = Literal[
    "Dilate",
    "Erode",
]


@invocation(
    "invokeai_img_dilate_erode",
    title="Image Dilate or Erode",
    tags=["image", "mask", "dilate", "erode", "expand", "contract", "mask"],
    category="image",
    version="1.3.0",
)
class InvokeImageDilateOrErodeInvocation(BaseInvocation, WithMetadata):
    """Dilate (expand) or erode (contract) an image. Originally created by @dwringer"""

    image: ImageField = InputField(description="The image from which to create a mask")
    lightness_only: bool = InputField(default=False, description="If true, only applies to image lightness (CIELa*b*)")
    radius_w: int = InputField(
        ge=0, default=4, description="Width (in pixels) by which to dilate (expand) or erode (contract) the image"
    )
    radius_h: int = InputField(
        ge=0, default=4, description="Height (in pixels) by which to dilate (expand) or erode (contract) the image"
    )
    mode: DILATE_ERODE_MODES = InputField(default="Dilate", description="How to operate on the image")

    def expand_or_contract(self, image_in: Image.Image):
        image_out = numpy.array(image_in)
        expand_radius_w = self.radius_w
        expand_radius_h = self.radius_h

        expand_fn = None
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (expand_radius_w * 2 + 1, expand_radius_h * 2 + 1))
        if self.mode == "Dilate":
            expand_fn = cv2.dilate
        elif self.mode == "Erode":
            expand_fn = cv2.erode
        else:
            raise ValueError("Invalid mode selected")
        image_out = expand_fn(image_out, kernel, iterations=1)
        return Image.fromarray(image_out, mode=image_in.mode)

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image_in = context.images.get_pil(self.image.image_name)
        image_out = image_in

        if self.lightness_only:
            image_mode = image_in.mode
            alpha_channel = None
            if (image_mode == "RGBA") or (image_mode == "LA") or (image_mode == "PA"):
                alpha_channel = image_in.getchannel("A")
            elif (image_mode == "RGBa") or (image_mode == "La") or (image_mode == "Pa"):
                alpha_channel = image_in.getchannel("a")
            if (image_mode == "RGBA") or (image_mode == "RGBa"):
                image_mode = "RGB"
            elif (image_mode == "LA") or (image_mode == "La"):
                image_mode = "L"
            elif image_mode == "PA":
                image_mode = "P"
            image_out = image_out.convert("RGB")
            image_out = image_out.convert("LAB")
            l_channel = self.expand_or_contract(image_out.getchannel("L"))
            image_out = Image.merge("LAB", (l_channel, image_out.getchannel("A"), image_out.getchannel("B")))
            if (image_mode == "L") or (image_mode == "P"):
                image_out = image_out.convert("RGB")
            image_out = image_out.convert(image_mode)
            if "a" in image_in.mode.lower():
                image_out = Image.merge(
                    image_in.mode, tuple([image_out.getchannel(c) for c in image_mode] + [alpha_channel])
                )
        else:
            image_out = self.expand_or_contract(image_out)

        image_dto = context.images.save(image_out)

        return ImageOutput.build(image_dto)


@invocation(
    "invokeai_img_val_thresholds",
    title="Image Value Thresholds",
    tags=["image", "mask", "value", "threshold"],
    category="image",
    version="1.2.0",
)
class InvokeImageValueThresholdsInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Clip image to pure black/white past specified thresholds. Originally created by @dwringer"""

    image: ImageField = InputField(description="The image from which to create a mask")
    invert_output: bool = InputField(default=False, description="Make light areas dark and vice versa")
    renormalize_values: bool = InputField(default=False, description="Rescale remaining values from minimum to maximum")
    lightness_only: bool = InputField(default=False, description="If true, only applies to image lightness (CIELa*b*)")
    threshold_upper: float = InputField(default=0.5, description="Threshold above which will be set to full value")
    threshold_lower: float = InputField(default=0.5, description="Threshold below which will be set to minimum value")

    def get_threshold_mask(self, image_tensor: torch.Tensor):
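        """Set values at or above threshold_upper to 1 and values below threshold_lower to 0, optionally inverting or renormalizing values in between"""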
        img_tensor = image_tensor.clone()
        threshold_h, threshold_s = self.threshold_upper, self.threshold_lower
        ones_tensor = torch.ones(img_tensor.shape)
        zeros_tensor = torch.zeros(img_tensor.shape)

        zeros_mask, ones_mask = None, None
        if self.invert_output:
            zeros_mask, ones_mask = torch.ge(img_tensor, threshold_h), torch.lt(img_tensor, threshold_s)
        else:
            ones_mask, zeros_mask = torch.ge(img_tensor, threshold_h), torch.lt(img_tensor, threshold_s)

        if not (threshold_h == threshold_s):
            mask_hi = torch.ge(img_tensor, threshold_s)
            mask_lo = torch.lt(img_tensor, threshold_h)
            mask = torch.logical_and(mask_hi, mask_lo)
            masked = img_tensor[mask]
            if 0 < masked.numel():
                if self.renormalize_values:
                    vmax, vmin = max(threshold_h, threshold_s), min(threshold_h, threshold_s)
                    if vmax == vmin:
                        img_tensor[mask] = vmin * ones_tensor[mask]
                    elif self.invert_output:
                        img_tensor[mask] = torch.sub(1.0, (img_tensor[mask] - vmin) / (vmax - vmin))
                    else:
                        img_tensor[mask] = (img_tensor[mask] - vmin) / (vmax - vmin)

        img_tensor[ones_mask] = ones_tensor[ones_mask]
        img_tensor[zeros_mask] = zeros_tensor[zeros_mask]

        return img_tensor

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image_in = context.images.get_pil(self.image.image_name)

        if self.lightness_only:
            image_mode = image_in.mode
            alpha_channel = None
            if (image_mode == "RGBA") or (image_mode == "LA") or (image_mode == "PA"):
                alpha_channel = image_in.getchannel("A")
            elif (image_mode == "RGBa") or (image_mode == "La") or (image_mode == "Pa"):
                alpha_channel = image_in.getchannel("a")
            if (image_mode == "RGBA") or (image_mode == "RGBa"):
                image_mode = "RGB"
            elif (image_mode == "LA") or (image_mode == "La"):
                image_mode = "L"
            elif image_mode == "PA":
                image_mode = "P"
            image_out = image_in.convert("RGB")
            image_out = image_out.convert("LAB")

            l_channel = image_resized_to_grid_as_tensor(image_out.getchannel("L"), normalize=False)
            l_channel = self.get_threshold_mask(l_channel)
            l_channel = pil_image_from_tensor(l_channel)

            image_out = Image.merge("LAB", (l_channel, image_out.getchannel("A"), image_out.getchannel("B")))
            if (image_mode == "L") or (image_mode == "P"):
                image_out = image_out.convert("RGB")
            image_out = image_out.convert(image_mode)
            if "a" in image_in.mode.lower():
                image_out = Image.merge(
                    image_in.mode, tuple([image_out.getchannel(c) for c in image_mode] + [alpha_channel])
                )
        else:
            image_out = image_resized_to_grid_as_tensor(image_in, normalize=False)
            image_out = self.get_threshold_mask(image_out)
            image_out = pil_image_from_tensor(image_out)

        image_dto = context.images.save(image_out)

        return ImageOutput.build(image_dto)