mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 06:58:13 -05:00
Add support for cfg_scale list on FLUX Denoise node.
This commit is contained in:
committed by
psychedelicious
parent
da171114ea
commit
5df10cc494
@@ -85,9 +85,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
negative_text_conditioning: FluxConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
# TODO(ryand): Add support for cfg_scale to be a list of floats: one for each step.
|
||||
# TODO(ryand): Add cfg_scale range validation.
|
||||
cfg_scale: float = InputField(default=3.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
@@ -28,7 +29,7 @@ def denoise(
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[PipelineIntermediateState], None],
|
||||
guidance: float,
|
||||
cfg_scale: float,
|
||||
cfg_scale: float | list[float],
|
||||
inpaint_extension: InpaintExtension | None,
|
||||
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
|
||||
):
|
||||
@@ -43,10 +44,9 @@ def denoise(
|
||||
latents=img,
|
||||
),
|
||||
)
|
||||
step = 1
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
|
||||
# Run ControlNet models.
|
||||
@@ -54,7 +54,7 @@ def denoise(
|
||||
for controlnet_extension in controlnet_extensions:
|
||||
controlnet_residuals.append(
|
||||
controlnet_extension.run_controlnet(
|
||||
timestep_index=step - 1,
|
||||
timestep_index=step_index,
|
||||
total_num_timesteps=total_steps,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
@@ -84,21 +84,24 @@ def denoise(
|
||||
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
|
||||
)
|
||||
|
||||
# TODO(ryand): Add option to apply controlnet to negative conditioning as well.
|
||||
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance on
|
||||
# systems with sufficient VRAM.
|
||||
neg_pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=neg_txt,
|
||||
txt_ids=neg_txt_ids,
|
||||
y=neg_vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
controlnet_double_block_residuals=None,
|
||||
controlnet_single_block_residuals=None,
|
||||
)
|
||||
pred = neg_pred + cfg_scale * (pred - neg_pred)
|
||||
step_cfg_scale = cfg_scale[step_index] if isinstance(cfg_scale, list) else cfg_scale
|
||||
|
||||
# If step_cfg_scale, is 1.0, then we don't need to run the negative prediction.
|
||||
if not math.isclose(step_cfg_scale, 1.0):
|
||||
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance on
|
||||
# systems with sufficient VRAM.
|
||||
neg_pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=neg_txt,
|
||||
txt_ids=neg_txt_ids,
|
||||
y=neg_vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
controlnet_double_block_residuals=None,
|
||||
controlnet_single_block_residuals=None,
|
||||
)
|
||||
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
|
||||
|
||||
preview_img = img - t_curr * pred
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
@@ -109,13 +112,12 @@ def denoise(
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=step,
|
||||
step=step_index + 1,
|
||||
order=1,
|
||||
total_steps=total_steps,
|
||||
timestep=int(t_curr),
|
||||
latents=preview_img,
|
||||
),
|
||||
)
|
||||
step += 1
|
||||
|
||||
return img
|
||||
|
||||
Reference in New Issue
Block a user