mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Bugfixes to get InstantX ControlNet working.
This commit is contained in:
@@ -359,8 +359,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
)
|
||||
)
|
||||
elif isinstance(model, InstantXControlNetFlux):
|
||||
# control_mode = torch.tensor(0, dtype=torch.long)
|
||||
# control_mode = control_mode.reshape([-1, 1])
|
||||
# TODO(ryand): Pass in the correct control mode.
|
||||
control_mode = torch.tensor(0, dtype=torch.long)
|
||||
control_mode = control_mode.reshape([-1, 1])
|
||||
|
||||
if self.controlnet_vae is None:
|
||||
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
|
||||
@@ -370,8 +371,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
InstantXControlNetExtension.from_controlnet_image(
|
||||
model=model,
|
||||
controlnet_image=image,
|
||||
# TODO(ryand): Pass in the correct control mode.
|
||||
instantx_control_mode=None,
|
||||
instantx_control_mode=control_mode,
|
||||
vae_info=vae_info,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
|
||||
@@ -129,7 +129,7 @@ class InstantXControlNetFlux(torch.nn.Module):
|
||||
assert controlnet_mode is not None
|
||||
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
||||
txt = torch.cat([controlnet_mode_emb, txt], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
|
||||
txt_ids = torch.cat([txt_ids[:, :1, :], txt_ids], dim=1)
|
||||
else:
|
||||
assert controlnet_mode is None
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
|
||||
)
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
|
||||
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
|
||||
from invokeai.backend.flux.sampling_utils import pack
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
|
||||
|
||||
@@ -30,8 +31,13 @@ class InstantXControlNetExtension(BaseControlNetExtension):
|
||||
end_step_percent=end_step_percent,
|
||||
)
|
||||
self._model = model
|
||||
# The VAE-encoded and 'packed' control image to pass to the ControlNet model.
|
||||
self._controlnet_cond = controlnet_cond
|
||||
# TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future models?
|
||||
# The control mode for InstantX ControlNet union models.
|
||||
# See the values defined here: https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union#control-mode
|
||||
# Expected shape: (batch_size, 1)
|
||||
# Expected dtype: torch.long
|
||||
self._instantx_control_mode = instantx_control_mode
|
||||
|
||||
@classmethod
|
||||
@@ -67,6 +73,7 @@ class InstantXControlNetExtension(BaseControlNetExtension):
|
||||
|
||||
# Run VAE encoder.
|
||||
controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
|
||||
controlnet_cond = pack(controlnet_cond)
|
||||
|
||||
return cls(
|
||||
model=model,
|
||||
@@ -93,6 +100,12 @@ class InstantXControlNetExtension(BaseControlNetExtension):
|
||||
if weight < 1e-6:
|
||||
return None
|
||||
|
||||
# Make sure inputs have correct device and dtype.
|
||||
self._controlnet_cond = self._controlnet_cond.to(device=img.device, dtype=img.dtype)
|
||||
self._instantx_control_mode = (
|
||||
self._instantx_control_mode.to(device=img.device) if self._instantx_control_mode is not None else None
|
||||
)
|
||||
|
||||
output: InstantXControlNetFluxOutput = self._model(
|
||||
controlnet_cond=self._controlnet_cond,
|
||||
controlnet_mode=self._instantx_control_mode,
|
||||
|
||||
Reference in New Issue
Block a user