mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 23:58:00 -05:00
137 lines
5.3 KiB
Python
137 lines
5.3 KiB
Python
# Invocations for ControlNet image preprocessors
|
|
# initial implementation by Gregg Helt, 2023
|
|
from typing import List, Union
|
|
|
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
|
|
from invokeai.app.invocations.baseinvocation import (
|
|
BaseInvocation,
|
|
BaseInvocationOutput,
|
|
Classification,
|
|
invocation,
|
|
invocation_output,
|
|
)
|
|
from invokeai.app.invocations.fields import (
|
|
FieldDescriptions,
|
|
ImageField,
|
|
InputField,
|
|
OutputField,
|
|
)
|
|
from invokeai.app.invocations.model import ModelIdentifierField
|
|
from invokeai.app.invocations.primitives import ImageOutput
|
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
from invokeai.app.util.controlnet_utils import (
|
|
CONTROLNET_MODE_VALUES,
|
|
CONTROLNET_RESIZE_VALUES,
|
|
heuristic_resize_fast,
|
|
)
|
|
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
|
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
|
|
|
|
|
class ControlField(BaseModel):
    # Settings describing one ControlNet application: the control image, the model,
    # its weight(s), the fraction of total steps during which it is applied, and the
    # control/resize modes. Instances are produced by ControlNetInvocation.invoke and
    # emitted through ControlOutput.
    #
    # NOTE: no class docstring on purpose -- pydantic copies a model's docstring into
    # its JSON schema as the description, so adding one would change the schema.

    image: ImageField = Field(description="The control image")
    control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
    # Either one weight or a list of weights; checked by `validate_weights` below.
    # Defaults are written as float literals to match the declared float types:
    # pydantic does not validate defaults, so an int default would reach consumers
    # as an int (and `1.0` also matches ControlNetInvocation's default).
    control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0.0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = Field(
        default=1.0, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v: Union[float, List[float]]) -> Union[float, List[float]]:
        """Run the shared `validate_weights` check on `control_weight`.

        Returns:
            The value unchanged; any rejection comes from `validate_weights` raising.
        """
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self) -> "ControlField":
        """Check the begin/end step percentages together via `validate_begin_end_step`.

        Runs after field validation, so both values are already within [0, 1].
        """
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self
|
|
|
|
|
|
# Output type returned by ControlNetInvocation.invoke. The string passed to
# `@invocation_output` is presumably this output's type identifier -- confirm
# against `invocation_output` in baseinvocation.
@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
    """node output for ControlNet info"""

    # Outputs
    # The ControlNet settings bundle produced by the node.
    control: ControlField = OutputField(description=FieldDescriptions.control)
|
|
|
|
|
|
@invocation(
    "controlnet",
    title="ControlNet - SD1.5, SD2, SDXL",
    tags=["controlnet"],
    category="controlnet",
    version="1.1.3",
)
class ControlNetInvocation(BaseInvocation):
    """Collects ControlNet info to pass to other nodes"""

    # The class docstring and the field declarations below are schema-visible
    # (pydantic publishes docstrings and field order/metadata in the JSON schema),
    # so they are intentionally left exactly as they were.

    image: ImageField = InputField(description="The control image")
    # The UI model picker is restricted to ControlNet models for SD1.5, SD2 and SDXL.
    control_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model,
        ui_model_base=[
            BaseModelType.StableDiffusion1,
            BaseModelType.StableDiffusion2,
            BaseModelType.StableDiffusionXL,
        ],
        ui_model_type=ModelType.ControlNet,
    )
    control_weight: Union[float, List[float]] = InputField(
        description="The weight given to the ControlNet", default=1.0, ge=-1, le=2
    )
    begin_step_percent: float = InputField(
        description="When the ControlNet is first applied (% of total steps)", default=0, ge=0, le=1
    )
    end_step_percent: float = InputField(
        description="When the ControlNet is last applied (% of total steps)", default=1, ge=0, le=1
    )
    control_mode: CONTROLNET_MODE_VALUES = InputField(description="The control mode used", default="balanced")
    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(description="The resize mode used", default="just_resize")

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, weight):
        """Delegate the weight check to `validate_weights`; pass the value through untouched."""
        validate_weights(weight)
        return weight

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
        """Cross-check the start and end step fractions once all fields are validated."""
        first, last = self.begin_step_percent, self.end_step_percent
        validate_begin_end_step(first, last)
        return self

    def invoke(self, context: InvocationContext) -> ControlOutput:
        """Copy this node's inputs into a ControlField and return it wrapped in a ControlOutput.

        `context` is not used: the node only forwards its inputs.
        """
        settings = ControlField(
            image=self.image,
            control_model=self.control_model,
            control_weight=self.control_weight,
            begin_step_percent=self.begin_step_percent,
            end_step_percent=self.end_step_percent,
            control_mode=self.control_mode,
            resize_mode=self.resize_mode,
        )
        return ControlOutput(control=settings)
|
|
|
|
|
|
@invocation(
    "heuristic_resize",
    title="Heuristic Resize",
    # Two separate tags. This was previously `["image, controlnet"]`, a single tag
    # string containing a comma rather than the tags "image" and "controlnet".
    tags=["image", "controlnet"],
    category="image",
    version="1.1.1",
    classification=Classification.Prototype,
)
class HeuristicResizeInvocation(BaseInvocation):
    """Resize an image using a heuristic method. Preserves edge maps."""

    image: ImageField = InputField(description="The image to resize")
    width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
    height: int = InputField(default=512, ge=1, description="The height to resize to (px)")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Load the input as RGB, resize it with `heuristic_resize_fast`, and save the result.

        Returns:
            An ImageOutput built from the DTO of the newly saved image.
        """
        image = context.images.get_pil(self.image.image_name, "RGB")
        np_img = pil_to_np(image)
        # The target size is passed as (width, height).
        np_resized = heuristic_resize_fast(np_img, (self.width, self.height))
        resized = np_to_pil(np_resized)
        image_dto = context.images.save(image=resized)
        return ImageOutput.build(image_dto)
|