diff --git a/invokeai/app/invocations/flux_controlnet.py b/invokeai/app/invocations/flux_controlnet.py index 2daded9287..41f4a6cc8c 100644 --- a/invokeai/app/invocations/flux_controlnet.py +++ b/invokeai/app/invocations/flux_controlnet.py @@ -25,6 +25,10 @@ class FluxControlNetField(BaseModel): default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" ) resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") + instantx_control_mode: int = Field( + default=0, + description="The control mode for InstantX ControlNet union models. Ignored for other ControlNet models.", + ) @field_validator("control_weight") @classmethod @@ -70,6 +74,10 @@ class FluxControlNetInvocation(BaseInvocation): default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" ) resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used") + instantx_control_mode: int = InputField( + default=0, + description="The control mode for InstantX ControlNet union models. Ignored for other ControlNet models.", + ) @field_validator("control_weight") @classmethod @@ -91,5 +99,6 @@ class FluxControlNetInvocation(BaseInvocation): begin_step_percent=self.begin_step_percent, end_step_percent=self.end_step_percent, resize_mode=self.resize_mode, + instantx_control_mode=self.instantx_control_mode, ), ) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index b369a7ae92..c6bd1b969d 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -358,9 +358,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): ) ) elif isinstance(model, InstantXControlNetFlux): - # TODO(ryand): Pass in the correct control mode. - control_mode = torch.tensor(0, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) + instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long) + instantx_control_mode = instantx_control_mode.reshape([-1, 1]) if self.controlnet_vae is None: raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.") @@ -370,7 +369,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): InstantXControlNetExtension.from_controlnet_image( model=model, controlnet_image=image, - instantx_control_mode=control_mode, + instantx_control_mode=instantx_control_mode, vae_info=vae_info, latent_height=latent_height, latent_width=latent_width, diff --git a/invokeai/backend/flux/extensions/instantx_controlnet_extension.py b/invokeai/backend/flux/extensions/instantx_controlnet_extension.py index 1c8650be53..67100f6153 100644 --- a/invokeai/backend/flux/extensions/instantx_controlnet_extension.py +++ b/invokeai/backend/flux/extensions/instantx_controlnet_extension.py @@ -77,7 +77,7 @@ class InstantXControlNetExtension(BaseControlNetExtension): return cls( model=model, controlnet_cond=controlnet_cond, - instantx_control_mode=instantx_control_mode, + instantx_control_mode=instantx_control_mode if model.is_union else None, weight=weight, begin_step_percent=begin_step_percent, end_step_percent=end_step_percent,