Add multicontrolnet (#1958)

This commit is contained in:
gpetters94
2023-12-01 14:51:20 -05:00
committed by GitHub
parent 552b2c3ee3
commit ff15fd74f6
11 changed files with 392 additions and 144 deletions

View File

@@ -86,7 +86,7 @@ class SharkifyStableDiffusionModel:
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
stencils: list[str] = [],
use_lora: str = "",
use_quantize: str = None,
return_mlir: bool = False,
@@ -144,7 +144,7 @@ class SharkifyStableDiffusionModel:
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
self.stencils = [get_stencil_model_id(x) for x in stencils]
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
@@ -195,8 +195,9 @@ class SharkifyStableDiffusionModel:
)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
model_config = model_config + get_path_stem(self.use_stencil)
# TODO: Fix this
# if "stencil_adaptor" == model and self.use_stencil is not None:
# model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
@@ -394,6 +395,22 @@ class SharkifyStableDiffusionModel:
scale12,
scale13,
):
# TODO: Average pooling
db_res_samples = [
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple(
[
@@ -488,11 +505,11 @@ class SharkifyStableDiffusionModel:
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self, use_large=False):
def get_control_net(self, stencil_id, use_large=False):
stencil_id = get_stencil_model_id(stencil_id)
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
model_id,
@@ -507,6 +524,19 @@ class SharkifyStableDiffusionModel:
timestep,
text_embedding,
stencil_image_input,
acc1,
acc2,
acc3,
acc4,
acc5,
acc6,
acc7,
acc8,
acc9,
acc10,
acc11,
acc12,
acc13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
# TODO: guidance NOT NEEDED change in `get_input_info` later
@@ -528,6 +558,20 @@ class SharkifyStableDiffusionModel:
)
return tuple(
list(down_block_res_samples) + [mid_block_res_sample]
) + (
acc1 + down_block_res_samples[0],
acc2 + down_block_res_samples[1],
acc3 + down_block_res_samples[2],
acc4 + down_block_res_samples[3],
acc5 + down_block_res_samples[4],
acc6 + down_block_res_samples[5],
acc7 + down_block_res_samples[6],
acc8 + down_block_res_samples[7],
acc9 + down_block_res_samples[8],
acc10 + down_block_res_samples[9],
acc11 + down_block_res_samples[10],
acc12 + down_block_res_samples[11],
acc13 + mid_block_res_sample,
)
scnet = StencilControlNetModel(
@@ -543,7 +587,7 @@ class SharkifyStableDiffusionModel:
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
*inputs[3:],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
@@ -552,7 +596,7 @@ class SharkifyStableDiffusionModel:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
input_mask = [True, True, True, True]
input_mask = [True, True, True, True] + ([True] * 13)
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
@@ -837,7 +881,10 @@ class SharkifyStableDiffusionModel:
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
stencil_count = 0
for stencil in self.stencils:
stencil_count += 1
model = "stencil_unet" if stencil_count > 0 else "unet"
compiled_unet = None
unet_inputs = base_models[model]
@@ -906,13 +953,13 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def controlnet(self, use_large=False):
def controlnet(self, stencil_id, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
stencil_id, use_large=use_large
)
check_compilation(compiled_stencil_adaptor, "Stencil")

View File

@@ -158,7 +158,6 @@ class Image2ImagePipeline(StableDiffusionPipeline):
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# prompts and negative prompts must be a list.

View File

@@ -55,28 +55,47 @@ class StencilPipeline(StableDiffusionPipeline):
import_mlir: bool,
use_lora: str,
ondemand: bool,
controlnet_names: list[str],
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
self.controlnet_512 = None
self.controlnet = [None] * len(controlnet_names)
self.controlnet_512 = [None] * len(controlnet_names)
self.controlnet_id = [str] * len(controlnet_names)
self.controlnet_512_id = [str] * len(controlnet_names)
self.controlnet_names = controlnet_names
def load_controlnet(self):
if self.controlnet is not None:
def load_controlnet(self, index, model_name):
if model_name is None:
return
self.controlnet = self.sd_model.controlnet()
def unload_controlnet(self):
del self.controlnet
self.controlnet = None
def load_controlnet_512(self):
if self.controlnet_512 is not None:
if (
self.controlnet[index] is not None
and self.controlnet_id[index] is not None
and self.controlnet_id[index] == model_name
):
return
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
self.controlnet_id[index] = model_name
self.controlnet[index] = self.sd_model.controlnet(model_name)
def unload_controlnet_512(self):
del self.controlnet_512
self.controlnet_512 = None
def unload_controlnet(self, index):
del self.controlnet[index]
self.controlnet_id[index] = None
self.controlnet[index] = None
def load_controlnet_512(self, index, model_name):
if (
self.controlnet_512[index] is not None
and self.controlnet_512_id[index] == model_name
):
return
self.controlnet_512_id[index] = model_name
self.controlnet_512[index] = self.sd_model.controlnet(
model_name, use_large=True
)
def unload_controlnet_512(self, index):
del self.controlnet_512[index]
self.controlnet_512_id[index] = None
self.controlnet_512[index] = None
def prepare_latents(
self,
@@ -111,7 +130,7 @@ class StencilPipeline(StableDiffusionPipeline):
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
stencil_hints=[None],
controlnet_conditioning_scale: float = 1.0,
control_mode="Balanced", # Prompt, Balanced, or Controlnet
mask=None,
@@ -125,10 +144,15 @@ class StencilPipeline(StableDiffusionPipeline):
assert control_mode in ["Prompt", "Balanced", "Controlnet"]
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
self.load_controlnet()
else:
self.load_unet_512()
self.load_controlnet_512()
for i, name in enumerate(self.controlnet_names):
if text_embeddings.shape[1] <= self.model_max_length:
self.load_controlnet(i, name)
else:
self.load_controlnet_512(i, name)
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
@@ -151,28 +175,72 @@ class StencilPipeline(StableDiffusionPipeline):
).to(dtype)
else:
latent_model_input_1 = latent_model_input
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
control = self.controlnet_512(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
# Multicontrolnet
width = latent_model_input_1.shape[2]
height = latent_model_input_1.shape[3]
dtype = latent_model_input_1.dtype
control_acc = (
[torch.zeros((2, 320, height, width), dtype=dtype)] * 3
+ [
torch.zeros(
(2, 320, int(height / 2), int(width / 2)), dtype=dtype
)
]
+ [
torch.zeros(
(2, 640, int(height / 2), int(width / 2)), dtype=dtype
)
]
* 2
+ [
torch.zeros(
(2, 640, int(height / 4), int(width / 4)), dtype=dtype
)
]
+ [
torch.zeros(
(2, 1280, int(height / 4), int(width / 4)), dtype=dtype
)
]
* 2
+ [
torch.zeros(
(2, 1280, int(height / 8), int(width / 8)), dtype=dtype
)
]
* 4
)
for i, controlnet_hint in enumerate(stencil_hints):
if controlnet_hint is None:
continue
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
*control_acc,
),
send_to_host=False,
)
else:
control = self.controlnet_512[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
*control_acc,
),
send_to_host=False,
)
control_acc = control[13:]
control = control[:13]
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
@@ -289,8 +357,9 @@ class StencilPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
self.unload_controlnet()
self.unload_controlnet_512()
for i in range(len(self.controlnet_names)):
self.unload_controlnet(i)
self.unload_controlnet_512(i)
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -316,15 +385,30 @@ class StencilPipeline(StableDiffusionPipeline):
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
stencils,
stencil_images,
resample_type,
control_mode,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
controlnet_hint = controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
)
# controlnet_hint = controlnet_hint_conversion(
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
# )
stencil_hints = []
for i, stencil in enumerate(stencils):
image = stencil_images[i]
stencil_hints.append(
controlnet_hint_conversion(
image,
stencil,
height,
width,
dtype,
num_images_per_prompt=1,
)
)
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
@@ -372,8 +456,8 @@ class StencilPipeline(StableDiffusionPipeline):
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
control_mode=control_mode,
stencil_hints=stencil_hints,
)
# Img latents -> PIL images

View File

@@ -338,7 +338,8 @@ class StableDiffusionPipeline:
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
stencils: list[str] = [],
# stencil_images: list[Image] = []
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
@@ -371,7 +372,7 @@ class StableDiffusionPipeline:
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
stencils=stencils,
use_lora=use_lora,
use_quantize=use_quantize,
)
@@ -386,6 +387,10 @@ class StableDiffusionPipeline:
ondemand,
)
if cls.__name__ == "StencilPipeline":
return cls(
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
# #####################################################

View File

@@ -208,6 +208,58 @@
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
},
"acc1": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc2": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc3": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc4": {
"shape": [2, 320, "height/2", "width/2"],
"dtype": "f32"
},
"acc5": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"acc6": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"acc7": {
"shape": [2, 640, "height/4", "width/4"],
"dtype": "f32"
},
"acc8": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"acc9": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"acc10": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc11": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc12": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
},
"stencil_unet": {

View File

@@ -31,6 +31,10 @@ from apps.stable_diffusion.src.utils import (
get_generation_text_info,
resampler_list,
)
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
import numpy as np
@@ -60,7 +64,6 @@ def img2img_inf(
precision: str,
device: str,
max_length: int,
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
@@ -69,6 +72,8 @@ def img2img_inf(
repeatable_seeds: bool,
resample_type: str,
control_mode: str,
stencils: list,
images: list,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -90,11 +95,17 @@ def img2img_inf(
args.img_path = "not none"
args.ondemand = ondemand
for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
return None, "A stencil must have an Image input"
if images[i] is not None:
images[i] = images[i].convert("RGB")
if image_dict is None:
return None, "An Initial Image is required"
if use_stencil == "scribble":
image = image_dict["mask"].convert("RGB")
elif isinstance(image_dict, PIL.Image.Image):
# if use_stencil == "scribble":
# image = image_dict["mask"].convert("RGB")
if isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
else:
image = image_dict["image"].convert("RGB")
@@ -124,12 +135,14 @@ def img2img_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_stencil = None if use_stencil == "None" else use_stencil
args.use_stencil = use_stencil
if use_stencil is not None:
stencil_count = 0
for stencil in stencils:
if stencil is not None:
stencil_count += 1
if stencil_count > 0:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
# image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete "
@@ -151,7 +164,7 @@ def img2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=use_stencil,
stencils=stencils,
ondemand=ondemand,
)
if (
@@ -178,7 +191,7 @@ def img2img_inf(
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
if use_stencil is not None:
if stencil_count > 0:
args.use_tuned = False
global_obj.set_sd_obj(
StencilPipeline.from_pretrained(
@@ -195,7 +208,7 @@ def img2img_inf(
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
stencils=stencils,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
@@ -252,7 +265,8 @@ def img2img_inf(
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
stencils,
images,
resample_type=resample_type,
control_mode=control_mode,
)
@@ -274,12 +288,17 @@ def img2img_inf(
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Image-to-Image", current_batch + 1, batch_count, batch_size
)
), stencils, images
return generated_imgs, text_output, ""
return generated_imgs, text_output, "", stencils, images
with gr.Blocks(title="Image-to-Image") as img2img_web:
# Stencils
# TODO: Add more stencils here
STENCIL_COUNT = 2
stencils = gr.State([None] * STENCIL_COUNT)
images = gr.State([None] * STENCIL_COUNT)
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
@@ -350,70 +369,105 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
height=300,
)
with gr.Accordion(label="Stencil Options", open=False):
with gr.Accordion(label="Multistencil Options", open=False):
choices = ["None", "canny", "openpose", "scribble"]
def cnet_preview(
checked, model, input_image, index, stencils, images
):
if not checked:
stencils[index] = None
images[index] = None
return (None, stencils, images)
images[index] = input_image
stencils[index] = model
match model:
case "canny":
canny = CannyDetector()
result = canny(np.array(input_image), 100, 200)
return (
[Image.fromarray(result), result],
stencils,
images,
)
case "openpose":
openpose = OpenposeDetector()
result = openpose(np.array(input_image))
# TODO: This is just an empty canvas, need to draw the candidates (which are in result[1])
return (
[Image.fromarray(result[0]), result],
stencils,
images,
)
case _:
return (None, stencils, images)
with gr.Row():
use_stencil = gr.Dropdown(
elem_id="stencil_model",
label="Stencil model",
cnet_1 = gr.Checkbox(show_label=False)
cnet_1_model = gr.Dropdown(
label="Controlnet 1",
value="None",
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
choices=choices,
)
cnet_1_image = gr.Image(
source="upload",
tool=None,
type="pil",
)
cnet_1_output = gr.Gallery(
show_label=False,
object_fit="scale-down",
rows=1,
columns=1,
)
cnet_1.change(
fn=(
lambda a, b, c, s, i: cnet_preview(
a, b, c, 0, s, i
)
),
inputs=[
cnet_1,
cnet_1_model,
cnet_1_image,
stencils,
images,
],
outputs=[cnet_1_output, stencils, images],
)
def show_canvas(choice):
if choice == "scribble":
return (
gr.Slider.update(visible=True),
gr.Slider.update(visible=True),
gr.Button.update(visible=True),
)
else:
return (
gr.Slider.update(visible=False),
gr.Slider.update(visible=False),
gr.Button.update(visible=False),
)
def create_canvas(w, h):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
cnet_2 = gr.Checkbox(show_label=False)
cnet_2_model = gr.Dropdown(
label="Controlnet 2",
value="None",
choices=choices,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
cnet_2_image = gr.Image(
source="upload",
tool=None,
type="pil",
)
cnet_2_output = gr.Gallery(
show_label=False,
object_fit="scale-down",
rows=1,
columns=1,
)
cnet_2.change(
fn=(
lambda a, b, c, s, i: cnet_preview(
a, b, c, 1, s, i
)
),
inputs=[
cnet_2,
cnet_2_model,
cnet_2_image,
stencils,
images,
],
outputs=[cnet_2_output, stencils, images],
)
create_button = gr.Button(
label="Start",
value="Open drawing canvas!",
visible=False,
)
create_button.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[img2img_init_image],
)
use_stencil.change(
fn=show_canvas,
inputs=use_stencil,
outputs=[canvas_width, canvas_height, create_button],
)
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
@@ -624,7 +678,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
precision,
device,
max_length,
use_stencil,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
@@ -633,8 +686,16 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
repeatable_seeds,
resample_type,
control_mode,
stencils,
images,
],
outputs=[
img2img_gallery,
std_output,
img2img_status,
stencils,
images,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress="minimal" if args.progress_bar else "none",
)

View File

@@ -122,7 +122,7 @@ def inpaint_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (

View File

@@ -121,7 +121,7 @@ def outpaint_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (

View File

@@ -126,7 +126,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -226,7 +226,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil="None",
stencils=[],
ondemand=ondemand,
)
@@ -280,7 +280,7 @@ def txt2img_inf(
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil="None",
stencils=[],
resample_type=resample_type,
)
total_time = time.time() - start_time

View File

@@ -120,7 +120,7 @@ def upscaler_inf(
args.width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (

View File

@@ -30,7 +30,7 @@ class Config:
width: int
device: str
use_lora: str
use_stencil: str
stencils: list[str]
ondemand: str # should this be expecting a bool instead?