Use a black image as the negative IP prompt for parity with X-Labs implementation.

This commit is contained in:
Ryan Dick
2024-10-18 18:52:12 +00:00
parent dde54740c5
commit 73bbb12f7a

View File

@@ -1,6 +1,8 @@
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple
import numpy as np
import numpy.typing as npt
import torch
import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
@@ -570,12 +572,28 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
ipa_images = [context.images.get_pil(image.image_name) for image in ipa_image_fields]
pos_images: list[npt.NDArray[np.uint8]] = []
neg_images: list[npt.NDArray[np.uint8]] = []
for ipa_image in ipa_images:
assert ipa_image.mode == "RGB"
pos_image = np.array(ipa_image)
# We use a black image as the negative image prompt for parity with
# https://github.com/XLabs-AI/x-flux-comfyui/blob/45c834727dd2141aebc505ae4b01f193a8414e38/nodes.py#L592-L593
# An alternative scheme would be to apply zeros_like() after calling the clip_image_processor.
neg_image = np.zeros_like(pos_image)
pos_images.append(pos_image)
neg_images.append(neg_image)
with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
clip_image: torch.Tensor = clip_image_processor(images=ipa_images, return_tensors="pt").pixel_values
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
pos_clip_image_embeds = image_encoder_model(clip_image).image_embeds
neg_clip_image_embeds = image_encoder_model(torch.zeros_like(clip_image)).image_embeds
clip_image = clip_image_processor(images=neg_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
neg_clip_image_embeds = image_encoder_model(clip_image).image_embeds
pos_image_prompt_clip_embeds.append(pos_clip_image_embeds)
neg_image_prompt_clip_embeds.append(neg_clip_image_embeds)
@@ -590,7 +608,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
context: InvocationContext,
exit_stack: ExitStack,
dtype: torch.dtype,
) -> list[XLabsIPAdapterExtension]:
) -> tuple[list[XLabsIPAdapterExtension], list[XLabsIPAdapterExtension]]:
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(