Added minimal viable node support for ControlNet in TextToImageInvocation.

This commit is contained in:
user1
2023-04-25 13:30:00 -07:00
parent 6bd74de8f1
commit cf6941f665

View File

@@ -4,7 +4,9 @@ from functools import partial
from typing import Literal, Optional, Union
import numpy as np
from diffusers import ControlNetModel
from torch import Tensor
import torch
from pydantic import BaseModel, Field
@@ -53,6 +55,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
control_model: Optional[str] = Field(default=None, description="The control model to use")
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
# control_strength: Optional[float] = Field(default=1.0, ge=0, le=1, description="The strength of the controlnet")
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
@@ -70,20 +75,36 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# loading controlnet image (currently requires pre-processed image)
control_image = (
None if self.control_image is None
else context.services.images.get(
self.control_image.image_type, self.control_image.image_name
)
)
# loading controlnet model
if (self.control_model is None or self.control_model==''):
control_model = None
else:
# FIXME: change this to dropdown menu?
control_model = ControlNetModel.from_pretrained(self.control_model,
torch_dtype=torch.float16).to("cuda")
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Txt2Img(model).generate(
txt2img = Txt2Img(model, control_model=control_model)
outputs = txt2img.generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context, source_node_id),
control_image=control_image,
**self.dict(
exclude={"prompt"}
exclude={"prompt", "control_image" }
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object