ruff ruff

This commit is contained in:
Kent Keirsey
2025-07-03 14:17:53 -04:00
committed by psychedelicious
parent 52dbdb7118
commit 983cb5ebd2
2 changed files with 11 additions and 9 deletions

View File

@@ -49,10 +49,10 @@ def denoise(
)
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
# Store original sequence length for slicing predictions
original_seq_len = img.shape[1]
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -78,21 +78,23 @@ def denoise(
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
# Prepare input for model - concatenate fresh each step
img_input = img
img_input_ids = img_ids
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
if img_cond is not None:
img_input = torch.cat((img_input, img_cond), dim=-1)
# Add sequence-wise conditioning (for Kontext)
if img_cond_seq is not None:
assert img_cond_seq_ids is not None, "You need to provide either both or neither of the sequence conditioning"
assert img_cond_seq_ids is not None, (
"You need to provide either both or neither of the sequence conditioning"
)
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred = model(
img=img_input,
img_ids=img_input_ids,
@@ -108,7 +110,7 @@ def denoise(
ip_adapter_extensions=pos_ip_adapter_extensions,
regional_prompting_extension=pos_regional_prompting_extension,
)
# Slice prediction to only include the main image tokens
if img_input_ids is not None:
pred = pred[:, :original_seq_len]

View File

@@ -1,7 +1,7 @@
import einops
import numpy as np
import torch
from einops import repeat
import numpy as np
from PIL import Image
from invokeai.app.invocations.fields import FluxKontextConditioningField