mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add InstantX ControlNet logic to FLUX model forward().
This commit is contained in:
@@ -264,7 +264,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
step_callback=self._build_step_callback(context),
|
||||
guidance=self.guidance,
|
||||
inpaint_extension=inpaint_extension,
|
||||
controlnet_extensions=controlnet_extensions,
|
||||
xlabs_controlnet_extensions=xlabs_controlnet_extensions,
|
||||
instantx_controlnet_extensions=instantx_controlnet_extensions,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFluxOutput
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFluxOutput
|
||||
from invokeai.backend.flux.modules.layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
@@ -87,8 +90,9 @@ class Flux(nn.Module):
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
controlnet_block_residuals: list[list[Tensor] | None] | None = None,
|
||||
guidance: Tensor | None,
|
||||
xlabs_controlnet_residuals: list[XLabsControlNetFluxOutput],
|
||||
instantx_controlnet_residuals: list[InstantXControlNetFluxOutput],
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -109,15 +113,34 @@ class Flux(nn.Module):
|
||||
for block_index, block in enumerate(self.double_blocks):
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
# Apply ControlNet residuals.
|
||||
if controlnet_block_residuals is not None:
|
||||
for single_controlnet_block_residuals in controlnet_block_residuals:
|
||||
if single_controlnet_block_residuals:
|
||||
img += single_controlnet_block_residuals[block_index % len(single_controlnet_block_residuals)]
|
||||
# Apply XLabs ControlNet residuals.
|
||||
for single_xlabs_cn_res in xlabs_controlnet_residuals:
|
||||
double_block_res = single_xlabs_cn_res.controlnet_double_block_residuals
|
||||
if double_block_res:
|
||||
img += double_block_res[block_index % len(double_block_res)]
|
||||
|
||||
# Apply InstantX ControlNet residuals.
|
||||
for single_instantx_cn_res in instantx_controlnet_residuals:
|
||||
double_block_res = single_instantx_cn_res.controlnet_block_samples
|
||||
if double_block_res:
|
||||
interval_control = len(self.double_blocks) / len(double_block_res)
|
||||
interval_control = int(math.ceil(interval_control))
|
||||
img += double_block_res[block_index // interval_control]
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
for block_index, block in enumerate(self.single_blocks):
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
|
||||
# Apply InstantX ControlNet residuals.
|
||||
for single_instantx_cn_res in instantx_controlnet_residuals:
|
||||
single_block_res = single_instantx_cn_res.controlnet_single_block_samples
|
||||
if single_block_res:
|
||||
interval_control = len(self.single_blocks) / len(single_block_res)
|
||||
interval_control = int(math.ceil(interval_control))
|
||||
img[:, txt.shape[1] :, ...] = (
|
||||
img[:, txt.shape[1] :, ...] + single_block_res[block_index // interval_control]
|
||||
)
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
Reference in New Issue
Block a user