From 323cb2dbd0a2998a4e61db32dc493e25c46aab01 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Dec 2025 16:53:23 +0000 Subject: [PATCH] Fix dimension handling for Z-Image 2D tensors Co-authored-by: lstein <111189+lstein@users.noreply.github.com> --- .../app/invocations/seed_variance_enhancer.py | 55 ++++++++++++------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/invokeai/app/invocations/seed_variance_enhancer.py b/invokeai/app/invocations/seed_variance_enhancer.py index 3ee27b2a3e..8b0d64a390 100644 --- a/invokeai/app/invocations/seed_variance_enhancer.py +++ b/invokeai/app/invocations/seed_variance_enhancer.py @@ -140,7 +140,10 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): ) def _apply_noise(self, context: InvocationContext, prompt_embeds: torch.Tensor) -> torch.Tensor: - """Apply random noise to prompt embeddings.""" + """Apply random noise to prompt embeddings. + + Z-Image uses 2D tensors: [seq_len, hidden_size] + """ # Normalize parameters randomize_percent = max(1, min(100, self.randomize_percent)) / 100.0 mask_percent = max(0, min(99, self.mask_percent)) / 100.0 @@ -157,22 +160,28 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): first_null, last_nonnull, is_null_list = self._find_null_sequences(prompt_embeds) # Apply masking if needed - if mask_percent > 0 or last_nonnull < prompt_embeds.size(1) - 1: - seq_len = last_nonnull + 1 if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(1) - 1 else prompt_embeds.size(1) + # For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size + seq_dim = 0 + if mask_percent > 0 or last_nonnull < prompt_embeds.size(seq_dim) - 1: + seq_len = ( + last_nonnull + 1 + if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(seq_dim) - 1 + else prompt_embeds.size(seq_dim) + ) # Determine mask range if self.mask_starts_at == MaskStartPosition.END: mask_start = seq_len - int(seq_len * mask_percent) - mask_end = prompt_embeds.size(1) + mask_end = prompt_embeds.size(seq_dim) else: # BEGINNING mask_start = 0 mask_end = int(seq_len * mask_percent) - # Create position-based mask + # Create position-based mask for 2D tensor [seq_len, hidden_size] prompt_mask = ( - torch.arange(prompt_embeds.size(1), device=prompt_embeds.device) - .view(1, -1, 1) - .expand(prompt_embeds.size(0), -1, prompt_embeds.size(2)) + torch.arange(prompt_embeds.size(seq_dim), device=prompt_embeds.device) + .unsqueeze(1) # [seq_len, 1] + .expand(prompt_embeds.size(seq_dim), prompt_embeds.size(1)) # [seq_len, hidden_size] ) prompt_mask = (prompt_mask >= mask_start) & (prompt_mask < mask_end) @@ -182,11 +191,9 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): context.logger.info("Seed Variance Enhancer is masking null sequences from noise") # Convert is_null_list to tensor: True where sequences should be protected (null sequences) - null_mask_tensor = torch.tensor( - is_null_list, device=prompt_embeds.device, dtype=torch.bool - ) - null_mask_tensor = null_mask_tensor.view(1, -1, 1).expand( - prompt_embeds.size(0), -1, prompt_embeds.size(2) + null_mask_tensor = torch.tensor(is_null_list, device=prompt_embeds.device, dtype=torch.bool) + null_mask_tensor = null_mask_tensor.unsqueeze(1).expand( + prompt_embeds.size(seq_dim), prompt_embeds.size(1) ) prompt_mask = prompt_mask | null_mask_tensor @@ -201,6 +208,8 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): def _find_null_sequences(self, tensor: torch.Tensor) -> tuple[int, int, list[int]]: """Find sequences in tensor that contain all zeros (padding). + + Z-Image uses 2D tensors: [seq_len, hidden_size] Returns: Tuple of (first_null_index, last_nonnull_index, is_null_list) @@ -208,11 +217,13 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): """ first_null = -1 last_nonnull = -1 - is_null_list = [0] * tensor.size(1) + + # For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size + is_null_list = [0] * tensor.size(0) - if tensor.dim() == 3: - for i in range(tensor.size(1)): - sequence = tensor[:, i, ...] + if tensor.dim() == 2: + for i in range(tensor.size(0)): + sequence = tensor[i, :] # Get the i-th sequence (all hidden dimensions) is_all_zero = torch.all(sequence == 0) is_null_list[i] = 1 if is_all_zero else 0 @@ -226,7 +237,10 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): return first_null, last_nonnull, is_null_list def _log_statistics(self, context: InvocationContext, conditioning_data: ConditioningFieldData) -> None: - """Log statistics about the conditioning tensor.""" + """Log statistics about the conditioning tensor. + + Z-Image uses 2D tensors: [seq_len, hidden_size] + """ if not conditioning_data.conditionings: context.logger.warning("Conditioning data has no conditionings") return @@ -245,8 +259,9 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): first_null, last_nonnull, is_null_list = self._find_null_sequences(tensor) # Calculate statistics on non-null portion - if last_nonnull < tensor.size(1) - 1 and last_nonnull >= 0: - sliced_tensor = tensor[:, : last_nonnull + 1, :] + # For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size + if last_nonnull < tensor.size(0) - 1 and last_nonnull >= 0: + sliced_tensor = tensor[: last_nonnull + 1, :] mean = torch.mean(sliced_tensor).item() std = torch.std(sliced_tensor).item() min_val = torch.min(sliced_tensor).item()