From c8d1d146629e52d04d2469786bc2d5d17efeac76 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 7 Oct 2024 22:17:06 +0000 Subject: [PATCH] Work on integrating InstantX into denoise process. --- invokeai/app/invocations/flux_denoise.py | 83 +++++++++---- .../controlnet/instantx_controlnet_flux.py | 8 ++ .../flux/controlnet/xlabs_controlnet_flux.py | 16 ++- invokeai/backend/flux/denoise.py | 23 ++-- .../extensions/base_controlnet_extension.py | 46 ++++++++ .../instantx_controlnet_extension.py | 109 ++++++++++++++++++ ...nsion.py => xlabs_controlnet_extension.py} | 35 +++--- 7 files changed, 270 insertions(+), 50 deletions(-) create mode 100644 invokeai/backend/flux/extensions/base_controlnet_extension.py create mode 100644 invokeai/backend/flux/extensions/instantx_controlnet_extension.py rename invokeai/backend/flux/extensions/{controlnet_extension.py => xlabs_controlnet_extension.py} (77%) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 1851e76d0e..634cc3db6f 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -17,13 +17,15 @@ from invokeai.app.invocations.fields import ( WithBoard, WithMetadata, ) -from invokeai.app.invocations.model import TransformerField +from invokeai.app.invocations.model import TransformerField, VAEField from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux from invokeai.backend.flux.denoise import denoise -from invokeai.backend.flux.extensions.controlnet_extension import ControlNetExtension from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension +from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension +from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension from invokeai.backend.flux.model import Flux from invokeai.backend.flux.sampling_utils import ( clip_timestep_schedule_fractional, @@ -93,6 +95,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): controlnet: ControlField | list[ControlField] | None = InputField( default=None, input=Input.Connection, description="ControlNet models." ) + controlnet_vae: VAEField | None = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: @@ -238,7 +244,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): raise ValueError(f"Unsupported model format: {config.format}") # Prepare ControlNet extensions. - controlnet_extensions = self._prep_controlnet_extensions( + (xlabs_controlnet_extensions, instantx_controlnet_extensions) = self._prep_controlnet_extensions( context=context, exit_stack=exit_stack, latent_height=latent_h, @@ -313,11 +319,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): latent_width: int, dtype: torch.dtype, device: torch.device, - ) -> list[ControlNetExtension] | None: + ) -> tuple[list[XLabsControlNetExtension], list[InstantXControlNetExtension]]: # Normalize the controlnet input to list[ControlField]. controlnets: list[ControlField] if self.controlnet is None: - return None + controlnets = [] elif isinstance(self.controlnet, ControlField): controlnets = [self.controlnet] elif isinstance(self.controlnet, list): @@ -325,29 +331,62 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): else: raise ValueError(f"Unsupported controlnet type: {type(self.controlnet)}") - controlnet_extensions: list[ControlNetExtension] = [] + # TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets + # before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to + # minimize peak memory. + + xlabs_controlnet_extensions: list[XLabsControlNetExtension] = [] + instantx_controlnet_extensions: list[InstantXControlNetExtension] = [] for controlnet in controlnets: model = exit_stack.enter_context(context.models.load(controlnet.control_model)) - assert isinstance(model, XLabsControlNetFlux) image = context.images.get_pil(controlnet.image.image_name) - controlnet_extensions.append( - ControlNetExtension.from_controlnet_image( - model=model, - controlnet_image=image, - latent_height=latent_height, - latent_width=latent_width, - dtype=dtype, - device=device, - control_mode=controlnet.control_mode, - resize_mode=controlnet.resize_mode, - weight=controlnet.control_weight, - begin_step_percent=controlnet.begin_step_percent, - end_step_percent=controlnet.end_step_percent, + if isinstance(model, XLabsControlNetFlux): + xlabs_controlnet_extensions.append( + XLabsControlNetExtension.from_controlnet_image( + model=model, + controlnet_image=image, + latent_height=latent_height, + latent_width=latent_width, + dtype=dtype, + device=device, + control_mode=controlnet.control_mode, + resize_mode=controlnet.resize_mode, + weight=controlnet.control_weight, + begin_step_percent=controlnet.begin_step_percent, + end_step_percent=controlnet.end_step_percent, + ) ) - ) + elif isinstance(model, InstantXControlNetFlux): + # control_mode = torch.tensor(0, dtype=torch.long) + # control_mode = control_mode.reshape([-1, 1]) - return controlnet_extensions + if self.controlnet_vae is None: + raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.") + vae_info = context.models.load(self.controlnet_vae.vae) + + instantx_controlnet_extensions.append( + InstantXControlNetExtension.from_controlnet_image( + model=model, + controlnet_image=image, + # TODO(ryand): Pass in the correct control mode. + instantx_control_mode=None, + vae_info=vae_info, + latent_height=latent_height, + latent_width=latent_width, + dtype=dtype, + device=device, + control_mode=controlnet.control_mode, + resize_mode=controlnet.resize_mode, + weight=controlnet.control_weight, + begin_step_percent=controlnet.begin_step_percent, + end_step_percent=controlnet.end_step_percent, + ) + ) + else: + raise ValueError(f"Unsupported ControlNet model type: {type(model)}") + + return (xlabs_controlnet_extensions, instantx_controlnet_extensions) def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.transformer.loras: diff --git a/invokeai/backend/flux/controlnet/instantx_controlnet_flux.py b/invokeai/backend/flux/controlnet/instantx_controlnet_flux.py index d5230ac6bb..653d37b206 100644 --- a/invokeai/backend/flux/controlnet/instantx_controlnet_flux.py +++ b/invokeai/backend/flux/controlnet/instantx_controlnet_flux.py @@ -22,6 +22,14 @@ class InstantXControlNetFluxOutput: controlnet_block_samples: list[torch.Tensor] | None controlnet_single_block_samples: list[torch.Tensor] | None + def apply_weight(self, weight: float): + if self.controlnet_block_samples is not None: + for i in range(len(self.controlnet_block_samples)): + self.controlnet_block_samples[i] = self.controlnet_block_samples[i] * weight + if self.controlnet_single_block_samples is not None: + for i in range(len(self.controlnet_single_block_samples)): + self.controlnet_single_block_samples[i] = self.controlnet_single_block_samples[i] * weight + # NOTE(ryand): Mapping between diffusers FLUX transformer params and BFL FLUX transformer params: # - Diffusers: BFL diff --git a/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py b/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py index b5f24bc554..9c5353f611 100644 --- a/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py +++ b/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py @@ -2,6 +2,8 @@ # https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py +from dataclasses import dataclass + import torch from einops import rearrange @@ -10,6 +12,16 @@ from invokeai.backend.flux.model import FluxParams from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding +@dataclass +class XLabsControlNetFluxOutput: + controlnet_double_block_residuals: list[torch.Tensor] | None + + def apply_weight(self, weight: float): + if self.controlnet_double_block_residuals is not None: + for i in range(len(self.controlnet_double_block_residuals)): + self.controlnet_double_block_residuals[i] = self.controlnet_double_block_residuals[i] * weight + + class XLabsControlNetFlux(torch.nn.Module): """A ControlNet model for FLUX. @@ -88,7 +100,7 @@ class XLabsControlNetFlux(torch.nn.Module): timesteps: torch.Tensor, y: torch.Tensor, guidance: torch.Tensor | None = None, - ) -> list[torch.Tensor]: + ) -> XLabsControlNetFluxOutput: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -120,4 +132,4 @@ class XLabsControlNetFlux(torch.nn.Module): block_res_sample = controlnet_block(block_res_sample) controlnet_block_res_samples.append(block_res_sample) - return controlnet_block_res_samples + return XLabsControlNetFluxOutput(controlnet_double_block_residuals=controlnet_block_res_samples) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index a679aa1950..a3c58a707e 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -1,10 +1,14 @@ +import itertools from typing import Callable import torch from tqdm import tqdm -from invokeai.backend.flux.extensions.controlnet_extension import ControlNetExtension +from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFluxOutput +from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFluxOutput from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension +from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension +from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension from invokeai.backend.flux.model import Flux from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState @@ -22,7 +26,8 @@ def denoise( step_callback: Callable[[PipelineIntermediateState], None], guidance: float, inpaint_extension: InpaintExtension | None, - controlnet_extensions: list[ControlNetExtension] | None, + xlabs_controlnet_extensions: list[XLabsControlNetExtension], + instantx_controlnet_extensions: list[InstantXControlNetExtension], ): # step 0 is the initial state total_steps = len(timesteps) - 1 @@ -42,10 +47,9 @@ def denoise( t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) # Run ControlNet models. - # controlnet_block_residuals[i][j] is the residual of the j-th block of the i-th ControlNet model. - controlnet_block_residuals: list[list[torch.Tensor] | None] = [] - for controlnet_extension in controlnet_extensions or []: - controlnet_block_residuals.append( + controlnet_residuals: list[XLabsControlNetFluxOutput | InstantXControlNetFluxOutput | None] = [] + for controlnet_extension in itertools.chain(xlabs_controlnet_extensions, instantx_controlnet_extensions): + controlnet_residuals.append( controlnet_extension.run_controlnet( timestep_index=step - 1, total_num_timesteps=total_steps, @@ -58,6 +62,10 @@ def denoise( guidance=guidance_vec, ) ) + xlabs_controlnet_residuals = [res for res in controlnet_residuals if isinstance(res, XLabsControlNetFluxOutput)] + instantx_controlnet_residuals = [ + res for res in controlnet_residuals if isinstance(res, InstantXControlNetFluxOutput) + ] pred = model( img=img, @@ -67,7 +75,8 @@ def denoise( y=vec, timesteps=t_vec, guidance=guidance_vec, - controlnet_block_residuals=controlnet_block_residuals, + xlabs_controlnet_residuals=xlabs_controlnet_residuals, + instantx_controlnet_residuals=instantx_controlnet_residuals, ) preview_img = img - t_curr * pred diff --git a/invokeai/backend/flux/extensions/base_controlnet_extension.py b/invokeai/backend/flux/extensions/base_controlnet_extension.py new file mode 100644 index 0000000000..5e4fde9f90 --- /dev/null +++ b/invokeai/backend/flux/extensions/base_controlnet_extension.py @@ -0,0 +1,46 @@ +import math +from abc import ABC, abstractmethod +from typing import List, Union + +import torch + +from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFluxOutput +from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFluxOutput + + +class BaseControlNetExtension(ABC): + def __init__( + self, + weight: Union[float, List[float]], + begin_step_percent: float, + end_step_percent: float, + ): + self._weight = weight + self._begin_step_percent = begin_step_percent + self._end_step_percent = end_step_percent + + def _get_weight(self, timestep_index: int, total_num_timesteps: int) -> float: + first_step = math.floor(self._begin_step_percent * total_num_timesteps) + last_step = math.ceil(self._end_step_percent * total_num_timesteps) + + if timestep_index < first_step or timestep_index > last_step: + return 0.0 + + if isinstance(self._weight, list): + return self._weight[timestep_index] + + return self._weight + + @abstractmethod + def run_controlnet( + self, + timestep_index: int, + total_num_timesteps: int, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + y: torch.Tensor, + timesteps: torch.Tensor, + guidance: torch.Tensor | None, + ) -> InstantXControlNetFluxOutput | XLabsControlNetFluxOutput | None: ... diff --git a/invokeai/backend/flux/extensions/instantx_controlnet_extension.py b/invokeai/backend/flux/extensions/instantx_controlnet_extension.py new file mode 100644 index 0000000000..087fd63481 --- /dev/null +++ b/invokeai/backend/flux/extensions/instantx_controlnet_extension.py @@ -0,0 +1,109 @@ +from typing import List, Union + +import torch +from PIL.Image import Image + +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation +from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image +from invokeai.backend.flux.controlnet.instantx_controlnet_flux import ( + InstantXControlNetFlux, + InstantXControlNetFluxOutput, +) +from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension +from invokeai.backend.model_manager.load.load_base import LoadedModel + + +class InstantXControlNetExtension(BaseControlNetExtension): + def __init__( + self, + model: InstantXControlNetFlux, + controlnet_cond: torch.Tensor, + instantx_control_mode: torch.Tensor | None, + weight: Union[float, List[float]], + begin_step_percent: float, + end_step_percent: float, + ): + super().__init__( + weight=weight, + begin_step_percent=begin_step_percent, + end_step_percent=end_step_percent, + ) + self._model = model + self._controlnet_cond = controlnet_cond + # TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future models? + self._instantx_control_mode = instantx_control_mode + + @classmethod + def from_controlnet_image( + cls, + model: InstantXControlNetFlux, + controlnet_image: Image, + instantx_control_mode: torch.Tensor | None, + vae_info: LoadedModel, + latent_height: int, + latent_width: int, + dtype: torch.dtype, + device: torch.device, + control_mode: CONTROLNET_MODE_VALUES, + resize_mode: CONTROLNET_RESIZE_VALUES, + weight: Union[float, List[float]], + begin_step_percent: float, + end_step_percent: float, + ): + image_height = latent_height * LATENT_SCALE_FACTOR + image_width = latent_width * LATENT_SCALE_FACTOR + + resized_controlnet_image = prepare_control_image( + image=controlnet_image, + do_classifier_free_guidance=False, + width=image_width, + height=image_height, + device=device, + dtype=dtype, + control_mode=control_mode, + resize_mode=resize_mode, + ) + + # Run VAE encoder. + controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image) + + return cls( + model=model, + controlnet_cond=controlnet_cond, + instantx_control_mode=instantx_control_mode, + weight=weight, + begin_step_percent=begin_step_percent, + end_step_percent=end_step_percent, + ) + + def run_controlnet( + self, + timestep_index: int, + total_num_timesteps: int, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + y: torch.Tensor, + timesteps: torch.Tensor, + guidance: torch.Tensor | None, + ) -> InstantXControlNetFluxOutput | None: + weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps) + if weight < 1e-6: + return None + + output: InstantXControlNetFluxOutput = self._model( + controlnet_cond=self._controlnet_cond, + controlnet_mode=self._instantx_control_mode, + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + timesteps=timesteps, + y=y, + guidance=guidance, + ) + + output.apply_weight(weight) + return output diff --git a/invokeai/backend/flux/extensions/controlnet_extension.py b/invokeai/backend/flux/extensions/xlabs_controlnet_extension.py similarity index 77% rename from invokeai/backend/flux/extensions/controlnet_extension.py rename to invokeai/backend/flux/extensions/xlabs_controlnet_extension.py index 3b8b37ca36..0986af7414 100644 --- a/invokeai/backend/flux/extensions/controlnet_extension.py +++ b/invokeai/backend/flux/extensions/xlabs_controlnet_extension.py @@ -1,4 +1,3 @@ -import math from typing import List, Union import torch @@ -6,10 +5,11 @@ from PIL.Image import Image from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image -from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux +from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux, XLabsControlNetFluxOutput +from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension -class ControlNetExtension: +class XLabsControlNetExtension(BaseControlNetExtension): def __init__( self, model: XLabsControlNetFlux, @@ -18,15 +18,17 @@ class ControlNetExtension: begin_step_percent: float, end_step_percent: float, ): + super().__init__( + weight=weight, + begin_step_percent=begin_step_percent, + end_step_percent=end_step_percent, + ) + self._model = model # _controlnet_cond is the control image passed to the ControlNet model. # Pixel values are in the range [-1, 1]. Shape: (batch_size, 3, height, width). self._controlnet_cond = controlnet_cond - self._weight = weight - self._begin_step_percent = begin_step_percent - self._end_step_percent = end_step_percent - @classmethod def from_controlnet_image( cls, @@ -78,14 +80,12 @@ class ControlNetExtension: y: torch.Tensor, timesteps: torch.Tensor, guidance: torch.Tensor | None, - ) -> list[torch.Tensor] | None: - first_step = math.floor(self._begin_step_percent * total_num_timesteps) - last_step = math.ceil(self._end_step_percent * total_num_timesteps) - if timestep_index < first_step or timestep_index > last_step: - return - weight = self._weight + ) -> XLabsControlNetFluxOutput | None: + weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps) + if weight < 1e-6: + return None - controlnet_block_res_samples = self._model( + output: XLabsControlNetFluxOutput = self._model( img=img, img_ids=img_ids, controlnet_cond=self._controlnet_cond, @@ -96,8 +96,5 @@ class ControlNetExtension: guidance=guidance, ) - # Apply weight to the residuals. - for block_res_sample in controlnet_block_res_samples: - block_res_sample *= weight - - return controlnet_block_res_samples + output.apply_weight(weight) + return output