Remove instantx_control_mode from FLUX ControlNet node.

This commit is contained in:
Ryan Dick
2024-10-09 22:00:54 +00:00
parent 63a2e17f6b
commit 3953e60a4f
4 changed files with 14 additions and 16 deletions

View File

@@ -25,10 +25,6 @@ class FluxControlNetField(BaseModel):
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
instantx_control_mode: int = Field(
default=0,
description="The control mode for InstantX ControlNet union models. Ignored for other ControlNet models.",
)
@field_validator("control_weight")
@classmethod
@@ -74,10 +70,6 @@ class FluxControlNetInvocation(BaseInvocation):
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
instantx_control_mode: int = InputField(
default=0,
description="The control mode for InstantX ControlNet union models. Ignored for other ControlNet models.",
)
@field_validator("control_weight")
@classmethod
@@ -99,6 +91,5 @@ class FluxControlNetInvocation(BaseInvocation):
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
resize_mode=self.resize_mode,
instantx_control_mode=self.instantx_control_mode,
),
)

View File

@@ -356,8 +356,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
)
)
elif isinstance(model, InstantXControlNetFlux):
instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
instantx_control_mode = instantx_control_mode.reshape([-1, 1])
instantx_control_mode: torch.Tensor | None = None
# if controlnet.instantx_control_mode is not None:
# instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
# instantx_control_mode = instantx_control_mode.reshape([-1, 1])
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")

View File

@@ -134,8 +134,13 @@ class InstantXControlNetFlux(torch.nn.Module):
# If this is a union ControlNet, then concat the control mode embedding to the T5 text embedding.
if self.is_union:
assert controlnet_mode is not None
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
if controlnet_mode is None:
# We allow users to enter 'None' as the controlnet_mode if they don't want to worry about this input.
# We've chosen to use a zero-embedding in this case.
zero_index = torch.zeros([1, 1], dtype=torch.long, device=txt.device)
controlnet_mode_emb = torch.zeros_like(self.controlnet_mode_embedder(zero_index))
else:
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
txt = torch.cat([controlnet_mode_emb, txt], dim=1)
txt_ids = torch.cat([txt_ids[:, :1, :], txt_ids], dim=1)
else:

View File

@@ -38,8 +38,8 @@ class InstantXControlNetExtension(BaseControlNetExtension):
# TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future models?
# The control mode for InstantX ControlNet union models.
# See the values defined here: https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union#control-mode
# Expected shape: (batch_size, 1)
# Expected dtype: torch.long
# Expected shape: (batch_size, 1), Expected dtype: torch.long
# If None, a zero-embedding will be used.
self._instantx_control_mode = instantx_control_mode
# TODO(ryand): Pass in these params if a new base transformer / InstantX ControlNet pair get released.
@@ -86,7 +86,7 @@ class InstantXControlNetExtension(BaseControlNetExtension):
return cls(
model=model,
controlnet_cond=controlnet_cond,
instantx_control_mode=instantx_control_mode if model.is_union else None,
instantx_control_mode=instantx_control_mode,
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,