diff --git a/invokeai/app/invocations/seed_variance_enhancer.py b/invokeai/app/invocations/seed_variance_enhancer.py index d08f51ec4a..3ee27b2a3e 100644 --- a/invokeai/app/invocations/seed_variance_enhancer.py +++ b/invokeai/app/invocations/seed_variance_enhancer.py @@ -20,6 +20,13 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( ZImageConditioningInfo, ) +# Seed offset for noise mask generation (v2.2 behavior for consistency with prompt variations) +NOISE_MASK_SEED_OFFSET = 1 + +# Factors for suggested strength calculation based on embedding standard deviation +MIN_STRENGTH_FACTOR = 0.1 +MAX_STRENGTH_FACTOR = 10.0 + class MaskStartPosition(str, Enum): """Which end of the prompt will be protected from noise.""" @@ -142,12 +149,12 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): torch.manual_seed(self.seed) noise = torch.rand_like(prompt_embeds) * 2 * self.strength - self.strength - # Reset seed for value selection (v2.2 behavior for consistency with prompt variations) - torch.manual_seed(self.seed + 1) - noise_mask = torch.bernoulli(torch.ones_like(prompt_embeds) * randomize_percent).bool() + # Reset seed for value selection to ensure consistency with prompt variations + torch.manual_seed(self.seed + NOISE_MASK_SEED_OFFSET) + noise_mask = torch.bernoulli(torch.full_like(prompt_embeds, randomize_percent)).bool() # Check for null sequences (padding) - first_null, last_nonnull, null_sequences = self._find_null_sequences(prompt_embeds) + first_null, last_nonnull, is_null_list = self._find_null_sequences(prompt_embeds) # Apply masking if needed if mask_percent > 0 or last_nonnull < prompt_embeds.size(1) - 1: @@ -174,8 +181,9 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): if self.log_statistics: context.logger.info("Seed Variance Enhancer is masking null sequences from noise") - null_mask_tensor = ~torch.tensor( - null_sequences, device=prompt_embeds.device, dtype=torch.bool + # Convert is_null_list to tensor: True where sequences should be protected (null sequences) + null_mask_tensor = torch.tensor( + is_null_list, device=prompt_embeds.device, dtype=torch.bool ) null_mask_tensor = null_mask_tensor.view(1, -1, 1).expand( prompt_embeds.size(0), -1, prompt_embeds.size(2) @@ -195,18 +203,19 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): """Find sequences in tensor that contain all zeros (padding). Returns: - Tuple of (first_null_index, last_nonnull_index, null_sequences_list) + Tuple of (first_null_index, last_nonnull_index, is_null_list) + where is_null_list contains 1 for null sequences and 0 for non-null sequences """ first_null = -1 last_nonnull = -1 - null_sequences = [0] * tensor.size(1) + is_null_list = [0] * tensor.size(1) if tensor.dim() == 3: for i in range(tensor.size(1)): sequence = tensor[:, i, ...] is_all_zero = torch.all(sequence == 0) - null_sequences[i] = 0 if is_all_zero else 1 + is_null_list[i] = 1 if is_all_zero else 0 if not is_all_zero: last_nonnull = i @@ -214,7 +223,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): if is_all_zero and first_null == -1: first_null = i - return first_null, last_nonnull, null_sequences + return first_null, last_nonnull, is_null_list def _log_statistics(self, context: InvocationContext, conditioning_data: ConditioningFieldData) -> None: """Log statistics about the conditioning tensor.""" @@ -233,7 +242,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): return # Find null sequences - first_null, last_nonnull, null_sequences = self._find_null_sequences(tensor) + first_null, last_nonnull, is_null_list = self._find_null_sequences(tensor) # Calculate statistics on non-null portion if last_nonnull < tensor.size(1) - 1 and last_nonnull >= 0: @@ -252,10 +261,12 @@ class SeedVarianceEnhancerInvocation(BaseInvocation): context.logger.info(f"Dimensions: {list(tensor.shape)}") context.logger.info(f"Min: {min_val:.6f}, Max: {max_val:.6f}") context.logger.info(f"Mean: {mean:.6f}, Std Dev: {std:.6f}") - context.logger.info(f"Suggested strength range: {std/10:.6f} - {std*10:.6f}") + context.logger.info( + f"Suggested strength range: {std * MIN_STRENGTH_FACTOR:.6f} - {std * MAX_STRENGTH_FACTOR:.6f}" + ) if first_null != -1: - num_null = sum(1 for x in null_sequences if x == 0) + num_null = sum(1 for x in is_null_list if x == 1) context.logger.info( f"Null sequences: First at {first_null}, Last non-null at {last_nonnull}, Total null: {num_null}" )