(SD) Fix schedulers and multi-controlnet. (#2006)

* (SD) Fixes schedulers if recieving noise preds as numpy arrays

* Fix schedulers and stencil name

* Multicontrolnet fixes
This commit is contained in:
Ean Garvey
2023-12-05 03:29:18 -06:00
committed by GitHub
parent d72da3801f
commit ce9ce3a7c8
7 changed files with 127 additions and 67 deletions

View File

@@ -254,8 +254,8 @@ class SharkifyStableDiffusionModel:
"stencil_unet_512",
"vae",
"vae_encode",
"stencil_adaptor",
"stencil_adaptor_512",
"stencil_adapter",
"stencil_adapter_512",
]
index = 0
for model in sub_model_list:
@@ -268,11 +268,19 @@ class SharkifyStableDiffusionModel:
)
if self.base_vae:
sub_model = "base_vae"
# TODO: Fix this
# if "stencil_adaptor" == model and self.use_stencil is not None:
# model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
if "stencil_adapter" in model:
stencil_names = []
for i, stencil in enumerate(self.stencils):
if stencil is not None:
cnet_config = model_config + stencil.split("_")[-1]
stencil_names.append(
get_extended_name(sub_model + cnet_config)
)
model_name[model] = stencil_names
else:
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
def check_params(self, max_len, width, height):
@@ -679,6 +687,8 @@ class SharkifyStableDiffusionModel:
def get_control_net(self, stencil_id, use_large=False):
stencil_id = get_stencil_model_id(stencil_id)
adapter_id, base_model_safe_id, ext_model_name = (None, None, None)
print(f"Importing ControlNet adapter from {stencil_id}")
class StencilControlNetModel(torch.nn.Module):
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
@@ -687,7 +697,7 @@ class SharkifyStableDiffusionModel:
model_id,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.in_channels = self.cnet.in_channels
self.in_channels = self.cnet.config.in_channels
self.train(False)
def forward(
@@ -751,7 +761,23 @@ class SharkifyStableDiffusionModel:
)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
inputs = tuple(self.inputs["stencil_adapter"])
model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
ext_model_name = self.model_name[model_name]
if isinstance(ext_model_name, list):
for i in ext_model_name:
if stencil_id.split("_")[-1] in i:
desired_name = i
print(f"Multi-CN: compiling model {i}")
else:
continue
if desired_name:
ext_model_name = desired_name
else:
raise Exception(
f"Could not find extended configuration for {stencil_id}"
)
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
@@ -761,19 +787,13 @@ class SharkifyStableDiffusionModel:
torch.nn.functional.pad(inputs[2], pad),
*inputs[3:],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
save_dir = os.path.join(self.sharktank_dir, ext_model_name)
input_mask = [True, True, True, True] + ([True] * 13)
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name[model_name],
extended_model_name=ext_model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
@@ -1315,16 +1335,16 @@ class SharkifyStableDiffusionModel:
def controlnet(self, stencil_id, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
self.inputs["stencil_adapter"] = self.get_input_info_for(
base_models["stencil_adapter"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
compiled_stencil_adapter, controlnet_mlir = self.get_control_net(
stencil_id, use_large=use_large
)
check_compilation(compiled_stencil_adaptor, "Stencil")
check_compilation(compiled_stencil_adapter, "Stencil")
if self.return_mlir:
return controlnet_mlir
return compiled_stencil_adaptor
return compiled_stencil_adapter
except Exception as e:
sys.exit(e)

View File

@@ -391,7 +391,6 @@ class StencilPipeline(StableDiffusionPipeline):
control_mode,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
# controlnet_hint = controlnet_hint_conversion(
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
# )

View File

@@ -825,7 +825,7 @@ class StableDiffusionPipeline:
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()
return text_embeddings.numpy().astype(np.float16)
from typing import List, Optional, Union

View File

@@ -207,16 +207,21 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
sigma_hat = sigma * (gamma + 1)
noise = randn_tensor(
noise_pred.shape,
dtype=noise_pred.dtype,
device="cpu",
generator=generator,
noise_pred = (
torch.from_numpy(noise_pred)
if isinstance(noise_pred, np.ndarray)
else noise_pred
)
eps = noise * s_noise
if gamma > 0:
noise = randn_tensor(
torch.Size(noise_pred.shape),
dtype=torch.float16,
device="cpu",
generator=generator,
)
eps = noise * s_noise
latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5
if self.config.prediction_type == "v_prediction":

View File

@@ -276,7 +276,7 @@
}
}
},
"stencil_adaptor": {
"stencil_adapter": {
"latents": {
"shape": [
"1*batch_size",

View File

@@ -79,10 +79,12 @@ def controlnet_hint_shaping(
)
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width}),"
+ f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
+ f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
return controlnet_hint_shaping(
Image.fromarray(controlnet_hint.detach().numpy()),
height,
width,
dtype,
num_images_per_prompt,
)
elif isinstance(controlnet_hint, np.ndarray):
# np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_promot)
@@ -109,29 +111,36 @@ def controlnet_hint_shaping(
) # b h w c -> b c h w
return controlnet_hint
else:
raise ValueError(
f"Acceptble shape of `stencil` are any of ({width}, {channels}), "
+ f"({height}, {width}, {channels}), "
+ f"(1, {height}, {width}, {channels}) or "
+ f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
)
elif isinstance(controlnet_hint, Image.Image):
if controlnet_hint.size == (width, height):
controlnet_hint = controlnet_hint.convert(
"RGB"
) # make sure 3 channel RGB format
controlnet_hint = np.array(controlnet_hint) # to numpy
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
return controlnet_hint_shaping(
controlnet_hint, height, width, num_images_per_prompt
Image.fromarray(controlnet_hint),
height,
width,
dtype,
num_images_per_prompt,
)
elif isinstance(controlnet_hint, Image.Image):
controlnet_hint = controlnet_hint.convert(
"RGB"
) # make sure 3 channel RGB format
if controlnet_hint.size == (width, height):
controlnet_hint = np.array(controlnet_hint).astype(
np.float16
) # to numpy
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
return
else:
raise ValueError(
f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
)
(hint_w, hint_h) = controlnet_hint.size
left = int((hint_w - width) / 2)
right = left + height
controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h))
controlnet_hint = controlnet_hint.resize((width, height))
return controlnet_hint_shaping(
controlnet_hint, height, width, dtype, num_images_per_prompt
)
else:
raise ValueError(
f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
f"Acceptible controlnet input types are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
)
@@ -141,16 +150,22 @@ def controlnet_hint_conversion(
controlnet_hint = None
match use_stencil:
case "canny":
print("Detecting edge with canny")
print(
"Converting controlnet hint to edge detection mask with canny preprocessor."
)
controlnet_hint = hint_canny(image)
case "openpose":
print("Detecting human pose")
print(
"Detecting human pose in controlnet hint with openpose preprocessor."
)
controlnet_hint = hint_openpose(image)
case "scribble":
print("Working with scribble")
print("Using your scribble as a controlnet hint.")
controlnet_hint = hint_scribble(image)
case "zoedepth":
print("Working with ZoeDepth")
print(
"Converting controlnet hint to a depth mapping with ZoeDepth."
)
controlnet_hint = hint_zoedepth(image)
case _:
return None

View File

@@ -6,6 +6,12 @@ import PIL
from math import ceil
from PIL import Image
from gradio.components.image_editor import (
Brush,
Eraser,
EditorData,
EditorValue,
)
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -189,7 +195,7 @@ def img2img_inf(
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
else "stabilityai/stable-diffusion-1-5-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
@@ -403,6 +409,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
result = openpose(
np.array(input_image["composite"])
)
print(result)
# TODO: This is just an empty canvas, need to draw the candidates (which are in result[1])
return (
Image.fromarray(result[0]),
@@ -415,42 +422,53 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
np.array(input_image["composite"])
)
return (
Image.fromarray(result[0]),
Image.fromarray(result),
stencils,
images,
)
case "scribble":
result = input_image["composite"].convert("L")
return (result, stencils, images)
return (
input_image["composite"],
stencils,
images,
)
case _:
return (None, stencils, images)
def create_canvas(width, height):
data = (
data = Image.fromarray(
np.zeros(
shape=(height, width, 3),
dtype=np.uint8,
)
+ 255
)
return data
img_dict = {
"background": data,
"layers": [data],
"composite": None,
}
return EditorValue(img_dict)
def update_cn_input(model, width, height):
if model == "scribble":
return [
gr.ImageEditor(
visible=True,
image_mode="RGB",
interactive=True,
show_label=False,
image_mode="RGB",
type="pil",
value=create_canvas(width, height),
crop_size=(width, height),
brush=Brush(
colors=["#000000"], color_mode="fixed"
),
),
gr.Image(
visible=True,
show_label=False,
interactive=False,
show_download_button=False,
),
gr.Slider(visible=True),
gr.Slider(visible=True),
@@ -469,6 +487,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
visible=True,
show_label=False,
interactive=True,
show_download_button=False,
),
gr.Slider(visible=False),
gr.Slider(visible=False),
@@ -505,6 +524,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
value="Make Canvas!",
visible=False,
)
cnet_1_image = gr.ImageEditor(
visible=False,
image_mode="RGB",
@@ -515,6 +535,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
cnet_1_output = gr.Image(
visible=True, show_label=False
)
cnet_1_model.input(
update_cn_input,
[cnet_1_model, canvas_width, canvas_height],