Mirror of https://github.com/invoke-ai/InvokeAI.git
Feature: Add Z-Image-Turbo regional guidance (#8672)
* feat: Add Regional Guidance support for Z-Image model

  Implements regional prompting for Z-Image (S3-DiT Transformer), allowing different prompts to affect different image regions using attention masks.

  Backend changes:
  - Add ZImageRegionalPromptingExtension for mask preparation
  - Add ZImageTextConditioning and ZImageRegionalTextConditioning data classes
  - Patch the transformer forward to inject 4D regional attention masks
  - Use an additive float mask (0.0 = attend, -inf = block) in bfloat16 for compatibility
  - Alternate regional/full attention layers for global coherence

  Frontend changes:
  - Update buildZImageGraph to support regional conditioning collectors
  - Update addRegions to create z_image_text_encoder nodes for regions
  - Update addZImageLoRAs to handle optional negCond when guidance_scale=0
  - Add Z-Image validation (no IP adapters, no autoNegative)

* @Pfannkuchensack Fix Windows path again

* ruff check fix

* ruff formatting

* fix(ui): Z-Image CFG guidance_scale check uses > 1 instead of > 0

  Changed the guidance_scale check from > 0 to > 1 for Z-Image models. Since Z-Image uses guidance_scale=1.0 as "no CFG" (matching the FLUX convention), negative conditioning should only be created when guidance_scale > 1.

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
Commit 769cf52209 (parent de1aa557b8), committed via GitHub.
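To make the commit's mask convention concrete, here is a minimal standalone sketch (an editorial illustration, not part of the commit) of the boolean-to-additive conversion it describes: True (attend) becomes 0.0, False (block) becomes -inf, so adding the mask to attention logits zeroes out blocked positions after softmax.

import torch

bool_mask = torch.tensor([[True, False], [True, True]])  # True = attend
float_mask = torch.zeros_like(bool_mask, dtype=torch.bfloat16)
float_mask[~bool_mask] = float("-inf")  # blocked positions get -inf

logits = torch.randn(2, 2, dtype=torch.bfloat16)
weights = torch.softmax(logits + float_mask, dim=-1)
print(weights)  # row 0 puts ~all weight on the single allowed position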
invokeai/app/invocations/fields.py:
@@ -333,6 +333,11 @@ class ZImageConditioningField(BaseModel):
     """A Z-Image conditioning tensor primitive value"""

     conditioning_name: str = Field(description="The name of conditioning tensor")
+    mask: Optional[TensorField] = Field(
+        default=None,
+        description="The mask associated with this conditioning tensor for regional prompting. "
+        "Excluded regions should be set to False, included regions should be set to True.",
+    )


 class ConditioningField(BaseModel):
Z-Image denoise invocation:
@@ -32,11 +32,14 @@ from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import Rec
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
+from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
 from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
 from invokeai.backend.z_image.z_image_controlnet_extension import (
     ZImageControlNetExtension,
     z_image_forward_with_control,
 )
+from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting


 @invocation(
@@ -44,11 +47,14 @@ from invokeai.backend.z_image.z_image_controlnet_extension import (
     title="Denoise - Z-Image",
     tags=["image", "z-image"],
     category="image",
-    version="1.1.0",
+    version="1.2.0",
     classification=Classification.Prototype,
 )
 class ZImageDenoiseInvocation(BaseInvocation):
-    """Run the denoising process with a Z-Image model."""
+    """Run the denoising process with a Z-Image model.
+
+    Supports regional prompting by connecting multiple conditioning inputs with masks.
+    """

     # If latents is provided, this means we are doing image-to-image.
     latents: Optional[LatentsField] = InputField(
@@ -63,10 +69,10 @@ class ZImageDenoiseInvocation(BaseInvocation):
     transformer: TransformerField = InputField(
         description=FieldDescriptions.z_image_model, input=Input.Connection, title="Transformer"
     )
-    positive_conditioning: ZImageConditioningField = InputField(
+    positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
     )
-    negative_conditioning: Optional[ZImageConditioningField] = InputField(
+    negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
         default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
     )
     # Z-Image-Turbo works best without CFG (guidance_scale=1.0)
@@ -126,25 +132,50 @@ class ZImageDenoiseInvocation(BaseInvocation):
     def _load_text_conditioning(
         self,
         context: InvocationContext,
-        conditioning_name: str,
+        cond_field: ZImageConditioningField | list[ZImageConditioningField],
+        img_height: int,
+        img_width: int,
         dtype: torch.dtype,
         device: torch.device,
-    ) -> torch.Tensor:
-        """Load Z-Image text conditioning."""
-        cond_data = context.conditioning.load(conditioning_name)
-        if len(cond_data.conditionings) != 1:
-            raise ValueError(
-                f"Expected exactly 1 conditioning entry for Z-Image, got {len(cond_data.conditionings)}. "
-                "Ensure you are using the Z-Image text encoder."
-            )
-        z_image_conditioning = cond_data.conditionings[0]
-        if not isinstance(z_image_conditioning, ZImageConditioningInfo):
-            raise TypeError(
-                f"Expected ZImageConditioningInfo, got {type(z_image_conditioning).__name__}. "
-                "Ensure you are using the Z-Image text encoder."
-            )
-        z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
-        return z_image_conditioning.prompt_embeds
+    ) -> list[ZImageTextConditioning]:
+        """Load Z-Image text conditioning with optional regional masks.
+
+        Args:
+            context: The invocation context.
+            cond_field: Single conditioning field or list of fields.
+            img_height: Height of the image token grid (H // patch_size).
+            img_width: Width of the image token grid (W // patch_size).
+            dtype: Target dtype.
+            device: Target device.
+
+        Returns:
+            List of ZImageTextConditioning objects with embeddings and masks.
+        """
+        # Normalize to a list
+        cond_list = [cond_field] if isinstance(cond_field, ZImageConditioningField) else cond_field
+
+        text_conditionings: list[ZImageTextConditioning] = []
+        for cond in cond_list:
+            # Load the text embeddings
+            cond_data = context.conditioning.load(cond.conditioning_name)
+            assert len(cond_data.conditionings) == 1
+            z_image_conditioning = cond_data.conditionings[0]
+            assert isinstance(z_image_conditioning, ZImageConditioningInfo)
+            z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
+            prompt_embeds = z_image_conditioning.prompt_embeds
+
+            # Load the mask, if provided
+            mask: torch.Tensor | None = None
+            if cond.mask is not None:
+                mask = context.tensors.load(cond.mask.tensor_name)
+                mask = mask.to(device=device)
+                mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
+                    mask, img_height, img_width, dtype, device
+                )
+
+            text_conditionings.append(ZImageTextConditioning(prompt_embeds=prompt_embeds, mask=mask))
+
+        return text_conditionings

     def _get_noise(
         self,
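A quick shape check for the mask preprocessing used above (editorial sketch, not part of the commit; assumes the regional_prompting_extension module introduced later in this diff is importable): any 2D mask is resized to the token grid with nearest-neighbor interpolation and flattened to (1, 1, img_seq_len).

import torch
from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension

mask = torch.zeros(512, 512)
mask[:256, :] = 1.0  # region covers the top half of the image
out = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
    mask, target_height=64, target_width=64, dtype=torch.float32, device=torch.device("cpu")
)
assert out.shape == (1, 1, 64 * 64)  # one mask value per image token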
@@ -221,14 +252,33 @@ class ZImageDenoiseInvocation(BaseInvocation):

         transformer_info = context.models.load(self.transformer.transformer)

-        # Load positive conditioning
-        pos_prompt_embeds = self._load_text_conditioning(
+        # Calculate image token grid dimensions
+        patch_size = 2  # Z-Image uses patch_size=2
+        latent_height = self.height // LATENT_SCALE_FACTOR
+        latent_width = self.width // LATENT_SCALE_FACTOR
+        img_token_height = latent_height // patch_size
+        img_token_width = latent_width // patch_size
+        img_seq_len = img_token_height * img_token_width
+
+        # Load positive conditioning with regional masks
+        pos_text_conditionings = self._load_text_conditioning(
             context=context,
-            conditioning_name=self.positive_conditioning.conditioning_name,
+            cond_field=self.positive_conditioning,
+            img_height=img_token_height,
+            img_width=img_token_width,
             dtype=inference_dtype,
             device=device,
         )
+
+        # Create regional prompting extension
+        regional_extension = ZImageRegionalPromptingExtension.from_text_conditionings(
+            text_conditionings=pos_text_conditionings,
+            img_seq_len=img_seq_len,
+        )
+
+        # Get the concatenated prompt embeddings for the transformer
+        pos_prompt_embeds = regional_extension.regional_text_conditioning.prompt_embeds

         # Load negative conditioning if provided and guidance_scale != 1.0
         # CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
         # At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
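The token-grid arithmetic above, worked through for a 1024x1024 generation (editorial sketch; assumes InvokeAI's usual LATENT_SCALE_FACTOR of 8, which is not restated in this hunk):

height, width = 1024, 1024
LATENT_SCALE_FACTOR = 8  # assumption: VAE downsamples by 8x
patch_size = 2           # Z-Image uses patch_size=2

latent_h, latent_w = height // LATENT_SCALE_FACTOR, width // LATENT_SCALE_FACTOR  # 128 x 128
img_token_h, img_token_w = latent_h // patch_size, latent_w // patch_size          # 64 x 64
img_seq_len = img_token_h * img_token_w
print(img_seq_len)  # 4096 image tokens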
@@ -238,21 +288,22 @@ class ZImageDenoiseInvocation(BaseInvocation):
             not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
         )
         if do_classifier_free_guidance:
-            if self.negative_conditioning is None:
-                raise ValueError("Negative conditioning is required when guidance_scale != 1.0")
-            neg_prompt_embeds = self._load_text_conditioning(
+            assert self.negative_conditioning is not None
+            # Load all negative conditionings and concatenate embeddings
+            # Note: We ignore masks for negative conditioning as regional negative prompting is not fully supported
+            neg_text_conditionings = self._load_text_conditioning(
                 context=context,
-                conditioning_name=self.negative_conditioning.conditioning_name,
+                cond_field=self.negative_conditioning,
+                img_height=img_token_height,
+                img_width=img_token_width,
                 dtype=inference_dtype,
                 device=device,
             )
+            # Concatenate all negative embeddings
+            neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)

-        # Calculate image sequence length for timestep shifting
-        patch_size = 2  # Z-Image uses patch_size=2
-        image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (patch_size**2)
         # Calculate shift based on image sequence length
-        mu = self._calculate_shift(image_seq_len)
+        mu = self._calculate_shift(img_seq_len)

         # Generate sigma schedule with time shift
         sigmas = self._get_sigmas(mu, self.steps)
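The CFG formula quoted in the comments above, as a tiny runnable check (editorial sketch, not part of the commit): at guidance_scale=1.0 the combination collapses to pred_cond, which is why the unconditional pass is skipped for Z-Image-Turbo.

import torch

def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, scale: float) -> torch.Tensor:
    # pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
    return pred_uncond + scale * (pred_cond - pred_uncond)

pred_cond, pred_uncond = torch.randn(4), torch.randn(4)
assert torch.allclose(cfg_combine(pred_cond, pred_uncond, 1.0), pred_cond)  # scale 1.0 = no CFG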
@@ -443,6 +494,15 @@ class ZImageDenoiseInvocation(BaseInvocation):
                 )
             )

+            # Apply regional prompting patch if we have regional masks
+            exit_stack.enter_context(
+                patch_transformer_for_regional_prompting(
+                    transformer=transformer,
+                    regional_attn_mask=regional_extension.regional_attn_mask,
+                    img_seq_len=img_seq_len,
+                )
+            )
+
             # Denoising loop
             for step_idx in tqdm(range(total_steps)):
                 sigma_curr = sigmas[step_idx]
Z-Image text encoder invocation:
@@ -1,11 +1,18 @@
 from contextlib import ExitStack
-from typing import Iterator, Tuple
+from typing import Iterator, Optional, Tuple

 import torch
 from transformers import PreTrainedModel, PreTrainedTokenizerBase

 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    TensorField,
+    UIComponent,
+    ZImageConditioningField,
+)
 from invokeai.app.invocations.model import Qwen3EncoderField
 from invokeai.app.invocations.primitives import ZImageConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -27,11 +34,14 @@ Z_IMAGE_MAX_SEQ_LEN = 512
     title="Prompt - Z-Image",
     tags=["prompt", "conditioning", "z-image"],
     category="conditioning",
-    version="1.0.0",
+    version="1.1.0",
     classification=Classification.Prototype,
 )
 class ZImageTextEncoderInvocation(BaseInvocation):
-    """Encodes and preps a prompt for a Z-Image image."""
+    """Encodes and preps a prompt for a Z-Image image.
+
+    Supports regional prompting by connecting a mask input.
+    """

     prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
     qwen3_encoder: Qwen3EncoderField = InputField(
@@ -39,13 +49,19 @@ class ZImageTextEncoderInvocation(BaseInvocation):
         description=FieldDescriptions.qwen3_encoder,
         input=Input.Connection,
     )
+    mask: Optional[TensorField] = InputField(
+        default=None,
+        description="A mask defining the region that this conditioning prompt applies to.",
+    )

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
         prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
         conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
         conditioning_name = context.conditioning.save(conditioning_data)
-        return ZImageConditioningOutput.build(conditioning_name)
+        return ZImageConditioningOutput(
+            conditioning=ZImageConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+        )

     def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
         """Encode prompt using Qwen3 text encoder.
invokeai/backend/z_image/__init__.py:
@@ -1,4 +1,4 @@
-# Z-Image Control Transformer support for InvokeAI
+# Z-Image backend utilities
 from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
 from invokeai.backend.z_image.z_image_control_transformer import ZImageControlTransformer2DModel
 from invokeai.backend.z_image.z_image_controlnet_extension import (
invokeai/backend/z_image/extensions/__init__.py (new file, 1 line):
@@ -0,0 +1 @@
# Z-Image extensions
invokeai/backend/z_image/extensions/regional_prompting_extension.py (new file, 207 lines):
@@ -0,0 +1,207 @@
from typing import Optional

import torch
import torchvision

from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.z_image.text_conditioning import ZImageRegionalTextConditioning, ZImageTextConditioning


class ZImageRegionalPromptingExtension:
    """A class for managing regional prompting with Z-Image.

    This implementation is inspired by the FLUX regional prompting extension and
    the paper https://arxiv.org/pdf/2411.02395.

    Key difference from FLUX: Z-Image uses sequence order [img_tokens, txt_tokens],
    while FLUX uses [txt_tokens, img_tokens]. The attention mask construction
    accounts for this difference.
    """

    def __init__(
        self,
        regional_text_conditioning: ZImageRegionalTextConditioning,
        regional_attn_mask: torch.Tensor | None = None,
    ):
        self.regional_text_conditioning = regional_text_conditioning
        self.regional_attn_mask = regional_attn_mask

    def get_attn_mask(self, block_index: int) -> torch.Tensor | None:
        """Get the attention mask for a given block index.

        Uses alternating pattern: apply mask on even blocks, no mask on odd blocks.
        This helps balance regional control with global coherence.
        """
        order = [self.regional_attn_mask, None]
        return order[block_index % len(order)]

    @classmethod
    def from_text_conditionings(
        cls,
        text_conditionings: list[ZImageTextConditioning],
        img_seq_len: int,
    ) -> "ZImageRegionalPromptingExtension":
        """Create a ZImageRegionalPromptingExtension from a list of text conditionings.

        Args:
            text_conditionings: List of text conditionings with optional masks.
            img_seq_len: The image sequence length (i.e. (H // patch_size) * (W // patch_size)).

        Returns:
            A configured ZImageRegionalPromptingExtension.
        """
        regional_text_conditioning = ZImageRegionalTextConditioning.from_text_conditionings(text_conditionings)
        attn_mask = cls._prepare_regional_attn_mask(regional_text_conditioning, img_seq_len)
        return cls(
            regional_text_conditioning=regional_text_conditioning,
            regional_attn_mask=attn_mask,
        )

    @classmethod
    def _prepare_regional_attn_mask(
        cls,
        regional_text_conditioning: ZImageRegionalTextConditioning,
        img_seq_len: int,
    ) -> torch.Tensor | None:
        """Prepare a regional attention mask for Z-Image.

        The mask controls which tokens can attend to each other:
        - Image tokens within a region attend only to each other
        - Image tokens attend only to their corresponding regional text
        - Text tokens attend only to their corresponding regional image
        - Text tokens attend to themselves

        Z-Image sequence order: [img_tokens, txt_tokens]

        Args:
            regional_text_conditioning: The regional text conditioning data.
            img_seq_len: Number of image tokens.

        Returns:
            Attention mask of shape (img_seq_len + txt_seq_len, img_seq_len + txt_seq_len).
            Returns None if no regional masks are present.
        """
        # Check if any regional masks exist
        has_regional_masks = any(mask is not None for mask in regional_text_conditioning.image_masks)
        if not has_regional_masks:
            # No regional masks, return None to use default attention
            return None

        # Identify background region (area not covered by any mask)
        background_region_mask: torch.Tensor | None = None
        for image_mask in regional_text_conditioning.image_masks:
            if image_mask is not None:
                # image_mask shape: (1, 1, img_seq_len) -> flatten to (img_seq_len,)
                mask_flat = image_mask.view(-1)
                if background_region_mask is None:
                    background_region_mask = torch.ones_like(mask_flat)
                background_region_mask = background_region_mask * (1 - mask_flat)

        device = TorchDevice.choose_torch_device()
        txt_seq_len = regional_text_conditioning.prompt_embeds.shape[0]
        total_seq_len = img_seq_len + txt_seq_len

        # Initialize empty attention mask
        # Z-Image sequence: [img_tokens (0:img_seq_len), txt_tokens (img_seq_len:total_seq_len)]
        regional_attention_mask = torch.zeros((total_seq_len, total_seq_len), device=device, dtype=torch.float16)

        for image_mask, embedding_range in zip(
            regional_text_conditioning.image_masks,
            regional_text_conditioning.embedding_ranges,
            strict=True,
        ):
            # Calculate text token positions in the unified sequence
            txt_start = img_seq_len + embedding_range.start
            txt_end = img_seq_len + embedding_range.end

            # 1. txt attends to itself
            regional_attention_mask[txt_start:txt_end, txt_start:txt_end] = 1.0

            if image_mask is not None:
                # Flatten mask: (1, 1, img_seq_len) -> (img_seq_len,)
                mask_flat = image_mask.view(img_seq_len)

                # 2. img attends to corresponding regional txt
                # Reshape mask to (img_seq_len, 1) for broadcasting
                regional_attention_mask[:img_seq_len, txt_start:txt_end] = mask_flat.view(img_seq_len, 1)

                # 3. txt attends to corresponding regional img
                # Reshape mask to (1, img_seq_len) for broadcasting
                regional_attention_mask[txt_start:txt_end, :img_seq_len] = mask_flat.view(1, img_seq_len)

                # 4. img self-attention within region
                # mask @ mask.T creates pairwise attention within the masked region
                regional_attention_mask[:img_seq_len, :img_seq_len] += mask_flat.view(img_seq_len, 1) @ mask_flat.view(
                    1, img_seq_len
                )
            else:
                # Global prompt: allow attention to/from background regions only
                if background_region_mask is not None:
                    # 2. background img attends to global txt
                    regional_attention_mask[:img_seq_len, txt_start:txt_end] = background_region_mask.view(
                        img_seq_len, 1
                    )

                    # 3. global txt attends to background img
                    regional_attention_mask[txt_start:txt_end, :img_seq_len] = background_region_mask.view(
                        1, img_seq_len
                    )
                else:
                    # No regional masks at all, allow full attention
                    regional_attention_mask[:img_seq_len, txt_start:txt_end] = 1.0
                    regional_attention_mask[txt_start:txt_end, :img_seq_len] = 1.0

        # Allow background regions to attend to themselves
        if background_region_mask is not None:
            bg_mask = background_region_mask.view(img_seq_len, 1)
            regional_attention_mask[:img_seq_len, :img_seq_len] += bg_mask @ bg_mask.T

        # Convert to boolean mask
        regional_attention_mask = regional_attention_mask > 0.5

        return regional_attention_mask

    @staticmethod
    def preprocess_regional_prompt_mask(
        mask: Optional[torch.Tensor],
        target_height: int,
        target_width: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        """Preprocess a regional prompt mask to match the target image token grid.

        Args:
            mask: Input mask tensor. If None, returns a mask of all ones.
            target_height: Height of the image token grid (H // patch_size).
            target_width: Width of the image token grid (W // patch_size).
            dtype: Target dtype for the mask.
            device: Target device for the mask.

        Returns:
            Processed mask of shape (1, 1, target_height * target_width).
        """
        img_seq_len = target_height * target_width

        if mask is None:
            return torch.ones((1, 1, img_seq_len), dtype=dtype, device=device)

        mask = to_standard_float_mask(mask, out_dtype=dtype)

        # Resize mask to target dimensions
        tf = torchvision.transforms.Resize(
            (target_height, target_width),
            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        )

        # Add batch dimension if needed: (h, w) -> (1, h, w) -> (1, 1, h, w)
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        if mask.ndim == 3:
            mask = mask.unsqueeze(0)

        resized_mask = tf(mask)

        # Flatten to (1, 1, img_seq_len)
        return resized_mask.flatten(start_dim=2).to(device=device)
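A toy illustration of the mask layout that _prepare_regional_attn_mask documents (editorial sketch with hypothetical tiny sizes, not part of the commit). For the Z-Image sequence order [img_tokens, txt_tokens], a region covering the first two of four image tokens produces exactly the four rules above: txt self-attention, img-to-regional-txt, txt-to-regional-img, and img self-attention within the region.

import torch

img_seq_len, txt_len = 4, 2
total = img_seq_len + txt_len
region = torch.tensor([1.0, 1.0, 0.0, 0.0])  # first two image tokens belong to the region

mask = torch.zeros((total, total))
mask[img_seq_len:, img_seq_len:] = 1.0                        # 1. txt attends to itself
mask[:img_seq_len, img_seq_len:] = region.view(-1, 1)         # 2. img -> its regional txt
mask[img_seq_len:, :img_seq_len] = region.view(1, -1)         # 3. txt -> its regional img
mask[:img_seq_len, :img_seq_len] += region.view(-1, 1) @ region.view(1, -1)  # 4. img self-attn in region
print((mask > 0.5).int())  # boolean attend/block matrix, as returned by the extension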
invokeai/backend/z_image/text_conditioning.py (new file, 74 lines):
@@ -0,0 +1,74 @@
from dataclasses import dataclass

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range


@dataclass
class ZImageTextConditioning:
    """Z-Image text conditioning with optional regional mask.

    Attributes:
        prompt_embeds: Text embeddings from Qwen3 encoder. Shape: (seq_len, hidden_size).
        mask: Optional binary mask for regional prompting. If None, the prompt is global.
            Shape: (1, 1, img_seq_len) where img_seq_len = (H // patch_size) * (W // patch_size).
    """

    prompt_embeds: torch.Tensor
    mask: torch.Tensor | None = None


@dataclass
class ZImageRegionalTextConditioning:
    """Container for multiple regional text conditionings concatenated together.

    In Z-Image, the unified sequence is [img_tokens, txt_tokens], which is different
    from FLUX where it's [txt_tokens, img_tokens]. The attention mask must account for this.

    Attributes:
        prompt_embeds: Concatenated text embeddings from all regional prompts.
            Shape: (total_seq_len, hidden_size).
        image_masks: List of binary masks for each regional prompt.
            image_masks[i] corresponds to embedding_ranges[i].
            If None, the prompt is global (applies to entire image).
            Shape: (1, 1, img_seq_len).
        embedding_ranges: List of ranges indicating which portion of prompt_embeds
            corresponds to each regional prompt.
    """

    prompt_embeds: torch.Tensor
    image_masks: list[torch.Tensor | None]
    embedding_ranges: list[Range]

    @classmethod
    def from_text_conditionings(
        cls,
        text_conditionings: list[ZImageTextConditioning],
    ) -> "ZImageRegionalTextConditioning":
        """Create a ZImageRegionalTextConditioning from a list of ZImageTextConditioning objects.

        Args:
            text_conditionings: List of text conditionings, each with optional mask.

        Returns:
            A single ZImageRegionalTextConditioning with concatenated embeddings.
        """
        concat_embeds: list[torch.Tensor] = []
        concat_ranges: list[Range] = []
        image_masks: list[torch.Tensor | None] = []

        cur_embed_len = 0
        for tc in text_conditionings:
            concat_embeds.append(tc.prompt_embeds)
            concat_ranges.append(Range(start=cur_embed_len, end=cur_embed_len + tc.prompt_embeds.shape[0]))
            image_masks.append(tc.mask)
            cur_embed_len += tc.prompt_embeds.shape[0]

        prompt_embeds = torch.cat(concat_embeds, dim=0)

        return cls(
            prompt_embeds=prompt_embeds,
            image_masks=image_masks,
            embedding_ranges=concat_ranges,
        )
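A quick check of how from_text_conditionings assigns embedding ranges (editorial sketch with toy tensor sizes; assumes the module above is importable): two prompts of lengths 3 and 5 concatenate to a (8, hidden) tensor with ranges (0, 3) and (3, 8).

import torch
from invokeai.backend.z_image.text_conditioning import ZImageRegionalTextConditioning, ZImageTextConditioning

a = ZImageTextConditioning(prompt_embeds=torch.randn(3, 8))                          # global prompt, no mask
b = ZImageTextConditioning(prompt_embeds=torch.randn(5, 8), mask=torch.ones(1, 1, 16))  # regional prompt
regional = ZImageRegionalTextConditioning.from_text_conditionings([a, b])

assert regional.prompt_embeds.shape == (8, 8)
print([(r.start, r.end) for r in regional.embedding_ranges])  # [(0, 3), (3, 8)]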
invokeai/backend/z_image/z_image_transformer_patch.py (new file, 234 lines):
@@ -0,0 +1,234 @@
"""Utilities for patching the ZImageTransformer2DModel to support regional attention masks."""

from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence


def create_regional_forward(
    original_forward: Callable,
    regional_attn_mask: torch.Tensor,
    img_seq_len: int,
) -> Callable:
    """Create a modified forward function that uses a regional attention mask.

    The regional attention mask replaces the internally computed padding mask,
    allowing for regional prompting where different image regions attend to
    different text prompts.

    Args:
        original_forward: The original forward method of ZImageTransformer2DModel.
        regional_attn_mask: Attention mask of shape (seq_len, seq_len) where
            seq_len = img_seq_len + txt_seq_len.
        img_seq_len: Number of image tokens in the sequence.

    Returns:
        A modified forward function with regional attention support.
    """

    def regional_forward(
        self,
        x: List[torch.Tensor],
        t: torch.Tensor,
        cap_feats: List[torch.Tensor],
        patch_size: int = 2,
        f_patch_size: int = 1,
    ) -> Tuple[List[torch.Tensor], dict]:
        """Modified forward with regional attention mask injection.

        This is based on the original ZImageTransformer2DModel.forward but
        replaces the padding-based attention mask with a regional attention mask.
        """
        assert patch_size in self.all_patch_size
        assert f_patch_size in self.all_f_patch_size

        bsz = len(x)
        device = x[0].device
        t_scaled = t * self.t_scale
        t_emb = self.t_embedder(t_scaled)

        SEQ_MULTI_OF = 32  # From diffusers transformer_z_image.py

        # Patchify and embed (reusing the original method)
        (
            x,
            cap_feats,
            x_size,
            x_pos_ids,
            cap_pos_ids,
            x_inner_pad_mask,
            cap_inner_pad_mask,
        ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

        # x embed & refine
        x_item_seqlens = [len(_) for _ in x]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
        x_max_item_seqlen = max(x_item_seqlens)

        x_cat = torch.cat(x, dim=0)
        x_cat = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_cat)

        adaln_input = t_emb.type_as(x_cat)
        x_cat[torch.cat(x_inner_pad_mask)] = self.x_pad_token
        x_list = list(x_cat.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))

        x_padded = pad_sequence(x_list, batch_first=True, padding_value=0.0)
        x_freqs_cis_padded = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1

        # Process through noise_refiner
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer in self.noise_refiner:
                x_padded = self._gradient_checkpointing_func(
                    layer, x_padded, x_attn_mask, x_freqs_cis_padded, adaln_input
                )
        else:
            for layer in self.noise_refiner:
                x_padded = layer(x_padded, x_attn_mask, x_freqs_cis_padded, adaln_input)

        # cap embed & refine
        cap_item_seqlens = [len(_) for _ in cap_feats]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
        cap_max_item_seqlen = max(cap_item_seqlens)

        cap_cat = torch.cat(cap_feats, dim=0)
        cap_cat = self.cap_embedder(cap_cat)
        cap_cat[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
        cap_list = list(cap_cat.split(cap_item_seqlens, dim=0))
        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))

        cap_padded = pad_sequence(cap_list, batch_first=True, padding_value=0.0)
        cap_freqs_cis_padded = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(cap_item_seqlens):
            cap_attn_mask[i, :seq_len] = 1

        # Process through context_refiner
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer in self.context_refiner:
                cap_padded = self._gradient_checkpointing_func(layer, cap_padded, cap_attn_mask, cap_freqs_cis_padded)
        else:
            for layer in self.context_refiner:
                cap_padded = layer(cap_padded, cap_attn_mask, cap_freqs_cis_padded)

        # Unified sequence: [img_tokens, txt_tokens]
        unified = []
        unified_freqs_cis = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            unified.append(torch.cat([x_padded[i][:x_len], cap_padded[i][:cap_len]]))
            unified_freqs_cis.append(torch.cat([x_freqs_cis_padded[i][:x_len], cap_freqs_cis_padded[i][:cap_len]]))

        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens, strict=False)]
        assert unified_item_seqlens == [len(_) for _ in unified]
        unified_max_item_seqlen = max(unified_item_seqlens)

        unified_padded = pad_sequence(unified, batch_first=True, padding_value=0.0)
        unified_freqs_cis_padded = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)

        # --- REGIONAL ATTENTION MASK INJECTION ---
        # Instead of using the padding mask, we use the regional attention mask.
        # The regional mask is (seq_len, seq_len); we need to expand it to (batch, seq_len, seq_len)
        # and then add the head dimension for broadcasting: (batch, 1, seq_len, seq_len).

        # Expand regional mask to match the actual sequence length (may include padding)
        if regional_attn_mask.shape[0] != unified_max_item_seqlen:
            # Pad the regional mask to match unified sequence length
            padded_regional_mask = torch.zeros(
                (unified_max_item_seqlen, unified_max_item_seqlen),
                dtype=regional_attn_mask.dtype,
                device=device,
            )
            mask_size = min(regional_attn_mask.shape[0], unified_max_item_seqlen)
            padded_regional_mask[:mask_size, :mask_size] = regional_attn_mask[:mask_size, :mask_size]
        else:
            padded_regional_mask = regional_attn_mask.to(device)

        # Convert boolean mask to additive float mask for attention:
        # True (attend) -> 0.0, False (block) -> -inf.
        # This is required because the attention backend expects additive masks for 4D inputs.
        # Use bfloat16 to match the transformer's query dtype.
        float_mask = torch.zeros_like(padded_regional_mask, dtype=torch.bfloat16)
        float_mask[~padded_regional_mask] = float("-inf")

        # Expand to (batch, 1, seq_len, seq_len) for attention
        unified_attn_mask = float_mask.unsqueeze(0).unsqueeze(0).expand(bsz, 1, -1, -1)

        # Process through main layers with regional attention mask
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer_idx, layer in enumerate(self.layers):
                # Alternate between regional mask and full attention
                if layer_idx % 2 == 0:
                    unified_padded = self._gradient_checkpointing_func(
                        layer, unified_padded, unified_attn_mask, unified_freqs_cis_padded, adaln_input
                    )
                else:
                    # Use padding mask only for odd layers (allows global coherence)
                    padding_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
                    for i, seq_len in enumerate(unified_item_seqlens):
                        padding_mask[i, :seq_len] = 1
                    unified_padded = self._gradient_checkpointing_func(
                        layer, unified_padded, padding_mask, unified_freqs_cis_padded, adaln_input
                    )
        else:
            for layer_idx, layer in enumerate(self.layers):
                # Alternate between regional mask and full attention
                if layer_idx % 2 == 0:
                    unified_padded = layer(unified_padded, unified_attn_mask, unified_freqs_cis_padded, adaln_input)
                else:
                    # Use padding mask only for odd layers (allows global coherence)
                    padding_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
                    for i, seq_len in enumerate(unified_item_seqlens):
                        padding_mask[i, :seq_len] = 1
                    unified_padded = layer(unified_padded, padding_mask, unified_freqs_cis_padded, adaln_input)

        # Final layer
        unified_out = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified_padded, adaln_input)
        unified_list = list(unified_out.unbind(dim=0))
        x_out = self.unpatchify(unified_list, x_size, patch_size, f_patch_size)

        return x_out, {}

    return regional_forward


@contextmanager
def patch_transformer_for_regional_prompting(
    transformer,
    regional_attn_mask: Optional[torch.Tensor],
    img_seq_len: int,
):
    """Context manager to temporarily patch the transformer for regional prompting.

    Args:
        transformer: The ZImageTransformer2DModel instance.
        regional_attn_mask: Regional attention mask of shape (seq_len, seq_len).
            If None, the transformer is not patched.
        img_seq_len: Number of image tokens.

    Yields:
        The (possibly patched) transformer.
    """
    if regional_attn_mask is None:
        # No regional prompting, use original forward
        yield transformer
        return

    # Store original forward
    original_forward = transformer.forward

    # Create and bind the regional forward
    regional_fwd = create_regional_forward(original_forward, regional_attn_mask, img_seq_len)
    transformer.forward = lambda *args, **kwargs: regional_fwd(transformer, *args, **kwargs)

    try:
        yield transformer
    finally:
        # Restore original forward
        transformer.forward = original_forward
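The patch-and-restore pattern used by patch_transformer_for_regional_prompting, reduced to a generic standalone sketch (editorial illustration, not part of the commit): swap a module's forward inside a context manager so the original is guaranteed to be restored even if denoising raises.

from contextlib import contextmanager

import torch

@contextmanager
def swap_forward(module: torch.nn.Module, new_forward):
    original = module.forward
    module.forward = new_forward  # instance attribute shadows the class method
    try:
        yield module
    finally:
        module.forward = original  # always restored, even on exception

lin = torch.nn.Linear(2, 2)
with swap_forward(lin, lambda x: x * 0):
    assert torch.equal(lin(torch.ones(1, 2)), torch.zeros(1, 2))  # patched forward in effect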
(One file's diff was suppressed because it is too large.)
getRegionalGuidanceWarnings (frontend validation):
@@ -59,6 +59,17 @@ export const getRegionalGuidanceWarnings = (
     }
   }

+  if (model.base === 'z-image') {
+    // Z-Image has similar limitations to FLUX - no negative prompts via CFG by default
+    // Reference images (IP Adapters) are not supported for Z-Image
+    if (entity.referenceImages.length > 0) {
+      warnings.push(WARNINGS.RG_REFERENCE_IMAGES_NOT_SUPPORTED);
+    }
+    if (entity.autoNegative) {
+      warnings.push(WARNINGS.RG_AUTO_NEGATIVE_NOT_SUPPORTED);
+    }
+  }
+
   entity.referenceImages.forEach(({ config }) => {
     if (!config.model) {
       // No model selected
features/nodes/util/graph/generation/addRegions.ts:
@@ -32,8 +32,8 @@ type AddRegionsArg = {
   g: Graph;
   bbox: Rect;
   model: MainModelConfig;
-  posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
-  negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
+  posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'z_image_text_encoder'>;
+  negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'z_image_text_encoder'> | null;
   posCondCollect: Invocation<'collect'>;
   negCondCollect: Invocation<'collect'> | null;
   ipAdapterCollect: Invocation<'collect'>;
@@ -71,6 +71,7 @@ export const addRegions = async ({
 }: AddRegionsArg): Promise<AddedRegionResult[]> => {
   const isSDXL = model.base === 'sdxl';
   const isFLUX = model.base === 'flux';
+  const isZImage = model.base === 'z-image';

   const validRegions = regions
     .filter((entity) => entity.isEnabled)
@@ -111,7 +112,7 @@ export const addRegions = async ({
     if (region.positivePrompt) {
       // The main positive conditioning node
       result.addedPositivePrompt = true;
-      let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
+      let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'z_image_text_encoder'>;
       if (isSDXL) {
         regionalPosCond = g.addNode({
           type: 'sdxl_compel_prompt',
@@ -125,6 +126,12 @@ export const addRegions = async ({
           id: getPrefixedId('prompt_region_positive_cond'),
           prompt: region.positivePrompt,
         });
+      } else if (isZImage) {
+        regionalPosCond = g.addNode({
+          type: 'z_image_text_encoder',
+          id: getPrefixedId('prompt_region_positive_cond'),
+          prompt: region.positivePrompt,
+        });
       } else {
         regionalPosCond = g.addNode({
           type: 'compel',
@@ -155,6 +162,12 @@ export const addRegions = async ({
           clone.destination.node_id = regionalPosCond.id;
           g.addEdgeFromObj(clone);
         }
+      } else if (posCond.type === 'z_image_text_encoder') {
+        for (const edge of g.getEdgesTo(posCond, ['qwen3_encoder', 'mask'])) {
+          const clone = deepClone(edge);
+          clone.destination.node_id = regionalPosCond.id;
+          g.addEdgeFromObj(clone);
+        }
       } else {
         assert(false, 'Unsupported positive conditioning node type.');
       }
@@ -166,7 +179,7 @@ export const addRegions = async ({

       // The main negative conditioning node
       result.addedNegativePrompt = true;
-      let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
+      let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'z_image_text_encoder'>;
       if (isSDXL) {
         regionalNegCond = g.addNode({
           type: 'sdxl_compel_prompt',
@@ -180,6 +193,12 @@ export const addRegions = async ({
           id: getPrefixedId('prompt_region_negative_cond'),
           prompt: region.negativePrompt,
         });
+      } else if (isZImage) {
+        regionalNegCond = g.addNode({
+          type: 'z_image_text_encoder',
+          id: getPrefixedId('prompt_region_negative_cond'),
+          prompt: region.negativePrompt,
+        });
       } else {
         regionalNegCond = g.addNode({
           type: 'compel',
@@ -211,6 +230,12 @@ export const addRegions = async ({
           clone.destination.node_id = regionalNegCond.id;
           g.addEdgeFromObj(clone);
         }
+      } else if (negCond.type === 'z_image_text_encoder') {
+        for (const edge of g.getEdgesTo(negCond, ['qwen3_encoder', 'mask'])) {
+          const clone = deepClone(edge);
+          clone.destination.node_id = regionalNegCond.id;
+          g.addEdgeFromObj(clone);
+        }
       } else {
         assert(false, 'Unsupported negative conditioning node type.');
       }
@@ -229,7 +254,9 @@ export const addRegions = async ({
       // Connect the OG mask image to the inverted mask-to-tensor node
       g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask');
       // Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt
-      let regionalPosCondInverted: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
+      let regionalPosCondInverted: Invocation<
+        'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'z_image_text_encoder'
+      >;
       if (isSDXL) {
         regionalPosCondInverted = g.addNode({
           type: 'sdxl_compel_prompt',
@@ -243,6 +270,12 @@ export const addRegions = async ({
           id: getPrefixedId('prompt_region_positive_cond_inverted'),
           prompt: region.positivePrompt,
         });
+      } else if (isZImage) {
+        regionalPosCondInverted = g.addNode({
+          type: 'z_image_text_encoder',
+          id: getPrefixedId('prompt_region_positive_cond_inverted'),
+          prompt: region.positivePrompt,
+        });
       } else {
         regionalPosCondInverted = g.addNode({
           type: 'compel',
@@ -273,6 +306,12 @@ export const addRegions = async ({
           clone.destination.node_id = regionalPosCondInverted.id;
           g.addEdgeFromObj(clone);
         }
+      } else if (posCond.type === 'z_image_text_encoder') {
+        for (const edge of g.getEdgesTo(posCond, ['qwen3_encoder', 'mask'])) {
+          const clone = deepClone(edge);
+          clone.destination.node_id = regionalPosCondInverted.id;
+          g.addEdgeFromObj(clone);
+        }
       } else {
         assert(false, 'Unsupported positive conditioning node type.');
       }
features/nodes/util/graph/generation/addZImageLoRAs.ts:
@@ -10,7 +10,7 @@ export const addZImageLoRAs = (
   denoise: Invocation<'z_image_denoise'>,
   modelLoader: Invocation<'z_image_model_loader'>,
   posCond: Invocation<'z_image_text_encoder'>,
-  negCond: Invocation<'z_image_text_encoder'>
+  negCond: Invocation<'z_image_text_encoder'> | null
 ): void => {
   const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'z-image');
   const loraCount = enabledLoRAs.length;
@@ -39,10 +39,13 @@ export const addZImageLoRAs = (
   // Reroute model connections through the LoRA collection loader
   g.deleteEdgesTo(denoise, ['transformer']);
   g.deleteEdgesTo(posCond, ['qwen3_encoder']);
-  g.deleteEdgesTo(negCond, ['qwen3_encoder']);
   g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer');
   g.addEdge(loraCollectionLoader, 'qwen3_encoder', posCond, 'qwen3_encoder');
-  g.addEdge(loraCollectionLoader, 'qwen3_encoder', negCond, 'qwen3_encoder');
+  // Only reroute negCond if it exists (guidance_scale > 0)
+  if (negCond !== null) {
+    g.deleteEdgesTo(negCond, ['qwen3_encoder']);
+    g.addEdge(loraCollectionLoader, 'qwen3_encoder', negCond, 'qwen3_encoder');
+  }

   for (const lora of enabledLoRAs) {
     const { weight } = lora;
features/nodes/util/graph/generation/buildZImageGraph.ts:
@@ -14,6 +14,7 @@ import { addImageToImage } from 'features/nodes/util/graph/generation/addImageTo
 import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
 import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
 import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
+import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
 import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
 import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
 import { addZImageLoRAs } from 'features/nodes/util/graph/generation/addZImageLoRAs';
@@ -75,12 +76,32 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
     type: 'z_image_text_encoder',
     id: getPrefixedId('pos_prompt'),
   });
+  // Collect node for regional prompting support
+  const posCondCollect = g.addNode({
+    type: 'collect',
+    id: getPrefixedId('pos_cond_collect'),
+  });

-  // Z-Image supports negative conditioning when guidance_scale > 0
-  const negCond = g.addNode({
-    type: 'z_image_text_encoder',
-    id: getPrefixedId('neg_prompt'),
-    prompt: prompts.negative,
-  });
+  // Z-Image supports negative conditioning when guidance_scale > 1
+  // Only create negative conditioning nodes if guidance is used
+  let negCond: Invocation<'z_image_text_encoder'> | null = null;
+  let negCondCollect: Invocation<'collect'> | null = null;
+  if (guidance_scale > 1) {
+    negCond = g.addNode({
+      type: 'z_image_text_encoder',
+      id: getPrefixedId('neg_prompt'),
+      prompt: prompts.negative,
+    });
+    negCondCollect = g.addNode({
+      type: 'collect',
+      id: getPrefixedId('neg_cond_collect'),
+    });
+  }
+
+  // Placeholder collect node for IP adapters (not supported for Z-Image but needed for addRegions)
+  const ipAdapterCollect = g.addNode({
+    type: 'collect',
+    id: getPrefixedId('ip_adapter_collect'),
+  });

   const seed = g.addNode({
@@ -100,17 +121,20 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild

   g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
   g.addEdge(modelLoader, 'qwen3_encoder', posCond, 'qwen3_encoder');
-  g.addEdge(modelLoader, 'qwen3_encoder', negCond, 'qwen3_encoder');
   g.addEdge(modelLoader, 'vae', l2i, 'vae');
   // Connect VAE to denoise for control image encoding
   g.addEdge(modelLoader, 'vae', denoise, 'vae');

   g.addEdge(positivePrompt, 'value', posCond, 'prompt');
-  g.addEdge(posCond, 'conditioning', denoise, 'positive_conditioning');
+  // Connect positive conditioning through collector for regional support
+  g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
+  g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');

-  // Only add negative conditioning edge if guidance_scale > 0
-  if (guidance_scale > 0) {
-    g.addEdge(negCond, 'conditioning', denoise, 'negative_conditioning');
+  // Connect negative conditioning if guidance_scale > 1
+  if (negCond !== null && negCondCollect !== null) {
+    g.addEdge(modelLoader, 'qwen3_encoder', negCond, 'qwen3_encoder');
+    g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
+    g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
   }

   g.addEdge(seed, 'value', denoise, 'seed');
@@ -119,9 +143,10 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
   // Add Z-Image LoRAs if any are enabled
   addZImageLoRAs(state, g, denoise, modelLoader, posCond, negCond);

-  // Add Z-Image Control layers if any are enabled
-  const canvas = selectCanvasSlice(state);
+  // Add regional guidance if canvas manager is available
   if (manager !== null) {
+    const canvas = selectCanvasSlice(state);
+    // Add Z-Image Control layers if any are enabled
     const rect = canvas.bbox.rect;
     await addZImageControl({
       manager,
@@ -130,8 +155,26 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
       rect,
       denoise,
     });
+
+    // Add regional guidance
+    await addRegions({
+      manager,
+      regions: canvas.regionalGuidance.entities,
+      g,
+      bbox: canvas.bbox.rect,
+      model,
+      posCond,
+      negCond,
+      posCondCollect,
+      negCondCollect,
+      ipAdapterCollect,
+      fluxReduxCollect: null, // Not supported for Z-Image
+    });
   }

+  // IP Adapters are not supported for Z-Image, so delete the unused collector
+  g.deleteNode(ipAdapterCollect.id);
+
   const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
   assert(modelConfig.base === 'z-image');
Generated API types (schema.ts):
@@ -25308,6 +25308,11 @@ export type components = {
        * @description The name of conditioning tensor
        */
       conditioning_name: string;
+      /**
+       * @description The mask associated with this conditioning tensor for regional prompting. Excluded regions should be set to False, included regions should be set to True.
+       * @default null
+       */
+      mask?: components["schemas"]["TensorField"] | null;
     };
     /**
     * ZImageConditioningOutput
@@ -25435,6 +25440,8 @@ export type components = {
     /**
     * Denoise - Z-Image
     * @description Run the denoising process with a Z-Image model.
+    *
+    * Supports regional prompting by connecting multiple conditioning inputs with masks.
     */
    ZImageDenoiseInvocation: {
      /**
@@ -25483,15 +25490,17 @@ export type components = {
       */
      transformer?: components["schemas"]["TransformerField"] | null;
      /**
       * Positive Conditioning
       * @description Positive conditioning tensor
       * @default null
       */
-      positive_conditioning?: components["schemas"]["ZImageConditioningField"] | null;
+      positive_conditioning?: components["schemas"]["ZImageConditioningField"] | components["schemas"]["ZImageConditioningField"][] | null;
      /**
       * Negative Conditioning
       * @description Negative conditioning tensor
       * @default null
       */
-      negative_conditioning?: components["schemas"]["ZImageConditioningField"] | null;
+      negative_conditioning?: components["schemas"]["ZImageConditioningField"] | components["schemas"]["ZImageConditioningField"][] | null;
      /**
       * Guidance Scale
       * @description Guidance scale for classifier-free guidance. 1.0 = no CFG (recommended for Z-Image-Turbo). Values > 1.0 amplify guidance.
@@ -25848,6 +25857,8 @@ export type components = {
     /**
     * Prompt - Z-Image
     * @description Encodes and preps a prompt for a Z-Image image.
+    *
+    * Supports regional prompting by connecting a mask input.
     */
    ZImageTextEncoderInvocation: {
      /**
@@ -25879,6 +25890,11 @@ export type components = {
       * @default null
       */
      qwen3_encoder?: components["schemas"]["Qwen3EncoderField"] | null;
+      /**
+       * @description A mask defining the region that this conditioning prompt applies to.
+       * @default null
+       */
+      mask?: components["schemas"]["TensorField"] | null;
      /**
       * type
       * @default z_image_text_encoder