First pass at integrating FLUX ControlNets into the FLUX Denoise invocation.

This commit is contained in:
Ryan Dick
2024-10-03 16:18:57 +00:00
parent 1d4a58e52b
commit c81bb761ed
5 changed files with 179 additions and 5 deletions

View File

@@ -6,6 +6,7 @@ import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
@@ -19,6 +20,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.controlnet_flux import ControlNetFlux
from invokeai.backend.flux.controlnet_extension import ControlNetExtension
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
@@ -44,7 +47,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.0.0",
version="3.1.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -87,6 +90,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
controlnet: ControlField | list[ControlField] | None = InputField(
default=None, input=Input.Connection, description="ControlNet models."
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -167,8 +173,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inpaint_mask = self._prep_inpaint_mask(context, x)
b, _c, h, w = x.shape
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
b, _c, latent_h, latent_w = x.shape
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
@@ -231,6 +237,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
else:
raise ValueError(f"Unsupported model format: {config.format}")
# Prepare ControlNet extensions.
controlnet_extensions = self._prep_controlnet_extensions(
context=context,
exit_stack=exit_stack,
latent_height=latent_h,
latent_width=latent_w,
dtype=inference_dtype,
device=x.device,
)
x = denoise(
model=transformer,
img=x,
@@ -242,6 +258,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
step_callback=self._build_step_callback(context),
guidance=self.guidance,
inpaint_extension=inpaint_extension,
controlnet_extensions=controlnet_extensions,
)
x = unpack(x.float(), self.height, self.width)
@@ -288,6 +305,50 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# `latents`.
return mask.expand_as(latents)
def _prep_controlnet_extensions(
self,
context: InvocationContext,
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[ControlNetExtension] | None:
# Normalize the controlnet input to list[ControlField].
controlnets: list[ControlField]
if self.controlnet is None:
return None
elif isinstance(self.controlnet, ControlField):
controlnets = [self.controlnet]
elif isinstance(self.controlnet, list):
controlnets = self.controlnet
else:
raise ValueError(f"Unsupported controlnet type: {type(self.controlnet)}")
controlnet_extensions: list[ControlNetExtension] = []
for controlnet in controlnets:
model = exit_stack.enter_context(context.models.load(controlnet.control_model))
assert isinstance(model, ControlNetFlux)
image = context.images.get_pil(controlnet.image.image_name)
controlnet_extensions.append(
ControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
control_mode=controlnet.control_mode,
resize_mode=controlnet.resize_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
return controlnet_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)

View File

@@ -0,0 +1,88 @@
from typing import List, Union
import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux import ControlNetFlux
class ControlNetExtension:
def __init__(
self,
model: ControlNetFlux,
controlnet_cond: torch.Tensor,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
self._model = model
# _controlnet_cond is the control image passed to the ControlNet model.
# Pixel values are in the range [-1, 1]. Shape: (batch_size, 3, height, width).
self._controlnet_cond = controlnet_cond
self._weight = weight
self._begin_step_percent = begin_step_percent
self._end_step_percent = end_step_percent
@classmethod
def from_controlnet_image(
cls,
model: ControlNetFlux,
controlnet_image: Image,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
control_mode: CONTROLNET_MODE_VALUES,
resize_mode: CONTROLNET_RESIZE_VALUES,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
controlnet_cond = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode=control_mode,
resize_mode=resize_mode,
)
return cls(
model=model,
controlnet_cond=controlnet_cond,
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
def run_controlnet(
self,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> list[torch.Tensor]:
# TODO(ryand): Handle weight, begin_step_percent, end_step_percent.
controlnet_block_res_samples = self._model(
img=img,
img_ids=img_ids,
controlnet_cond=self._controlnet_cond,
txt=txt,
txt_ids=txt_ids,
timesteps=timesteps,
y=y,
guidance=guidance,
)
return controlnet_block_res_samples

View File

@@ -3,6 +3,7 @@ from typing import Callable
import torch
from tqdm import tqdm
from invokeai.backend.flux.controlnet_extension import ControlNetExtension
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -21,6 +22,7 @@ def denoise(
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
inpaint_extension: InpaintExtension | None,
controlnet_extensions: list[ControlNetExtension] | None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -38,6 +40,23 @@ def denoise(
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# Run ControlNet models.
# controlnet_block_residuals[i][j] is the residual of the j-th block of the i-th ControlNet model.
controlnet_block_residuals: list[list[torch.Tensor]] = []
for controlnet_extension in controlnet_extensions or []:
controlnet_block_residuals.append(
controlnet_extension.run_controlnet(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
)
pred = model(
img=img,
img_ids=img_ids,
@@ -46,6 +65,7 @@ def denoise(
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
block_controlnet_hidden_states=controlnet_block_residuals,
)
preview_img = img - t_curr * pred

View File

@@ -88,6 +88,7 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
block_controlnet_hidden_states: list[Tensor] | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -105,9 +106,13 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
for block_index, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
# Apply ControlNet residual.
if block_controlnet_hidden_states is not None:
img = img + block_controlnet_hidden_states[block_index % len(block_controlnet_hidden_states)]
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)

View File

@@ -311,7 +311,7 @@ class FluxControlnetModel(ModelLoader):
with accelerate.init_empty_weights():
# HACK(ryand): Is it safe to assume dev here?
model = ControlNetFlux(params["flux_dev"])
model = ControlNetFlux(params["flux-dev"])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)