From 3953e60a4fd5d4699e15a5e81418f150ccf91bb2 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 9 Oct 2024 22:00:54 +0000 Subject: [PATCH] Remove instantx_control_mode from FLUX ControlNet node. --- invokeai/app/invocations/flux_controlnet.py | 9 --------- invokeai/app/invocations/flux_denoise.py | 6 ++++-- .../backend/flux/controlnet/instantx_controlnet_flux.py | 9 +++++++-- .../flux/extensions/instantx_controlnet_extension.py | 6 +++--- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/invokeai/app/invocations/flux_controlnet.py b/invokeai/app/invocations/flux_controlnet.py index 41f4a6cc8c..2daded9287 100644 --- a/invokeai/app/invocations/flux_controlnet.py +++ b/invokeai/app/invocations/flux_controlnet.py @@ -25,10 +25,6 @@ class FluxControlNetField(BaseModel): default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" ) resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") - instantx_control_mode: int = Field( - default=0, - description="The control mode for InstantX ControlNet union models. Ignored for other ControlNet models.", - ) @field_validator("control_weight") @classmethod @@ -74,10 +70,6 @@ class FluxControlNetInvocation(BaseInvocation): default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" ) resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used") - instantx_control_mode: int = InputField( - default=0, - description="The control mode for InstantX ControlNet union models. Ignored for other ControlNet models.", - ) @field_validator("control_weight") @classmethod @@ -99,6 +91,5 @@ class FluxControlNetInvocation(BaseInvocation): begin_step_percent=self.begin_step_percent, end_step_percent=self.end_step_percent, resize_mode=self.resize_mode, - instantx_control_mode=self.instantx_control_mode, ), ) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 0fb65fbf1b..fc3d3cb8fd 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -356,8 +356,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): ) ) elif isinstance(model, InstantXControlNetFlux): - instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long) - instantx_control_mode = instantx_control_mode.reshape([-1, 1]) + instantx_control_mode: torch.Tensor | None = None + # if controlnet.instantx_control_mode is not None: + # instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long) + # instantx_control_mode = instantx_control_mode.reshape([-1, 1]) if self.controlnet_vae is None: raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.") diff --git a/invokeai/backend/flux/controlnet/instantx_controlnet_flux.py b/invokeai/backend/flux/controlnet/instantx_controlnet_flux.py index ebac20e822..1af5fbdfc0 100644 --- a/invokeai/backend/flux/controlnet/instantx_controlnet_flux.py +++ b/invokeai/backend/flux/controlnet/instantx_controlnet_flux.py @@ -134,8 +134,13 @@ class InstantXControlNetFlux(torch.nn.Module): # If this is a union ControlNet, then concat the control mode embedding to the T5 text embedding. if self.is_union: - assert controlnet_mode is not None - controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) + if controlnet_mode is None: + # We allow users to enter 'None' as the controlnet_mode if they don't want to worry about this input. + # We've chosen to use a zero-embedding in this case. + zero_index = torch.zeros([1, 1], dtype=torch.long, device=txt.device) + controlnet_mode_emb = torch.zeros_like(self.controlnet_mode_embedder(zero_index)) + else: + controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) txt = torch.cat([controlnet_mode_emb, txt], dim=1) txt_ids = torch.cat([txt_ids[:, :1, :], txt_ids], dim=1) else: diff --git a/invokeai/backend/flux/extensions/instantx_controlnet_extension.py b/invokeai/backend/flux/extensions/instantx_controlnet_extension.py index 2819a2d7e8..a57a89dbeb 100644 --- a/invokeai/backend/flux/extensions/instantx_controlnet_extension.py +++ b/invokeai/backend/flux/extensions/instantx_controlnet_extension.py @@ -38,8 +38,8 @@ class InstantXControlNetExtension(BaseControlNetExtension): # TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future models? # The control mode for InstantX ControlNet union models. # See the values defined here: https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union#control-mode - # Expected shape: (batch_size, 1) - # Expected dtype: torch.long + # Expected shape: (batch_size, 1), Expected dtype: torch.long + # If None, a zero-embedding will be used. self._instantx_control_mode = instantx_control_mode # TODO(ryand): Pass in these params if a new base transformer / InstantX ControlNet pair get released. @@ -86,7 +86,7 @@ class InstantXControlNetExtension(BaseControlNetExtension): return cls( model=model, controlnet_cond=controlnet_cond, - instantx_control_mode=instantx_control_mode if model.is_union else None, + instantx_control_mode=instantx_control_mode, weight=weight, begin_step_percent=begin_step_percent, end_step_percent=end_step_percent,