(style) fix ruff and typegen errors

This commit is contained in:
Lincoln Stein
2025-12-28 12:50:08 -05:00
parent 247130a32a
commit 847ac00e17
2 changed files with 97 additions and 15 deletions

View File

@@ -99,9 +99,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
# Early return if strength is zero
if self.strength == 0:
if self.log_statistics:
context.logger.info(
"Seed Variance Enhancer strength is zero. Passing conditioning through unchanged."
)
context.logger.info("Seed Variance Enhancer strength is zero. Passing conditioning through unchanged.")
self._log_statistics(context, conditioning_data)
return ZImageConditioningOutput(conditioning=self.conditioning)
@@ -141,7 +139,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
def _apply_noise(self, context: InvocationContext, prompt_embeds: torch.Tensor) -> torch.Tensor:
"""Apply random noise to prompt embeddings.
Z-Image uses 2D tensors: [seq_len, hidden_size]
"""
# Normalize parameters
@@ -191,9 +189,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
# Convert is_null_list to tensor: True where sequences should be protected (null sequences)
null_mask_tensor = torch.tensor(is_null_list, device=prompt_embeds.device, dtype=torch.bool)
null_mask_tensor = null_mask_tensor.unsqueeze(1).expand(
prompt_embeds.size(0), prompt_embeds.size(1)
)
null_mask_tensor = null_mask_tensor.unsqueeze(1).expand(prompt_embeds.size(0), prompt_embeds.size(1))
prompt_mask = prompt_mask | null_mask_tensor
# Combine with noise mask
@@ -207,7 +203,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
def _find_null_sequences(self, tensor: torch.Tensor) -> tuple[int, int, list[int]]:
"""Find sequences in tensor that contain all zeros (padding).
Z-Image uses 2D tensors: [seq_len, hidden_size]
Returns:
@@ -216,7 +212,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
"""
first_null = -1
last_nonnull = -1
# For 2D tensor: dimension 0 = seq_len, dimension 1 = hidden_size
is_null_list = [0] * tensor.size(0)
@@ -240,7 +236,7 @@ class SeedVarianceEnhancerInvocation(BaseInvocation):
def _log_statistics(self, context: InvocationContext, conditioning_data: ConditioningFieldData) -> None:
"""Log statistics about the conditioning tensor.
Z-Image uses 2D tensors: [seq_len, hidden_size]
"""
if not conditioning_data.conditionings:

File diff suppressed because one or more lines are too long