Test out IP-Adapter with CFG.

This commit is contained in:
Ryan Dick
2024-10-16 18:11:48 +00:00
parent f70a8e2c1a
commit dde54740c5
2 changed files with 52 additions and 26 deletions

View File

@@ -4,7 +4,7 @@ from typing import Callable, Iterator, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
@@ -258,7 +258,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# We do this before loading other models to minimize peak memory.
# TODO(ryand): We should really do this in a separate invocation to benefit from caching.
ip_adapter_fields = self._normalize_ip_adapter_fields()
image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(ip_adapter_fields, context)
pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(
ip_adapter_fields, context
)
cfg_scale = self.prep_cfg_scale(
cfg_scale=self.cfg_scale,
@@ -316,8 +318,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
raise ValueError(f"Unsupported model format: {config.format}")
# Prepare IP-Adapter extensions.
ip_adapter_extensions = self._prep_ip_adapter_extensions(
image_prompt_clip_embeds=image_prompt_clip_embeds,
pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,
neg_image_prompt_clip_embeds=neg_image_prompt_clip_embeds,
ip_adapter_fields=ip_adapter_fields,
context=context,
exit_stack=exit_stack,
@@ -340,7 +343,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
cfg_scale=cfg_scale,
inpaint_extension=inpaint_extension,
controlnet_extensions=controlnet_extensions,
ip_adapter_extensions=ip_adapter_extensions,
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
)
x = unpack(x.float(), self.height, self.width)
@@ -548,9 +552,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
self,
ip_adapter_fields: list[IPAdapterField],
context: InvocationContext,
) -> list[torch.Tensor]:
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
image_prompt_clip_embeds: list[torch.Tensor] = []
clip_image_processor = CLIPImageProcessor()
pos_image_prompt_clip_embeds: list[torch.Tensor] = []
neg_image_prompt_clip_embeds: list[torch.Tensor] = []
for ip_adapter_field in ip_adapter_fields:
# `ip_adapter_field.image` could be a list or a single ImageField. Normalize to a list here.
ipa_image_fields: list[ImageField]
@@ -565,24 +572,30 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
image_prompt_clip_embeds.append(
XLabsIPAdapterExtension.run_clip_image_encoder(
pil_image=ipa_images,
image_encoder=image_encoder_model,
)
)
return image_prompt_clip_embeds
clip_image: torch.Tensor = clip_image_processor(images=ipa_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
pos_clip_image_embeds = image_encoder_model(clip_image).image_embeds
neg_clip_image_embeds = image_encoder_model(torch.zeros_like(clip_image)).image_embeds
pos_image_prompt_clip_embeds.append(pos_clip_image_embeds)
neg_image_prompt_clip_embeds.append(neg_clip_image_embeds)
return pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds
def _prep_ip_adapter_extensions(
self,
ip_adapter_fields: list[IPAdapterField],
image_prompt_clip_embeds: list[torch.Tensor],
pos_image_prompt_clip_embeds: list[torch.Tensor],
neg_image_prompt_clip_embeds: list[torch.Tensor],
context: InvocationContext,
exit_stack: ExitStack,
dtype: torch.dtype,
) -> list[XLabsIPAdapterExtension]:
ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
for ip_adapter_field, image_prompt_clip_embed in zip(ip_adapter_fields, image_prompt_clip_embeds, strict=True):
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension] = []
for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model))
assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
ip_adapter_model = ip_adapter_model.to(dtype=dtype)
@@ -590,16 +603,25 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
raise ValueError("IP-Adapter masks are not yet supported in Flux.")
ip_adapter_extension = XLabsIPAdapterExtension(
model=ip_adapter_model,
image_prompt_clip_embed=image_prompt_clip_embed,
image_prompt_clip_embed=pos_image_prompt_clip_embed,
weight=ip_adapter_field.weight,
begin_step_percent=ip_adapter_field.begin_step_percent,
end_step_percent=ip_adapter_field.end_step_percent,
)
ip_adapter_extension.run_image_proj(dtype=dtype)
ip_adapter_extensions.append(ip_adapter_extension)
pos_ip_adapter_extensions.append(ip_adapter_extension)
return ip_adapter_extensions
ip_adapter_extension = XLabsIPAdapterExtension(
model=ip_adapter_model,
image_prompt_clip_embed=neg_image_prompt_clip_embed,
weight=ip_adapter_field.weight,
begin_step_percent=ip_adapter_field.begin_step_percent,
end_step_percent=ip_adapter_field.end_step_percent,
)
ip_adapter_extension.run_image_proj(dtype=dtype)
neg_ip_adapter_extensions.append(ip_adapter_extension)
return pos_ip_adapter_extensions, neg_ip_adapter_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:

View File

@@ -33,7 +33,8 @@ def denoise(
cfg_scale: list[float],
inpaint_extension: InpaintExtension | None,
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
ip_adapter_extensions: list[XLabsIPAdapterExtension],
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -69,7 +70,7 @@ def denoise(
)
# Merge the ControlNet residuals from multiple ControlNets.
# TODO(ryand): We may want to alculate the sum just-in-time to keep peak memory low. Keep in mind, that the
# TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind, that the
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
@@ -86,15 +87,15 @@ def denoise(
total_num_timesteps=total_steps,
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
ip_adapter_extensions=ip_adapter_extensions,
ip_adapter_extensions=pos_ip_adapter_extensions,
)
step_cfg_scale = cfg_scale[step_index]
# If step_cfg_scale, is 1.0, then we don't need to run the negative prediction.
if not math.isclose(step_cfg_scale, 1.0):
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance on
# systems with sufficient VRAM.
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
# on systems with sufficient VRAM.
if neg_txt is None or neg_txt_ids is None or neg_vec is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
@@ -107,8 +108,11 @@ def denoise(
y=neg_vec,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=step_index,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
ip_adapter_extensions=neg_ip_adapter_extensions,
)
pred = neg_pred + step_cfg_scale * (pred - neg_pred)