mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add Seed Variance Enhancer invocation for Z-Image
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
This commit is contained in:
286
invokeai/app/invocations/seed_variance_enhancer.py
Normal file
286
invokeai/app/invocations/seed_variance_enhancer.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""Seed Variance Enhancer for Z-Image conditioning.
|
||||
|
||||
Adds controlled random noise to conditioning embeddings to increase output diversity,
|
||||
particularly useful for Z-Image models with low seed variance.
|
||||
|
||||
Based on the ComfyUI SeedVarianceEnhancer node by ChangeTheConstants.
|
||||
Released under MIT No Attribution License.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import Input, InputField, ZImageConditioningField
|
||||
from invokeai.app.invocations.primitives import ZImageConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ConditioningFieldData,
|
||||
ZImageConditioningInfo,
|
||||
)
|
||||
|
||||
|
||||
class NoiseInsertMode(str, Enum):
|
||||
"""When to apply noise during the generation process."""
|
||||
|
||||
BEGINNING = "noise on beginning steps"
|
||||
ENDING = "noise on ending steps"
|
||||
ALL = "noise on all steps"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class MaskStartPosition(str, Enum):
|
||||
"""Which end of the prompt will be protected from noise."""
|
||||
|
||||
BEGINNING = "beginning"
|
||||
END = "end"
|
||||
|
||||
|
||||
@invocation(
|
||||
"seed_variance_enhancer",
|
||||
title="Seed Variance Enhancer - Z-Image",
|
||||
tags=["conditioning", "z-image", "variance", "seed", "noise"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
"""Adds random noise to Z-Image conditioning embeddings to increase output diversity.
|
||||
|
||||
This node compensates for low seed variance by adding controlled noise to the conditioning
|
||||
embeddings during specific steps of generation. Works specifically with Z-Image models.
|
||||
|
||||
The noise can be applied to beginning steps, ending steps, or all steps. Applying noise
|
||||
only to beginning steps (default) allows the model to pivot back toward prompt adherence.
|
||||
|
||||
Masking features allow protecting portions of the prompt from noise exposure.
|
||||
"""
|
||||
|
||||
conditioning: ZImageConditioningField = InputField(
|
||||
description="The Z-Image conditioning to enhance with variance.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
randomize_percent: float = InputField(
|
||||
default=50.0,
|
||||
ge=1.0,
|
||||
le=100.0,
|
||||
description="Percentage of embedding values to which random noise is added.",
|
||||
)
|
||||
strength: float = InputField(
|
||||
default=20.0,
|
||||
description="Scale of the random noise. Typical range: 15-40 for Z-Image.",
|
||||
)
|
||||
noise_insert: NoiseInsertMode = InputField(
|
||||
default=NoiseInsertMode.BEGINNING,
|
||||
description="Which steps of generation process use the noisy embedding.",
|
||||
)
|
||||
steps_switchover_percent: float = InputField(
|
||||
default=20.0,
|
||||
ge=1.0,
|
||||
le=99.0,
|
||||
description="Percentage of steps before switching between noisy and original embeddings. "
|
||||
"Formula: (100/TOTAL_STEPS) * STEPS - 1",
|
||||
)
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
ge=0,
|
||||
description="Random seed for noise generation and value selection.",
|
||||
)
|
||||
mask_starts_at: MaskStartPosition = InputField(
|
||||
default=MaskStartPosition.BEGINNING,
|
||||
description="Which end of prompt will be protected from noise.",
|
||||
)
|
||||
mask_percent: float = InputField(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
le=99.0,
|
||||
description="Percentage of prompt protected from noise. 0 = no masking.",
|
||||
)
|
||||
log_statistics: bool = InputField(
|
||||
default=False,
|
||||
description="Log embedding statistics to console for debugging.",
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
|
||||
# Load the conditioning data
|
||||
conditioning_data = context.conditioning.load(self.conditioning.conditioning_name)
|
||||
|
||||
# Early return if disabled
|
||||
if self.noise_insert == NoiseInsertMode.DISABLED:
|
||||
if self.log_statistics:
|
||||
context.logger.info("Seed Variance Enhancer is disabled. Passing conditioning through unchanged.")
|
||||
self._log_statistics(context, conditioning_data)
|
||||
return ZImageConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
# Early return if strength is zero
|
||||
if self.strength == 0:
|
||||
if self.log_statistics:
|
||||
context.logger.info(
|
||||
"Seed Variance Enhancer strength is zero. Passing conditioning through unchanged."
|
||||
)
|
||||
self._log_statistics(context, conditioning_data)
|
||||
return ZImageConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
# Validate we have Z-Image conditioning
|
||||
if not conditioning_data.conditionings or len(conditioning_data.conditionings) == 0:
|
||||
context.logger.warning("Seed Variance Enhancer received empty conditioning.")
|
||||
return ZImageConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
conditioning_info = conditioning_data.conditionings[0]
|
||||
if not isinstance(conditioning_info, ZImageConditioningInfo):
|
||||
context.logger.warning(
|
||||
f"Seed Variance Enhancer expected Z-Image conditioning, got {type(conditioning_info).__name__}"
|
||||
)
|
||||
return ZImageConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
# Get the prompt embeddings tensor
|
||||
prompt_embeds = conditioning_info.prompt_embeds
|
||||
|
||||
if self.log_statistics:
|
||||
self._log_statistics(context, conditioning_data)
|
||||
|
||||
# Apply noise to the embeddings
|
||||
noisy_prompt_embeds = self._apply_noise(context, prompt_embeds)
|
||||
|
||||
# Create new conditioning with noisy embeddings
|
||||
new_conditioning_info = ZImageConditioningInfo(prompt_embeds=noisy_prompt_embeds)
|
||||
new_conditioning_data = ConditioningFieldData(conditionings=[new_conditioning_info])
|
||||
|
||||
# Save and return
|
||||
conditioning_name = context.conditioning.save(new_conditioning_data)
|
||||
return ZImageConditioningOutput(
|
||||
conditioning=ZImageConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
mask=self.conditioning.mask,
|
||||
)
|
||||
)
|
||||
|
||||
def _apply_noise(self, context: InvocationContext, prompt_embeds: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply random noise to prompt embeddings."""
|
||||
# Normalize parameters
|
||||
randomize_percent = max(1, min(100, self.randomize_percent)) / 100.0
|
||||
mask_percent = max(0, min(99, self.mask_percent)) / 100.0
|
||||
|
||||
# Generate noise
|
||||
torch.manual_seed(self.seed)
|
||||
noise = torch.rand_like(prompt_embeds) * 2 * self.strength - self.strength
|
||||
|
||||
# Reset seed for value selection (v2.2 behavior for consistency with prompt variations)
|
||||
torch.manual_seed(self.seed + 1)
|
||||
noise_mask = torch.bernoulli(torch.ones_like(prompt_embeds) * randomize_percent).bool()
|
||||
|
||||
# Check for null sequences (padding)
|
||||
first_null, last_nonnull, null_sequences = self._find_null_sequences(prompt_embeds)
|
||||
|
||||
# Apply masking if needed
|
||||
if mask_percent > 0 or last_nonnull < prompt_embeds.size(1) - 1:
|
||||
seq_len = last_nonnull + 1 if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(1) - 1 else prompt_embeds.size(1)
|
||||
|
||||
# Determine mask range
|
||||
if self.mask_starts_at == MaskStartPosition.END:
|
||||
mask_start = seq_len - int(seq_len * mask_percent)
|
||||
mask_end = prompt_embeds.size(1)
|
||||
else: # BEGINNING
|
||||
mask_start = 0
|
||||
mask_end = int(seq_len * mask_percent)
|
||||
|
||||
# Create position-based mask
|
||||
prompt_mask = (
|
||||
torch.arange(prompt_embeds.size(1), device=prompt_embeds.device)
|
||||
.view(1, -1, 1)
|
||||
.expand(prompt_embeds.size(0), -1, prompt_embeds.size(2))
|
||||
)
|
||||
prompt_mask = (prompt_mask >= mask_start) & (prompt_mask < mask_end)
|
||||
|
||||
# Include null sequences in protected region
|
||||
if first_null > -1:
|
||||
if self.log_statistics:
|
||||
context.logger.info("Seed Variance Enhancer is masking null sequences from noise")
|
||||
|
||||
null_mask_tensor = ~torch.tensor(
|
||||
null_sequences, device=prompt_embeds.device, dtype=torch.bool
|
||||
)
|
||||
null_mask_tensor = null_mask_tensor.view(1, -1, 1).expand(
|
||||
prompt_embeds.size(0), -1, prompt_embeds.size(2)
|
||||
)
|
||||
prompt_mask = prompt_mask | null_mask_tensor
|
||||
|
||||
# Combine with noise mask
|
||||
noise_mask = noise_mask & (~prompt_mask)
|
||||
|
||||
# Apply masked noise
|
||||
modified_noise = noise * noise_mask
|
||||
noisy_embeds = prompt_embeds + modified_noise
|
||||
|
||||
return noisy_embeds
|
||||
|
||||
def _find_null_sequences(self, tensor: torch.Tensor) -> tuple[int, int, list[int]]:
|
||||
"""Find sequences in tensor that contain all zeros (padding).
|
||||
|
||||
Returns:
|
||||
Tuple of (first_null_index, last_nonnull_index, null_sequences_list)
|
||||
"""
|
||||
first_null = -1
|
||||
last_nonnull = -1
|
||||
null_sequences = [0] * tensor.size(1)
|
||||
|
||||
if tensor.dim() == 3:
|
||||
for i in range(tensor.size(1)):
|
||||
sequence = tensor[:, i, ...]
|
||||
is_all_zero = torch.all(sequence == 0)
|
||||
|
||||
null_sequences[i] = 0 if is_all_zero else 1
|
||||
|
||||
if not is_all_zero:
|
||||
last_nonnull = i
|
||||
|
||||
if is_all_zero and first_null == -1:
|
||||
first_null = i
|
||||
|
||||
return first_null, last_nonnull, null_sequences
|
||||
|
||||
def _log_statistics(self, context: InvocationContext, conditioning_data: ConditioningFieldData) -> None:
|
||||
"""Log statistics about the conditioning tensor."""
|
||||
if not conditioning_data.conditionings:
|
||||
context.logger.warning("Conditioning data has no conditionings")
|
||||
return
|
||||
|
||||
conditioning_info = conditioning_data.conditionings[0]
|
||||
if not isinstance(conditioning_info, ZImageConditioningInfo):
|
||||
context.logger.warning(f"Expected ZImageConditioningInfo, got {type(conditioning_info).__name__}")
|
||||
return
|
||||
|
||||
tensor = conditioning_info.prompt_embeds
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
context.logger.warning("Conditioning does not contain a tensor")
|
||||
return
|
||||
|
||||
# Find null sequences
|
||||
first_null, last_nonnull, null_sequences = self._find_null_sequences(tensor)
|
||||
|
||||
# Calculate statistics on non-null portion
|
||||
if last_nonnull < tensor.size(1) - 1 and last_nonnull >= 0:
|
||||
sliced_tensor = tensor[:, : last_nonnull + 1, :]
|
||||
mean = torch.mean(sliced_tensor).item()
|
||||
std = torch.std(sliced_tensor).item()
|
||||
min_val = torch.min(sliced_tensor).item()
|
||||
max_val = torch.max(sliced_tensor).item()
|
||||
else:
|
||||
mean = torch.mean(tensor).item()
|
||||
std = torch.std(tensor).item()
|
||||
min_val = torch.min(tensor).item()
|
||||
max_val = torch.max(tensor).item()
|
||||
|
||||
context.logger.info("=== Seed Variance Enhancer - Embedding Statistics ===")
|
||||
context.logger.info(f"Dimensions: {list(tensor.shape)}")
|
||||
context.logger.info(f"Min: {min_val:.6f}, Max: {max_val:.6f}")
|
||||
context.logger.info(f"Mean: {mean:.6f}, Std Dev: {std:.6f}")
|
||||
context.logger.info(f"Suggested strength range: {std/10:.6f} - {std*10:.6f}")
|
||||
|
||||
if first_null != -1:
|
||||
num_null = sum(1 for x in null_sequences if x == 0)
|
||||
context.logger.info(
|
||||
f"Null sequences: First at {first_null}, Last non-null at {last_nonnull}, Total null: {num_null}"
|
||||
)
|
||||
Reference in New Issue
Block a user