Work on integrating InstantX ControlNet support into the FLUX denoise process.

This commit is contained in:
Ryan Dick
2024-10-07 22:17:06 +00:00
committed by Kent Keirsey
parent 44c588d778
commit c8d1d14662
7 changed files with 270 additions and 50 deletions

View File

@@ -17,13 +17,15 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.controlnet_extension import ControlNetExtension
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule_fractional,
@@ -93,6 +95,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
controlnet: ControlField | list[ControlField] | None = InputField(
default=None, input=Input.Connection, description="ControlNet models."
)
# Optional VAE input. It is only needed when an InstantX ControlNet is connected; that case is validated
# (with a ValueError) where the ControlNet extensions are prepared. `default=None` keeps the field optional,
# matching its `VAEField | None` annotation and the sibling `controlnet` field.
controlnet_vae: VAEField | None = InputField(
    default=None,
    description=FieldDescriptions.vae,
    input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -238,7 +244,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
raise ValueError(f"Unsupported model format: {config.format}")
# Prepare ControlNet extensions.
controlnet_extensions = self._prep_controlnet_extensions(
(xlabs_controlnet_extensions, instantx_controlnet_extensions) = self._prep_controlnet_extensions(
context=context,
exit_stack=exit_stack,
latent_height=latent_h,
@@ -313,11 +319,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
latent_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[ControlNetExtension] | None:
) -> tuple[list[XLabsControlNetExtension], list[InstantXControlNetExtension]]:
# Normalize the controlnet input to list[ControlField].
controlnets: list[ControlField]
if self.controlnet is None:
return None
controlnets = []
elif isinstance(self.controlnet, ControlField):
controlnets = [self.controlnet]
elif isinstance(self.controlnet, list):
@@ -325,29 +331,62 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
else:
raise ValueError(f"Unsupported controlnet type: {type(self.controlnet)}")
controlnet_extensions: list[ControlNetExtension] = []
# TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
# minimize peak memory.
xlabs_controlnet_extensions: list[XLabsControlNetExtension] = []
instantx_controlnet_extensions: list[InstantXControlNetExtension] = []
for controlnet in controlnets:
model = exit_stack.enter_context(context.models.load(controlnet.control_model))
assert isinstance(model, XLabsControlNetFlux)
image = context.images.get_pil(controlnet.image.image_name)
controlnet_extensions.append(
ControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
control_mode=controlnet.control_mode,
resize_mode=controlnet.resize_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
if isinstance(model, XLabsControlNetFlux):
xlabs_controlnet_extensions.append(
XLabsControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
control_mode=controlnet.control_mode,
resize_mode=controlnet.resize_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
)
elif isinstance(model, InstantXControlNetFlux):
# control_mode = torch.tensor(0, dtype=torch.long)
# control_mode = control_mode.reshape([-1, 1])
return controlnet_extensions
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
vae_info = context.models.load(self.controlnet_vae.vae)
instantx_controlnet_extensions.append(
InstantXControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
# TODO(ryand): Pass in the correct control mode.
instantx_control_mode=None,
vae_info=vae_info,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
control_mode=controlnet.control_mode,
resize_mode=controlnet.resize_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
else:
raise ValueError(f"Unsupported ControlNet model type: {type(model)}")
return (xlabs_controlnet_extensions, instantx_controlnet_extensions)
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras: