Feature: Add Z-Image-Turbo regional guidance (#8672)

* feat: Add Regional Guidance support for Z-Image model

Implements regional prompting for Z-Image (S3-DiT Transformer), allowing
different prompts to affect different image regions via attention masks.

Backend changes:
- Add ZImageRegionalPromptingExtension for mask preparation
- Add ZImageTextConditioning and ZImageRegionalTextConditioning data classes
- Patch transformer forward to inject 4D regional attention masks
- Use additive float mask (0.0 attend, -inf block) in bfloat16 for compatibility (a sketch follows this list)
- Alternate regional/full attention layers for global coherence
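
A minimal sketch of the two backend points above, assuming PyTorch; the helper names (to_additive_mask, pick_mask) and shapes are illustrative, not taken from the diff. The boolean regional mask becomes an additive bfloat16 mask, and only every other transformer layer receives it:

    import torch

    def to_additive_mask(bool_mask: torch.Tensor) -> torch.Tensor:
        """Convert a boolean attention mask (True = attend) into an additive
        float mask (0.0 = attend, -inf = block) in bfloat16."""
        float_mask = torch.zeros_like(bool_mask, dtype=torch.bfloat16)
        float_mask[~bool_mask] = float("-inf")
        # (seq_len, seq_len) -> (1, 1, seq_len, seq_len) for 4D attention input.
        return float_mask.unsqueeze(0).unsqueeze(0)

    def pick_mask(layer_idx: int, regional_mask: torch.Tensor) -> torch.Tensor | None:
        """Alternate regional/full attention: even layers get the regional mask,
        odd layers run unrestricted attention for global coherence."""
        return regional_mask if layer_idx % 2 == 0 else None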

Frontend changes:
- Update buildZImageGraph to support regional conditioning collectors
- Update addRegions to create z_image_text_encoder nodes for regions
- Update addZImageLoRAs to handle optional negCond when guidance_scale=0
- Add Z-Image validation (no IP adapters, no autoNegative)

* @Pfannkuchensack
Fix Windows path again

* ruff check fix

* ruff formatting

* fix(ui): Z-Image CFG guidance_scale check uses > 1 instead of > 0

Changed the guidance_scale check from > 0 to > 1 for Z-Image models.
Since Z-Image uses guidance_scale=1.0 as "no CFG" (matching FLUX convention),
negative conditioning should only be created when guidance_scale > 1.
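
For illustration, a hedged sketch of the CFG convention this fix relies on (variable names are illustrative): at guidance_scale = 1.0 the formula collapses to the conditional prediction, so the unconditional pass and the negative prompt can be skipped entirely.

    import math
    import torch

    def apply_cfg(
        pred_cond: torch.Tensor,
        pred_uncond: torch.Tensor | None,
        guidance_scale: float,
    ) -> torch.Tensor:
        # pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
        # At guidance_scale == 1.0 this reduces to pred_cond, so no uncond pass is needed.
        if math.isclose(guidance_scale, 1.0) or pred_uncond is None:
            return pred_cond
        return pred_uncond + guidance_scale * (pred_cond - pred_uncond)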

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
Alexander Eichhorn
2025-12-26 03:25:38 +01:00
committed by GitHub
parent de1aa557b8
commit 769cf52209
14 changed files with 24448 additions and 5284 deletions

View File

@@ -333,6 +333,11 @@ class ZImageConditioningField(BaseModel):
"""A Z-Image conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
mask: Optional[TensorField] = Field(
default=None,
description="The mask associated with this conditioning tensor for regional prompting. "
"Excluded regions should be set to False, included regions should be set to True.",
)
class ConditioningField(BaseModel):

View File

@@ -32,11 +32,14 @@ from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import Rec
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
from invokeai.backend.z_image.z_image_controlnet_extension import (
ZImageControlNetExtension,
z_image_forward_with_control,
)
from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting
@invocation(
@@ -44,11 +47,14 @@ from invokeai.backend.z_image.z_image_controlnet_extension import (
title="Denoise - Z-Image",
tags=["image", "z-image"],
category="image",
version="1.1.0",
version="1.2.0",
classification=Classification.Prototype,
)
class ZImageDenoiseInvocation(BaseInvocation):
"""Run the denoising process with a Z-Image model."""
"""Run the denoising process with a Z-Image model.
Supports regional prompting by connecting multiple conditioning inputs with masks.
"""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
@@ -63,10 +69,10 @@ class ZImageDenoiseInvocation(BaseInvocation):
transformer: TransformerField = InputField(
description=FieldDescriptions.z_image_model, input=Input.Connection, title="Transformer"
)
positive_conditioning: ZImageConditioningField = InputField(
positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: Optional[ZImageConditioningField] = InputField(
negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
)
# Z-Image-Turbo works best without CFG (guidance_scale=1.0)
@@ -126,25 +132,50 @@ class ZImageDenoiseInvocation(BaseInvocation):
def _load_text_conditioning(
self,
context: InvocationContext,
conditioning_name: str,
cond_field: ZImageConditioningField | list[ZImageConditioningField],
img_height: int,
img_width: int,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
"""Load Z-Image text conditioning."""
cond_data = context.conditioning.load(conditioning_name)
if len(cond_data.conditionings) != 1:
raise ValueError(
f"Expected exactly 1 conditioning entry for Z-Image, got {len(cond_data.conditionings)}. "
"Ensure you are using the Z-Image text encoder."
)
z_image_conditioning = cond_data.conditionings[0]
if not isinstance(z_image_conditioning, ZImageConditioningInfo):
raise TypeError(
f"Expected ZImageConditioningInfo, got {type(z_image_conditioning).__name__}. "
"Ensure you are using the Z-Image text encoder."
)
z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
return z_image_conditioning.prompt_embeds
) -> list[ZImageTextConditioning]:
"""Load Z-Image text conditioning with optional regional masks.
Args:
context: The invocation context.
cond_field: Single conditioning field or list of fields.
img_height: Height of the image token grid (H // patch_size).
img_width: Width of the image token grid (W // patch_size).
dtype: Target dtype.
device: Target device.
Returns:
List of ZImageTextConditioning objects with embeddings and masks.
"""
# Normalize to a list
cond_list = [cond_field] if isinstance(cond_field, ZImageConditioningField) else cond_field
text_conditionings: list[ZImageTextConditioning] = []
for cond in cond_list:
# Load the text embeddings
cond_data = context.conditioning.load(cond.conditioning_name)
assert len(cond_data.conditionings) == 1
z_image_conditioning = cond_data.conditionings[0]
assert isinstance(z_image_conditioning, ZImageConditioningInfo)
z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
prompt_embeds = z_image_conditioning.prompt_embeds
# Load the mask, if provided
mask: torch.Tensor | None = None
if cond.mask is not None:
mask = context.tensors.load(cond.mask.tensor_name)
mask = mask.to(device=device)
mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
mask, img_height, img_width, dtype, device
)
text_conditionings.append(ZImageTextConditioning(prompt_embeds=prompt_embeds, mask=mask))
return text_conditionings
def _get_noise(
self,
@@ -221,14 +252,33 @@ class ZImageDenoiseInvocation(BaseInvocation):
transformer_info = context.models.load(self.transformer.transformer)
# Load positive conditioning
pos_prompt_embeds = self._load_text_conditioning(
# Calculate image token grid dimensions
patch_size = 2 # Z-Image uses patch_size=2
latent_height = self.height // LATENT_SCALE_FACTOR
latent_width = self.width // LATENT_SCALE_FACTOR
img_token_height = latent_height // patch_size
img_token_width = latent_width // patch_size
img_seq_len = img_token_height * img_token_width
# Load positive conditioning with regional masks
pos_text_conditionings = self._load_text_conditioning(
context=context,
conditioning_name=self.positive_conditioning.conditioning_name,
cond_field=self.positive_conditioning,
img_height=img_token_height,
img_width=img_token_width,
dtype=inference_dtype,
device=device,
)
# Create regional prompting extension
regional_extension = ZImageRegionalPromptingExtension.from_text_conditionings(
text_conditionings=pos_text_conditionings,
img_seq_len=img_seq_len,
)
# Get the concatenated prompt embeddings for the transformer
pos_prompt_embeds = regional_extension.regional_text_conditioning.prompt_embeds
# Load negative conditioning if provided and guidance_scale != 1.0
# CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
# At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
@@ -238,21 +288,22 @@ class ZImageDenoiseInvocation(BaseInvocation):
not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
)
if do_classifier_free_guidance:
if self.negative_conditioning is None:
raise ValueError("Negative conditioning is required when guidance_scale != 1.0")
neg_prompt_embeds = self._load_text_conditioning(
assert self.negative_conditioning is not None
# Load all negative conditionings and concatenate embeddings
# Note: We ignore masks for negative conditioning as regional negative prompting is not fully supported
neg_text_conditionings = self._load_text_conditioning(
context=context,
conditioning_name=self.negative_conditioning.conditioning_name,
cond_field=self.negative_conditioning,
img_height=img_token_height,
img_width=img_token_width,
dtype=inference_dtype,
device=device,
)
# Calculate image sequence length for timestep shifting
patch_size = 2 # Z-Image uses patch_size=2
image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (patch_size**2)
# Concatenate all negative embeddings
neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)
# Calculate shift based on image sequence length
mu = self._calculate_shift(image_seq_len)
mu = self._calculate_shift(img_seq_len)
# Generate sigma schedule with time shift
sigmas = self._get_sigmas(mu, self.steps)
@@ -443,6 +494,15 @@ class ZImageDenoiseInvocation(BaseInvocation):
)
)
# Apply regional prompting patch if we have regional masks
exit_stack.enter_context(
patch_transformer_for_regional_prompting(
transformer=transformer,
regional_attn_mask=regional_extension.regional_attn_mask,
img_seq_len=img_seq_len,
)
)
# Denoising loop
for step_idx in tqdm(range(total_steps)):
sigma_curr = sigmas[step_idx]

View File

@@ -1,11 +1,18 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
from typing import Iterator, Optional, Tuple
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
TensorField,
UIComponent,
ZImageConditioningField,
)
from invokeai.app.invocations.model import Qwen3EncoderField
from invokeai.app.invocations.primitives import ZImageConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -27,11 +34,14 @@ Z_IMAGE_MAX_SEQ_LEN = 512
title="Prompt - Z-Image",
tags=["prompt", "conditioning", "z-image"],
category="conditioning",
version="1.0.0",
version="1.1.0",
classification=Classification.Prototype,
)
class ZImageTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a Z-Image image."""
"""Encodes and preps a prompt for a Z-Image image.
Supports regional prompting by connecting a mask input.
"""
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
qwen3_encoder: Qwen3EncoderField = InputField(
@@ -39,13 +49,19 @@ class ZImageTextEncoderInvocation(BaseInvocation):
description=FieldDescriptions.qwen3_encoder,
input=Input.Connection,
)
mask: Optional[TensorField] = InputField(
default=None,
description="A mask defining the region that this conditioning prompt applies to.",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
conditioning_name = context.conditioning.save(conditioning_data)
return ZImageConditioningOutput.build(conditioning_name)
return ZImageConditioningOutput(
conditioning=ZImageConditioningField(conditioning_name=conditioning_name, mask=self.mask)
)
def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
"""Encode prompt using Qwen3 text encoder.

View File

@@ -1,4 +1,4 @@
# Z-Image Control Transformer support for InvokeAI
# Z-Image backend utilities
from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
from invokeai.backend.z_image.z_image_control_transformer import ZImageControlTransformer2DModel
from invokeai.backend.z_image.z_image_controlnet_extension import (

View File

@@ -0,0 +1 @@
# Z-Image extensions

View File

@@ -0,0 +1,207 @@
from typing import Optional
import torch
import torchvision
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.z_image.text_conditioning import ZImageRegionalTextConditioning, ZImageTextConditioning
class ZImageRegionalPromptingExtension:
"""A class for managing regional prompting with Z-Image.
This implementation is inspired by the FLUX regional prompting extension and
the paper https://arxiv.org/pdf/2411.02395.
Key difference from FLUX: Z-Image uses sequence order [img_tokens, txt_tokens],
while FLUX uses [txt_tokens, img_tokens]. The attention mask construction
accounts for this difference.
"""
def __init__(
self,
regional_text_conditioning: ZImageRegionalTextConditioning,
regional_attn_mask: torch.Tensor | None = None,
):
self.regional_text_conditioning = regional_text_conditioning
self.regional_attn_mask = regional_attn_mask
def get_attn_mask(self, block_index: int) -> torch.Tensor | None:
"""Get the attention mask for a given block index.
Uses alternating pattern: apply mask on even blocks, no mask on odd blocks.
This helps balance regional control with global coherence.
"""
order = [self.regional_attn_mask, None]
return order[block_index % len(order)]
@classmethod
def from_text_conditionings(
cls,
text_conditionings: list[ZImageTextConditioning],
img_seq_len: int,
) -> "ZImageRegionalPromptingExtension":
"""Create a ZImageRegionalPromptingExtension from a list of text conditionings.
Args:
text_conditionings: List of text conditionings with optional masks.
img_seq_len: The image sequence length (i.e. (H // patch_size) * (W // patch_size)).
Returns:
A configured ZImageRegionalPromptingExtension.
"""
regional_text_conditioning = ZImageRegionalTextConditioning.from_text_conditionings(text_conditionings)
attn_mask = cls._prepare_regional_attn_mask(regional_text_conditioning, img_seq_len)
return cls(
regional_text_conditioning=regional_text_conditioning,
regional_attn_mask=attn_mask,
)
@classmethod
def _prepare_regional_attn_mask(
cls,
regional_text_conditioning: ZImageRegionalTextConditioning,
img_seq_len: int,
) -> torch.Tensor | None:
"""Prepare a regional attention mask for Z-Image.
The mask controls which tokens can attend to each other:
- Image tokens within a region attend only to each other
- Image tokens attend only to their corresponding regional text
- Text tokens attend only to their corresponding regional image
- Text tokens attend to themselves
Z-Image sequence order: [img_tokens, txt_tokens]
Args:
regional_text_conditioning: The regional text conditioning data.
img_seq_len: Number of image tokens.
Returns:
Attention mask of shape (img_seq_len + txt_seq_len, img_seq_len + txt_seq_len).
Returns None if no regional masks are present.
"""
# Check if any regional masks exist
has_regional_masks = any(mask is not None for mask in regional_text_conditioning.image_masks)
if not has_regional_masks:
# No regional masks, return None to use default attention
return None
# Identify background region (area not covered by any mask)
background_region_mask: torch.Tensor | None = None
for image_mask in regional_text_conditioning.image_masks:
if image_mask is not None:
# image_mask shape: (1, 1, img_seq_len) -> flatten to (img_seq_len,)
mask_flat = image_mask.view(-1)
if background_region_mask is None:
background_region_mask = torch.ones_like(mask_flat)
background_region_mask = background_region_mask * (1 - mask_flat)
device = TorchDevice.choose_torch_device()
txt_seq_len = regional_text_conditioning.prompt_embeds.shape[0]
total_seq_len = img_seq_len + txt_seq_len
# Initialize empty attention mask
# Z-Image sequence: [img_tokens (0:img_seq_len), txt_tokens (img_seq_len:total_seq_len)]
regional_attention_mask = torch.zeros((total_seq_len, total_seq_len), device=device, dtype=torch.float16)
for image_mask, embedding_range in zip(
regional_text_conditioning.image_masks,
regional_text_conditioning.embedding_ranges,
strict=True,
):
# Calculate text token positions in the unified sequence
txt_start = img_seq_len + embedding_range.start
txt_end = img_seq_len + embedding_range.end
# 1. txt attends to itself
regional_attention_mask[txt_start:txt_end, txt_start:txt_end] = 1.0
if image_mask is not None:
# Flatten mask: (1, 1, img_seq_len) -> (img_seq_len,)
mask_flat = image_mask.view(img_seq_len)
# 2. img attends to corresponding regional txt
# Reshape mask to (img_seq_len, 1) for broadcasting
regional_attention_mask[:img_seq_len, txt_start:txt_end] = mask_flat.view(img_seq_len, 1)
# 3. txt attends to corresponding regional img
# Reshape mask to (1, img_seq_len) for broadcasting
regional_attention_mask[txt_start:txt_end, :img_seq_len] = mask_flat.view(1, img_seq_len)
# 4. img self-attention within region
# mask @ mask.T creates pairwise attention within the masked region
regional_attention_mask[:img_seq_len, :img_seq_len] += mask_flat.view(img_seq_len, 1) @ mask_flat.view(
1, img_seq_len
)
else:
# Global prompt: allow attention to/from background regions only
if background_region_mask is not None:
# 2. background img attends to global txt
regional_attention_mask[:img_seq_len, txt_start:txt_end] = background_region_mask.view(
img_seq_len, 1
)
# 3. global txt attends to background img
regional_attention_mask[txt_start:txt_end, :img_seq_len] = background_region_mask.view(
1, img_seq_len
)
else:
# No regional masks at all, allow full attention
regional_attention_mask[:img_seq_len, txt_start:txt_end] = 1.0
regional_attention_mask[txt_start:txt_end, :img_seq_len] = 1.0
# Allow background regions to attend to themselves
if background_region_mask is not None:
bg_mask = background_region_mask.view(img_seq_len, 1)
regional_attention_mask[:img_seq_len, :img_seq_len] += bg_mask @ bg_mask.T
# Convert to boolean mask
regional_attention_mask = regional_attention_mask > 0.5
return regional_attention_mask
@staticmethod
def preprocess_regional_prompt_mask(
mask: Optional[torch.Tensor],
target_height: int,
target_width: int,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target image token grid.
Args:
mask: Input mask tensor. If None, returns a mask of all ones.
target_height: Height of the image token grid (H // patch_size).
target_width: Width of the image token grid (W // patch_size).
dtype: Target dtype for the mask.
device: Target device for the mask.
Returns:
Processed mask of shape (1, 1, target_height * target_width).
"""
img_seq_len = target_height * target_width
if mask is None:
return torch.ones((1, 1, img_seq_len), dtype=dtype, device=device)
mask = to_standard_float_mask(mask, out_dtype=dtype)
# Resize mask to target dimensions
tf = torchvision.transforms.Resize(
(target_height, target_width),
interpolation=torchvision.transforms.InterpolationMode.NEAREST,
)
# Add batch dimension if needed: (h, w) -> (1, h, w) -> (1, 1, h, w)
if mask.ndim == 2:
mask = mask.unsqueeze(0)
if mask.ndim == 3:
mask = mask.unsqueeze(0)
resized_mask = tf(mask)
# Flatten to (1, 1, img_seq_len)
return resized_mask.flatten(start_dim=2).to(device=device)

View File

@@ -0,0 +1,74 @@
from dataclasses import dataclass
import torch
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
@dataclass
class ZImageTextConditioning:
"""Z-Image text conditioning with optional regional mask.
Attributes:
prompt_embeds: Text embeddings from Qwen3 encoder. Shape: (seq_len, hidden_size).
mask: Optional binary mask for regional prompting. If None, the prompt is global.
Shape: (1, 1, img_seq_len) where img_seq_len = (H // patch_size) * (W // patch_size).
"""
prompt_embeds: torch.Tensor
mask: torch.Tensor | None = None
@dataclass
class ZImageRegionalTextConditioning:
"""Container for multiple regional text conditionings concatenated together.
In Z-Image, the unified sequence is [img_tokens, txt_tokens], which is different
from FLUX where it's [txt_tokens, img_tokens]. The attention mask must account for this.
Attributes:
prompt_embeds: Concatenated text embeddings from all regional prompts.
Shape: (total_seq_len, hidden_size).
image_masks: List of binary masks for each regional prompt.
image_masks[i] corresponds to embedding_ranges[i].
If None, the prompt is global (applies to entire image).
Shape: (1, 1, img_seq_len).
embedding_ranges: List of ranges indicating which portion of prompt_embeds
corresponds to each regional prompt.
"""
prompt_embeds: torch.Tensor
image_masks: list[torch.Tensor | None]
embedding_ranges: list[Range]
@classmethod
def from_text_conditionings(
cls,
text_conditionings: list[ZImageTextConditioning],
) -> "ZImageRegionalTextConditioning":
"""Create a ZImageRegionalTextConditioning from a list of ZImageTextConditioning objects.
Args:
text_conditionings: List of text conditionings, each with optional mask.
Returns:
A single ZImageRegionalTextConditioning with concatenated embeddings.
"""
concat_embeds: list[torch.Tensor] = []
concat_ranges: list[Range] = []
image_masks: list[torch.Tensor | None] = []
cur_embed_len = 0
for tc in text_conditionings:
concat_embeds.append(tc.prompt_embeds)
concat_ranges.append(Range(start=cur_embed_len, end=cur_embed_len + tc.prompt_embeds.shape[0]))
image_masks.append(tc.mask)
cur_embed_len += tc.prompt_embeds.shape[0]
prompt_embeds = torch.cat(concat_embeds, dim=0)
return cls(
prompt_embeds=prompt_embeds,
image_masks=image_masks,
embedding_ranges=concat_ranges,
)

View File

@@ -0,0 +1,234 @@
"""Utilities for patching the ZImageTransformer2DModel to support regional attention masks."""
from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
def create_regional_forward(
original_forward: Callable,
regional_attn_mask: torch.Tensor,
img_seq_len: int,
) -> Callable:
"""Create a modified forward function that uses a regional attention mask.
The regional attention mask replaces the internally computed padding mask,
allowing for regional prompting where different image regions attend to
different text prompts.
Args:
original_forward: The original forward method of ZImageTransformer2DModel.
regional_attn_mask: Attention mask of shape (seq_len, seq_len) where
seq_len = img_seq_len + txt_seq_len.
img_seq_len: Number of image tokens in the sequence.
Returns:
A modified forward function with regional attention support.
"""
def regional_forward(
self,
x: List[torch.Tensor],
t: torch.Tensor,
cap_feats: List[torch.Tensor],
patch_size: int = 2,
f_patch_size: int = 1,
) -> Tuple[List[torch.Tensor], dict]:
"""Modified forward with regional attention mask injection.
This is based on the original ZImageTransformer2DModel.forward but
replaces the padding-based attention mask with a regional attention mask.
"""
assert patch_size in self.all_patch_size
assert f_patch_size in self.all_f_patch_size
bsz = len(x)
device = x[0].device
t_scaled = t * self.t_scale
t_emb = self.t_embedder(t_scaled)
SEQ_MULTI_OF = 32 # From diffusers transformer_z_image.py
# Patchify and embed (reusing the original method)
(
x,
cap_feats,
x_size,
x_pos_ids,
cap_pos_ids,
x_inner_pad_mask,
cap_inner_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
# x embed & refine
x_item_seqlens = [len(_) for _ in x]
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
x_max_item_seqlen = max(x_item_seqlens)
x_cat = torch.cat(x, dim=0)
x_cat = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_cat)
adaln_input = t_emb.type_as(x_cat)
x_cat[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x_list = list(x_cat.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
x_padded = pad_sequence(x_list, batch_first=True, padding_value=0.0)
x_freqs_cis_padded = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(x_item_seqlens):
x_attn_mask[i, :seq_len] = 1
# Process through noise_refiner
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.noise_refiner:
x_padded = self._gradient_checkpointing_func(
layer, x_padded, x_attn_mask, x_freqs_cis_padded, adaln_input
)
else:
for layer in self.noise_refiner:
x_padded = layer(x_padded, x_attn_mask, x_freqs_cis_padded, adaln_input)
# cap embed & refine
cap_item_seqlens = [len(_) for _ in cap_feats]
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
cap_max_item_seqlen = max(cap_item_seqlens)
cap_cat = torch.cat(cap_feats, dim=0)
cap_cat = self.cap_embedder(cap_cat)
cap_cat[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_list = list(cap_cat.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
cap_padded = pad_sequence(cap_list, batch_first=True, padding_value=0.0)
cap_freqs_cis_padded = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
cap_attn_mask[i, :seq_len] = 1
# Process through context_refiner
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.context_refiner:
cap_padded = self._gradient_checkpointing_func(layer, cap_padded, cap_attn_mask, cap_freqs_cis_padded)
else:
for layer in self.context_refiner:
cap_padded = layer(cap_padded, cap_attn_mask, cap_freqs_cis_padded)
# Unified sequence: [img_tokens, txt_tokens]
unified = []
unified_freqs_cis = []
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
unified.append(torch.cat([x_padded[i][:x_len], cap_padded[i][:cap_len]]))
unified_freqs_cis.append(torch.cat([x_freqs_cis_padded[i][:x_len], cap_freqs_cis_padded[i][:cap_len]]))
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens, strict=False)]
assert unified_item_seqlens == [len(_) for _ in unified]
unified_max_item_seqlen = max(unified_item_seqlens)
unified_padded = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs_cis_padded = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
# --- REGIONAL ATTENTION MASK INJECTION ---
# Instead of using the padding mask, we use the regional attention mask
# The regional mask is (seq_len, seq_len), we need to expand it to (batch, seq_len, seq_len)
# and then add the batch dimension for broadcasting: (batch, 1, seq_len, seq_len)
# Expand regional mask to match the actual sequence length (may include padding)
if regional_attn_mask.shape[0] != unified_max_item_seqlen:
# Pad the regional mask to match unified sequence length
padded_regional_mask = torch.zeros(
(unified_max_item_seqlen, unified_max_item_seqlen),
dtype=regional_attn_mask.dtype,
device=device,
)
mask_size = min(regional_attn_mask.shape[0], unified_max_item_seqlen)
padded_regional_mask[:mask_size, :mask_size] = regional_attn_mask[:mask_size, :mask_size]
else:
padded_regional_mask = regional_attn_mask.to(device)
# Convert boolean mask to additive float mask for attention
# True (attend) -> 0.0, False (block) -> -inf
# This is required because the attention backend expects additive masks for 4D inputs
# Use bfloat16 to match the transformer's query dtype
float_mask = torch.zeros_like(padded_regional_mask, dtype=torch.bfloat16)
float_mask[~padded_regional_mask] = float("-inf")
# Expand to (batch, 1, seq_len, seq_len) for attention
unified_attn_mask = float_mask.unsqueeze(0).unsqueeze(0).expand(bsz, 1, -1, -1)
# Process through main layers with regional attention mask
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer_idx, layer in enumerate(self.layers):
# Alternate between regional mask and full attention
if layer_idx % 2 == 0:
unified_padded = self._gradient_checkpointing_func(
layer, unified_padded, unified_attn_mask, unified_freqs_cis_padded, adaln_input
)
else:
# Use padding mask only for odd layers (allows global coherence)
padding_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
padding_mask[i, :seq_len] = 1
unified_padded = self._gradient_checkpointing_func(
layer, unified_padded, padding_mask, unified_freqs_cis_padded, adaln_input
)
else:
for layer_idx, layer in enumerate(self.layers):
# Alternate between regional mask and full attention
if layer_idx % 2 == 0:
unified_padded = layer(unified_padded, unified_attn_mask, unified_freqs_cis_padded, adaln_input)
else:
# Use padding mask only for odd layers (allows global coherence)
padding_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
padding_mask[i, :seq_len] = 1
unified_padded = layer(unified_padded, padding_mask, unified_freqs_cis_padded, adaln_input)
# Final layer
unified_out = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified_padded, adaln_input)
unified_list = list(unified_out.unbind(dim=0))
x_out = self.unpatchify(unified_list, x_size, patch_size, f_patch_size)
return x_out, {}
return regional_forward
@contextmanager
def patch_transformer_for_regional_prompting(
transformer,
regional_attn_mask: Optional[torch.Tensor],
img_seq_len: int,
):
"""Context manager to temporarily patch the transformer for regional prompting.
Args:
transformer: The ZImageTransformer2DModel instance.
regional_attn_mask: Regional attention mask of shape (seq_len, seq_len).
If None, the transformer is not patched.
img_seq_len: Number of image tokens.
Yields:
The (possibly patched) transformer.
"""
if regional_attn_mask is None:
# No regional prompting, use original forward
yield transformer
return
# Store original forward
original_forward = transformer.forward
# Create and bind the regional forward
regional_fwd = create_regional_forward(original_forward, regional_attn_mask, img_seq_len)
transformer.forward = lambda *args, **kwargs: regional_fwd(transformer, *args, **kwargs)
try:
yield transformer
finally:
# Restore original forward
transformer.forward = original_forward

File diff suppressed because it is too large

View File

@@ -59,6 +59,17 @@ export const getRegionalGuidanceWarnings = (
}
}
if (model.base === 'z-image') {
// Z-Image has similar limitations to FLUX - no negative prompts via CFG by default
// Reference images (IP Adapters) are not supported for Z-Image
if (entity.referenceImages.length > 0) {
warnings.push(WARNINGS.RG_REFERENCE_IMAGES_NOT_SUPPORTED);
}
if (entity.autoNegative) {
warnings.push(WARNINGS.RG_AUTO_NEGATIVE_NOT_SUPPORTED);
}
}
entity.referenceImages.forEach(({ config }) => {
if (!config.model) {
// No model selected

View File

@@ -32,8 +32,8 @@ type AddRegionsArg = {
g: Graph;
bbox: Rect;
model: MainModelConfig;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'z_image_text_encoder'>;
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'z_image_text_encoder'> | null;
posCondCollect: Invocation<'collect'>;
negCondCollect: Invocation<'collect'> | null;
ipAdapterCollect: Invocation<'collect'>;
@@ -71,6 +71,7 @@ export const addRegions = async ({
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
const isSDXL = model.base === 'sdxl';
const isFLUX = model.base === 'flux';
const isZImage = model.base === 'z-image';
const validRegions = regions
.filter((entity) => entity.isEnabled)
@@ -111,7 +112,7 @@ export const addRegions = async ({
if (region.positivePrompt) {
// The main positive conditioning node
result.addedPositivePrompt = true;
let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'z_image_text_encoder'>;
if (isSDXL) {
regionalPosCond = g.addNode({
type: 'sdxl_compel_prompt',
@@ -125,6 +126,12 @@ export const addRegions = async ({
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
});
} else if (isZImage) {
regionalPosCond = g.addNode({
type: 'z_image_text_encoder',
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
});
} else {
regionalPosCond = g.addNode({
type: 'compel',
@@ -155,6 +162,12 @@ export const addRegions = async ({
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
} else if (posCond.type === 'z_image_text_encoder') {
for (const edge of g.getEdgesTo(posCond, ['qwen3_encoder', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
} else {
assert(false, 'Unsupported positive conditioning node type.');
}
@@ -166,7 +179,7 @@ export const addRegions = async ({
// The main negative conditioning node
result.addedNegativePrompt = true;
let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'z_image_text_encoder'>;
if (isSDXL) {
regionalNegCond = g.addNode({
type: 'sdxl_compel_prompt',
@@ -180,6 +193,12 @@ export const addRegions = async ({
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
});
} else if (isZImage) {
regionalNegCond = g.addNode({
type: 'z_image_text_encoder',
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
});
} else {
regionalNegCond = g.addNode({
type: 'compel',
@@ -211,6 +230,12 @@ export const addRegions = async ({
clone.destination.node_id = regionalNegCond.id;
g.addEdgeFromObj(clone);
}
} else if (negCond.type === 'z_image_text_encoder') {
for (const edge of g.getEdgesTo(negCond, ['qwen3_encoder', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalNegCond.id;
g.addEdgeFromObj(clone);
}
} else {
assert(false, 'Unsupported negative conditioning node type.');
}
@@ -229,7 +254,9 @@ export const addRegions = async ({
// Connect the OG mask image to the inverted mask-to-tensor node
g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask');
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt
let regionalPosCondInverted: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
let regionalPosCondInverted: Invocation<
'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'z_image_text_encoder'
>;
if (isSDXL) {
regionalPosCondInverted = g.addNode({
type: 'sdxl_compel_prompt',
@@ -243,6 +270,12 @@ export const addRegions = async ({
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
});
} else if (isZImage) {
regionalPosCondInverted = g.addNode({
type: 'z_image_text_encoder',
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
});
} else {
regionalPosCondInverted = g.addNode({
type: 'compel',
@@ -273,6 +306,12 @@ export const addRegions = async ({
clone.destination.node_id = regionalPosCondInverted.id;
g.addEdgeFromObj(clone);
}
} else if (posCond.type === 'z_image_text_encoder') {
for (const edge of g.getEdgesTo(posCond, ['qwen3_encoder', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCondInverted.id;
g.addEdgeFromObj(clone);
}
} else {
assert(false, 'Unsupported positive conditioning node type.');
}

View File

@@ -10,7 +10,7 @@ export const addZImageLoRAs = (
denoise: Invocation<'z_image_denoise'>,
modelLoader: Invocation<'z_image_model_loader'>,
posCond: Invocation<'z_image_text_encoder'>,
negCond: Invocation<'z_image_text_encoder'>
negCond: Invocation<'z_image_text_encoder'> | null
): void => {
const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'z-image');
const loraCount = enabledLoRAs.length;
@@ -39,10 +39,13 @@ export const addZImageLoRAs = (
// Reroute model connections through the LoRA collection loader
g.deleteEdgesTo(denoise, ['transformer']);
g.deleteEdgesTo(posCond, ['qwen3_encoder']);
g.deleteEdgesTo(negCond, ['qwen3_encoder']);
g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer');
g.addEdge(loraCollectionLoader, 'qwen3_encoder', posCond, 'qwen3_encoder');
g.addEdge(loraCollectionLoader, 'qwen3_encoder', negCond, 'qwen3_encoder');
// Only reroute negCond if it exists (guidance_scale > 0)
if (negCond !== null) {
g.deleteEdgesTo(negCond, ['qwen3_encoder']);
g.addEdge(loraCollectionLoader, 'qwen3_encoder', negCond, 'qwen3_encoder');
}
for (const lora of enabledLoRAs) {
const { weight } = lora;

View File

@@ -14,6 +14,7 @@ import { addImageToImage } from 'features/nodes/util/graph/generation/addImageTo
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { addZImageLoRAs } from 'features/nodes/util/graph/generation/addZImageLoRAs';
@@ -75,12 +76,32 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
type: 'z_image_text_encoder',
id: getPrefixedId('pos_prompt'),
});
// Collect node for regional prompting support
const posCondCollect = g.addNode({
type: 'collect',
id: getPrefixedId('pos_cond_collect'),
});
// Z-Image supports negative conditioning when guidance_scale > 0
const negCond = g.addNode({
type: 'z_image_text_encoder',
id: getPrefixedId('neg_prompt'),
prompt: prompts.negative,
// Z-Image supports negative conditioning when guidance_scale > 1
// Only create negative conditioning nodes if guidance is used
let negCond: Invocation<'z_image_text_encoder'> | null = null;
let negCondCollect: Invocation<'collect'> | null = null;
if (guidance_scale > 1) {
negCond = g.addNode({
type: 'z_image_text_encoder',
id: getPrefixedId('neg_prompt'),
prompt: prompts.negative,
});
negCondCollect = g.addNode({
type: 'collect',
id: getPrefixedId('neg_cond_collect'),
});
}
// Placeholder collect node for IP adapters (not supported for Z-Image but needed for addRegions)
const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collect'),
});
const seed = g.addNode({
@@ -100,17 +121,20 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
g.addEdge(modelLoader, 'qwen3_encoder', posCond, 'qwen3_encoder');
g.addEdge(modelLoader, 'qwen3_encoder', negCond, 'qwen3_encoder');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
// Connect VAE to denoise for control image encoding
g.addEdge(modelLoader, 'vae', denoise, 'vae');
g.addEdge(positivePrompt, 'value', posCond, 'prompt');
g.addEdge(posCond, 'conditioning', denoise, 'positive_conditioning');
// Connect positive conditioning through collector for regional support
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
// Only add negative conditioning edge if guidance_scale > 0
if (guidance_scale > 0) {
g.addEdge(negCond, 'conditioning', denoise, 'negative_conditioning');
// Connect negative conditioning if guidance_scale > 1
if (negCond !== null && negCondCollect !== null) {
g.addEdge(modelLoader, 'qwen3_encoder', negCond, 'qwen3_encoder');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
}
g.addEdge(seed, 'value', denoise, 'seed');
@@ -119,9 +143,10 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
// Add Z-Image LoRAs if any are enabled
addZImageLoRAs(state, g, denoise, modelLoader, posCond, negCond);
// Add Z-Image Control layers if any are enabled
// Add regional guidance if canvas manager is available
const canvas = selectCanvasSlice(state);
if (manager !== null) {
const canvas = selectCanvasSlice(state);
// Add Z-Image Control layers if any are enabled
const rect = canvas.bbox.rect;
await addZImageControl({
manager,
@@ -130,8 +155,26 @@ export const buildZImageGraph = async (arg: GraphBuilderArg): Promise<GraphBuild
rect,
denoise,
});
// Add regional guidance
await addRegions({
manager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
model,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollect,
fluxReduxCollect: null, // Not supported for Z-Image
});
}
// IP Adapters are not supported for Z-Image, so delete the unused collector
g.deleteNode(ipAdapterCollect.id);
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'z-image');

View File

@@ -25308,6 +25308,11 @@ export type components = {
* @description The name of conditioning tensor
*/
conditioning_name: string;
/**
* @description The mask associated with this conditioning tensor for regional prompting. Excluded regions should be set to False, included regions should be set to True.
* @default null
*/
mask?: components["schemas"]["TensorField"] | null;
};
/**
* ZImageConditioningOutput
@@ -25435,6 +25440,8 @@ export type components = {
/**
* Denoise - Z-Image
* @description Run the denoising process with a Z-Image model.
*
* Supports regional prompting by connecting multiple conditioning inputs with masks.
*/
ZImageDenoiseInvocation: {
/**
@@ -25483,15 +25490,17 @@ export type components = {
*/
transformer?: components["schemas"]["TransformerField"] | null;
/**
* Positive Conditioning
* @description Positive conditioning tensor
* @default null
*/
positive_conditioning?: components["schemas"]["ZImageConditioningField"] | null;
positive_conditioning?: components["schemas"]["ZImageConditioningField"] | components["schemas"]["ZImageConditioningField"][] | null;
/**
* Negative Conditioning
* @description Negative conditioning tensor
* @default null
*/
negative_conditioning?: components["schemas"]["ZImageConditioningField"] | null;
negative_conditioning?: components["schemas"]["ZImageConditioningField"] | components["schemas"]["ZImageConditioningField"][] | null;
/**
* Guidance Scale
* @description Guidance scale for classifier-free guidance. 1.0 = no CFG (recommended for Z-Image-Turbo). Values > 1.0 amplify guidance.
@@ -25848,6 +25857,8 @@ export type components = {
/**
* Prompt - Z-Image
* @description Encodes and preps a prompt for a Z-Image image.
*
* Supports regional prompting by connecting a mask input.
*/
ZImageTextEncoderInvocation: {
/**
@@ -25879,6 +25890,11 @@ export type components = {
* @default null
*/
qwen3_encoder?: components["schemas"]["Qwen3EncoderField"] | null;
/**
* @description A mask defining the region that this conditioning prompt applies to.
* @default null
*/
mask?: components["schemas"]["TensorField"] | null;
/**
* type
* @default z_image_text_encoder