Shift the controlnet-type-specific logic into the specific ControlNet extensions and make the FLUX model controlnet-type-agnostic.

This commit is contained in:
Ryan Dick
2024-10-09 16:12:09 +00:00
committed by Kent Keirsey
parent d99e7dd4e4
commit 0559480dd6
11 changed files with 176 additions and 97 deletions

View File

@@ -244,7 +244,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
raise ValueError(f"Unsupported model format: {config.format}")
# Prepare ControlNet extensions.
(xlabs_controlnet_extensions, instantx_controlnet_extensions) = self._prep_controlnet_extensions(
controlnet_extensions = self._prep_controlnet_extensions(
context=context,
exit_stack=exit_stack,
latent_height=latent_h,
@@ -264,8 +264,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
step_callback=self._build_step_callback(context),
guidance=self.guidance,
inpaint_extension=inpaint_extension,
xlabs_controlnet_extensions=xlabs_controlnet_extensions,
instantx_controlnet_extensions=instantx_controlnet_extensions,
controlnet_extensions=controlnet_extensions,
)
x = unpack(x.float(), self.height, self.width)
@@ -320,7 +319,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
latent_width: int,
dtype: torch.dtype,
device: torch.device,
) -> tuple[list[XLabsControlNetExtension], list[InstantXControlNetExtension]]:
) -> list[XLabsControlNetExtension | InstantXControlNetExtension]:
# Normalize the controlnet input to list[ControlField].
controlnets: list[FluxControlNetField]
if self.controlnet is None:
@@ -336,14 +335,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
# minimize peak memory.
xlabs_controlnet_extensions: list[XLabsControlNetExtension] = []
instantx_controlnet_extensions: list[InstantXControlNetExtension] = []
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
for controlnet in controlnets:
model = exit_stack.enter_context(context.models.load(controlnet.controlnet_model))
image = context.images.get_pil(controlnet.image.image_name)
if isinstance(model, XLabsControlNetFlux):
xlabs_controlnet_extensions.append(
controlnet_extensions.append(
XLabsControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
@@ -365,7 +363,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
vae_info = context.models.load(self.controlnet_vae.vae)
instantx_controlnet_extensions.append(
controlnet_extensions.append(
InstantXControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
@@ -384,7 +382,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
else:
raise ValueError(f"Unsupported ControlNet model type: {type(model)}")
return (xlabs_controlnet_extensions, instantx_controlnet_extensions)
return controlnet_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:

View File

@@ -0,0 +1,58 @@
from dataclasses import dataclass
import torch
@dataclass
class ControlNetFluxOutput:
single_block_residuals: list[torch.Tensor] | None
double_block_residuals: list[torch.Tensor] | None
def apply_weight(self, weight: float):
if self.single_block_residuals is not None:
for i in range(len(self.single_block_residuals)):
self.single_block_residuals[i] = self.single_block_residuals[i] * weight
if self.double_block_residuals is not None:
for i in range(len(self.double_block_residuals)):
self.double_block_residuals[i] = self.double_block_residuals[i] * weight
def add_tensor_lists_elementwise(
list1: list[torch.Tensor] | None, list2: list[torch.Tensor] | None
) -> list[torch.Tensor] | None:
"""Add two tensor lists elementwise that could be None."""
if list1 is None and list2 is None:
return None
if list1 is None:
return list2
if list2 is None:
return list1
new_list: list[torch.Tensor] = []
for list1_tensor, list2_tensor in zip(list1, list2, strict=True):
new_list.append(list1_tensor + list2_tensor)
return new_list
def add_controlnet_flux_outputs(
    controlnet_output_1: ControlNetFluxOutput, controlnet_output_2: ControlNetFluxOutput
) -> ControlNetFluxOutput:
    """Combine two ControlNet outputs by summing their single-block and double-block residuals independently."""
    summed_single = add_tensor_lists_elementwise(
        controlnet_output_1.single_block_residuals, controlnet_output_2.single_block_residuals
    )
    summed_double = add_tensor_lists_elementwise(
        controlnet_output_1.double_block_residuals, controlnet_output_2.double_block_residuals
    )
    return ControlNetFluxOutput(single_block_residuals=summed_single, double_block_residuals=summed_double)
def sum_controlnet_flux_outputs(
    controlnet_outputs: list[ControlNetFluxOutput],
) -> ControlNetFluxOutput:
    """Merge the outputs of several ControlNets into one by summing them in order.

    An empty input list yields an output whose residual fields are both None.
    """
    # Start from an "empty" output (no residuals of either kind) and fold each output into it.
    total = ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
    for current in controlnet_outputs:
        total = add_controlnet_flux_outputs(total, current)
    return total

View File

@@ -2,10 +2,11 @@
# https://github.com/huggingface/diffusers/blob/99f608218caa069a2f16dcf9efab46959b15aec0/src/diffusers/models/controlnet_flux.py
from dataclasses import dataclass
import torch
import torch.nn as nn
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import (
@@ -16,6 +17,13 @@ from invokeai.backend.flux.modules.layers import (
timestep_embedding,
)
@dataclass
class InstantXControlNetFluxOutput:
    # Raw, InstantX-specific output of the InstantX FLUX ControlNet model. The InstantX ControlNet extension converts
    # this into the controlnet-type-agnostic ControlNetFluxOutput.

    # Samples for the transformer's double blocks, or None if there are none.
    controlnet_block_samples: list[torch.Tensor] | None
    # Samples for the transformer's single blocks, or None if there are none.
    controlnet_single_block_samples: list[torch.Tensor] | None
# NOTE(ryand): Mapping between diffusers FLUX transformer params and BFL FLUX transformer params:
# - Diffusers: BFL
# - in_channels: in_channels

View File

@@ -1,17 +0,0 @@
from dataclasses import dataclass
import torch
@dataclass
class InstantXControlNetFluxOutput:
controlnet_block_samples: list[torch.Tensor] | None
controlnet_single_block_samples: list[torch.Tensor] | None
def apply_weight(self, weight: float):
if self.controlnet_block_samples is not None:
for i in range(len(self.controlnet_block_samples)):
self.controlnet_block_samples[i] = self.controlnet_block_samples[i] * weight
if self.controlnet_single_block_samples is not None:
for i in range(len(self.controlnet_single_block_samples)):
self.controlnet_single_block_samples[i] = self.controlnet_single_block_samples[i] * weight

View File

@@ -2,15 +2,21 @@
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py
from dataclasses import dataclass
import torch
from einops import rearrange
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
@dataclass
class XLabsControlNetFluxOutput:
    # Raw, XLabs-specific output of the XLabs FLUX ControlNet model. The XLabs ControlNet extension converts this into
    # the controlnet-type-agnostic ControlNetFluxOutput.

    # Residuals for the transformer's double blocks, or None if there are none.
    controlnet_double_block_residuals: list[torch.Tensor] | None
class XLabsControlNetFlux(torch.nn.Module):
"""A ControlNet model for FLUX.

View File

@@ -1,13 +0,0 @@
from dataclasses import dataclass
import torch
@dataclass
class XLabsControlNetFluxOutput:
controlnet_double_block_residuals: list[torch.Tensor] | None
def apply_weight(self, weight: float):
if self.controlnet_double_block_residuals is not None:
for i in range(len(self.controlnet_double_block_residuals)):
self.controlnet_double_block_residuals[i] = self.controlnet_double_block_residuals[i] * weight

View File

@@ -1,11 +1,9 @@
import itertools
from typing import Callable
import torch
from tqdm import tqdm
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
@@ -26,8 +24,7 @@ def denoise(
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
inpaint_extension: InpaintExtension | None,
xlabs_controlnet_extensions: list[XLabsControlNetExtension],
instantx_controlnet_extensions: list[InstantXControlNetExtension],
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -47,8 +44,8 @@ def denoise(
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# Run ControlNet models.
controlnet_residuals: list[XLabsControlNetFluxOutput | InstantXControlNetFluxOutput | None] = []
for controlnet_extension in itertools.chain(xlabs_controlnet_extensions, instantx_controlnet_extensions):
controlnet_residuals: list[ControlNetFluxOutput] = []
for controlnet_extension in controlnet_extensions:
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step - 1,
@@ -62,10 +59,9 @@ def denoise(
guidance=guidance_vec,
)
)
xlabs_controlnet_residuals = [res for res in controlnet_residuals if isinstance(res, XLabsControlNetFluxOutput)]
instantx_controlnet_residuals = [
res for res in controlnet_residuals if isinstance(res, InstantXControlNetFluxOutput)
]
# Merge the ControlNet residuals from multiple ControlNets.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
pred = model(
img=img,
@@ -75,8 +71,8 @@ def denoise(
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
xlabs_controlnet_residuals=xlabs_controlnet_residuals,
instantx_controlnet_residuals=instantx_controlnet_residuals,
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
)
preview_img = img - t_curr * pred

View File

@@ -4,8 +4,7 @@ from typing import List, Union
import torch
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
class BaseControlNetExtension(ABC):
@@ -43,4 +42,4 @@ class BaseControlNetExtension(ABC):
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> InstantXControlNetFluxOutput | XLabsControlNetFluxOutput | None: ...
) -> ControlNetFluxOutput: ...

View File

@@ -1,3 +1,4 @@
import math
from typing import List, Union
import torch
@@ -6,10 +7,11 @@ from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
InstantXControlNetFlux,
InstantXControlNetFluxOutput,
)
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.model_manager.load.load_base import LoadedModel
@@ -40,6 +42,10 @@ class InstantXControlNetExtension(BaseControlNetExtension):
# Expected dtype: torch.long
self._instantx_control_mode = instantx_control_mode
# TODO(ryand): Pass in these params if a new base transformer / InstantX ControlNet pair gets released.
self._flux_transformer_num_double_blocks = 19
self._flux_transformer_num_single_blocks = 38
@classmethod
def from_controlnet_image(
cls,
@@ -83,6 +89,35 @@ class InstantXControlNetExtension(BaseControlNetExtension):
end_step_percent=end_step_percent,
)
def _instantx_output_to_controlnet_output(
    self, instantx_output: InstantXControlNetFluxOutput
) -> ControlNetFluxOutput:
    """Convert an InstantX-specific output into a controlnet-type-agnostic ControlNetFluxOutput.

    The InstantX ControlNet may produce fewer samples than the base transformer has blocks. Each sample list is
    expanded so that it has exactly one entry per transformer block (several consecutive blocks share the same
    sample).

    Args:
        instantx_output: The raw output of the InstantX ControlNet model.

    Returns:
        A ControlNetFluxOutput whose residual lists have one entry per transformer block. A field is None when the
        corresponding sample list is None or empty, so that consumers which treat None as "no residuals" (and
        validate the length of non-None lists) skip it rather than receiving an empty list.
    """
    # The `interval_control` logic here is based on
    # https://github.com/huggingface/diffusers/blob/31058cdaef63ca660a1a045281d156239fba8192/src/diffusers/models/transformers/transformer_flux.py#L507-L511

    def _expand_samples(samples: list[torch.Tensor] | None, num_blocks: int) -> list[torch.Tensor] | None:
        # No samples of this kind: report "no residuals" as None rather than as an empty list.
        if not samples:
            return None
        # Each sample is shared by `interval_control` consecutive transformer blocks.
        interval_control = int(math.ceil(num_blocks / len(samples)))
        return [samples[i // interval_control] for i in range(num_blocks)]

    return ControlNetFluxOutput(
        double_block_residuals=_expand_samples(
            instantx_output.controlnet_block_samples, self._flux_transformer_num_double_blocks
        ),
        single_block_residuals=_expand_samples(
            instantx_output.controlnet_single_block_samples, self._flux_transformer_num_single_blocks
        ),
    )
def run_controlnet(
self,
timestep_index: int,
@@ -94,10 +129,10 @@ class InstantXControlNetExtension(BaseControlNetExtension):
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> InstantXControlNetFluxOutput | None:
) -> ControlNetFluxOutput:
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
if weight < 1e-6:
return None
return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
# Make sure inputs have correct device and dtype.
self._controlnet_cond = self._controlnet_cond.to(device=img.device, dtype=img.dtype)
@@ -105,7 +140,7 @@ class InstantXControlNetExtension(BaseControlNetExtension):
self._instantx_control_mode.to(device=img.device) if self._instantx_control_mode is not None else None
)
output: InstantXControlNetFluxOutput = self._model(
instantx_output: InstantXControlNetFluxOutput = self._model(
controlnet_cond=self._controlnet_cond,
controlnet_mode=self._instantx_control_mode,
img=img,
@@ -117,5 +152,6 @@ class InstantXControlNetExtension(BaseControlNetExtension):
guidance=guidance,
)
output.apply_weight(weight)
return output
controlnet_output = self._instantx_output_to_controlnet_output(instantx_output)
controlnet_output.apply_weight(weight)
return controlnet_output

View File

@@ -5,8 +5,8 @@ from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux, XLabsControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
@@ -30,6 +30,10 @@ class XLabsControlNetExtension(BaseControlNetExtension):
# Pixel values are in the range [-1, 1]. Shape: (batch_size, 3, height, width).
self._controlnet_cond = controlnet_cond
# TODO(ryand): Pass in these params if a new base transformer / XLabs ControlNet pair gets released.
self._flux_transformer_num_double_blocks = 19
self._flux_transformer_num_single_blocks = 38
@classmethod
def from_controlnet_image(
cls,
@@ -69,6 +73,22 @@ class XLabsControlNetExtension(BaseControlNetExtension):
end_step_percent=end_step_percent,
)
def _xlabs_output_to_controlnet_output(self, xlabs_output: XLabsControlNetFluxOutput) -> ControlNetFluxOutput:
    """Convert an XLabs-specific output into a controlnet-type-agnostic ControlNetFluxOutput.

    The XLabs ControlNet may produce fewer double-block residuals than the base transformer has double blocks. The
    residuals are repeated cyclically so that there is exactly one entry per transformer double block.

    Args:
        xlabs_output: The raw output of the XLabs ControlNet model.

    Returns:
        A ControlNetFluxOutput with `single_block_residuals=None`. `double_block_residuals` has one entry per
        transformer double block, or is None when the XLabs output has no (or an empty list of) residuals, so that
        consumers which treat None as "no residuals" skip it rather than receiving an empty list.
    """
    # The modulo index logic used here is based on:
    # https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/model.py#L198-L200

    # Handle double block residuals.
    double_block_residuals: list[torch.Tensor] | None = None
    xlabs_double_block_residuals = xlabs_output.controlnet_double_block_residuals
    # Truthiness check (not `is not None`): an empty list would otherwise cause a modulo-by-zero below.
    if xlabs_double_block_residuals:
        num_xlabs_residuals = len(xlabs_double_block_residuals)
        double_block_residuals = [
            xlabs_double_block_residuals[i % num_xlabs_residuals]
            for i in range(self._flux_transformer_num_double_blocks)
        ]

    return ControlNetFluxOutput(
        double_block_residuals=double_block_residuals,
        single_block_residuals=None,
    )
def run_controlnet(
self,
timestep_index: int,
@@ -80,12 +100,12 @@ class XLabsControlNetExtension(BaseControlNetExtension):
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> XLabsControlNetFluxOutput | None:
) -> ControlNetFluxOutput:
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
if weight < 1e-6:
return None
return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
output: XLabsControlNetFluxOutput = self._model(
xlabs_output: XLabsControlNetFluxOutput = self._model(
img=img,
img_ids=img_ids,
controlnet_cond=self._controlnet_cond,
@@ -96,5 +116,6 @@ class XLabsControlNetExtension(BaseControlNetExtension):
guidance=guidance,
)
output.apply_weight(weight)
return output
controlnet_output = self._xlabs_output_to_controlnet_output(xlabs_output)
controlnet_output.apply_weight(weight)
return controlnet_output

View File

@@ -1,13 +1,10 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
@@ -91,8 +88,8 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None,
xlabs_controlnet_residuals: list[XLabsControlNetFluxOutput],
instantx_controlnet_residuals: list[InstantXControlNetFluxOutput],
controlnet_double_block_residuals: list[Tensor] | None,
controlnet_single_block_residuals: list[Tensor] | None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -110,36 +107,26 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
# Validate double_block_residuals shape.
if controlnet_double_block_residuals is not None:
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
for block_index, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
# Apply XLabs ControlNet residuals.
for single_xlabs_cn_res in xlabs_controlnet_residuals:
double_block_res = single_xlabs_cn_res.controlnet_double_block_residuals
if double_block_res:
img += double_block_res[block_index % len(double_block_res)]
# Apply InstantX ControlNet residuals.
for single_instantx_cn_res in instantx_controlnet_residuals:
double_block_res = single_instantx_cn_res.controlnet_block_samples
if double_block_res:
interval_control = len(self.double_blocks) / len(double_block_res)
interval_control = int(math.ceil(interval_control))
img += double_block_res[block_index // interval_control]
if controlnet_double_block_residuals is not None:
img += controlnet_double_block_residuals[block_index]
img = torch.cat((txt, img), 1)
# Validate single_block_residuals shape.
if controlnet_single_block_residuals is not None:
assert len(controlnet_single_block_residuals) == len(self.single_blocks)
for block_index, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe)
# Apply InstantX ControlNet residuals.
for single_instantx_cn_res in instantx_controlnet_residuals:
single_block_res = single_instantx_cn_res.controlnet_single_block_samples
if single_block_res:
interval_control = len(self.single_blocks) / len(single_block_res)
interval_control = int(math.ceil(interval_control))
img[:, txt.shape[1] :, ...] = (
img[:, txt.shape[1] :, ...] + single_block_res[block_index // interval_control]
)
if controlnet_single_block_residuals is not None:
img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]
img = img[:, txt.shape[1] :, ...]