mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Bugfixes to get InstantX ControlNet working.
This commit is contained in:
@@ -129,7 +129,7 @@ class InstantXControlNetFlux(torch.nn.Module):
|
||||
assert controlnet_mode is not None
|
||||
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
||||
txt = torch.cat([controlnet_mode_emb, txt], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
|
||||
txt_ids = torch.cat([txt_ids[:, :1, :], txt_ids], dim=1)
|
||||
else:
|
||||
assert controlnet_mode is None
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
|
||||
)
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux_output import InstantXControlNetFluxOutput
|
||||
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
|
||||
from invokeai.backend.flux.sampling_utils import pack
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
|
||||
|
||||
@@ -30,8 +31,13 @@ class InstantXControlNetExtension(BaseControlNetExtension):
|
||||
end_step_percent=end_step_percent,
|
||||
)
|
||||
self._model = model
|
||||
# The VAE-encoded and 'packed' control image to pass to the ControlNet model.
|
||||
self._controlnet_cond = controlnet_cond
|
||||
# TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future models?
|
||||
# The control mode for InstantX ControlNet union models.
|
||||
# See the values defined here: https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union#control-mode
|
||||
# Expected shape: (batch_size, 1)
|
||||
# Expected dtype: torch.long
|
||||
self._instantx_control_mode = instantx_control_mode
|
||||
|
||||
@classmethod
|
||||
@@ -67,6 +73,7 @@ class InstantXControlNetExtension(BaseControlNetExtension):
|
||||
|
||||
# Run VAE encoder.
|
||||
controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
|
||||
controlnet_cond = pack(controlnet_cond)
|
||||
|
||||
return cls(
|
||||
model=model,
|
||||
@@ -93,6 +100,12 @@ class InstantXControlNetExtension(BaseControlNetExtension):
|
||||
if weight < 1e-6:
|
||||
return None
|
||||
|
||||
# Make sure inputs have correct device and dtype.
|
||||
self._controlnet_cond = self._controlnet_cond.to(device=img.device, dtype=img.dtype)
|
||||
self._instantx_control_mode = (
|
||||
self._instantx_control_mode.to(device=img.device) if self._instantx_control_mode is not None else None
|
||||
)
|
||||
|
||||
output: InstantXControlNetFluxOutput = self._model(
|
||||
controlnet_cond=self._controlnet_cond,
|
||||
controlnet_mode=self._instantx_control_mode,
|
||||
|
||||
Reference in New Issue
Block a user