Fix dimension handling for Z-Image 2D tensors

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-12-28 16:53:23 +00:00
parent 5f4ef67f92
commit 323cb2dbd0

View File

@@ -140,7 +140,10 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
)
def _apply_noise(self, context: InvocationContext, prompt_embeds: torch.Tensor) -> torch.Tensor:
"""Apply random noise to prompt embeddings."""
"""Apply random noise to prompt embeddings.
Z-Image uses 2D tensors: [seq_len, hidden_size]
"""
# Normalize parameters
randomize_percent = max(1, min(100, self.randomize_percent)) / 100.0
mask_percent = max(0, min(99, self.mask_percent)) / 100.0
@@ -157,22 +160,28 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
first_null, last_nonnull, is_null_list = self._find_null_sequences(prompt_embeds)
# Apply masking if needed
if mask_percent > 0 or last_nonnull < prompt_embeds.size(1) - 1:
seq_len = last_nonnull + 1 if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(1) - 1 else prompt_embeds.size(1)
# For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size
seq_dim = 0
if mask_percent > 0 or last_nonnull < prompt_embeds.size(seq_dim) - 1:
seq_len = (
last_nonnull + 1
if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(seq_dim) - 1
else prompt_embeds.size(seq_dim)
)
# Determine mask range
if self.mask_starts_at == MaskStartPosition.END:
mask_start = seq_len - int(seq_len * mask_percent)
mask_end = prompt_embeds.size(1)
mask_end = prompt_embeds.size(seq_dim)
else: # BEGINNING
mask_start = 0
mask_end = int(seq_len * mask_percent)
# Create position-based mask
# Create position-based mask for 2D tensor [seq_len, hidden_size]
prompt_mask = (
torch.arange(prompt_embeds.size(1), device=prompt_embeds.device)
.view(1, -1, 1)
.expand(prompt_embeds.size(0), -1, prompt_embeds.size(2))
torch.arange(prompt_embeds.size(seq_dim), device=prompt_embeds.device)
.unsqueeze(1) # [seq_len, 1]
.expand(prompt_embeds.size(seq_dim), prompt_embeds.size(1)) # [seq_len, hidden_size]
)
prompt_mask = (prompt_mask >= mask_start) & (prompt_mask < mask_end)
@@ -182,11 +191,9 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
context.logger.info("Seed Variance Enhancer is masking null sequences from noise")
# Convert is_null_list to tensor: True where sequences should be protected (null sequences)
null_mask_tensor = torch.tensor(
is_null_list, device=prompt_embeds.device, dtype=torch.bool
)
null_mask_tensor = null_mask_tensor.view(1, -1, 1).expand(
prompt_embeds.size(0), -1, prompt_embeds.size(2)
null_mask_tensor = torch.tensor(is_null_list, device=prompt_embeds.device, dtype=torch.bool)
null_mask_tensor = null_mask_tensor.unsqueeze(1).expand(
prompt_embeds.size(seq_dim), prompt_embeds.size(1)
)
prompt_mask = prompt_mask | null_mask_tensor
@@ -201,6 +208,8 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
def _find_null_sequences(self, tensor: torch.Tensor) -> tuple[int, int, list[int]]:
"""Find sequences in tensor that contain all zeros (padding).
Z-Image uses 2D tensors: [seq_len, hidden_size]
Returns:
Tuple of (first_null_index, last_nonnull_index, is_null_list)
@@ -208,11 +217,13 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
"""
first_null = -1
last_nonnull = -1
is_null_list = [0] * tensor.size(1)
# For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size
is_null_list = [0] * tensor.size(0)
if tensor.dim() == 3:
for i in range(tensor.size(1)):
sequence = tensor[:, i, ...]
if tensor.dim() == 2:
for i in range(tensor.size(0)):
sequence = tensor[i, :] # Get the i-th sequence (all hidden dimensions)
is_all_zero = torch.all(sequence == 0)
is_null_list[i] = 1 if is_all_zero else 0
@@ -226,7 +237,10 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
return first_null, last_nonnull, is_null_list
def _log_statistics(self, context: InvocationContext, conditioning_data: ConditioningFieldData) -> None:
"""Log statistics about the conditioning tensor."""
"""Log statistics about the conditioning tensor.
Z-Image uses 2D tensors: [seq_len, hidden_size]
"""
if not conditioning_data.conditionings:
context.logger.warning("Conditioning data has no conditionings")
return
@@ -245,8 +259,9 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
first_null, last_nonnull, is_null_list = self._find_null_sequences(tensor)
# Calculate statistics on non-null portion
if last_nonnull < tensor.size(1) - 1 and last_nonnull >= 0:
sliced_tensor = tensor[:, : last_nonnull + 1, :]
# For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size
if last_nonnull < tensor.size(0) - 1 and last_nonnull >= 0:
sliced_tensor = tensor[: last_nonnull + 1, :]
mean = torch.mean(sliced_tensor).item()
std = torch.std(sliced_tensor).item()
min_val = torch.min(sliced_tensor).item()