mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Fix dimension handling for Z-Image 2D tensors
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
This commit is contained in:
@@ -140,7 +140,10 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def _apply_noise(self, context: InvocationContext, prompt_embeds: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply random noise to prompt embeddings."""
|
||||
"""Apply random noise to prompt embeddings.
|
||||
|
||||
Z-Image uses 2D tensors: [seq_len, hidden_size]
|
||||
"""
|
||||
# Normalize parameters
|
||||
randomize_percent = max(1, min(100, self.randomize_percent)) / 100.0
|
||||
mask_percent = max(0, min(99, self.mask_percent)) / 100.0
|
||||
@@ -157,22 +160,28 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
first_null, last_nonnull, is_null_list = self._find_null_sequences(prompt_embeds)
|
||||
|
||||
# Apply masking if needed
|
||||
if mask_percent > 0 or last_nonnull < prompt_embeds.size(1) - 1:
|
||||
seq_len = last_nonnull + 1 if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(1) - 1 else prompt_embeds.size(1)
|
||||
# For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size
|
||||
seq_dim = 0
|
||||
if mask_percent > 0 or last_nonnull < prompt_embeds.size(seq_dim) - 1:
|
||||
seq_len = (
|
||||
last_nonnull + 1
|
||||
if last_nonnull >= 0 and last_nonnull < prompt_embeds.size(seq_dim) - 1
|
||||
else prompt_embeds.size(seq_dim)
|
||||
)
|
||||
|
||||
# Determine mask range
|
||||
if self.mask_starts_at == MaskStartPosition.END:
|
||||
mask_start = seq_len - int(seq_len * mask_percent)
|
||||
mask_end = prompt_embeds.size(1)
|
||||
mask_end = prompt_embeds.size(seq_dim)
|
||||
else: # BEGINNING
|
||||
mask_start = 0
|
||||
mask_end = int(seq_len * mask_percent)
|
||||
|
||||
# Create position-based mask
|
||||
# Create position-based mask for 2D tensor [seq_len, hidden_size]
|
||||
prompt_mask = (
|
||||
torch.arange(prompt_embeds.size(1), device=prompt_embeds.device)
|
||||
.view(1, -1, 1)
|
||||
.expand(prompt_embeds.size(0), -1, prompt_embeds.size(2))
|
||||
torch.arange(prompt_embeds.size(seq_dim), device=prompt_embeds.device)
|
||||
.unsqueeze(1) # [seq_len, 1]
|
||||
.expand(prompt_embeds.size(seq_dim), prompt_embeds.size(1)) # [seq_len, hidden_size]
|
||||
)
|
||||
prompt_mask = (prompt_mask >= mask_start) & (prompt_mask < mask_end)
|
||||
|
||||
@@ -182,11 +191,9 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
context.logger.info("Seed Variance Enhancer is masking null sequences from noise")
|
||||
|
||||
# Convert is_null_list to tensor: True where sequences should be protected (null sequences)
|
||||
null_mask_tensor = torch.tensor(
|
||||
is_null_list, device=prompt_embeds.device, dtype=torch.bool
|
||||
)
|
||||
null_mask_tensor = null_mask_tensor.view(1, -1, 1).expand(
|
||||
prompt_embeds.size(0), -1, prompt_embeds.size(2)
|
||||
null_mask_tensor = torch.tensor(is_null_list, device=prompt_embeds.device, dtype=torch.bool)
|
||||
null_mask_tensor = null_mask_tensor.unsqueeze(1).expand(
|
||||
prompt_embeds.size(seq_dim), prompt_embeds.size(1)
|
||||
)
|
||||
prompt_mask = prompt_mask | null_mask_tensor
|
||||
|
||||
@@ -201,6 +208,8 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
|
||||
def _find_null_sequences(self, tensor: torch.Tensor) -> tuple[int, int, list[int]]:
|
||||
"""Find sequences in tensor that contain all zeros (padding).
|
||||
|
||||
Z-Image uses 2D tensors: [seq_len, hidden_size]
|
||||
|
||||
Returns:
|
||||
Tuple of (first_null_index, last_nonnull_index, is_null_list)
|
||||
@@ -208,11 +217,13 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
"""
|
||||
first_null = -1
|
||||
last_nonnull = -1
|
||||
is_null_list = [0] * tensor.size(1)
|
||||
|
||||
# For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size
|
||||
is_null_list = [0] * tensor.size(0)
|
||||
|
||||
if tensor.dim() == 3:
|
||||
for i in range(tensor.size(1)):
|
||||
sequence = tensor[:, i, ...]
|
||||
if tensor.dim() == 2:
|
||||
for i in range(tensor.size(0)):
|
||||
sequence = tensor[i, :] # Get the i-th sequence (all hidden dimensions)
|
||||
is_all_zero = torch.all(sequence == 0)
|
||||
|
||||
is_null_list[i] = 1 if is_all_zero else 0
|
||||
@@ -226,7 +237,10 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
return first_null, last_nonnull, is_null_list
|
||||
|
||||
def _log_statistics(self, context: InvocationContext, conditioning_data: ConditioningFieldData) -> None:
|
||||
"""Log statistics about the conditioning tensor."""
|
||||
"""Log statistics about the conditioning tensor.
|
||||
|
||||
Z-Image uses 2D tensors: [seq_len, hidden_size]
|
||||
"""
|
||||
if not conditioning_data.conditionings:
|
||||
context.logger.warning("Conditioning data has no conditionings")
|
||||
return
|
||||
@@ -245,8 +259,9 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
|
||||
first_null, last_nonnull, is_null_list = self._find_null_sequences(tensor)
|
||||
|
||||
# Calculate statistics on non-null portion
|
||||
if last_nonnull < tensor.size(1) - 1 and last_nonnull >= 0:
|
||||
sliced_tensor = tensor[:, : last_nonnull + 1, :]
|
||||
# For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size
|
||||
if last_nonnull < tensor.size(0) - 1 and last_nonnull >= 0:
|
||||
sliced_tensor = tensor[: last_nonnull + 1, :]
|
||||
mean = torch.mean(sliced_tensor).item()
|
||||
std = torch.std(sliced_tensor).item()
|
||||
min_val = torch.min(sliced_tensor).item()
|
||||
|
||||
Reference in New Issue
Block a user