Add cfg_scale_start_step and cfg_scale_end_step to FLUX Denoise node.

This commit is contained in:
Ryan Dick
2024-10-21 14:52:02 +00:00
committed by psychedelicious
parent 20362448b9
commit d20b894a61
3 changed files with 133 additions and 4 deletions

View File

@@ -87,8 +87,19 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
)
# TODO(ryand): Add cfg_scale range validation.
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
# Inclusive index of the first denoising step that uses cfg_scale. Negative values are resolved relative to
# the end of the schedule, like Python sequence indices.
cfg_scale_start_step: int = InputField(
    default=0,
    title="CFG Scale Start Step",
    description="Index of the first step to apply cfg_scale. Negative indices count backwards from the "
    + "last step (e.g. a value of -1 refers to the final step).",
)
# Inclusive index of the last denoising step that uses cfg_scale. Negative values are resolved relative to the
# end of the schedule, so the default of -1 selects the final step.
cfg_scale_end_step: int = InputField(
default=-1,
title="CFG Scale End Step",
description="Index of the last step to apply cfg_scale. Negative indices count backwards from the "
+ "last step (e.g. a value of -1 refers to the final step).",
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
@@ -235,6 +246,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
noise=noise,
)
cfg_scale = self.prep_cfg_scale(
cfg_scale=self.cfg_scale,
timesteps=timesteps,
cfg_scale_start_step=self.cfg_scale_start_step,
cfg_scale_end_step=self.cfg_scale_end_step,
)
with ExitStack() as exit_stack:
# Prepare ControlNet extensions.
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
@@ -296,7 +314,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
cfg_scale=self.cfg_scale,
cfg_scale=cfg_scale,
inpaint_extension=inpaint_extension,
controlnet_extensions=controlnet_extensions,
)
@@ -304,6 +322,55 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
x = unpack(x.float(), self.height, self.width)
return x
@classmethod
def prep_cfg_scale(
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
) -> list[float]:
"""Prepare the cfg_scale schedule.
- Clips the cfg_scale schedule based on cfg_scale_start_step and cfg_scale_end_step.
- If cfg_scale is a list, then it is assumed to be a schedule and is returned as-is.
- If cfg_scale is a scalar, then a linear schedule is created from cfg_scale_start_step to cfg_scale_end_step.
"""
# num_steps is the number of denoising steps, which is one less than the number of timesteps.
num_steps = len(timesteps) - 1
# Normalize cfg_scale to a list if it is a scalar.
cfg_scale_list: list[float]
if isinstance(cfg_scale, float):
cfg_scale_list = [cfg_scale] * num_steps
elif isinstance(cfg_scale, list):
cfg_scale_list = cfg_scale
else:
raise ValueError(f"Unsupported cfg_scale type: {type(cfg_scale)}")
assert len(cfg_scale_list) == num_steps
# Handle negative indices for cfg_scale_start_step and cfg_scale_end_step.
start_step_index = cfg_scale_start_step
if start_step_index < 0:
start_step_index = num_steps + start_step_index
end_step_index = cfg_scale_end_step
if end_step_index < 0:
end_step_index = num_steps + end_step_index
# Validate the start and end step indices.
if not (0 <= start_step_index < num_steps):
raise ValueError(f"Invalid cfg_scale_start_step. Out of range: {cfg_scale_start_step}.")
if not (0 <= end_step_index < num_steps):
raise ValueError(f"Invalid cfg_scale_end_step. Out of range: {cfg_scale_end_step}.")
if start_step_index > end_step_index:
raise ValueError(
f"cfg_scale_start_step ({cfg_scale_start_step}) must be before cfg_scale_end_step "
+ f"({cfg_scale_end_step})."
)
# Set values outside the start and end step indices to 1.0. This is equivalent to disabling cfg_scale for those
# steps.
clipped_cfg_scale = [1.0] * num_steps
clipped_cfg_scale[start_step_index : end_step_index + 1] = cfg_scale_list[start_step_index : end_step_index + 1]
return clipped_cfg_scale
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask.

View File

@@ -29,7 +29,7 @@ def denoise(
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
cfg_scale: float | list[float],
cfg_scale: list[float],
inpaint_extension: InpaintExtension | None,
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
):
@@ -84,7 +84,7 @@ def denoise(
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
)
step_cfg_scale = cfg_scale[step_index] if isinstance(cfg_scale, list) else cfg_scale
step_cfg_scale = cfg_scale[step_index]
# If step_cfg_scale is 1.0, then we don't need to run the negative prediction.
if not math.isclose(step_cfg_scale, 1.0):