Add instantx controlnet logic to FLUX model forward().

This commit is contained in:
Ryan Dick
2024-10-08 15:50:42 +00:00
committed by Kent Keirsey
parent c8d1d14662
commit 4289b5e6c3
2 changed files with 33 additions and 9 deletions

View File

@@ -264,7 +264,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
step_callback=self._build_step_callback(context),
guidance=self.guidance,
inpaint_extension=inpaint_extension,
controlnet_extensions=controlnet_extensions,
xlabs_controlnet_extensions=xlabs_controlnet_extensions,
instantx_controlnet_extensions=instantx_controlnet_extensions,
)
x = unpack(x.float(), self.height, self.width)

View File

@@ -1,10 +1,13 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFluxOutput
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
@@ -87,8 +90,9 @@ class Flux(nn.Module):
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
controlnet_block_residuals: list[list[Tensor] | None] | None = None,
guidance: Tensor | None,
xlabs_controlnet_residuals: list[XLabsControlNetFluxOutput],
instantx_controlnet_residuals: list[InstantXControlNetFluxOutput],
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -109,15 +113,34 @@ class Flux(nn.Module):
for block_index, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
# Apply ControlNet residuals.
if controlnet_block_residuals is not None:
for single_controlnet_block_residuals in controlnet_block_residuals:
if single_controlnet_block_residuals:
img += single_controlnet_block_residuals[block_index % len(single_controlnet_block_residuals)]
# Apply XLabs ControlNet residuals.
for single_xlabs_cn_res in xlabs_controlnet_residuals:
double_block_res = single_xlabs_cn_res.controlnet_double_block_residuals
if double_block_res:
img += double_block_res[block_index % len(double_block_res)]
# Apply InstantX ControlNet residuals.
for single_instantx_cn_res in instantx_controlnet_residuals:
double_block_res = single_instantx_cn_res.controlnet_block_samples
if double_block_res:
interval_control = len(self.double_blocks) / len(double_block_res)
interval_control = int(math.ceil(interval_control))
img += double_block_res[block_index // interval_control]
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
for block_index, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe)
# Apply InstantX ControlNet residuals.
for single_instantx_cn_res in instantx_controlnet_residuals:
single_block_res = single_instantx_cn_res.controlnet_single_block_samples
if single_block_res:
interval_control = len(self.single_blocks) / len(single_block_res)
interval_control = int(math.ceil(interval_control))
img[:, txt.shape[1] :, ...] = (
img[:, txt.shape[1] :, ...] + single_block_res[block_index // interval_control]
)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)