Address code review feedback - improve error handling

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-12-28 16:55:00 +00:00
parent 323cb2dbd0
commit 247130a32a

View File

@@ -161,27 +161,26 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
# Apply masking if needed
# For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size
seq_dim = 0
if mask_percent > 0 or last_nonnull < prompt_embeds.size(seq_dim) - 1:
if mask_percent > 0 or last_nonnull < prompt_embeds.size(0) - 1:
seq_len = (
last_nonnull + 1
if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(seq_dim) - 1
else prompt_embeds.size(seq_dim)
if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(0) - 1
else prompt_embeds.size(0)
)
# Determine mask range
if self.mask_starts_at == MaskStartPosition.END:
mask_start = seq_len - int(seq_len * mask_percent)
mask_end = prompt_embeds.size(seq_dim)
mask_end = prompt_embeds.size(0)
else: # BEGINNING
mask_start = 0
mask_end = int(seq_len * mask_percent)
# Create position-based mask for 2D tensor [seq_len, hidden_size]
prompt_mask = (
torch.arange(prompt_embeds.size(seq_dim), device=prompt_embeds.device)
torch.arange(prompt_embeds.size(0), device=prompt_embeds.device)
.unsqueeze(1) # [seq_len, 1]
.expand(prompt_embeds.size(seq_dim), prompt_embeds.size(1)) # [seq_len, hidden_size]
.expand(prompt_embeds.size(0), prompt_embeds.size(1)) # [seq_len, hidden_size]
)
prompt_mask = (prompt_mask >= mask_start) & (prompt_mask < mask_end)
@@ -193,7 +192,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
# Convert is_null_list to tensor: True where sequences should be protected (null sequences)
null_mask_tensor = torch.tensor(is_null_list, device=prompt_embeds.device, dtype=torch.bool)
null_mask_tensor = null_mask_tensor.unsqueeze(1).expand(
prompt_embeds.size(seq_dim), prompt_embeds.size(1)
prompt_embeds.size(0), prompt_embeds.size(1)
)
prompt_mask = prompt_mask | null_mask_tensor
@@ -221,18 +220,21 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
# For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size
is_null_list = [0] * tensor.size(0)
if tensor.dim() == 2:
for i in range(tensor.size(0)):
sequence = tensor[i, :] # Get the i-th sequence (all hidden dimensions)
is_all_zero = torch.all(sequence == 0)
if tensor.dim() != 2:
# Unexpected tensor dimensions - return empty results
return first_null, last_nonnull, is_null_list
is_null_list[i] = 1 if is_all_zero else 0
for i in range(tensor.size(0)):
sequence = tensor[i, :] # Get the i-th sequence (all hidden dimensions)
is_all_zero = torch.all(sequence == 0)
if not is_all_zero:
last_nonnull = i
is_null_list[i] = 1 if is_all_zero else 0
if is_all_zero and first_null == -1:
first_null = i
if not is_all_zero:
last_nonnull = i
if is_all_zero and first_null == -1:
first_null = i
return first_null, last_nonnull, is_null_list