Address code review feedback - fix logic and add constants

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-12-28 14:35:32 +00:00
parent 7d65cdfc16
commit 5f4ef67f92

View File

@@ -20,6 +20,13 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ZImageConditioningInfo,
)
# Seed offset for noise mask generation (v2.2 behavior for consistency with prompt variations)
NOISE_MASK_SEED_OFFSET = 1
# Factors for suggested strength calculation based on embedding standard deviation
MIN_STRENGTH_FACTOR = 0.1
MAX_STRENGTH_FACTOR = 10.0
class MaskStartPosition(str, Enum):
"""Which end of the prompt will be protected from noise."""
@@ -142,12 +149,12 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
torch.manual_seed(self.seed)
noise = torch.rand_like(prompt_embeds) * 2 * self.strength - self.strength
# Reset seed for value selection (v2.2 behavior for consistency with prompt variations)
torch.manual_seed(self.seed + 1)
noise_mask = torch.bernoulli(torch.ones_like(prompt_embeds) * randomize_percent).bool()
# Reset seed for value selection to ensure consistency with prompt variations
torch.manual_seed(self.seed + NOISE_MASK_SEED_OFFSET)
noise_mask = torch.bernoulli(torch.full_like(prompt_embeds, randomize_percent)).bool()
# Check for null sequences (padding)
first_null, last_nonnull, null_sequences = self._find_null_sequences(prompt_embeds)
first_null, last_nonnull, is_null_list = self._find_null_sequences(prompt_embeds)
# Apply masking if needed
if mask_percent > 0 or last_nonnull < prompt_embeds.size(1) - 1:
@@ -174,8 +181,9 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
if self.log_statistics:
context.logger.info("Seed Variance Enhancer is masking null sequences from noise")
null_mask_tensor = ~torch.tensor(
null_sequences, device=prompt_embeds.device, dtype=torch.bool
# Convert is_null_list to tensor: True where sequences should be protected (null sequences)
null_mask_tensor = torch.tensor(
is_null_list, device=prompt_embeds.device, dtype=torch.bool
)
null_mask_tensor = null_mask_tensor.view(1, -1, 1).expand(
prompt_embeds.size(0), -1, prompt_embeds.size(2)
@@ -195,18 +203,19 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
"""Find sequences in tensor that contain all zeros (padding).
Returns:
Tuple of (first_null_index, last_nonnull_index, null_sequences_list)
Tuple of (first_null_index, last_nonnull_index, is_null_list)
where is_null_list contains 1 for null sequences and 0 for non-null sequences
"""
first_null = -1
last_nonnull = -1
null_sequences = [0] * tensor.size(1)
is_null_list = [0] * tensor.size(1)
if tensor.dim() == 3:
for i in range(tensor.size(1)):
sequence = tensor[:, i, ...]
is_all_zero = torch.all(sequence == 0)
null_sequences[i] = 0 if is_all_zero else 1
is_null_list[i] = 1 if is_all_zero else 0
if not is_all_zero:
last_nonnull = i
@@ -214,7 +223,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
if is_all_zero and first_null == -1:
first_null = i
return first_null, last_nonnull, null_sequences
return first_null, last_nonnull, is_null_list
def _log_statistics(self, context: InvocationContext, conditioning_data: ConditioningFieldData) -> None:
"""Log statistics about the conditioning tensor."""
@@ -233,7 +242,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
return
# Find null sequences
first_null, last_nonnull, null_sequences = self._find_null_sequences(tensor)
first_null, last_nonnull, is_null_list = self._find_null_sequences(tensor)
# Calculate statistics on non-null portion
if last_nonnull < tensor.size(1) - 1 and last_nonnull >= 0:
@@ -252,10 +261,12 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
context.logger.info(f"Dimensions: {list(tensor.shape)}")
context.logger.info(f"Min: {min_val:.6f}, Max: {max_val:.6f}")
context.logger.info(f"Mean: {mean:.6f}, Std Dev: {std:.6f}")
context.logger.info(f"Suggested strength range: {std/10:.6f} - {std*10:.6f}")
context.logger.info(
f"Suggested strength range: {std * MIN_STRENGTH_FACTOR:.6f} - {std * MAX_STRENGTH_FACTOR:.6f}"
)
if first_null != -1:
num_null = sum(1 for x in null_sequences if x == 0)
num_null = sum(1 for x in is_null_list if x == 1)
context.logger.info(
f"Null sequences: First at {first_null}, Last non-null at {last_nonnull}, Total null: {num_null}"
)