mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
ruff ruff
This commit is contained in:
committed by
psychedelicious
parent
52dbdb7118
commit
983cb5ebd2
@@ -49,10 +49,10 @@ def denoise(
|
|||||||
)
|
)
|
||||||
# guidance_vec is ignored for schnell.
|
# guidance_vec is ignored for schnell.
|
||||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||||
|
|
||||||
# Store original sequence length for slicing predictions
|
# Store original sequence length for slicing predictions
|
||||||
original_seq_len = img.shape[1]
|
original_seq_len = img.shape[1]
|
||||||
|
|
||||||
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
|
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
|
||||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||||
|
|
||||||
@@ -78,21 +78,23 @@ def denoise(
|
|||||||
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
|
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
|
||||||
# tensors. Calculating the sum materializes each tensor into its own instance.
|
# tensors. Calculating the sum materializes each tensor into its own instance.
|
||||||
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
|
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
|
||||||
|
|
||||||
# Prepare input for model - concatenate fresh each step
|
# Prepare input for model - concatenate fresh each step
|
||||||
img_input = img
|
img_input = img
|
||||||
img_input_ids = img_ids
|
img_input_ids = img_ids
|
||||||
|
|
||||||
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
|
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
|
||||||
if img_cond is not None:
|
if img_cond is not None:
|
||||||
img_input = torch.cat((img_input, img_cond), dim=-1)
|
img_input = torch.cat((img_input, img_cond), dim=-1)
|
||||||
|
|
||||||
# Add sequence-wise conditioning (for Kontext)
|
# Add sequence-wise conditioning (for Kontext)
|
||||||
if img_cond_seq is not None:
|
if img_cond_seq is not None:
|
||||||
assert img_cond_seq_ids is not None, "You need to provide either both or neither of the sequence conditioning"
|
assert img_cond_seq_ids is not None, (
|
||||||
|
"You need to provide either both or neither of the sequence conditioning"
|
||||||
|
)
|
||||||
img_input = torch.cat((img_input, img_cond_seq), dim=1)
|
img_input = torch.cat((img_input, img_cond_seq), dim=1)
|
||||||
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
|
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
|
||||||
|
|
||||||
pred = model(
|
pred = model(
|
||||||
img=img_input,
|
img=img_input,
|
||||||
img_ids=img_input_ids,
|
img_ids=img_input_ids,
|
||||||
@@ -108,7 +110,7 @@ def denoise(
|
|||||||
ip_adapter_extensions=pos_ip_adapter_extensions,
|
ip_adapter_extensions=pos_ip_adapter_extensions,
|
||||||
regional_prompting_extension=pos_regional_prompting_extension,
|
regional_prompting_extension=pos_regional_prompting_extension,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Slice prediction to only include the main image tokens
|
# Slice prediction to only include the main image tokens
|
||||||
if img_input_ids is not None:
|
if img_input_ids is not None:
|
||||||
pred = pred[:, :original_seq_len]
|
pred = pred[:, :original_seq_len]
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import einops
|
import einops
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from einops import repeat
|
from einops import repeat
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from invokeai.app.invocations.fields import FluxKontextConditioningField
|
from invokeai.app.invocations.fields import FluxKontextConditioningField
|
||||||
|
|||||||
Reference in New Issue
Block a user