From 5df10cc494096e609d99101fd04853d4ba04afcb Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 18 Oct 2024 20:14:47 +0000 Subject: [PATCH] Add support for cfg_scale list on FLUX Denoise node. --- invokeai/app/invocations/flux_denoise.py | 3 +- invokeai/backend/flux/denoise.py | 44 +++++++++++++----------- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index fc6e153ac7..b90b7d8519 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -85,9 +85,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): negative_text_conditioning: FluxConditioningField = InputField( description=FieldDescriptions.negative_cond, input=Input.Connection ) - # TODO(ryand): Add support for cfg_scale to be a list of floats: one for each step. # TODO(ryand): Add cfg_scale range validation. - cfg_scale: float = InputField(default=3.0, description=FieldDescriptions.cfg_scale, title="CFG Scale") + cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale") width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.") height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.") num_steps: int = InputField( diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index b524d67e7c..bcdb15a18f 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -1,3 +1,4 @@ +import math from typing import Callable import torch @@ -28,7 +29,7 @@ def denoise( timesteps: list[float], step_callback: Callable[[PipelineIntermediateState], None], guidance: float, - cfg_scale: float, + cfg_scale: float | list[float], inpaint_extension: InpaintExtension | None, controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension], ): @@ -43,10 +44,9 @@ def denoise( latents=img, ), ) - step = 1 # guidance_vec is ignored for schnell. guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) - for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))): + for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) # Run ControlNet models. @@ -54,7 +54,7 @@ def denoise( for controlnet_extension in controlnet_extensions: controlnet_residuals.append( controlnet_extension.run_controlnet( - timestep_index=step - 1, + timestep_index=step_index, total_num_timesteps=total_steps, img=img, img_ids=img_ids, @@ -84,21 +84,24 @@ def denoise( controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals, ) - # TODO(ryand): Add option to apply controlnet to negative conditioning as well. - # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance on - # systems with sufficient VRAM. - neg_pred = model( - img=img, - img_ids=img_ids, - txt=neg_txt, - txt_ids=neg_txt_ids, - y=neg_vec, - timesteps=t_vec, - guidance=guidance_vec, - controlnet_double_block_residuals=None, - controlnet_single_block_residuals=None, - ) - pred = neg_pred + cfg_scale * (pred - neg_pred) + step_cfg_scale = cfg_scale[step_index] if isinstance(cfg_scale, list) else cfg_scale + + # If step_cfg_scale, is 1.0, then we don't need to run the negative prediction. + if not math.isclose(step_cfg_scale, 1.0): + # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance on + # systems with sufficient VRAM. + neg_pred = model( + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + controlnet_double_block_residuals=None, + controlnet_single_block_residuals=None, + ) + pred = neg_pred + step_cfg_scale * (pred - neg_pred) preview_img = img - t_curr * pred img = img + (t_prev - t_curr) * pred @@ -109,13 +112,12 @@ def denoise( step_callback( PipelineIntermediateState( - step=step, + step=step_index + 1, order=1, total_steps=total_steps, timestep=int(t_curr), latents=preview_img, ), ) - step += 1 return img