Add instantx_control_mode param to FLUX ControlNet invocation.

This commit is contained in:
Ryan Dick
2024-10-08 21:52:59 +00:00
parent dea6cbd599
commit cd88723a80
3 changed files with 13 additions and 5 deletions

View File

@@ -25,6 +25,10 @@ class FluxControlNetField(BaseModel):
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
instantx_control_mode: int = Field(
default=0,
description="The control mode for InstantX ControlNet union models. Ignored for other ControlNet models.",
)
@field_validator("control_weight")
@classmethod
@@ -70,6 +74,10 @@ class FluxControlNetInvocation(BaseInvocation):
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
instantx_control_mode: int = InputField(
default=0,
description="The control mode for InstantX ControlNet union models. Ignored for other ControlNet models.",
)
@field_validator("control_weight")
@classmethod
@@ -91,5 +99,6 @@ class FluxControlNetInvocation(BaseInvocation):
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
resize_mode=self.resize_mode,
instantx_control_mode=self.instantx_control_mode,
),
)

View File

@@ -358,9 +358,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
)
)
elif isinstance(model, InstantXControlNetFlux):
# TODO(ryand): Pass in the correct control mode.
control_mode = torch.tensor(0, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
instantx_control_mode = instantx_control_mode.reshape([-1, 1])
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
@@ -370,7 +369,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
InstantXControlNetExtension.from_controlnet_image(
model=model,
controlnet_image=image,
instantx_control_mode=control_mode,
instantx_control_mode=instantx_control_mode,
vae_info=vae_info,
latent_height=latent_height,
latent_width=latent_width,

View File

@@ -77,7 +77,7 @@ class InstantXControlNetExtension(BaseControlNetExtension):
return cls(
model=model,
controlnet_cond=controlnet_cond,
instantx_control_mode=instantx_control_mode,
instantx_control_mode=instantx_control_mode if model.is_union else None,
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,