mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Address code review feedback - fix logic and add constants
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
This commit is contained in:
@@ -20,6 +20,13 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ZImageConditioningInfo,
|
||||
)
|
||||
|
||||
# Seed offset for noise mask generation (v2.2 behavior for consistency with prompt variations)
|
||||
NOISE_MASK_SEED_OFFSET = 1
|
||||
|
||||
# Factors for suggested strength calculation based on embedding standard deviation
|
||||
MIN_STRENGTH_FACTOR = 0.1
|
||||
MAX_STRENGTH_FACTOR = 10.0
|
||||
|
||||
|
||||
class MaskStartPosition(str, Enum):
|
||||
"""Which end of the prompt will be protected from noise."""
|
||||
@@ -142,12 +149,12 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
torch.manual_seed(self.seed)
|
||||
noise = torch.rand_like(prompt_embeds) * 2 * self.strength - self.strength
|
||||
|
||||
# Reset seed for value selection (v2.2 behavior for consistency with prompt variations)
|
||||
torch.manual_seed(self.seed + 1)
|
||||
noise_mask = torch.bernoulli(torch.ones_like(prompt_embeds) * randomize_percent).bool()
|
||||
# Reset seed for value selection to ensure consistency with prompt variations
|
||||
torch.manual_seed(self.seed + NOISE_MASK_SEED_OFFSET)
|
||||
noise_mask = torch.bernoulli(torch.full_like(prompt_embeds, randomize_percent)).bool()
|
||||
|
||||
# Check for null sequences (padding)
|
||||
first_null, last_nonnull, null_sequences = self._find_null_sequences(prompt_embeds)
|
||||
first_null, last_nonnull, is_null_list = self._find_null_sequences(prompt_embeds)
|
||||
|
||||
# Apply masking if needed
|
||||
if mask_percent > 0 or last_nonnull < prompt_embeds.size(1) - 1:
|
||||
@@ -174,8 +181,9 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
if self.log_statistics:
|
||||
context.logger.info("Seed Variance Enhancer is masking null sequences from noise")
|
||||
|
||||
null_mask_tensor = ~torch.tensor(
|
||||
null_sequences, device=prompt_embeds.device, dtype=torch.bool
|
||||
# Convert is_null_list to tensor: True where sequences should be protected (null sequences)
|
||||
null_mask_tensor = torch.tensor(
|
||||
is_null_list, device=prompt_embeds.device, dtype=torch.bool
|
||||
)
|
||||
null_mask_tensor = null_mask_tensor.view(1, -1, 1).expand(
|
||||
prompt_embeds.size(0), -1, prompt_embeds.size(2)
|
||||
@@ -195,18 +203,19 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
"""Find sequences in tensor that contain all zeros (padding).
|
||||
|
||||
Returns:
|
||||
Tuple of (first_null_index, last_nonnull_index, null_sequences_list)
|
||||
Tuple of (first_null_index, last_nonnull_index, is_null_list)
|
||||
where is_null_list contains 1 for null sequences and 0 for non-null sequences
|
||||
"""
|
||||
first_null = -1
|
||||
last_nonnull = -1
|
||||
null_sequences = [0] * tensor.size(1)
|
||||
is_null_list = [0] * tensor.size(1)
|
||||
|
||||
if tensor.dim() == 3:
|
||||
for i in range(tensor.size(1)):
|
||||
sequence = tensor[:, i, ...]
|
||||
is_all_zero = torch.all(sequence == 0)
|
||||
|
||||
null_sequences[i] = 0 if is_all_zero else 1
|
||||
is_null_list[i] = 1 if is_all_zero else 0
|
||||
|
||||
if not is_all_zero:
|
||||
last_nonnull = i
|
||||
@@ -214,7 +223,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
if is_all_zero and first_null == -1:
|
||||
first_null = i
|
||||
|
||||
return first_null, last_nonnull, null_sequences
|
||||
return first_null, last_nonnull, is_null_list
|
||||
|
||||
def _log_statistics(self, context: InvocationContext, conditioning_data: ConditioningFieldData) -> None:
|
||||
"""Log statistics about the conditioning tensor."""
|
||||
@@ -233,7 +242,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
return
|
||||
|
||||
# Find null sequences
|
||||
first_null, last_nonnull, null_sequences = self._find_null_sequences(tensor)
|
||||
first_null, last_nonnull, is_null_list = self._find_null_sequences(tensor)
|
||||
|
||||
# Calculate statistics on non-null portion
|
||||
if last_nonnull < tensor.size(1) - 1 and last_nonnull >= 0:
|
||||
@@ -252,10 +261,12 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
context.logger.info(f"Dimensions: {list(tensor.shape)}")
|
||||
context.logger.info(f"Min: {min_val:.6f}, Max: {max_val:.6f}")
|
||||
context.logger.info(f"Mean: {mean:.6f}, Std Dev: {std:.6f}")
|
||||
context.logger.info(f"Suggested strength range: {std/10:.6f} - {std*10:.6f}")
|
||||
context.logger.info(
|
||||
f"Suggested strength range: {std * MIN_STRENGTH_FACTOR:.6f} - {std * MAX_STRENGTH_FACTOR:.6f}"
|
||||
)
|
||||
|
||||
if first_null != -1:
|
||||
num_null = sum(1 for x in null_sequences if x == 0)
|
||||
num_null = sum(1 for x in is_null_list if x == 1)
|
||||
context.logger.info(
|
||||
f"Null sequences: First at {first_null}, Last non-null at {last_nonnull}, Total null: {num_null}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user