mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add support for FLUX controlnet weight, begin_step_percent and end_step_percent.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
@@ -55,6 +56,9 @@ class ControlNetExtension:
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
# Map pixel values from [0, 1] to [-1, 1].
|
||||
controlnet_cond = controlnet_cond * 2 - 1
|
||||
|
||||
return cls(
|
||||
model=model,
|
||||
controlnet_cond=controlnet_cond,
|
||||
@@ -65,6 +69,8 @@ class ControlNetExtension:
|
||||
|
||||
def run_controlnet(
|
||||
self,
|
||||
timestep_index: int,
|
||||
total_num_timesteps: int,
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
@@ -72,8 +78,12 @@ class ControlNetExtension:
|
||||
y: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
guidance: torch.Tensor | None,
|
||||
) -> list[torch.Tensor]:
|
||||
# TODO(ryand): Handle weight, begin_step_percent, end_step_percent.
|
||||
) -> list[torch.Tensor] | None:
|
||||
first_step = math.floor(self._begin_step_percent * total_num_timesteps)
|
||||
last_step = math.ceil(self._end_step_percent * total_num_timesteps)
|
||||
if timestep_index < first_step or timestep_index > last_step:
|
||||
return
|
||||
weight = self._weight
|
||||
|
||||
controlnet_block_res_samples = self._model(
|
||||
img=img,
|
||||
@@ -85,4 +95,9 @@ class ControlNetExtension:
|
||||
y=y,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
# Apply weight to the residuals.
|
||||
for block_res_sample in controlnet_block_res_samples:
|
||||
block_res_sample *= weight
|
||||
|
||||
return controlnet_block_res_samples
|
||||
|
||||
@@ -43,10 +43,12 @@ def denoise(
|
||||
|
||||
# Run ControlNet models.
|
||||
# controlnet_block_residuals[i][j] is the residual of the j-th block of the i-th ControlNet model.
|
||||
controlnet_block_residuals: list[list[torch.Tensor]] = []
|
||||
controlnet_block_residuals: list[list[torch.Tensor] | None] = []
|
||||
for controlnet_extension in controlnet_extensions or []:
|
||||
controlnet_block_residuals.append(
|
||||
controlnet_extension.run_controlnet(
|
||||
timestep_index=step - 1,
|
||||
total_num_timesteps=total_steps,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
@@ -65,7 +67,7 @@ def denoise(
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
block_controlnet_hidden_states=controlnet_block_residuals,
|
||||
controlnet_block_residuals=controlnet_block_residuals,
|
||||
)
|
||||
|
||||
preview_img = img - t_curr * pred
|
||||
|
||||
@@ -88,7 +88,7 @@ class Flux(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
block_controlnet_hidden_states: list[Tensor] | None = None,
|
||||
controlnet_block_residuals: list[list[Tensor] | None] | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -109,9 +109,11 @@ class Flux(nn.Module):
|
||||
for block_index, block in enumerate(self.double_blocks):
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
# Apply ControlNet residual.
|
||||
if block_controlnet_hidden_states is not None:
|
||||
img = img + block_controlnet_hidden_states[block_index % len(block_controlnet_hidden_states)]
|
||||
# Apply ControlNet residuals.
|
||||
if controlnet_block_residuals is not None:
|
||||
for single_controlnet_block_residuals in controlnet_block_residuals:
|
||||
if single_controlnet_block_residuals:
|
||||
img += single_controlnet_block_residuals[block_index % len(single_controlnet_block_residuals)]
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
|
||||
Reference in New Issue
Block a user