mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Work on integrating InstantX into denoise process.
This commit is contained in:
@@ -17,13 +17,15 @@ from invokeai.app.invocations.fields import (
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.extensions.controlnet_extension import ControlNetExtension
|
||||
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.sampling_utils import (
|
||||
clip_timestep_schedule_fractional,
|
||||
@@ -93,6 +95,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
controlnet: ControlField | list[ControlField] | None = InputField(
|
||||
default=None, input=Input.Connection, description="ControlNet models."
|
||||
)
|
||||
controlnet_vae: VAEField | None = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
@@ -238,7 +244,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
raise ValueError(f"Unsupported model format: {config.format}")
|
||||
|
||||
# Prepare ControlNet extensions.
|
||||
controlnet_extensions = self._prep_controlnet_extensions(
|
||||
(xlabs_controlnet_extensions, instantx_controlnet_extensions) = self._prep_controlnet_extensions(
|
||||
context=context,
|
||||
exit_stack=exit_stack,
|
||||
latent_height=latent_h,
|
||||
@@ -313,11 +319,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> list[ControlNetExtension] | None:
|
||||
) -> tuple[list[XLabsControlNetExtension], list[InstantXControlNetExtension]]:
|
||||
# Normalize the controlnet input to list[ControlField].
|
||||
controlnets: list[ControlField]
|
||||
if self.controlnet is None:
|
||||
return None
|
||||
controlnets = []
|
||||
elif isinstance(self.controlnet, ControlField):
|
||||
controlnets = [self.controlnet]
|
||||
elif isinstance(self.controlnet, list):
|
||||
@@ -325,29 +331,62 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
else:
|
||||
raise ValueError(f"Unsupported controlnet type: {type(self.controlnet)}")
|
||||
|
||||
controlnet_extensions: list[ControlNetExtension] = []
|
||||
# TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
|
||||
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
|
||||
# minimize peak memory.
|
||||
|
||||
xlabs_controlnet_extensions: list[XLabsControlNetExtension] = []
|
||||
instantx_controlnet_extensions: list[InstantXControlNetExtension] = []
|
||||
for controlnet in controlnets:
|
||||
model = exit_stack.enter_context(context.models.load(controlnet.control_model))
|
||||
assert isinstance(model, XLabsControlNetFlux)
|
||||
image = context.images.get_pil(controlnet.image.image_name)
|
||||
|
||||
controlnet_extensions.append(
|
||||
ControlNetExtension.from_controlnet_image(
|
||||
model=model,
|
||||
controlnet_image=image,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
control_mode=controlnet.control_mode,
|
||||
resize_mode=controlnet.resize_mode,
|
||||
weight=controlnet.control_weight,
|
||||
begin_step_percent=controlnet.begin_step_percent,
|
||||
end_step_percent=controlnet.end_step_percent,
|
||||
if isinstance(model, XLabsControlNetFlux):
|
||||
xlabs_controlnet_extensions.append(
|
||||
XLabsControlNetExtension.from_controlnet_image(
|
||||
model=model,
|
||||
controlnet_image=image,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
control_mode=controlnet.control_mode,
|
||||
resize_mode=controlnet.resize_mode,
|
||||
weight=controlnet.control_weight,
|
||||
begin_step_percent=controlnet.begin_step_percent,
|
||||
end_step_percent=controlnet.end_step_percent,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(model, InstantXControlNetFlux):
|
||||
# control_mode = torch.tensor(0, dtype=torch.long)
|
||||
# control_mode = control_mode.reshape([-1, 1])
|
||||
|
||||
return controlnet_extensions
|
||||
if self.controlnet_vae is None:
|
||||
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
|
||||
instantx_controlnet_extensions.append(
|
||||
InstantXControlNetExtension.from_controlnet_image(
|
||||
model=model,
|
||||
controlnet_image=image,
|
||||
# TODO(ryand): Pass in the correct control mode.
|
||||
instantx_control_mode=None,
|
||||
vae_info=vae_info,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
control_mode=controlnet.control_mode,
|
||||
resize_mode=controlnet.resize_mode,
|
||||
weight=controlnet.control_weight,
|
||||
begin_step_percent=controlnet.begin_step_percent,
|
||||
end_step_percent=controlnet.end_step_percent,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported ControlNet model type: {type(model)}")
|
||||
|
||||
return (xlabs_controlnet_extensions, instantx_controlnet_extensions)
|
||||
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.transformer.loras:
|
||||
|
||||
Reference in New Issue
Block a user