mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-02 21:14:58 -05:00
Add cfg_scale_start_step and cfg_scale_end_step to FLUX Denoise node.
This commit is contained in:
committed by
psychedelicious
parent
20362448b9
commit
d20b894a61
@@ -87,8 +87,19 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
# TODO(ryand): Add cfg_scale range validation.
|
||||
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
cfg_scale_start_step: int = InputField(
|
||||
default=0,
|
||||
title="CFG Scale Start Step",
|
||||
description="Index of the first step to apply cfg_scale. Negative indices count backwards from the "
|
||||
+ "the last step (e.g. a value of -1 refers to the final step).",
|
||||
)
|
||||
cfg_scale_end_step: int = InputField(
|
||||
default=-1,
|
||||
title="CFG Scale End Step",
|
||||
description="Index of the last step to apply cfg_scale. Negative indices count backwards from the "
|
||||
+ "last step (e.g. a value of -1 refers to the final step).",
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(
|
||||
@@ -235,6 +246,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
cfg_scale = self.prep_cfg_scale(
|
||||
cfg_scale=self.cfg_scale,
|
||||
timesteps=timesteps,
|
||||
cfg_scale_start_step=self.cfg_scale_start_step,
|
||||
cfg_scale_end_step=self.cfg_scale_end_step,
|
||||
)
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
# Prepare ControlNet extensions.
|
||||
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
|
||||
@@ -296,7 +314,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
timesteps=timesteps,
|
||||
step_callback=self._build_step_callback(context),
|
||||
guidance=self.guidance,
|
||||
cfg_scale=self.cfg_scale,
|
||||
cfg_scale=cfg_scale,
|
||||
inpaint_extension=inpaint_extension,
|
||||
controlnet_extensions=controlnet_extensions,
|
||||
)
|
||||
@@ -304,6 +322,55 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def prep_cfg_scale(
|
||||
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
|
||||
) -> list[float]:
|
||||
"""Prepare the cfg_scale schedule.
|
||||
|
||||
- Clips the cfg_scale schedule based on cfg_scale_start_step and cfg_scale_end_step.
|
||||
- If cfg_scale is a list, then it is assumed to be a schedule and is returned as-is.
|
||||
- If cfg_scale is a scalar, then a linear schedule is created from cfg_scale_start_step to cfg_scale_end_step.
|
||||
"""
|
||||
# num_steps is the number of denoising steps, which is one less than the number of timesteps.
|
||||
num_steps = len(timesteps) - 1
|
||||
|
||||
# Normalize cfg_scale to a list if it is a scalar.
|
||||
cfg_scale_list: list[float]
|
||||
if isinstance(cfg_scale, float):
|
||||
cfg_scale_list = [cfg_scale] * num_steps
|
||||
elif isinstance(cfg_scale, list):
|
||||
cfg_scale_list = cfg_scale
|
||||
else:
|
||||
raise ValueError(f"Unsupported cfg_scale type: {type(cfg_scale)}")
|
||||
assert len(cfg_scale_list) == num_steps
|
||||
|
||||
# Handle negative indices for cfg_scale_start_step and cfg_scale_end_step.
|
||||
start_step_index = cfg_scale_start_step
|
||||
if start_step_index < 0:
|
||||
start_step_index = num_steps + start_step_index
|
||||
end_step_index = cfg_scale_end_step
|
||||
if end_step_index < 0:
|
||||
end_step_index = num_steps + end_step_index
|
||||
|
||||
# Validate the start and end step indices.
|
||||
if not (0 <= start_step_index < num_steps):
|
||||
raise ValueError(f"Invalid cfg_scale_start_step. Out of range: {cfg_scale_start_step}.")
|
||||
if not (0 <= end_step_index < num_steps):
|
||||
raise ValueError(f"Invalid cfg_scale_end_step. Out of range: {cfg_scale_end_step}.")
|
||||
if start_step_index > end_step_index:
|
||||
raise ValueError(
|
||||
f"cfg_scale_start_step ({cfg_scale_start_step}) must be before cfg_scale_end_step "
|
||||
+ f"({cfg_scale_end_step})."
|
||||
)
|
||||
|
||||
# Set values outside the start and end step indices to 1.0. This is equivalent to disabling cfg_scale for those
|
||||
# steps.
|
||||
clipped_cfg_scale = [1.0] * num_steps
|
||||
clipped_cfg_scale[start_step_index : end_step_index + 1] = cfg_scale_list[start_step_index : end_step_index + 1]
|
||||
|
||||
return clipped_cfg_scale
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Prepare the inpaint mask.
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ def denoise(
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[PipelineIntermediateState], None],
|
||||
guidance: float,
|
||||
cfg_scale: float | list[float],
|
||||
cfg_scale: list[float],
|
||||
inpaint_extension: InpaintExtension | None,
|
||||
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
|
||||
):
|
||||
@@ -84,7 +84,7 @@ def denoise(
|
||||
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
|
||||
)
|
||||
|
||||
step_cfg_scale = cfg_scale[step_index] if isinstance(cfg_scale, list) else cfg_scale
|
||||
step_cfg_scale = cfg_scale[step_index]
|
||||
|
||||
# If step_cfg_scale, is 1.0, then we don't need to run the negative prediction.
|
||||
if not math.isclose(step_cfg_scale, 1.0):
|
||||
|
||||
Reference in New Issue
Block a user