mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Create a dedicated FLUX ControlNet invocation.
This commit is contained in:
95
invokeai/app/invocations/flux_controlnet.py
Normal file
95
invokeai/app/invocations/flux_controlnet.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
|
||||
|
||||
|
||||
class FluxControlNetField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
controlnet_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_weight: float | list[float] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
|
||||
@invocation_output("flux_controlnet_output")
|
||||
class FluxControlNetOutput(BaseInvocationOutput):
|
||||
"""FLUX ControlNet info"""
|
||||
|
||||
controlnet: FluxControlNetField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_controlnet",
|
||||
title="FLUX ControlNet",
|
||||
tags=["controlnet", "flux"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxControlNetInvocation(BaseInvocation):
|
||||
"""Collect FLUX ControlNet info to pass to other nodes."""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
controlnet_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: float | list[float] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
@field_validator("control_weight")
|
||||
@classmethod
|
||||
def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxControlNetOutput:
|
||||
return FluxControlNetOutput(
|
||||
controlnet=FluxControlNetField(
|
||||
image=self.image,
|
||||
controlnet_model=self.controlnet_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
@@ -6,7 +6,6 @@ import torchvision.transforms as tv_transforms
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
@@ -17,6 +16,7 @@ from invokeai.app.invocations.fields import (
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@@ -92,7 +92,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
controlnet: ControlField | list[ControlField] | None = InputField(
|
||||
controlnet: FluxControlNetField | list[FluxControlNetField] | None = InputField(
|
||||
default=None, input=Input.Connection, description="ControlNet models."
|
||||
)
|
||||
controlnet_vae: VAEField | None = InputField(
|
||||
@@ -322,10 +322,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
device: torch.device,
|
||||
) -> tuple[list[XLabsControlNetExtension], list[InstantXControlNetExtension]]:
|
||||
# Normalize the controlnet input to list[ControlField].
|
||||
controlnets: list[ControlField]
|
||||
controlnets: list[FluxControlNetField]
|
||||
if self.controlnet is None:
|
||||
controlnets = []
|
||||
elif isinstance(self.controlnet, ControlField):
|
||||
elif isinstance(self.controlnet, FluxControlNetField):
|
||||
controlnets = [self.controlnet]
|
||||
elif isinstance(self.controlnet, list):
|
||||
controlnets = self.controlnet
|
||||
@@ -339,7 +339,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
xlabs_controlnet_extensions: list[XLabsControlNetExtension] = []
|
||||
instantx_controlnet_extensions: list[InstantXControlNetExtension] = []
|
||||
for controlnet in controlnets:
|
||||
model = exit_stack.enter_context(context.models.load(controlnet.control_model))
|
||||
model = exit_stack.enter_context(context.models.load(controlnet.controlnet_model))
|
||||
image = context.images.get_pil(controlnet.image.image_name)
|
||||
|
||||
if isinstance(model, XLabsControlNetFlux):
|
||||
@@ -351,7 +351,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
latent_width=latent_width,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
control_mode=controlnet.control_mode,
|
||||
resize_mode=controlnet.resize_mode,
|
||||
weight=controlnet.control_weight,
|
||||
begin_step_percent=controlnet.begin_step_percent,
|
||||
@@ -377,7 +376,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
latent_width=latent_width,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
control_mode=controlnet.control_mode,
|
||||
resize_mode=controlnet.resize_mode,
|
||||
weight=controlnet.control_weight,
|
||||
begin_step_percent=controlnet.begin_step_percent,
|
||||
|
||||
@@ -5,7 +5,7 @@ from PIL.Image import Image
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
|
||||
InstantXControlNetFlux,
|
||||
)
|
||||
@@ -51,7 +51,6 @@ class InstantXControlNetExtension(BaseControlNetExtension):
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
control_mode: CONTROLNET_MODE_VALUES,
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES,
|
||||
weight: Union[float, List[float]],
|
||||
begin_step_percent: float,
|
||||
@@ -67,7 +66,7 @@ class InstantXControlNetExtension(BaseControlNetExtension):
|
||||
height=image_height,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
control_mode=control_mode,
|
||||
control_mode="balanced",
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
from PIL.Image import Image
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux_output import XLabsControlNetFluxOutput
|
||||
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
|
||||
@@ -39,7 +39,6 @@ class XLabsControlNetExtension(BaseControlNetExtension):
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
control_mode: CONTROLNET_MODE_VALUES,
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES,
|
||||
weight: Union[float, List[float]],
|
||||
begin_step_percent: float,
|
||||
@@ -55,7 +54,7 @@ class XLabsControlNetExtension(BaseControlNetExtension):
|
||||
height=image_height,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
control_mode=control_mode,
|
||||
control_mode="balanced",
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user