From 247130a32ac66b32327ebfcfdd7a5168154b5ccf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Dec 2025 16:55:00 +0000 Subject: [PATCH] Address code review feedback - improve error handling Co-authored-by: lstein <111189+lstein@users.noreply.github.com> --- .../app/invocations/seed_variance_enhancer.py | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/invokeai/app/invocations/seed_variance_enhancer.py b/invokeai/app/invocations/seed_variance_enhancer.py index 8b0d64a390..d7a1655fff 100644 --- a/invokeai/app/invocations/seed_variance_enhancer.py +++ b/invokeai/app/invocations/seed_variance_enhancer.py @@ -161,27 +161,26 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): # Apply masking if needed # For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size - seq_dim = 0 - if mask_percent > 0 or last_nonnull < prompt_embeds.size(seq_dim) - 1: + if mask_percent > 0 or last_nonnull < prompt_embeds.size(0) - 1: seq_len = ( last_nonnull + 1 - if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(seq_dim) - 1 - else prompt_embeds.size(seq_dim) + if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(0) - 1 + else prompt_embeds.size(0) ) # Determine mask range if self.mask_starts_at == MaskStartPosition.END: mask_start = seq_len - int(seq_len * mask_percent) - mask_end = prompt_embeds.size(seq_dim) + mask_end = prompt_embeds.size(0) else: # BEGINNING mask_start = 0 mask_end = int(seq_len * mask_percent) # Create position-based mask for 2D tensor [seq_len, hidden_size] prompt_mask = ( - torch.arange(prompt_embeds.size(seq_dim), device=prompt_embeds.device) + torch.arange(prompt_embeds.size(0), device=prompt_embeds.device) .unsqueeze(1) # [seq_len, 1] - .expand(prompt_embeds.size(seq_dim), prompt_embeds.size(1)) # [seq_len, hidden_size] + .expand(prompt_embeds.size(0), prompt_embeds.size(1)) # [seq_len, hidden_size] ) prompt_mask = (prompt_mask >= mask_start) & (prompt_mask < mask_end) @@ -193,7 +192,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): # Convert is_null_list to tensor: True where sequences should be protected (null sequences) null_mask_tensor = torch.tensor(is_null_list, device=prompt_embeds.device, dtype=torch.bool) null_mask_tensor = null_mask_tensor.unsqueeze(1).expand( - prompt_embeds.size(seq_dim), prompt_embeds.size(1) + prompt_embeds.size(0), prompt_embeds.size(1) ) prompt_mask = prompt_mask | null_mask_tensor @@ -221,18 +220,21 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): # For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size is_null_list = [0] * tensor.size(0) - if tensor.dim() == 2: - for i in range(tensor.size(0)): - sequence = tensor[i, :] # Get the i-th sequence (all hidden dimensions) - is_all_zero = torch.all(sequence == 0) + if tensor.dim() != 2: + # Unexpected tensor dimensions - return empty results + return first_null, last_nonnull, is_null_list - is_null_list[i] = 1 if is_all_zero else 0 + for i in range(tensor.size(0)): + sequence = tensor[i, :] # Get the i-th sequence (all hidden dimensions) + is_all_zero = torch.all(sequence == 0) - if not is_all_zero: - last_nonnull = i + is_null_list[i] = 1 if is_all_zero else 0 - if is_all_zero and first_null == -1: - first_null = i + if not is_all_zero: + last_nonnull = i + + if is_all_zero and first_null == -1: + first_null = i return first_null, last_nonnull, is_null_list