Add support for cfg_scale list on FLUX Denoise node.

This commit is contained in:
Ryan Dick
2024-10-18 20:14:47 +00:00
committed by psychedelicious
parent da171114ea
commit 5df10cc494
2 changed files with 24 additions and 23 deletions

View File

@@ -85,9 +85,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
negative_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
# TODO(ryand): Add support for cfg_scale to be a list of floats: one for each step.
# TODO(ryand): Add cfg_scale range validation.
cfg_scale: float = InputField(default=3.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(

View File

@@ -1,3 +1,4 @@
import math
from typing import Callable
import torch
@@ -28,7 +29,7 @@ def denoise(
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
cfg_scale: float,
cfg_scale: float | list[float],
inpaint_extension: InpaintExtension | None,
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
):
@@ -43,10 +44,9 @@ def denoise(
latents=img,
),
)
step = 1
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# Run ControlNet models.
@@ -54,7 +54,7 @@ def denoise(
for controlnet_extension in controlnet_extensions:
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step - 1,
timestep_index=step_index,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
@@ -84,21 +84,24 @@ def denoise(
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
)
# TODO(ryand): Add option to apply controlnet to negative conditioning as well.
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance on
# systems with sufficient VRAM.
neg_pred = model(
img=img,
img_ids=img_ids,
txt=neg_txt,
txt_ids=neg_txt_ids,
y=neg_vec,
timesteps=t_vec,
guidance=guidance_vec,
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
)
pred = neg_pred + cfg_scale * (pred - neg_pred)
step_cfg_scale = cfg_scale[step_index] if isinstance(cfg_scale, list) else cfg_scale
# If step_cfg_scale is 1.0, then we don't need to run the negative prediction.
if not math.isclose(step_cfg_scale, 1.0):
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance on
# systems with sufficient VRAM.
neg_pred = model(
img=img,
img_ids=img_ids,
txt=neg_txt,
txt_ids=neg_txt_ids,
y=neg_vec,
timesteps=t_vec,
guidance=guidance_vec,
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
)
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
preview_img = img - t_curr * pred
img = img + (t_prev - t_curr) * pred
@@ -109,13 +112,12 @@ def denoise(
step_callback(
PipelineIntermediateState(
step=step,
step=step_index + 1,
order=1,
total_steps=total_steps,
timestep=int(t_curr),
latents=preview_img,
),
)
step += 1
return img