Bugfixes to get InstantX ControlNet working.

This commit is contained in:
Ryan Dick
2024-10-08 19:22:29 +00:00
committed by Kent Keirsey
parent ce4624f72b
commit de414c09fd
3 changed files with 18 additions and 5 deletions

View File

@@ -359,8 +359,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
)
)
elif isinstance(model, InstantXControlNetFlux):
# control_mode = torch.tensor(0, dtype=torch.long)
# control_mode = control_mode.reshape([-1, 1])
# TODO(ryand): Pass in the correct control mode.
control_mode = torch.tensor(0, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
@@ -370,8 +371,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
InstantXControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
# TODO(ryand): Pass in the correct control mode.
instantx_control_mode=None,
instantx_control_mode=control_mode,
vae_info=vae_info,
latent_height=latent_height,
latent_width=latent_width,

View File

@@ -129,7 +129,7 @@ class InstantXControlNetFlux(torch.nn.Module):
assert controlnet_mode is not None
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
txt = torch.cat([controlnet_mode_emb, txt], dim=1)
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
txt_ids = torch.cat([txt_ids[:, :1, :], txt_ids], dim=1)
else:
assert controlnet_mode is None

View File

@@ -11,6 +11,7 @@ from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
)
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.model_manager.load.load_base import LoadedModel
@@ -30,8 +31,13 @@ class InstantXControlNetExtension(BaseControlNetExtension):
end_step_percent=end_step_percent,
)
self._model = model
# The VAE-encoded and 'packed' control image to pass to the ControlNet model.
self._controlnet_cond = controlnet_cond
# TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future models?
# The control mode for InstantX ControlNet union models.
# See the values defined here: https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union#control-mode
# Expected shape: (batch_size, 1)
# Expected dtype: torch.long
self._instantx_control_mode = instantx_control_mode
@classmethod
@@ -67,6 +73,7 @@ class InstantXControlNetExtension(BaseControlNetExtension):
# Run VAE encoder.
controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
controlnet_cond = pack(controlnet_cond)
return cls(
model=model,
@@ -93,6 +100,12 @@ class InstantXControlNetExtension(BaseControlNetExtension):
if weight < 1e-6:
return None
# Make sure inputs have correct device and dtype.
self._controlnet_cond = self._controlnet_cond.to(device=img.device, dtype=img.dtype)
self._instantx_control_mode = (
self._instantx_control_mode.to(device=img.device) if self._instantx_control_mode is not None else None
)
output: InstantXControlNetFluxOutput = self._model(
controlnet_cond=self._controlnet_cond,
controlnet_mode=self._instantx_control_mode,