tidy(nodes): bring masked blend latents masking logic into invoke core

This commit is contained in:
psychedelicious
2024-11-19 08:47:24 -06:00
parent a69c5ff9ef
commit 4267e5ffc4
2 changed files with 51 additions and 150 deletions

View File

@@ -1,51 +1,75 @@
from typing import Any, Union
from typing import Optional, Union
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
@invocation(
"lblend",
title="Blend Latents",
tags=["latents", "blend"],
tags=["latents", "blend", "mask"],
category="latents",
version="1.0.3",
version="1.1.0",
)
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""
"""Blend two latents using a given alpha. If a mask is provided, the second latents will be masked before blending.
Latents must have same size. Masking functionality added by @dwringer."""
latents_a: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
latents_b: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
latents_a: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
latents_b: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
mask: Optional[ImageField] = InputField(default=None, description="Mask for blending in latents B")
alpha: float = InputField(ge=0, default=0.5, description=FieldDescriptions.blend_alpha)
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
return mask_tensor
def replace_tensor_from_masked_tensor(
self, tensor: torch.Tensor, other_tensor: torch.Tensor, mask_tensor: torch.Tensor
):
output = tensor.clone()
mask_tensor = mask_tensor.expand(output.shape)
if output.dtype != torch.float16:
output = torch.add(output, mask_tensor * torch.sub(other_tensor, tensor))
else:
output = torch.add(output, mask_tensor.half() * torch.sub(other_tensor, tensor))
return output
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents_a = context.tensors.load(self.latents_a.latents_name)
latents_b = context.tensors.load(self.latents_b.latents_name)
if self.mask is None:
mask_tensor = torch.zeros(latents_a.shape[-2:])
else:
mask_tensor = self.prep_mask_tensor(context.images.get_pil(self.mask.image_name))
mask_tensor = tv_resize(mask_tensor, latents_a.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
latents_b = self.replace_tensor_from_masked_tensor(latents_b, latents_a, mask_tensor)
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
raise ValueError("Latents to blend must be the same size.")
device = TorchDevice.choose_torch_device()
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
v0: Union[torch.Tensor, npt.NDArray[Any]],
v1: Union[torch.Tensor, npt.NDArray[Any]],
t: Union[float, np.ndarray],
v0: Union[torch.Tensor, np.ndarray],
v1: Union[torch.Tensor, np.ndarray],
DOT_THRESHOLD: float = 0.9995,
) -> Union[torch.Tensor, npt.NDArray[Any]]:
):
"""
Spherical linear interpolation
Args:
@@ -78,21 +102,16 @@ class BlendLatentsInvocation(BaseInvocation):
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
return v2_torch
else:
assert isinstance(v2, np.ndarray)
return v2
v2 = torch.from_numpy(v2).to(device)
return v2
# blend
bl = slerp(self.alpha, latents_a, latents_b)
assert isinstance(bl, torch.Tensor)
blended_latents: torch.Tensor = bl # for type checking convenience
blended_latents = slerp(self.alpha, latents_a, latents_b)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
TorchDevice.empty_cache()
torch.cuda.empty_cache()
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
return LatentsOutput.build(latents_name=name, latents=blended_latents)

View File

@@ -5,15 +5,12 @@ from ast import literal_eval as tuple_from_string
from functools import reduce
from io import BytesIO
from math import pi as PI
from typing import Literal, Optional, Union
from typing import Literal, Optional
import cv2
import numpy
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageChops, ImageCms, ImageColor, ImageDraw, ImageEnhance, ImageOps
from torchvision.transforms.functional import resize as tv_resize
from torchvision.transforms.functional import to_pil_image as pil_image_from_tensor
from invokeai.app.invocations.primitives import ImageOutput
@@ -38,16 +35,11 @@ from invokeai.backend.image_util.composition import (
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.invocation_api import (
BaseInvocation,
FieldDescriptions,
ImageField,
Input,
InputField,
InvocationContext,
LatentsField,
LatentsOutput,
WithBoard,
WithMetadata,
choose_torch_device,
invocation,
)
@@ -277,116 +269,6 @@ class InvokeAdjustImageHuePlusInvocation(BaseInvocation, WithMetadata, WithBoard
return ImageOutput.build(image_dto)
@invocation(
"invokeai_lmblend",
title="Blend Latents/Noise (Masked)",
tags=["latents", "noise", "blend"],
category="latents",
version="1.2.0",
)
class InvokeMaskedBlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha and mask. Latents must have same size. Originally created by @dwringer"""
latents_a: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
latents_b: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
mask: Optional[ImageField] = InputField(default=None, description="Mask for blending in latents B")
alpha: float = InputField(ge=0, default=0.5, description=FieldDescriptions.blend_alpha)
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
return mask_tensor
def replace_tensor_from_masked_tensor(
self, tensor: torch.Tensor, other_tensor: torch.Tensor, mask_tensor: torch.Tensor
):
output = tensor.clone()
mask_tensor = mask_tensor.expand(output.shape)
if output.dtype != torch.float16:
output = torch.add(output, mask_tensor * torch.sub(other_tensor, tensor))
else:
output = torch.add(output, mask_tensor.half() * torch.sub(other_tensor, tensor))
return output
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents_a = context.tensors.load(self.latents_a.latents_name)
latents_b = context.tensors.load(self.latents_b.latents_name)
if self.mask is None:
mask_tensor = torch.zeros(latents_a.shape[-2:])
else:
mask_tensor = self.prep_mask_tensor(context.images.get_pil(self.mask.image_name))
mask_tensor = tv_resize(mask_tensor, latents_a.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
latents_b = self.replace_tensor_from_masked_tensor(latents_b, latents_a, mask_tensor)
if latents_a.shape != latents_b.shape:
raise ValueError("Latents to blend must be the same size.")
# TODO:
device = choose_torch_device()
def slerp(
t: Union[float, np.ndarray],
v0: Union[torch.Tensor, np.ndarray],
v1: Union[torch.Tensor, np.ndarray],
DOT_THRESHOLD: float = 0.9995,
):
"""
Spherical linear interpolation
Args:
t (float/np.ndarray): Float value between 0.0 and 1.0
v0 (np.ndarray): Starting vector
v1 (np.ndarray): Final vector
DOT_THRESHOLD (float): Threshold for considering the two vectors as
colineal. Not recommended to alter this.
Returns:
v2 (np.ndarray): Interpolation vector between v0 and v1
"""
inputs_are_torch = False
if not isinstance(v0, np.ndarray):
inputs_are_torch = True
v0 = v0.detach().cpu().numpy()
if not isinstance(v1, np.ndarray):
inputs_are_torch = True
v1 = v1.detach().cpu().numpy()
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > DOT_THRESHOLD:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2 = torch.from_numpy(v2).to(device)
return v2
# blend
blended_latents = slerp(self.alpha, latents_a, latents_b)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
torch.cuda.empty_cache()
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents)
@invocation(
"invokeai_img_enhance",
title="Enhance Image",