Work on integrating InstantX into denoise process.

This commit is contained in:
Ryan Dick
2024-10-07 22:17:06 +00:00
committed by Kent Keirsey
parent 44c588d778
commit c8d1d14662
7 changed files with 270 additions and 50 deletions

View File

@@ -17,13 +17,15 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.controlnet_extension import ControlNetExtension
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule_fractional,
@@ -93,6 +95,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
controlnet: ControlField | list[ControlField] | None = InputField(
default=None, input=Input.Connection, description="ControlNet models."
)
controlnet_vae: VAEField | None = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -238,7 +244,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
raise ValueError(f"Unsupported model format: {config.format}")
# Prepare ControlNet extensions.
controlnet_extensions = self._prep_controlnet_extensions(
(xlabs_controlnet_extensions, instantx_controlnet_extensions) = self._prep_controlnet_extensions(
context=context,
exit_stack=exit_stack,
latent_height=latent_h,
@@ -313,11 +319,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
latent_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[ControlNetExtension] | None:
) -> tuple[list[XLabsControlNetExtension], list[InstantXControlNetExtension]]:
# Normalize the controlnet input to list[ControlField].
controlnets: list[ControlField]
if self.controlnet is None:
return None
controlnets = []
elif isinstance(self.controlnet, ControlField):
controlnets = [self.controlnet]
elif isinstance(self.controlnet, list):
@@ -325,29 +331,62 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
else:
raise ValueError(f"Unsupported controlnet type: {type(self.controlnet)}")
controlnet_extensions: list[ControlNetExtension] = []
# TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
# minimize peak memory.
xlabs_controlnet_extensions: list[XLabsControlNetExtension] = []
instantx_controlnet_extensions: list[InstantXControlNetExtension] = []
for controlnet in controlnets:
model = exit_stack.enter_context(context.models.load(controlnet.control_model))
assert isinstance(model, XLabsControlNetFlux)
image = context.images.get_pil(controlnet.image.image_name)
controlnet_extensions.append(
ControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
control_mode=controlnet.control_mode,
resize_mode=controlnet.resize_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
if isinstance(model, XLabsControlNetFlux):
xlabs_controlnet_extensions.append(
XLabsControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
control_mode=controlnet.control_mode,
resize_mode=controlnet.resize_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
)
elif isinstance(model, InstantXControlNetFlux):
# control_mode = torch.tensor(0, dtype=torch.long)
# control_mode = control_mode.reshape([-1, 1])
return controlnet_extensions
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
vae_info = context.models.load(self.controlnet_vae.vae)
instantx_controlnet_extensions.append(
InstantXControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
# TODO(ryand): Pass in the correct control mode.
instantx_control_mode=None,
vae_info=vae_info,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
control_mode=controlnet.control_mode,
resize_mode=controlnet.resize_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
else:
raise ValueError(f"Unsupported ControlNet model type: {type(model)}")
return (xlabs_controlnet_extensions, instantx_controlnet_extensions)
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:

View File

@@ -22,6 +22,14 @@ class InstantXControlNetFluxOutput:
controlnet_block_samples: list[torch.Tensor] | None
controlnet_single_block_samples: list[torch.Tensor] | None
def apply_weight(self, weight: float):
    """Scale every ControlNet residual by `weight`.

    Replaces each list entry with a scaled copy (out-of-place multiply, then
    in-place list assignment), so tensors shared with callers are not mutated.
    Lists that are None are left untouched.
    """
    for samples in (self.controlnet_block_samples, self.controlnet_single_block_samples):
        if samples is None:
            continue
        for idx, sample in enumerate(samples):
            samples[idx] = sample * weight
# NOTE(ryand): Mapping between diffusers FLUX transformer params and BFL FLUX transformer params:
# - Diffusers: BFL

View File

@@ -2,6 +2,8 @@
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py
from dataclasses import dataclass
import torch
from einops import rearrange
@@ -10,6 +12,16 @@ from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
@dataclass
class XLabsControlNetFluxOutput:
controlnet_double_block_residuals: list[torch.Tensor] | None
def apply_weight(self, weight: float):
if self.controlnet_double_block_residuals is not None:
for i in range(len(self.controlnet_double_block_residuals)):
self.controlnet_double_block_residuals[i] = self.controlnet_double_block_residuals[i] * weight
class XLabsControlNetFlux(torch.nn.Module):
"""A ControlNet model for FLUX.
@@ -88,7 +100,7 @@ class XLabsControlNetFlux(torch.nn.Module):
timesteps: torch.Tensor,
y: torch.Tensor,
guidance: torch.Tensor | None = None,
) -> list[torch.Tensor]:
) -> XLabsControlNetFluxOutput:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -120,4 +132,4 @@ class XLabsControlNetFlux(torch.nn.Module):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples.append(block_res_sample)
return controlnet_block_res_samples
return XLabsControlNetFluxOutput(controlnet_double_block_residuals=controlnet_block_res_samples)

View File

@@ -1,10 +1,14 @@
import itertools
from typing import Callable
import torch
from tqdm import tqdm
from invokeai.backend.flux.extensions.controlnet_extension import ControlNetExtension
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFluxOutput
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -22,7 +26,8 @@ def denoise(
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
inpaint_extension: InpaintExtension | None,
controlnet_extensions: list[ControlNetExtension] | None,
xlabs_controlnet_extensions: list[XLabsControlNetExtension],
instantx_controlnet_extensions: list[InstantXControlNetExtension],
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -42,10 +47,9 @@ def denoise(
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# Run ControlNet models.
# controlnet_block_residuals[i][j] is the residual of the j-th block of the i-th ControlNet model.
controlnet_block_residuals: list[list[torch.Tensor] | None] = []
for controlnet_extension in controlnet_extensions or []:
controlnet_block_residuals.append(
controlnet_residuals: list[XLabsControlNetFluxOutput | InstantXControlNetFluxOutput | None] = []
for controlnet_extension in itertools.chain(xlabs_controlnet_extensions, instantx_controlnet_extensions):
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step - 1,
total_num_timesteps=total_steps,
@@ -58,6 +62,10 @@ def denoise(
guidance=guidance_vec,
)
)
xlabs_controlnet_residuals = [res for res in controlnet_residuals if isinstance(res, XLabsControlNetFluxOutput)]
instantx_controlnet_residuals = [
res for res in controlnet_residuals if isinstance(res, InstantXControlNetFluxOutput)
]
pred = model(
img=img,
@@ -67,7 +75,8 @@ def denoise(
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
controlnet_block_residuals=controlnet_block_residuals,
xlabs_controlnet_residuals=xlabs_controlnet_residuals,
instantx_controlnet_residuals=instantx_controlnet_residuals,
)
preview_img = img - t_curr * pred

View File

@@ -0,0 +1,46 @@
import math
from abc import ABC, abstractmethod
from typing import List, Union
import torch
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFluxOutput
class BaseControlNetExtension(ABC):
    """Shared step-scheduling and weighting logic for FLUX ControlNet extensions.

    Subclasses implement `run_controlnet` for a specific ControlNet flavor
    (e.g. XLabs, InstantX) and use `_get_weight` to decide whether / how
    strongly the ControlNet applies at a given denoising step.
    """

    def __init__(
        self,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
    ):
        # weight is either one scalar for all steps, or a per-timestep list
        # indexed by timestep_index.
        self._weight = weight
        # Fractional [0, 1] range of the denoising schedule in which this
        # ControlNet is active.
        self._begin_step_percent = begin_step_percent
        self._end_step_percent = end_step_percent

    def _get_weight(self, timestep_index: int, total_num_timesteps: int) -> float:
        """Return the control weight for this step (0.0 outside the active range)."""
        first_step = math.floor(self._begin_step_percent * total_num_timesteps)
        last_step = math.ceil(self._end_step_percent * total_num_timesteps)
        if not (first_step <= timestep_index <= last_step):
            return 0.0
        # Per-step weight schedule, if one was provided.
        # NOTE(review): assumes the list is long enough to index by
        # timestep_index — confirm against callers.
        if isinstance(self._weight, list):
            return self._weight[timestep_index]
        return self._weight

    @abstractmethod
    def run_controlnet(
        self,
        timestep_index: int,
        total_num_timesteps: int,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        txt: torch.Tensor,
        txt_ids: torch.Tensor,
        y: torch.Tensor,
        timesteps: torch.Tensor,
        guidance: torch.Tensor | None,
    ) -> InstantXControlNetFluxOutput | XLabsControlNetFluxOutput | None: ...

View File

@@ -0,0 +1,109 @@
from typing import List, Union
import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
InstantXControlNetFlux,
InstantXControlNetFluxOutput,
)
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
from invokeai.backend.model_manager.load.load_base import LoadedModel
class InstantXControlNetExtension(BaseControlNetExtension):
    """Runs an InstantX FLUX ControlNet during denoising.

    Unlike the XLabs variant, the InstantX ControlNet consumes a VAE-encoded
    control image (latents) rather than raw pixels, so construction requires a
    VAE model.
    """

    def __init__(
        self,
        model: InstantXControlNetFlux,
        controlnet_cond: torch.Tensor,
        instantx_control_mode: torch.Tensor | None,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
    ):
        super().__init__(
            weight=weight,
            begin_step_percent=begin_step_percent,
            end_step_percent=end_step_percent,
        )
        self._model = model
        # VAE-encoded control image latents fed to the ControlNet each step.
        self._controlnet_cond = controlnet_cond
        # TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future
        # models?
        self._instantx_control_mode = instantx_control_mode

    @classmethod
    def from_controlnet_image(
        cls,
        model: InstantXControlNetFlux,
        controlnet_image: Image,
        instantx_control_mode: torch.Tensor | None,
        vae_info: LoadedModel,
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
        device: torch.device,
        control_mode: CONTROLNET_MODE_VALUES,
        resize_mode: CONTROLNET_RESIZE_VALUES,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
    ):
        """Build an extension from a PIL control image.

        The image is resized/normalized to match the target latent size, then
        VAE-encoded to produce the conditioning latents.
        """
        image_height = latent_height * LATENT_SCALE_FACTOR
        image_width = latent_width * LATENT_SCALE_FACTOR

        resized = prepare_control_image(
            image=controlnet_image,
            do_classifier_free_guidance=False,
            width=image_width,
            height=image_height,
            device=device,
            dtype=dtype,
            control_mode=control_mode,
            resize_mode=resize_mode,
        )

        # Encode the prepared image to latents with the provided VAE.
        cond_latents = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized)

        return cls(
            model=model,
            controlnet_cond=cond_latents,
            instantx_control_mode=instantx_control_mode,
            weight=weight,
            begin_step_percent=begin_step_percent,
            end_step_percent=end_step_percent,
        )

    def run_controlnet(
        self,
        timestep_index: int,
        total_num_timesteps: int,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        txt: torch.Tensor,
        txt_ids: torch.Tensor,
        y: torch.Tensor,
        timesteps: torch.Tensor,
        guidance: torch.Tensor | None,
    ) -> InstantXControlNetFluxOutput | None:
        """Run the ControlNet for one step and return weighted residuals.

        Returns None when the step falls outside the configured active range
        (i.e. the scheduled weight is effectively zero), skipping the forward
        pass entirely.
        """
        weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
        if weight < 1e-6:
            return None

        output: InstantXControlNetFluxOutput = self._model(
            controlnet_cond=self._controlnet_cond,
            controlnet_mode=self._instantx_control_mode,
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            timesteps=timesteps,
            y=y,
            guidance=guidance,
        )
        output.apply_weight(weight)
        return output

View File

@@ -1,4 +1,3 @@
import math
from typing import List, Union
import torch
@@ -6,10 +5,11 @@ from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux, XLabsControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
class ControlNetExtension:
class XLabsControlNetExtension(BaseControlNetExtension):
def __init__(
self,
model: XLabsControlNetFlux,
@@ -18,15 +18,17 @@ class ControlNetExtension:
begin_step_percent: float,
end_step_percent: float,
):
super().__init__(
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
self._model = model
# _controlnet_cond is the control image passed to the ControlNet model.
# Pixel values are in the range [-1, 1]. Shape: (batch_size, 3, height, width).
self._controlnet_cond = controlnet_cond
self._weight = weight
self._begin_step_percent = begin_step_percent
self._end_step_percent = end_step_percent
@classmethod
def from_controlnet_image(
cls,
@@ -78,14 +80,12 @@ class ControlNetExtension:
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> list[torch.Tensor] | None:
first_step = math.floor(self._begin_step_percent * total_num_timesteps)
last_step = math.ceil(self._end_step_percent * total_num_timesteps)
if timestep_index < first_step or timestep_index > last_step:
return
weight = self._weight
) -> XLabsControlNetFluxOutput | None:
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
if weight < 1e-6:
return None
controlnet_block_res_samples = self._model(
output: XLabsControlNetFluxOutput = self._model(
img=img,
img_ids=img_ids,
controlnet_cond=self._controlnet_cond,
@@ -96,8 +96,5 @@ class ControlNetExtension:
guidance=guidance,
)
# Apply weight to the residuals.
for block_res_sample in controlnet_block_res_samples:
block_res_sample *= weight
return controlnet_block_res_samples
output.apply_weight(weight)
return output