mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add multicontrolnet (#1958)
This commit is contained in:
@@ -86,7 +86,7 @@ class SharkifyStableDiffusionModel:
|
||||
generate_vmfb: bool = True,
|
||||
is_inpaint: bool = False,
|
||||
is_upscaler: bool = False,
|
||||
use_stencil: str = None,
|
||||
stencils: list[str] = [],
|
||||
use_lora: str = "",
|
||||
use_quantize: str = None,
|
||||
return_mlir: bool = False,
|
||||
@@ -144,7 +144,7 @@ class SharkifyStableDiffusionModel:
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
self.is_inpaint = is_inpaint
|
||||
self.is_upscaler = is_upscaler
|
||||
self.use_stencil = get_stencil_model_id(use_stencil)
|
||||
self.stencils = [get_stencil_model_id(x) for x in stencils]
|
||||
if use_lora != "":
|
||||
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
|
||||
self.use_lora = use_lora
|
||||
@@ -195,8 +195,9 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
if self.base_vae:
|
||||
sub_model = "base_vae"
|
||||
if "stencil_adaptor" == model and self.use_stencil is not None:
|
||||
model_config = model_config + get_path_stem(self.use_stencil)
|
||||
# TODO: Fix this
|
||||
# if "stencil_adaptor" == model and self.use_stencil is not None:
|
||||
# model_config = model_config + get_path_stem(self.use_stencil)
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
index += 1
|
||||
return model_name
|
||||
@@ -394,6 +395,22 @@ class SharkifyStableDiffusionModel:
|
||||
scale12,
|
||||
scale13,
|
||||
):
|
||||
# TODO: Average pooling
|
||||
db_res_samples = [
|
||||
control1,
|
||||
control2,
|
||||
control3,
|
||||
control4,
|
||||
control5,
|
||||
control6,
|
||||
control7,
|
||||
control8,
|
||||
control9,
|
||||
control10,
|
||||
control11,
|
||||
control12,
|
||||
]
|
||||
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
db_res_samples = tuple(
|
||||
[
|
||||
@@ -488,11 +505,11 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_controlled_unet, controlled_unet_mlir
|
||||
|
||||
def get_control_net(self, use_large=False):
|
||||
def get_control_net(self, stencil_id, use_large=False):
|
||||
stencil_id = get_stencil_model_id(stencil_id)
|
||||
|
||||
class StencilControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.use_stencil, low_cpu_mem_usage=False
|
||||
):
|
||||
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.cnet = ControlNetModel.from_pretrained(
|
||||
model_id,
|
||||
@@ -507,6 +524,19 @@ class SharkifyStableDiffusionModel:
|
||||
timestep,
|
||||
text_embedding,
|
||||
stencil_image_input,
|
||||
acc1,
|
||||
acc2,
|
||||
acc3,
|
||||
acc4,
|
||||
acc5,
|
||||
acc6,
|
||||
acc7,
|
||||
acc8,
|
||||
acc9,
|
||||
acc10,
|
||||
acc11,
|
||||
acc12,
|
||||
acc13,
|
||||
):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
# TODO: guidance NOT NEEDED change in `get_input_info` later
|
||||
@@ -528,6 +558,20 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return tuple(
|
||||
list(down_block_res_samples) + [mid_block_res_sample]
|
||||
) + (
|
||||
acc1 + down_block_res_samples[0],
|
||||
acc2 + down_block_res_samples[1],
|
||||
acc3 + down_block_res_samples[2],
|
||||
acc4 + down_block_res_samples[3],
|
||||
acc5 + down_block_res_samples[4],
|
||||
acc6 + down_block_res_samples[5],
|
||||
acc7 + down_block_res_samples[6],
|
||||
acc8 + down_block_res_samples[7],
|
||||
acc9 + down_block_res_samples[8],
|
||||
acc10 + down_block_res_samples[9],
|
||||
acc11 + down_block_res_samples[10],
|
||||
acc12 + down_block_res_samples[11],
|
||||
acc13 + mid_block_res_sample,
|
||||
)
|
||||
|
||||
scnet = StencilControlNetModel(
|
||||
@@ -543,7 +587,7 @@ class SharkifyStableDiffusionModel:
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3],
|
||||
*inputs[3:],
|
||||
)
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
|
||||
@@ -552,7 +596,7 @@ class SharkifyStableDiffusionModel:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor"]
|
||||
)
|
||||
input_mask = [True, True, True, True]
|
||||
input_mask = [True, True, True, True] + ([True] * 13)
|
||||
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
|
||||
shark_cnet, cnet_mlir = compile_through_fx(
|
||||
scnet,
|
||||
@@ -837,7 +881,10 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def unet(self, use_large=False):
|
||||
try:
|
||||
model = "stencil_unet" if self.use_stencil is not None else "unet"
|
||||
stencil_count = 0
|
||||
for stencil in self.stencils:
|
||||
stencil_count += 1
|
||||
model = "stencil_unet" if stencil_count > 0 else "unet"
|
||||
compiled_unet = None
|
||||
unet_inputs = base_models[model]
|
||||
|
||||
@@ -906,13 +953,13 @@ class SharkifyStableDiffusionModel:
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def controlnet(self, use_large=False):
|
||||
def controlnet(self, stencil_id, use_large=False):
|
||||
try:
|
||||
self.inputs["stencil_adaptor"] = self.get_input_info_for(
|
||||
base_models["stencil_adaptor"]
|
||||
)
|
||||
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
|
||||
use_large=use_large
|
||||
stencil_id, use_large=use_large
|
||||
)
|
||||
|
||||
check_compilation(compiled_stencil_adaptor, "Stencil")
|
||||
|
||||
@@ -158,7 +158,6 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
resample_type,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
|
||||
@@ -55,28 +55,47 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
ondemand: bool,
|
||||
controlnet_names: list[str],
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.controlnet = None
|
||||
self.controlnet_512 = None
|
||||
self.controlnet = [None] * len(controlnet_names)
|
||||
self.controlnet_512 = [None] * len(controlnet_names)
|
||||
self.controlnet_id = [str] * len(controlnet_names)
|
||||
self.controlnet_512_id = [str] * len(controlnet_names)
|
||||
self.controlnet_names = controlnet_names
|
||||
|
||||
def load_controlnet(self):
|
||||
if self.controlnet is not None:
|
||||
def load_controlnet(self, index, model_name):
|
||||
if model_name is None:
|
||||
return
|
||||
self.controlnet = self.sd_model.controlnet()
|
||||
|
||||
def unload_controlnet(self):
|
||||
del self.controlnet
|
||||
self.controlnet = None
|
||||
|
||||
def load_controlnet_512(self):
|
||||
if self.controlnet_512 is not None:
|
||||
if (
|
||||
self.controlnet[index] is not None
|
||||
and self.controlnet_id[index] is not None
|
||||
and self.controlnet_id[index] == model_name
|
||||
):
|
||||
return
|
||||
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
|
||||
self.controlnet_id[index] = model_name
|
||||
self.controlnet[index] = self.sd_model.controlnet(model_name)
|
||||
|
||||
def unload_controlnet_512(self):
|
||||
del self.controlnet_512
|
||||
self.controlnet_512 = None
|
||||
def unload_controlnet(self, index):
    """Release the ControlNet held in slot ``index`` and forget its model id.

    ``self.controlnet`` and ``self.controlnet_id`` are lists with one slot per
    entry of ``self.controlnet_names``, so the slot is cleared in place and
    the lists keep their length.

    Args:
        index: position of the ControlNet slot to clear.
    """
    # Do not use ``del self.controlnet[index]`` here: on a list that removes
    # the slot and shifts every later entry down, so a follow-up
    # ``self.controlnet[index] = None`` would clobber the next ControlNet
    # (or raise IndexError for the last slot). Rebinding the slot to None
    # drops the reference to the compiled module just the same.
    self.controlnet[index] = None
    self.controlnet_id[index] = None
|
||||
|
||||
def load_controlnet_512(self, index, model_name):
    """Make sure slot ``index`` holds the ``use_large=True`` ControlNet for ``model_name``.

    The compiled module is obtained from ``self.sd_model.controlnet`` and
    cached in ``self.controlnet_512[index]``; the name it was built for is
    remembered in ``self.controlnet_512_id[index]`` so a repeated request for
    the same name is a no-op.

    Args:
        index: position of the ControlNet slot to fill.
        model_name: stencil name for this slot, or None for an unused slot.
    """
    # An unused slot has no model to compile; mirrors the guard in
    # ``load_controlnet`` so both loaders treat empty slots the same way.
    if model_name is None:
        return
    already_loaded = (
        self.controlnet_512[index] is not None
        and self.controlnet_512_id[index] == model_name
    )
    if already_loaded:
        return
    self.controlnet_512_id[index] = model_name
    self.controlnet_512[index] = self.sd_model.controlnet(
        model_name, use_large=True
    )
|
||||
|
||||
def unload_controlnet_512(self, index):
    """Release the ``use_large`` ControlNet held in slot ``index`` and forget its id.

    ``self.controlnet_512`` and ``self.controlnet_512_id`` are lists with one
    slot per entry of ``self.controlnet_names``, so the slot is cleared in
    place and the lists keep their length.

    Args:
        index: position of the ControlNet slot to clear.
    """
    # ``del self.controlnet_512[index]`` would remove the list slot and shift
    # later entries, making the subsequent assignment hit the wrong slot (or
    # raise IndexError on the last one). Rebinding to None releases the
    # reference without disturbing the other slots.
    self.controlnet_512[index] = None
    self.controlnet_512_id[index] = None
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
@@ -111,7 +130,7 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
total_timesteps,
|
||||
dtype,
|
||||
cpu_scheduling,
|
||||
controlnet_hint=None,
|
||||
stencil_hints=[None],
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
control_mode="Balanced", # Prompt, Balanced, or Controlnet
|
||||
mask=None,
|
||||
@@ -125,10 +144,15 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
assert control_mode in ["Prompt", "Balanced", "Controlnet"]
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
self.load_controlnet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
self.load_controlnet_512()
|
||||
|
||||
for i, name in enumerate(self.controlnet_names):
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_controlnet(i, name)
|
||||
else:
|
||||
self.load_controlnet_512(i, name)
|
||||
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype)
|
||||
@@ -151,28 +175,72 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
).to(dtype)
|
||||
else:
|
||||
latent_model_input_1 = latent_model_input
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
control = self.controlnet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
control = self.controlnet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
# Multicontrolnet
|
||||
width = latent_model_input_1.shape[2]
|
||||
height = latent_model_input_1.shape[3]
|
||||
dtype = latent_model_input_1.dtype
|
||||
control_acc = (
|
||||
[torch.zeros((2, 320, height, width), dtype=dtype)] * 3
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 320, int(height / 2), int(width / 2)), dtype=dtype
|
||||
)
|
||||
]
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 640, int(height / 2), int(width / 2)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 2
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 640, int(height / 4), int(width / 4)), dtype=dtype
|
||||
)
|
||||
]
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 1280, int(height / 4), int(width / 4)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 2
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 1280, int(height / 8), int(width / 8)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 4
|
||||
)
|
||||
for i, controlnet_hint in enumerate(stencil_hints):
|
||||
if controlnet_hint is None:
|
||||
continue
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
control = self.controlnet[i](
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
*control_acc,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
control = self.controlnet_512[i](
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
*control_acc,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
control_acc = control[13:]
|
||||
control = control[:13]
|
||||
|
||||
timestep = timestep.detach().numpy()
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
@@ -289,8 +357,9 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
self.unload_controlnet()
|
||||
self.unload_controlnet_512()
|
||||
for i in range(len(self.controlnet_names)):
|
||||
self.unload_controlnet(i)
|
||||
self.unload_controlnet_512(i)
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -316,15 +385,30 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
stencils,
|
||||
stencil_images,
|
||||
resample_type,
|
||||
control_mode,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# TODO: 1. Change `num_images_per_prompt`.
|
||||
controlnet_hint = controlnet_hint_conversion(
|
||||
image, use_stencil, height, width, dtype, num_images_per_prompt=1
|
||||
)
|
||||
# controlnet_hint = controlnet_hint_conversion(
|
||||
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
|
||||
# )
|
||||
stencil_hints = []
|
||||
for i, stencil in enumerate(stencils):
|
||||
image = stencil_images[i]
|
||||
stencil_hints.append(
|
||||
controlnet_hint_conversion(
|
||||
image,
|
||||
stencil,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt=1,
|
||||
)
|
||||
)
|
||||
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
@@ -372,8 +456,8 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
total_timesteps=final_timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
controlnet_hint=controlnet_hint,
|
||||
control_mode=control_mode,
|
||||
stencil_hints=stencil_hints,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
|
||||
@@ -338,7 +338,8 @@ class StableDiffusionPipeline:
|
||||
ondemand: bool,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
debug: bool = False,
|
||||
use_stencil: str = None,
|
||||
stencils: list[str] = [],
|
||||
# stencil_images: list[Image] = []
|
||||
use_lora: str = "",
|
||||
ddpm_scheduler: DDPMScheduler = None,
|
||||
use_quantize=None,
|
||||
@@ -371,7 +372,7 @@ class StableDiffusionPipeline:
|
||||
debug=debug,
|
||||
is_inpaint=is_inpaint,
|
||||
is_upscaler=is_upscaler,
|
||||
use_stencil=use_stencil,
|
||||
stencils=stencils,
|
||||
use_lora=use_lora,
|
||||
use_quantize=use_quantize,
|
||||
)
|
||||
@@ -386,6 +387,10 @@ class StableDiffusionPipeline:
|
||||
ondemand,
|
||||
)
|
||||
|
||||
if cls.__name__ == "StencilPipeline":
|
||||
return cls(
|
||||
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
|
||||
)
|
||||
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
|
||||
# #####################################################
|
||||
|
||||
@@ -208,6 +208,58 @@
|
||||
"controlnet_hint": {
|
||||
"shape": [1, 3, "8*height", "8*width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc1": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc2": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc3": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc4": {
|
||||
"shape": [2, 320, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc5": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc6": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc7": {
|
||||
"shape": [2, 640, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc8": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc9": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc10": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc11": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc12": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc13": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stencil_unet": {
|
||||
|
||||
@@ -31,6 +31,10 @@ from apps.stable_diffusion.src.utils import (
|
||||
get_generation_text_info,
|
||||
resampler_list,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.stencils import (
|
||||
CannyDetector,
|
||||
OpenposeDetector,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
import numpy as np
|
||||
|
||||
@@ -60,7 +64,6 @@ def img2img_inf(
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
use_stencil: str,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
@@ -69,6 +72,8 @@ def img2img_inf(
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
control_mode: str,
|
||||
stencils: list,
|
||||
images: list,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -90,11 +95,17 @@ def img2img_inf(
|
||||
args.img_path = "not none"
|
||||
args.ondemand = ondemand
|
||||
|
||||
for i, stencil in enumerate(stencils):
|
||||
if images[i] is None and stencil is not None:
|
||||
return None, "A stencil must have an Image input"
|
||||
if images[i] is not None:
|
||||
images[i] = images[i].convert("RGB")
|
||||
|
||||
if image_dict is None:
|
||||
return None, "An Initial Image is required"
|
||||
if use_stencil == "scribble":
|
||||
image = image_dict["mask"].convert("RGB")
|
||||
elif isinstance(image_dict, PIL.Image.Image):
|
||||
# if use_stencil == "scribble":
|
||||
# image = image_dict["mask"].convert("RGB")
|
||||
if isinstance(image_dict, PIL.Image.Image):
|
||||
image = image_dict.convert("RGB")
|
||||
else:
|
||||
image = image_dict["image"].convert("RGB")
|
||||
@@ -124,12 +135,14 @@ def img2img_inf(
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
use_stencil = None if use_stencil == "None" else use_stencil
|
||||
args.use_stencil = use_stencil
|
||||
if use_stencil is not None:
|
||||
stencil_count = 0
|
||||
for stencil in stencils:
|
||||
if stencil is not None:
|
||||
stencil_count += 1
|
||||
if stencil_count > 0:
|
||||
args.scheduler = "DDIM"
|
||||
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
|
||||
image, width, height = resize_stencil(image)
|
||||
# image, width, height = resize_stencil(image)
|
||||
elif "Shark" in args.scheduler:
|
||||
print(
|
||||
f"Shark schedulers are not supported. Switching to EulerDiscrete "
|
||||
@@ -151,7 +164,7 @@ def img2img_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=use_stencil,
|
||||
stencils=stencils,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -178,7 +191,7 @@ def img2img_inf(
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(args.scheduler)
|
||||
|
||||
if use_stencil is not None:
|
||||
if stencil_count > 0:
|
||||
args.use_tuned = False
|
||||
global_obj.set_sd_obj(
|
||||
StencilPipeline.from_pretrained(
|
||||
@@ -195,7 +208,7 @@ def img2img_inf(
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_stencil=use_stencil,
|
||||
stencils=stencils,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
ondemand=args.ondemand,
|
||||
@@ -252,7 +265,8 @@ def img2img_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
stencils,
|
||||
images,
|
||||
resample_type=resample_type,
|
||||
control_mode=control_mode,
|
||||
)
|
||||
@@ -274,12 +288,17 @@ def img2img_inf(
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Image-to-Image", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
), stencils, images
|
||||
|
||||
return generated_imgs, text_output, ""
|
||||
return generated_imgs, text_output, "", stencils, images
|
||||
|
||||
|
||||
with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
# Stencils
|
||||
# TODO: Add more stencils here
|
||||
STENCIL_COUNT = 2
|
||||
stencils = gr.State([None] * STENCIL_COUNT)
|
||||
images = gr.State([None] * STENCIL_COUNT)
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
@@ -350,70 +369,105 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
height=300,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Stencil Options", open=False):
|
||||
with gr.Accordion(label="Multistencil Options", open=False):
|
||||
choices = ["None", "canny", "openpose", "scribble"]
|
||||
|
||||
def cnet_preview(checked, model, input_image, index, stencils, images):
    """Record one ControlNet slot's settings and build its preview output.

    Args:
        checked: state of the slot's enable checkbox.
        model: stencil name chosen in the slot's dropdown; the string
            "None" is the dropdown's "nothing selected" entry.
        input_image: image supplied for this slot; presumably None when
            nothing has been uploaded yet (TODO confirm against Gradio).
        index: position of the slot in ``stencils`` / ``images``.
        stencils: per-slot stencil names; updated in place.
        images: per-slot input images; updated in place.

    Returns:
        A ``(preview, stencils, images)`` tuple. ``preview`` is a list for
        the slot's gallery when a detector ran, otherwise None.
    """
    # A disabled slot and the "None" placeholder both mean "no stencil".
    # Storing the string "None" would make callers that test
    # ``stencil is not None`` treat the empty choice as a real model.
    if not checked or model == "None":
        stencils[index] = None
        images[index] = None
        return (None, stencils, images)

    images[index] = input_image
    stencils[index] = model

    # Without an image there is nothing to run a detector on. The selection
    # is still recorded above so img2img_inf can report the missing image.
    if input_image is None:
        return (None, stencils, images)

    if model == "canny":
        canny = CannyDetector()
        result = canny(np.array(input_image), 100, 200)
        return (
            [Image.fromarray(result), result],
            stencils,
            images,
        )
    if model == "openpose":
        openpose = OpenposeDetector()
        result = openpose(np.array(input_image))
        # TODO: This is just an empty canvas, need to draw the candidates (which are in result[1])
        return (
            [Image.fromarray(result[0]), result],
            stencils,
            images,
        )
    # Any other stencil (e.g. "scribble") has no preview detector.
    return (None, stencils, images)
|
||||
|
||||
with gr.Row():
|
||||
use_stencil = gr.Dropdown(
|
||||
elem_id="stencil_model",
|
||||
label="Stencil model",
|
||||
cnet_1 = gr.Checkbox(show_label=False)
|
||||
cnet_1_model = gr.Dropdown(
|
||||
label="Controlnet 1",
|
||||
value="None",
|
||||
choices=[
|
||||
"None",
|
||||
"canny",
|
||||
"openpose",
|
||||
"scribble",
|
||||
"zoedepth",
|
||||
choices=choices,
|
||||
)
|
||||
cnet_1_image = gr.Image(
|
||||
source="upload",
|
||||
tool=None,
|
||||
type="pil",
|
||||
)
|
||||
cnet_1_output = gr.Gallery(
|
||||
show_label=False,
|
||||
object_fit="scale-down",
|
||||
rows=1,
|
||||
columns=1,
|
||||
)
|
||||
cnet_1.change(
|
||||
fn=(
|
||||
lambda a, b, c, s, i: cnet_preview(
|
||||
a, b, c, 0, s, i
|
||||
)
|
||||
),
|
||||
inputs=[
|
||||
cnet_1,
|
||||
cnet_1_model,
|
||||
cnet_1_image,
|
||||
stencils,
|
||||
images,
|
||||
],
|
||||
outputs=[cnet_1_output, stencils, images],
|
||||
)
|
||||
|
||||
def show_canvas(choice):
|
||||
if choice == "scribble":
|
||||
return (
|
||||
gr.Slider.update(visible=True),
|
||||
gr.Slider.update(visible=True),
|
||||
gr.Button.update(visible=True),
|
||||
)
|
||||
else:
|
||||
return (
|
||||
gr.Slider.update(visible=False),
|
||||
gr.Slider.update(visible=False),
|
||||
gr.Button.update(visible=False),
|
||||
)
|
||||
|
||||
def create_canvas(w, h):
|
||||
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
|
||||
|
||||
with gr.Row():
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
cnet_2 = gr.Checkbox(show_label=False)
|
||||
cnet_2_model = gr.Dropdown(
|
||||
label="Controlnet 2",
|
||||
value="None",
|
||||
choices=choices,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
cnet_2_image = gr.Image(
|
||||
source="upload",
|
||||
tool=None,
|
||||
type="pil",
|
||||
)
|
||||
cnet_2_output = gr.Gallery(
|
||||
show_label=False,
|
||||
object_fit="scale-down",
|
||||
rows=1,
|
||||
columns=1,
|
||||
)
|
||||
cnet_2.change(
|
||||
fn=(
|
||||
lambda a, b, c, s, i: cnet_preview(
|
||||
a, b, c, 1, s, i
|
||||
)
|
||||
),
|
||||
inputs=[
|
||||
cnet_2,
|
||||
cnet_2_model,
|
||||
cnet_2_image,
|
||||
stencils,
|
||||
images,
|
||||
],
|
||||
outputs=[cnet_2_output, stencils, images],
|
||||
)
|
||||
create_button = gr.Button(
|
||||
label="Start",
|
||||
value="Open drawing canvas!",
|
||||
visible=False,
|
||||
)
|
||||
create_button.click(
|
||||
fn=create_canvas,
|
||||
inputs=[canvas_width, canvas_height],
|
||||
outputs=[img2img_init_image],
|
||||
)
|
||||
use_stencil.change(
|
||||
fn=show_canvas,
|
||||
inputs=use_stencil,
|
||||
outputs=[canvas_width, canvas_height, create_button],
|
||||
)
|
||||
control_mode = gr.Radio(
|
||||
choices=["Prompt", "Balanced", "Controlnet"],
|
||||
value="Balanced",
|
||||
@@ -624,7 +678,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
use_stencil,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
@@ -633,8 +686,16 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
control_mode,
|
||||
stencils,
|
||||
images,
|
||||
],
|
||||
outputs=[
|
||||
img2img_gallery,
|
||||
std_output,
|
||||
img2img_status,
|
||||
stencils,
|
||||
images,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output, img2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
)
|
||||
|
||||
|
||||
@@ -122,7 +122,7 @@ def inpaint_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
|
||||
@@ -121,7 +121,7 @@ def outpaint_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
|
||||
@@ -126,7 +126,7 @@ def txt2img_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -226,7 +226,7 @@ def txt2img_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil="None",
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
|
||||
@@ -280,7 +280,7 @@ def txt2img_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil="None",
|
||||
stencils=[],
|
||||
resample_type=resample_type,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
|
||||
@@ -120,7 +120,7 @@ def upscaler_inf(
|
||||
args.width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
|
||||
@@ -30,7 +30,7 @@ class Config:
|
||||
width: int
|
||||
device: str
|
||||
use_lora: str
|
||||
use_stencil: str
|
||||
stencils: list[str]
|
||||
ondemand: str # should this be expecting a bool instead?
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user