Create a dedicated FLUX ControlNet invocation.

This commit is contained in:
Ryan Dick
2024-10-08 21:24:55 +00:00
parent 0dd9f1f772
commit dea6cbd599
4 changed files with 104 additions and 13 deletions

View File

@@ -0,0 +1,95 @@
from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
class FluxControlNetField(BaseModel):
    """ControlNet settings bundled for transport between FLUX nodes.

    Produced by ``FluxControlNetInvocation`` and consumed by the FLUX denoise
    node, which loads ``controlnet_model`` and applies ``image`` as guidance.
    """

    # The control/conditioning image whose features guide denoising.
    image: ImageField = Field(description="The control image")
    # Identifier of the ControlNet model to load at denoise time.
    controlnet_model: ModelIdentifierField = Field(description="The ControlNet model to use")
    # Scalar weight, or a per-step list of weights (validated below).
    control_weight: float | list[float] = Field(default=1, description="The weight given to the ControlNet")
    # Fraction of total denoise steps at which the ControlNet starts applying.
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    # Fraction of total denoise steps at which the ControlNet stops applying.
    end_step_percent: float = Field(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    # How the control image is resized to match the generation dimensions.
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
        """Reject out-of-range weights via the shared validate_weights helper."""
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self):
        """Ensure begin_step_percent <= end_step_percent (runs after field validation)."""
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self
@invocation_output("flux_controlnet_output")
class FluxControlNetOutput(BaseInvocationOutput):
"""FLUX ControlNet info"""
controlnet: FluxControlNetField = OutputField(description=FieldDescriptions.control)
@invocation(
    "flux_controlnet",
    title="FLUX ControlNet",
    tags=["controlnet", "flux"],
    category="controlnet",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxControlNetInvocation(BaseInvocation):
    """Collect FLUX ControlNet info to pass to other nodes."""

    # The control/conditioning image whose features guide denoising.
    image: ImageField = InputField(description="The control image")
    # The ControlNet model to use; UIType restricts the model picker in the UI.
    controlnet_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
    )
    # Scalar weight, or per-step list; ge/le bound the scalar case, the
    # field_validator below additionally checks list entries.
    control_weight: float | list[float] = InputField(
        default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
    )
    # Fraction of total denoise steps at which the ControlNet starts applying.
    begin_step_percent: float = InputField(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    # Fraction of total denoise steps at which the ControlNet stops applying.
    end_step_percent: float = InputField(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    # How the control image is resized to match the generation dimensions.
    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
        """Reject out-of-range weights via the shared validate_weights helper."""
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self):
        """Ensure begin_step_percent <= end_step_percent (runs after field validation)."""
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self

    def invoke(self, context: InvocationContext) -> FluxControlNetOutput:
        """Pack this node's inputs into a FluxControlNetField for downstream nodes.

        No models are loaded here; loading happens in the consuming denoise node.
        """
        return FluxControlNetOutput(
            controlnet=FluxControlNetField(
                image=self.image,
                controlnet_model=self.controlnet_model,
                control_weight=self.control_weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
                resize_mode=self.resize_mode,
            ),
        )

View File

@@ -6,7 +6,6 @@ import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
@@ -17,6 +16,7 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -92,7 +92,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
controlnet: ControlField | list[ControlField] | None = InputField(
controlnet: FluxControlNetField | list[FluxControlNetField] | None = InputField(
default=None, input=Input.Connection, description="ControlNet models."
)
controlnet_vae: VAEField | None = InputField(
@@ -322,10 +322,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
device: torch.device,
) -> tuple[list[XLabsControlNetExtension], list[InstantXControlNetExtension]]:
# Normalize the controlnet input to list[ControlField].
controlnets: list[ControlField]
controlnets: list[FluxControlNetField]
if self.controlnet is None:
controlnets = []
elif isinstance(self.controlnet, ControlField):
elif isinstance(self.controlnet, FluxControlNetField):
controlnets = [self.controlnet]
elif isinstance(self.controlnet, list):
controlnets = self.controlnet
@@ -339,7 +339,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
xlabs_controlnet_extensions: list[XLabsControlNetExtension] = []
instantx_controlnet_extensions: list[InstantXControlNetExtension] = []
for controlnet in controlnets:
model = exit_stack.enter_context(context.models.load(controlnet.control_model))
model = exit_stack.enter_context(context.models.load(controlnet.controlnet_model))
image = context.images.get_pil(controlnet.image.image_name)
if isinstance(model, XLabsControlNetFlux):
@@ -351,7 +351,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
latent_width=latent_width,
dtype=dtype,
device=device,
control_mode=controlnet.control_mode,
resize_mode=controlnet.resize_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
@@ -377,7 +376,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
latent_width=latent_width,
dtype=dtype,
device=device,
control_mode=controlnet.control_mode,
resize_mode=controlnet.resize_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,

View File

@@ -5,7 +5,7 @@ from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
InstantXControlNetFlux,
)
@@ -51,7 +51,6 @@ class InstantXControlNetExtension(BaseControlNetExtension):
latent_width: int,
dtype: torch.dtype,
device: torch.device,
control_mode: CONTROLNET_MODE_VALUES,
resize_mode: CONTROLNET_RESIZE_VALUES,
weight: Union[float, List[float]],
begin_step_percent: float,
@@ -67,7 +66,7 @@ class InstantXControlNetExtension(BaseControlNetExtension):
height=image_height,
device=device,
dtype=dtype,
control_mode=control_mode,
control_mode="balanced",
resize_mode=resize_mode,
)

View File

@@ -4,7 +4,7 @@ import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
@@ -39,7 +39,6 @@ class XLabsControlNetExtension(BaseControlNetExtension):
latent_width: int,
dtype: torch.dtype,
device: torch.device,
control_mode: CONTROLNET_MODE_VALUES,
resize_mode: CONTROLNET_RESIZE_VALUES,
weight: Union[float, List[float]],
begin_step_percent: float,
@@ -55,7 +54,7 @@ class XLabsControlNetExtension(BaseControlNetExtension):
height=image_height,
device=device,
dtype=dtype,
control_mode=control_mode,
control_mode="balanced",
resize_mode=resize_mode,
)