mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Address code review feedback - improve error handling
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
This commit is contained in:
@@ -161,27 +161,26 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
|
||||
# Apply masking if needed
|
||||
# For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size
|
||||
seq_dim = 0
|
||||
if mask_percent > 0 or last_nonnull < prompt_embeds.size(seq_dim) - 1:
|
||||
if mask_percent > 0 or last_nonnull < prompt_embeds.size(0) - 1:
|
||||
seq_len = (
|
||||
last_nonnull + 1
|
||||
if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(seq_dim) - 1
|
||||
else prompt_embeds.size(seq_dim)
|
||||
if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(0) - 1
|
||||
else prompt_embeds.size(0)
|
||||
)
|
||||
|
||||
# Determine mask range
|
||||
if self.mask_starts_at == MaskStartPosition.END:
|
||||
mask_start = seq_len - int(seq_len * mask_percent)
|
||||
mask_end = prompt_embeds.size(seq_dim)
|
||||
mask_end = prompt_embeds.size(0)
|
||||
else: # BEGINNING
|
||||
mask_start = 0
|
||||
mask_end = int(seq_len * mask_percent)
|
||||
|
||||
# Create position-based mask for 2D tensor [seq_len, hidden_size]
|
||||
prompt_mask = (
|
||||
torch.arange(prompt_embeds.size(seq_dim), device=prompt_embeds.device)
|
||||
torch.arange(prompt_embeds.size(0), device=prompt_embeds.device)
|
||||
.unsqueeze(1) # [seq_len, 1]
|
||||
.expand(prompt_embeds.size(seq_dim), prompt_embeds.size(1)) # [seq_len, hidden_size]
|
||||
.expand(prompt_embeds.size(0), prompt_embeds.size(1)) # [seq_len, hidden_size]
|
||||
)
|
||||
prompt_mask = (prompt_mask >= mask_start) & (prompt_mask < mask_end)
|
||||
|
||||
@@ -193,7 +192,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
# Convert is_null_list to tensor: True where sequences should be protected (null sequences)
|
||||
null_mask_tensor = torch.tensor(is_null_list, device=prompt_embeds.device, dtype=torch.bool)
|
||||
null_mask_tensor = null_mask_tensor.unsqueeze(1).expand(
|
||||
prompt_embeds.size(seq_dim), prompt_embeds.size(1)
|
||||
prompt_embeds.size(0), prompt_embeds.size(1)
|
||||
)
|
||||
prompt_mask = prompt_mask | null_mask_tensor
|
||||
|
||||
@@ -221,18 +220,21 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
# For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size
|
||||
is_null_list = [0] * tensor.size(0)
|
||||
|
||||
if tensor.dim() == 2:
|
||||
for i in range(tensor.size(0)):
|
||||
sequence = tensor[i, :] # Get the i-th sequence (all hidden dimensions)
|
||||
is_all_zero = torch.all(sequence == 0)
|
||||
if tensor.dim() != 2:
|
||||
# Unexpected tensor dimensions - return empty results
|
||||
return first_null, last_nonnull, is_null_list
|
||||
|
||||
is_null_list[i] = 1 if is_all_zero else 0
|
||||
for i in range(tensor.size(0)):
|
||||
sequence = tensor[i, :] # Get the i-th sequence (all hidden dimensions)
|
||||
is_all_zero = torch.all(sequence == 0)
|
||||
|
||||
if not is_all_zero:
|
||||
last_nonnull = i
|
||||
is_null_list[i] = 1 if is_all_zero else 0
|
||||
|
||||
if is_all_zero and first_null == -1:
|
||||
first_null = i
|
||||
if not is_all_zero:
|
||||
last_nonnull = i
|
||||
|
||||
if is_all_zero and first_null == -1:
|
||||
first_null = i
|
||||
|
||||
return first_null, last_nonnull, is_null_list
|
||||
|
||||
|
||||
Reference in New Issue
Block a user