WIP - Pass prompt masks to FLUX model during denoising.

This commit is contained in:
Ryan Dick
2024-11-20 18:51:43 +00:00
parent 1948ffe106
commit 85c616fa34
5 changed files with 186 additions and 54 deletions

View File

@@ -10,6 +10,7 @@ from invokeai.backend.flux.extensions.instantx_controlnet_extension import Insta
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -18,14 +19,8 @@ def denoise(
# model input
img: torch.Tensor,
img_ids: torch.Tensor,
# positive text conditioning
txt: torch.Tensor,
txt_ids: torch.Tensor,
vec: torch.Tensor,
# negative text conditioning
neg_txt: torch.Tensor | None,
neg_txt_ids: torch.Tensor | None,
neg_vec: torch.Tensor | None,
pos_text_conditioning: FluxRegionalTextConditioning,
neg_text_conditioning: FluxRegionalTextConditioning | None,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
@@ -55,6 +50,7 @@ def denoise(
# Run ControlNet models.
controlnet_residuals: list[ControlNetFluxOutput] = []
for controlnet_extension in controlnet_extensions:
# FIX(ryand): Revive ControlNet functionality.
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step_index,

View File

@@ -0,0 +1,32 @@
from dataclasses import dataclass
import torch
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
@dataclass
class FluxTextConditioning:
    """Text conditioning tensors for a single FLUX prompt, bundled with its mask.

    The field order defines the positional order of the generated ``__init__``.
    """

    # Prompt embeddings; presumably the T5 encoder output (inferred from the name).
    t5_embeddings: torch.Tensor
    # Prompt embeddings; presumably the CLIP encoder output (inferred from the name).
    clip_embeddings: torch.Tensor
    # Mask associated with this prompt.
    # NOTE(review): shape and dtype are not specified in this file - confirm against the caller.
    mask: torch.Tensor
@dataclass
class FluxRegionalTextConditioning:
    """Text conditioning for several regional prompts, merged into shared tensors.

    The embeddings of all prompts are concatenated; the ``*_embedding_ranges``
    lists record which slice of each concatenated tensor belongs to which
    prompt mask. The field order defines the positional order of the generated
    ``__init__``.
    """

    # Text embeddings of all prompts, concatenated.
    t5_embeddings: torch.Tensor
    clip_embeddings: torch.Tensor

    # NOTE(review): presumably the ids tensor that accompanies `t5_embeddings`
    # when it is fed to the model - confirm against the caller.
    t5_txt_ids: torch.Tensor

    # Binary masks selecting the image regions each prompt should be applied to.
    # Shape: (1, num_prompts, height, width)
    # Dtype: torch.bool
    image_masks: torch.Tensor

    # Per-mask embedding ranges:
    #   t5_embedding_ranges[i]   -> slice of the t5 embeddings for prompt mask i
    #   clip_embedding_ranges[i] -> slice of the clip embeddings for prompt mask i
    # NOTE(review): given the documented shape, "mask i" presumably indexes the
    # num_prompts dimension (image_masks[:, i]) rather than dim 0 - confirm.
    t5_embedding_ranges: list[Range]
    clip_embedding_ranges: list[Range]