ruff ruff

This commit is contained in:
Kent Keirsey
2025-07-03 14:17:53 -04:00
committed by psychedelicious
parent 52dbdb7118
commit 983cb5ebd2
2 changed files with 11 additions and 9 deletions

View File

@@ -49,10 +49,10 @@ def denoise(
) )
# guidance_vec is ignored for schnell. # guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
# Store original sequence length for slicing predictions # Store original sequence length for slicing predictions
original_seq_len = img.shape[1] original_seq_len = img.shape[1]
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -78,21 +78,23 @@ def denoise(
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same # controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance. # tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals) merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
# Prepare input for model - concatenate fresh each step # Prepare input for model - concatenate fresh each step
img_input = img img_input = img
img_input_ids = img_ids img_input_ids = img_ids
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.) # Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
if img_cond is not None: if img_cond is not None:
img_input = torch.cat((img_input, img_cond), dim=-1) img_input = torch.cat((img_input, img_cond), dim=-1)
# Add sequence-wise conditioning (for Kontext) # Add sequence-wise conditioning (for Kontext)
if img_cond_seq is not None: if img_cond_seq is not None:
assert img_cond_seq_ids is not None, "You need to provide either both or neither of the sequence conditioning" assert img_cond_seq_ids is not None, (
"You need to provide either both or neither of the sequence conditioning"
)
img_input = torch.cat((img_input, img_cond_seq), dim=1) img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred = model( pred = model(
img=img_input, img=img_input,
img_ids=img_input_ids, img_ids=img_input_ids,
@@ -108,7 +110,7 @@ def denoise(
ip_adapter_extensions=pos_ip_adapter_extensions, ip_adapter_extensions=pos_ip_adapter_extensions,
regional_prompting_extension=pos_regional_prompting_extension, regional_prompting_extension=pos_regional_prompting_extension,
) )
# Slice prediction to only include the main image tokens # Slice prediction to only include the main image tokens
if img_input_ids is not None: if img_input_ids is not None:
pred = pred[:, :original_seq_len] pred = pred[:, :original_seq_len]

View File

@@ -1,7 +1,7 @@
import einops import einops
import numpy as np
import torch import torch
from einops import repeat from einops import repeat
import numpy as np
from PIL import Image from PIL import Image
from invokeai.app.invocations.fields import FluxKontextConditioningField from invokeai.app.invocations.fields import FluxKontextConditioningField