Add support for FLUX ControlNet weight, begin_step_percent, and end_step_percent.

This commit is contained in:
Ryan Dick
2024-10-03 19:02:04 +00:00
parent c81bb761ed
commit 36515e1e2a
3 changed files with 27 additions and 8 deletions

View File

@@ -1,3 +1,4 @@
import math
from typing import List, Union
import torch
@@ -55,6 +56,9 @@ class ControlNetExtension:
resize_mode=resize_mode,
)
# Map pixel values from [0, 1] to [-1, 1].
controlnet_cond = controlnet_cond * 2 - 1
return cls(
model=model,
controlnet_cond=controlnet_cond,
@@ -65,6 +69,8 @@ class ControlNetExtension:
def run_controlnet(
self,
timestep_index: int,
total_num_timesteps: int,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
@@ -72,8 +78,12 @@ class ControlNetExtension:
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> list[torch.Tensor]:
# TODO(ryand): Handle weight, begin_step_percent, end_step_percent.
) -> list[torch.Tensor] | None:
first_step = math.floor(self._begin_step_percent * total_num_timesteps)
last_step = math.ceil(self._end_step_percent * total_num_timesteps)
if timestep_index < first_step or timestep_index > last_step:
return
weight = self._weight
controlnet_block_res_samples = self._model(
img=img,
@@ -85,4 +95,9 @@ class ControlNetExtension:
y=y,
guidance=guidance,
)
# Apply weight to the residuals.
for block_res_sample in controlnet_block_res_samples:
block_res_sample *= weight
return controlnet_block_res_samples

View File

@@ -43,10 +43,12 @@ def denoise(
# Run ControlNet models.
# controlnet_block_residuals[i][j] is the residual of the j-th block of the i-th ControlNet model.
controlnet_block_residuals: list[list[torch.Tensor]] = []
controlnet_block_residuals: list[list[torch.Tensor] | None] = []
for controlnet_extension in controlnet_extensions or []:
controlnet_block_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step - 1,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
txt=txt,
@@ -65,7 +67,7 @@ def denoise(
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
block_controlnet_hidden_states=controlnet_block_residuals,
controlnet_block_residuals=controlnet_block_residuals,
)
preview_img = img - t_curr * pred

View File

@@ -88,7 +88,7 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
block_controlnet_hidden_states: list[Tensor] | None = None,
controlnet_block_residuals: list[list[Tensor] | None] | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -109,9 +109,11 @@ class Flux(nn.Module):
for block_index, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
# Apply ControlNet residual.
if block_controlnet_hidden_states is not None:
img = img + block_controlnet_hidden_states[block_index % len(block_controlnet_hidden_states)]
# Apply ControlNet residuals.
if controlnet_block_residuals is not None:
for single_controlnet_block_residuals in controlnet_block_residuals:
if single_controlnet_block_residuals:
img += single_controlnet_block_residuals[block_index % len(single_controlnet_block_residuals)]
img = torch.cat((txt, img), 1)
for block in self.single_blocks: