From d20b894a619d1f9196b7d5c8fac0aa4e97bc2c2f Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 21 Oct 2024 14:52:02 +0000 Subject: [PATCH] Add cfg_scale_start_step and cfg_scale_end_step to FLUX Denoise node. --- invokeai/app/invocations/flux_denoise.py | 71 +++++++++++++++++++++- invokeai/backend/flux/denoise.py | 4 +- tests/app/invocations/test_flux_denoise.py | 62 +++++++++++++++++++ 3 files changed, 133 insertions(+), 4 deletions(-) create mode 100644 tests/app/invocations/test_flux_denoise.py diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 5c1cb70822..925787b422 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -87,8 +87,19 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.", input=Input.Connection, ) - # TODO(ryand): Add cfg_scale range validation. cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale") + cfg_scale_start_step: int = InputField( + default=0, + title="CFG Scale Start Step", + description="Index of the first step to apply cfg_scale. Negative indices count backwards from the " + + "the last step (e.g. a value of -1 refers to the final step).", + ) + cfg_scale_end_step: int = InputField( + default=-1, + title="CFG Scale End Step", + description="Index of the last step to apply cfg_scale. Negative indices count backwards from the " + + "last step (e.g. a value of -1 refers to the final step).", + ) width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.") height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.") num_steps: int = InputField( @@ -235,6 +246,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): noise=noise, ) + cfg_scale = self.prep_cfg_scale( + cfg_scale=self.cfg_scale, + timesteps=timesteps, + cfg_scale_start_step=self.cfg_scale_start_step, + cfg_scale_end_step=self.cfg_scale_end_step, + ) + with ExitStack() as exit_stack: # Prepare ControlNet extensions. # Note: We do this before loading the transformer model to minimize peak memory (see implementation). @@ -296,7 +314,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): timesteps=timesteps, step_callback=self._build_step_callback(context), guidance=self.guidance, - cfg_scale=self.cfg_scale, + cfg_scale=cfg_scale, inpaint_extension=inpaint_extension, controlnet_extensions=controlnet_extensions, ) @@ -304,6 +322,55 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): x = unpack(x.float(), self.height, self.width) return x + @classmethod + def prep_cfg_scale( + cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int + ) -> list[float]: + """Prepare the cfg_scale schedule. + + - Clips the cfg_scale schedule based on cfg_scale_start_step and cfg_scale_end_step. + - If cfg_scale is a list, then it is assumed to be a schedule and is returned as-is. + - If cfg_scale is a scalar, then a linear schedule is created from cfg_scale_start_step to cfg_scale_end_step. + """ + # num_steps is the number of denoising steps, which is one less than the number of timesteps. + num_steps = len(timesteps) - 1 + + # Normalize cfg_scale to a list if it is a scalar. + cfg_scale_list: list[float] + if isinstance(cfg_scale, float): + cfg_scale_list = [cfg_scale] * num_steps + elif isinstance(cfg_scale, list): + cfg_scale_list = cfg_scale + else: + raise ValueError(f"Unsupported cfg_scale type: {type(cfg_scale)}") + assert len(cfg_scale_list) == num_steps + + # Handle negative indices for cfg_scale_start_step and cfg_scale_end_step. + start_step_index = cfg_scale_start_step + if start_step_index < 0: + start_step_index = num_steps + start_step_index + end_step_index = cfg_scale_end_step + if end_step_index < 0: + end_step_index = num_steps + end_step_index + + # Validate the start and end step indices. + if not (0 <= start_step_index < num_steps): + raise ValueError(f"Invalid cfg_scale_start_step. Out of range: {cfg_scale_start_step}.") + if not (0 <= end_step_index < num_steps): + raise ValueError(f"Invalid cfg_scale_end_step. Out of range: {cfg_scale_end_step}.") + if start_step_index > end_step_index: + raise ValueError( + f"cfg_scale_start_step ({cfg_scale_start_step}) must be before cfg_scale_end_step " + + f"({cfg_scale_end_step})." + ) + + # Set values outside the start and end step indices to 1.0. This is equivalent to disabling cfg_scale for those + # steps. + clipped_cfg_scale = [1.0] * num_steps + clipped_cfg_scale[start_step_index : end_step_index + 1] = cfg_scale_list[start_step_index : end_step_index + 1] + + return clipped_cfg_scale + def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None: """Prepare the inpaint mask. diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index 92811f76f6..7ce375f4a2 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -29,7 +29,7 @@ def denoise( timesteps: list[float], step_callback: Callable[[PipelineIntermediateState], None], guidance: float, - cfg_scale: float | list[float], + cfg_scale: list[float], inpaint_extension: InpaintExtension | None, controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension], ): @@ -84,7 +84,7 @@ def denoise( controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals, ) - step_cfg_scale = cfg_scale[step_index] if isinstance(cfg_scale, list) else cfg_scale + step_cfg_scale = cfg_scale[step_index] # If step_cfg_scale, is 1.0, then we don't need to run the negative prediction. if not math.isclose(step_cfg_scale, 1.0): diff --git a/tests/app/invocations/test_flux_denoise.py b/tests/app/invocations/test_flux_denoise.py new file mode 100644 index 0000000000..412ef7a490 --- /dev/null +++ b/tests/app/invocations/test_flux_denoise.py @@ -0,0 +1,62 @@ +import pytest + +from invokeai.app.invocations.flux_denoise import FluxDenoiseInvocation + +TIMESTEPS = [1.0, 0.75, 0.5, 0.25, 0.0] + + +@pytest.mark.parametrize( + ["cfg_scale", "timesteps", "cfg_scale_start_step", "cfg_scale_end_step", "expected"], + [ + # Test scalar cfg_scale. + (2.0, TIMESTEPS, 0, -1, [2.0, 2.0, 2.0, 2.0]), + # Test list cfg_scale. + ([1.0, 2.0, 3.0, 4.0], TIMESTEPS, 0, -1, [1.0, 2.0, 3.0, 4.0]), + # Test positive cfg_scale_start_step. + (2.0, TIMESTEPS, 1, -1, [1.0, 2.0, 2.0, 2.0]), + # Test positive cfg_scale_end_step. + (2.0, TIMESTEPS, 0, 2, [2.0, 2.0, 2.0, 1.0]), + # Test negative cfg_scale_start_step. + (2.0, TIMESTEPS, -3, -1, [1.0, 2.0, 2.0, 2.0]), + # Test negative cfg_scale_end_step. + (2.0, TIMESTEPS, 0, -2, [2.0, 2.0, 2.0, 1.0]), + # Test single step application. + (2.0, TIMESTEPS, 2, 2, [1.0, 1.0, 2.0, 1.0]), + ], +) +def test_prep_cfg_scale( + cfg_scale: float | list[float], + timesteps: list[float], + cfg_scale_start_step: int, + cfg_scale_end_step: int, + expected: list[float], +): + result = FluxDenoiseInvocation.prep_cfg_scale(cfg_scale, timesteps, cfg_scale_start_step, cfg_scale_end_step) + assert result == expected + + +def test_prep_cfg_scale_invalid_type(): + with pytest.raises(ValueError, match="Unsupported cfg_scale type"): + FluxDenoiseInvocation.prep_cfg_scale("invalid", [1.0, 0.5], 0, -1) # type: ignore + + +@pytest.mark.parametrize("cfg_scale_start_step", [4, -5]) +def test_prep_cfg_scale_invalid_start_step(cfg_scale_start_step: int): + with pytest.raises(ValueError, match="Invalid cfg_scale_start_step"): + FluxDenoiseInvocation.prep_cfg_scale(2.0, TIMESTEPS, cfg_scale_start_step, -1) + + +@pytest.mark.parametrize("cfg_scale_end_step", [4, -5]) +def test_prep_cfg_scale_invalid_end_step(cfg_scale_end_step: int): + with pytest.raises(ValueError, match="Invalid cfg_scale_end_step"): + FluxDenoiseInvocation.prep_cfg_scale(2.0, TIMESTEPS, 0, cfg_scale_end_step) + + +def test_prep_cfg_scale_start_after_end(): + with pytest.raises(ValueError, match="cfg_scale_start_step .* must be before cfg_scale_end_step"): + FluxDenoiseInvocation.prep_cfg_scale(2.0, TIMESTEPS, 3, 2) + + +def test_prep_cfg_scale_list_length_mismatch(): + with pytest.raises(AssertionError): + FluxDenoiseInvocation.prep_cfg_scale([1.0, 2.0, 3.0], TIMESTEPS, 0, -1)