From 36515e1e2a30f26d4513b421f35478c2405495ba Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 3 Oct 2024 19:02:04 +0000
Subject: [PATCH] Add support for FLUX controlnet weight, begin_step_percent
 and end_step_percent.

---
 invokeai/backend/flux/controlnet_extension.py | 19 +++++++++++++++++--
 invokeai/backend/flux/denoise.py              |  6 ++++--
 invokeai/backend/flux/model.py                | 10 ++++++----
 3 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/invokeai/backend/flux/controlnet_extension.py b/invokeai/backend/flux/controlnet_extension.py
index 9ed25c3897..39c71a1883 100644
--- a/invokeai/backend/flux/controlnet_extension.py
+++ b/invokeai/backend/flux/controlnet_extension.py
@@ -1,3 +1,4 @@
+import math
 from typing import List, Union
 
 import torch
@@ -55,6 +56,9 @@ class ControlNetExtension:
             resize_mode=resize_mode,
         )
 
+        # Map pixel values from [0, 1] to [-1, 1].
+        controlnet_cond = controlnet_cond * 2 - 1
+
         return cls(
             model=model,
             controlnet_cond=controlnet_cond,
@@ -65,6 +69,8 @@ class ControlNetExtension:
 
     def run_controlnet(
         self,
+        timestep_index: int,
+        total_num_timesteps: int,
         img: torch.Tensor,
         img_ids: torch.Tensor,
         txt: torch.Tensor,
@@ -72,8 +78,12 @@ class ControlNetExtension:
         y: torch.Tensor,
         timesteps: torch.Tensor,
         guidance: torch.Tensor | None,
-    ) -> list[torch.Tensor]:
-        # TODO(ryand): Handle weight, begin_step_percent, end_step_percent.
-
+    ) -> list[torch.Tensor] | None:
+        first_step = math.floor(self._begin_step_percent * total_num_timesteps)
+        last_step = math.ceil(self._end_step_percent * total_num_timesteps)
+        if timestep_index < first_step or timestep_index > last_step:
+            return
+
+        weight = self._weight
         controlnet_block_res_samples = self._model(
             img=img,
@@ -85,4 +95,9 @@ class ControlNetExtension:
             y=y,
             guidance=guidance,
         )
+
+        # Apply weight to the residuals.
+        for block_res_sample in controlnet_block_res_samples:
+            block_res_sample *= weight
+
         return controlnet_block_res_samples
diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py
index 9da1e45a24..934da50617 100644
--- a/invokeai/backend/flux/denoise.py
+++ b/invokeai/backend/flux/denoise.py
@@ -43,10 +43,12 @@ def denoise(
 
         # Run ControlNet models.
         # controlnet_block_residuals[i][j] is the residual of the j-th block of the i-th ControlNet model.
-        controlnet_block_residuals: list[list[torch.Tensor]] = []
+        controlnet_block_residuals: list[list[torch.Tensor] | None] = []
         for controlnet_extension in controlnet_extensions or []:
             controlnet_block_residuals.append(
                 controlnet_extension.run_controlnet(
+                    timestep_index=step - 1,
+                    total_num_timesteps=total_steps,
                     img=img,
                     img_ids=img_ids,
                     txt=txt,
@@ -65,7 +67,7 @@ def denoise(
             y=vec,
             timesteps=t_vec,
             guidance=guidance_vec,
-            block_controlnet_hidden_states=controlnet_block_residuals,
+            controlnet_block_residuals=controlnet_block_residuals,
         )
 
         preview_img = img - t_curr * pred
diff --git a/invokeai/backend/flux/model.py b/invokeai/backend/flux/model.py
index 2caf615eb3..cbb15465c8 100644
--- a/invokeai/backend/flux/model.py
+++ b/invokeai/backend/flux/model.py
@@ -88,7 +88,7 @@ class Flux(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor | None = None,
-        block_controlnet_hidden_states: list[Tensor] | None = None,
+        controlnet_block_residuals: list[list[Tensor] | None] | None = None,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -109,9 +109,11 @@ class Flux(nn.Module):
         for block_index, block in enumerate(self.double_blocks):
             img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
 
-            # Apply ControlNet residual.
-            if block_controlnet_hidden_states is not None:
-                img = img + block_controlnet_hidden_states[block_index % len(block_controlnet_hidden_states)]
+            # Apply ControlNet residuals.
+            if controlnet_block_residuals is not None:
+                for single_controlnet_block_residuals in controlnet_block_residuals:
+                    if single_controlnet_block_residuals:
+                        img += single_controlnet_block_residuals[block_index % len(single_controlnet_block_residuals)]
 
         img = torch.cat((txt, img), 1)
         for block in self.single_blocks: