mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
(SD) Fix schedulers and multi-controlnet. (#2006)
* (SD) Fixes schedulers if receiving noise preds as numpy arrays
* Fix schedulers and stencil name
* Multicontrolnet fixes
This commit is contained in:
@@ -254,8 +254,8 @@ class SharkifyStableDiffusionModel:
|
||||
"stencil_unet_512",
|
||||
"vae",
|
||||
"vae_encode",
|
||||
"stencil_adaptor",
|
||||
"stencil_adaptor_512",
|
||||
"stencil_adapter",
|
||||
"stencil_adapter_512",
|
||||
]
|
||||
index = 0
|
||||
for model in sub_model_list:
|
||||
@@ -268,11 +268,19 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
if self.base_vae:
|
||||
sub_model = "base_vae"
|
||||
# TODO: Fix this
|
||||
# if "stencil_adaptor" == model and self.use_stencil is not None:
|
||||
# model_config = model_config + get_path_stem(self.use_stencil)
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
index += 1
|
||||
if "stencil_adapter" in model:
|
||||
stencil_names = []
|
||||
for i, stencil in enumerate(self.stencils):
|
||||
if stencil is not None:
|
||||
cnet_config = model_config + stencil.split("_")[-1]
|
||||
stencil_names.append(
|
||||
get_extended_name(sub_model + cnet_config)
|
||||
)
|
||||
|
||||
model_name[model] = stencil_names
|
||||
else:
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
index += 1
|
||||
return model_name
|
||||
|
||||
def check_params(self, max_len, width, height):
|
||||
@@ -679,6 +687,8 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def get_control_net(self, stencil_id, use_large=False):
|
||||
stencil_id = get_stencil_model_id(stencil_id)
|
||||
adapter_id, base_model_safe_id, ext_model_name = (None, None, None)
|
||||
print(f"Importing ControlNet adapter from {stencil_id}")
|
||||
|
||||
class StencilControlNetModel(torch.nn.Module):
|
||||
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
|
||||
@@ -687,7 +697,7 @@ class SharkifyStableDiffusionModel:
|
||||
model_id,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.in_channels = self.cnet.in_channels
|
||||
self.in_channels = self.cnet.config.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(
|
||||
@@ -751,7 +761,23 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
|
||||
inputs = tuple(self.inputs["stencil_adaptor"])
|
||||
inputs = tuple(self.inputs["stencil_adapter"])
|
||||
model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
|
||||
ext_model_name = self.model_name[model_name]
|
||||
if isinstance(ext_model_name, list):
|
||||
for i in ext_model_name:
|
||||
if stencil_id.split("_")[-1] in i:
|
||||
desired_name = i
|
||||
print(f"Multi-CN: compiling model {i}")
|
||||
else:
|
||||
continue
|
||||
if desired_name:
|
||||
ext_model_name = desired_name
|
||||
else:
|
||||
raise Exception(
|
||||
f"Could not find extended configuration for {stencil_id}"
|
||||
)
|
||||
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
@@ -761,19 +787,13 @@ class SharkifyStableDiffusionModel:
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
*inputs[3:],
|
||||
)
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
|
||||
)
|
||||
else:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor"]
|
||||
)
|
||||
save_dir = os.path.join(self.sharktank_dir, ext_model_name)
|
||||
input_mask = [True, True, True, True] + ([True] * 13)
|
||||
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
|
||||
|
||||
shark_cnet, cnet_mlir = compile_through_fx(
|
||||
scnet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name[model_name],
|
||||
extended_model_name=ext_model_name,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
@@ -1315,16 +1335,16 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
def controlnet(self, stencil_id, use_large=False):
|
||||
try:
|
||||
self.inputs["stencil_adaptor"] = self.get_input_info_for(
|
||||
base_models["stencil_adaptor"]
|
||||
self.inputs["stencil_adapter"] = self.get_input_info_for(
|
||||
base_models["stencil_adapter"]
|
||||
)
|
||||
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
|
||||
compiled_stencil_adapter, controlnet_mlir = self.get_control_net(
|
||||
stencil_id, use_large=use_large
|
||||
)
|
||||
|
||||
check_compilation(compiled_stencil_adaptor, "Stencil")
|
||||
check_compilation(compiled_stencil_adapter, "Stencil")
|
||||
if self.return_mlir:
|
||||
return controlnet_mlir
|
||||
return compiled_stencil_adaptor
|
||||
return compiled_stencil_adapter
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
@@ -391,7 +391,6 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
control_mode,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# TODO: 1. Change `num_images_per_prompt`.
|
||||
# controlnet_hint = controlnet_hint_conversion(
|
||||
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
|
||||
# )
|
||||
|
||||
@@ -825,7 +825,7 @@ class StableDiffusionPipeline:
|
||||
gc.collect()
|
||||
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
|
||||
|
||||
return text_embeddings.numpy()
|
||||
return text_embeddings.numpy().astype(np.float16)
|
||||
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
@@ -207,16 +207,21 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
|
||||
noise = randn_tensor(
|
||||
noise_pred.shape,
|
||||
dtype=noise_pred.dtype,
|
||||
device="cpu",
|
||||
generator=generator,
|
||||
noise_pred = (
|
||||
torch.from_numpy(noise_pred)
|
||||
if isinstance(noise_pred, np.ndarray)
|
||||
else noise_pred
|
||||
)
|
||||
|
||||
eps = noise * s_noise
|
||||
|
||||
if gamma > 0:
|
||||
noise = randn_tensor(
|
||||
torch.Size(noise_pred.shape),
|
||||
dtype=torch.float16,
|
||||
device="cpu",
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
eps = noise * s_noise
|
||||
latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
if self.config.prediction_type == "v_prediction":
|
||||
|
||||
@@ -276,7 +276,7 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"stencil_adaptor": {
|
||||
"stencil_adapter": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
|
||||
@@ -79,10 +79,12 @@ def controlnet_hint_shaping(
|
||||
)
|
||||
return controlnet_hint
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width}),"
|
||||
+ f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
|
||||
+ f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
|
||||
return controlnet_hint_shaping(
|
||||
Image.fromarray(controlnet_hint.detach().numpy()),
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt,
|
||||
)
|
||||
elif isinstance(controlnet_hint, np.ndarray):
|
||||
# np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_promot)
|
||||
@@ -109,29 +111,36 @@ def controlnet_hint_shaping(
|
||||
) # b h w c -> b c h w
|
||||
return controlnet_hint
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptble shape of `stencil` are any of ({width}, {channels}), "
|
||||
+ f"({height}, {width}, {channels}), "
|
||||
+ f"(1, {height}, {width}, {channels}) or "
|
||||
+ f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
|
||||
)
|
||||
elif isinstance(controlnet_hint, Image.Image):
|
||||
if controlnet_hint.size == (width, height):
|
||||
controlnet_hint = controlnet_hint.convert(
|
||||
"RGB"
|
||||
) # make sure 3 channel RGB format
|
||||
controlnet_hint = np.array(controlnet_hint) # to numpy
|
||||
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
|
||||
return controlnet_hint_shaping(
|
||||
controlnet_hint, height, width, num_images_per_prompt
|
||||
Image.fromarray(controlnet_hint),
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt,
|
||||
)
|
||||
|
||||
elif isinstance(controlnet_hint, Image.Image):
|
||||
controlnet_hint = controlnet_hint.convert(
|
||||
"RGB"
|
||||
) # make sure 3 channel RGB format
|
||||
if controlnet_hint.size == (width, height):
|
||||
controlnet_hint = np.array(controlnet_hint).astype(
|
||||
np.float16
|
||||
) # to numpy
|
||||
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
|
||||
)
|
||||
(hint_w, hint_h) = controlnet_hint.size
|
||||
left = int((hint_w - width) / 2)
|
||||
right = left + height
|
||||
controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h))
|
||||
controlnet_hint = controlnet_hint.resize((width, height))
|
||||
return controlnet_hint_shaping(
|
||||
controlnet_hint, height, width, dtype, num_images_per_prompt
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
|
||||
f"Acceptible controlnet input types are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -141,16 +150,22 @@ def controlnet_hint_conversion(
|
||||
controlnet_hint = None
|
||||
match use_stencil:
|
||||
case "canny":
|
||||
print("Detecting edge with canny")
|
||||
print(
|
||||
"Converting controlnet hint to edge detection mask with canny preprocessor."
|
||||
)
|
||||
controlnet_hint = hint_canny(image)
|
||||
case "openpose":
|
||||
print("Detecting human pose")
|
||||
print(
|
||||
"Detecting human pose in controlnet hint with openpose preprocessor."
|
||||
)
|
||||
controlnet_hint = hint_openpose(image)
|
||||
case "scribble":
|
||||
print("Working with scribble")
|
||||
print("Using your scribble as a controlnet hint.")
|
||||
controlnet_hint = hint_scribble(image)
|
||||
case "zoedepth":
|
||||
print("Working with ZoeDepth")
|
||||
print(
|
||||
"Converting controlnet hint to a depth mapping with ZoeDepth."
|
||||
)
|
||||
controlnet_hint = hint_zoedepth(image)
|
||||
case _:
|
||||
return None
|
||||
|
||||
@@ -6,6 +6,12 @@ import PIL
|
||||
from math import ceil
|
||||
from PIL import Image
|
||||
|
||||
from gradio.components.image_editor import (
|
||||
Brush,
|
||||
Eraser,
|
||||
EditorData,
|
||||
EditorValue,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -189,7 +195,7 @@ def img2img_inf(
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
else "stabilityai/stable-diffusion-1-5-base"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(args.scheduler)
|
||||
@@ -403,6 +409,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
result = openpose(
|
||||
np.array(input_image["composite"])
|
||||
)
|
||||
print(result)
|
||||
# TODO: This is just an empty canvas, need to draw the candidates (which are in result[1])
|
||||
return (
|
||||
Image.fromarray(result[0]),
|
||||
@@ -415,42 +422,53 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
np.array(input_image["composite"])
|
||||
)
|
||||
return (
|
||||
Image.fromarray(result[0]),
|
||||
Image.fromarray(result),
|
||||
stencils,
|
||||
images,
|
||||
)
|
||||
case "scribble":
|
||||
result = input_image["composite"].convert("L")
|
||||
return (result, stencils, images)
|
||||
return (
|
||||
input_image["composite"],
|
||||
stencils,
|
||||
images,
|
||||
)
|
||||
case _:
|
||||
return (None, stencils, images)
|
||||
|
||||
def create_canvas(width, height):
|
||||
data = (
|
||||
data = Image.fromarray(
|
||||
np.zeros(
|
||||
shape=(height, width, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
+ 255
|
||||
)
|
||||
return data
|
||||
img_dict = {
|
||||
"background": data,
|
||||
"layers": [data],
|
||||
"composite": None,
|
||||
}
|
||||
return EditorValue(img_dict)
|
||||
|
||||
def update_cn_input(model, width, height):
|
||||
if model == "scribble":
|
||||
return [
|
||||
gr.ImageEditor(
|
||||
visible=True,
|
||||
image_mode="RGB",
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
image_mode="RGB",
|
||||
type="pil",
|
||||
value=create_canvas(width, height),
|
||||
crop_size=(width, height),
|
||||
brush=Brush(
|
||||
colors=["#000000"], color_mode="fixed"
|
||||
),
|
||||
),
|
||||
gr.Image(
|
||||
visible=True,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
show_download_button=False,
|
||||
),
|
||||
gr.Slider(visible=True),
|
||||
gr.Slider(visible=True),
|
||||
@@ -469,6 +487,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
visible=True,
|
||||
show_label=False,
|
||||
interactive=True,
|
||||
show_download_button=False,
|
||||
),
|
||||
gr.Slider(visible=False),
|
||||
gr.Slider(visible=False),
|
||||
@@ -505,6 +524,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
value="Make Canvas!",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
cnet_1_image = gr.ImageEditor(
|
||||
visible=False,
|
||||
image_mode="RGB",
|
||||
@@ -515,6 +535,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
cnet_1_output = gr.Image(
|
||||
visible=True, show_label=False
|
||||
)
|
||||
|
||||
cnet_1_model.input(
|
||||
update_cn_input,
|
||||
[cnet_1_model, canvas_width, canvas_height],
|
||||
|
||||
Reference in New Issue
Block a user