mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
WIP - Pass prompt masks to FLUX model during denoising.
This commit is contained in:
@@ -10,6 +10,7 @@ from invokeai.backend.flux.extensions.instantx_controlnet_extension import Insta
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
|
||||
@@ -18,14 +19,8 @@ def denoise(
|
||||
# model input
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
# positive text conditioning
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
# negative text conditioning
|
||||
neg_txt: torch.Tensor | None,
|
||||
neg_txt_ids: torch.Tensor | None,
|
||||
neg_vec: torch.Tensor | None,
|
||||
pos_text_conditioning: FluxRegionalTextConditioning,
|
||||
neg_text_conditioning: FluxRegionalTextConditioning | None,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[PipelineIntermediateState], None],
|
||||
@@ -55,6 +50,7 @@ def denoise(
|
||||
# Run ControlNet models.
|
||||
controlnet_residuals: list[ControlNetFluxOutput] = []
|
||||
for controlnet_extension in controlnet_extensions:
|
||||
# FIX(ryand): Revive ControlNet functionality.
|
||||
controlnet_residuals.append(
|
||||
controlnet_extension.run_controlnet(
|
||||
timestep_index=step_index,
|
||||
|
||||
32
invokeai/backend/flux/text_conditioning.py
Normal file
32
invokeai/backend/flux/text_conditioning.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
|
||||
|
||||
|
||||
@dataclass
class FluxTextConditioning:
    """Text conditioning tensors for a single FLUX prompt, together with its mask.

    The field order is part of the interface: the dataclass-generated
    ``__init__`` accepts ``(t5_embeddings, clip_embeddings, mask)`` positionally.
    """

    # Embeddings of the prompt produced by the T5 text encoder.
    t5_embeddings: torch.Tensor
    # Embeddings of the prompt produced by the CLIP text encoder.
    clip_embeddings: torch.Tensor
    # Mask associated with this prompt.
    # NOTE(review): presumably the image-region mask the prompt applies to; the
    # expected shape and dtype are not stated here - confirm against the caller.
    mask: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class FluxRegionalTextConditioning:
    """Text conditioning for several prompts, each associated with an image mask.

    The embeddings of all prompts are stored concatenated; the ``*_embedding_ranges``
    lists record which slice of each concatenated tensor belongs to which mask.
    The field order is part of the interface (positional ``__init__`` arguments).
    """

    # Concatenated text embeddings.
    t5_embeddings: torch.Tensor
    clip_embeddings: torch.Tensor

    # NOTE(review): presumably the ids that accompany the concatenated T5
    # embeddings when they are passed to the model - confirm against the caller.
    t5_txt_ids: torch.Tensor

    # A binary mask indicating the regions of the image that the prompt should be applied to.
    # Shape: (1, num_prompts, height, width)
    # Dtype: torch.bool
    image_masks: torch.Tensor

    # List of ranges that represent the embedding ranges for each mask.
    # t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
    # clip_embedding_ranges[i] contains the range of the clip embeddings that correspond to image_masks[i].
    # NOTE(review): with the shape documented above, the per-prompt axis of
    # image_masks is dim 1, so "image_masks[i]" likely means image_masks[:, i] - confirm.
    t5_embedding_ranges: list[Range]
    clip_embedding_ranges: list[Range]
|
||||
Reference in New Issue
Block a user