Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
(synced 2026-04-03 03:00:17 -04:00)
UI/app structure and utility implementation.

- Initializers for webui/API launch
- Schedulers file for SD scheduling utilities
- Additions to API-level utilities
- Added embeddings module for LoRA, Lycoris, yada yada
- Added image_processing module for resamplers, resize tools, transforms, and any image annotation (PNG metadata)
- shared_cmd_opts module -- sorry, this is stable_args.py. It lives on. We still want to have some global control over the app exclusively from the command line. At least we will be free from shark_args.
- Moving around some utility pieces.
- Try to make api+webui concurrency possible in index.py
- SD UI -- this is just the img2img UI but hopefully a little better.
- UI utilities for your nod logos and your gradio temps.
apps/shark_studio/api/initializers.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import importlib
import logging
import os
import signal
import sys
import re
import warnings
import json
from threading import Thread

from apps.shark_studio.modules.timer import startup_timer


def imports():
    import torch  # noqa: F401

    startup_timer.record("import torch")
    warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="torch")
    warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")

    import gradio  # noqa: F401

    startup_timer.record("import gradio")

    # from apps.shark_studio.modules import shared_init
    # shared_init.initialize()
    # startup_timer.record("initialize shared")

    from apps.shark_studio.modules import processing, gradio_extensons, ui  # noqa: F401

    startup_timer.record("other imports")


def initialize():
    configure_sigint_handler()
    configure_opts_onchange()

    # from apps.shark_studio.modules import modelloader
    # modelloader.cleanup_models()

    # from apps.shark_studio.modules import sd_models
    # sd_models.setup_model()
    # startup_timer.record("setup SD model")

    # initialize_rest(reload_script_modules=False)


def initialize_rest(*, reload_script_modules=False):
    """
    Called both from initialize() and when reloading the webui.
    """
    # Keep this for adding reload options to the webUI.


def dumpstacks():
    import threading
    import traceback

    id2name = {th.ident: th.name for th in threading.enumerate()}
    code = []
    for threadId, stack in sys._current_frames().items():
        code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
        for filename, lineno, name, line in traceback.extract_stack(stack):
            code.append(f"""File: "{filename}", line {lineno}, in {name}""")
            if line:
                code.append("  " + line.strip())

    print("\n".join(code))


def configure_sigint_handler():
    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(sig, frame):
        print(f"Interrupted with signal {sig} in {frame}")

        dumpstacks()

        os._exit(0)

    signal.signal(signal.SIGINT, sigint_handler)
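For context, a launcher would typically call these hooks before building the UI. The sketch below is illustrative only; the main() wrapper is a hypothetical stand-in and not part of this commit -- only imports() and initialize() come from the module above.

    # Hypothetical entry point showing the intended call order.
    from apps.shark_studio.api.initializers import imports, initialize

    def main():
        imports()      # heavy imports, timed by startup_timer
        initialize()   # installs the SIGINT handler and option callbacks
        # ... launch the gradio app / REST API here ...

    if __name__ == "__main__":
        main()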
apps/shark_studio/api/schedulers.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# from shark_turbine.turbine_models.schedulers import export_scheduler_model


def export_scheduler_model(model):
    return "None", "None"


scheduler_model_map = {
    "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
    "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
    "LCM": export_scheduler_model("LCMScheduler"),
    "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
    "PNDM": export_scheduler_model("PNDMScheduler"),
    "DDPM": export_scheduler_model("DDPMScheduler"),
    "DDIM": export_scheduler_model("DDIMScheduler"),
    "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
    "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
    "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
    "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
    "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
    "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
}
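Since export_scheduler_model() is stubbed to return ("None", "None"), every map entry currently holds that tuple; the map's keys are the UI-facing scheduler names. A minimal lookup sketch (get_scheduler_entry is a hypothetical helper, not part of this commit):

    def get_scheduler_entry(name: str):
        # The UI passes a name such as "EulerDiscrete" and receives
        # whatever the (currently stubbed) exporter returned for it.
        try:
            return scheduler_model_map[name]
        except KeyError:
            raise ValueError(f"Unknown scheduler: {name}")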
@@ -1,12 +1,280 @@
import os
import sys
import json

import numpy as np

from datetime import datetime as dt
from random import (
    randint,
    seed as seed_random,
    getstate as random_getstate,
    setstate as random_setstate,
)

from pathlib import Path
from safetensors.torch import load_file
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

# NOTE: `import os` appeared twice in this hunk and has been deduplicated;
# `json` and `datetime` are added because parse_seed_input() and
# get_generated_imgs_todays_subdir() below use them.


def get_available_devices():
    return ["cpu-task"]

    # Everything below the early return above is the legacy device
    # enumeration, left unreachable by this commit. Names such as
    # set_iree_runtime_flags, get_all_devices, and get_cpu_info resolve
    # against shark.iree_utils / py-cpuinfo in the original app.
    def get_devices_by_name(driver_name):
        from shark.iree_utils._common import iree_device_map

        device_list = []
        try:
            driver_name = iree_device_map(driver_name)
            device_list_dict = get_all_devices(driver_name)
            print(f"{driver_name} devices are available.")
        except Exception:
            print(f"{driver_name} devices are not available.")
        else:
            cpu_name = get_cpu_info()["brand_raw"]
            for i, device in enumerate(device_list_dict):
                device_name = (
                    cpu_name if device["name"] == "default" else device["name"]
                )
                if "local" in driver_name:
                    device_list.append(
                        f"{device_name} => {driver_name.replace('local', 'cpu')}"
                    )
                else:
                    # for drivers with single devices
                    # let the default device be selected without any indexing
                    if len(device_list_dict) == 1:
                        device_list.append(f"{device_name} => {driver_name}")
                    else:
                        device_list.append(
                            f"{device_name} => {driver_name}://{i}"
                        )
        return device_list

    set_iree_runtime_flags()

    available_devices = []
    from shark.iree_utils.vulkan_utils import (
        get_all_vulkan_devices,
    )

    vulkaninfo_list = get_all_vulkan_devices()
    vulkan_devices = []
    id = 0
    for device in vulkaninfo_list:
        vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
        id += 1
    if id != 0:
        print("vulkan devices are available.")
    available_devices.extend(vulkan_devices)
    metal_devices = get_devices_by_name("metal")
    available_devices.extend(metal_devices)
    cuda_devices = get_devices_by_name("cuda")
    available_devices.extend(cuda_devices)
    rocm_devices = get_devices_by_name("rocm")
    available_devices.extend(rocm_devices)
    cpu_device = get_devices_by_name("cpu-sync")
    available_devices.extend(cpu_device)
    cpu_device = get_devices_by_name("cpu-task")
    available_devices.extend(cpu_device)
    return available_devices
def get_resource_path(relative_path):
    """Get absolute path to resource, works for dev and for PyInstaller"""
    base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
    return os.path.join(base_path, relative_path)


def get_generated_imgs_path() -> Path:
    # Raw strings keep the Windows-style relative paths from producing
    # invalid escape sequences.
    return Path(
        cmd_opts.output_dir
        if cmd_opts.output_dir
        else get_resource_path(r"..\web\generated_imgs")
    )


def get_generated_imgs_todays_subdir() -> str:
    return dt.now().strftime("%Y%m%d")


def get_checkpoints_path(model=""):
    return get_resource_path(rf"..\web\models\{model}")


def get_checkpoints(path):
    # The loop body was truncated in this hunk; listing the directory is
    # the minimal behavior implied by the name and by the callers in sd.py.
    files = []
    for file in os.listdir(path):
        files.append(file)
    return files
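A sketch of how a UI dropdown might consume these helpers, assuming the reconstructed get_checkpoints() above returns plain file names; the extension filter is illustrative:

    ckpt_dir = get_checkpoints_path("checkpoints")
    ckpt_choices = [
        f for f in get_checkpoints(ckpt_dir)
        if f.endswith((".ckpt", ".safetensors"))
    ]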
def get_device_mapping(driver, key_combination=3):
    """This method ensures consistent device ordering when choosing
    specific devices for execution

    Args:
        driver (str): execution driver (vulkan, cuda, rocm, etc)
        key_combination (int, optional): choice for mapping value for
            device name.
            1 : path
            2 : name
            3 : (name, path)
            Defaults to 3.

    Returns:
        dict: map to possible device names user can input mapped to desired
            combination of name/path.
    """
    from shark.iree_utils._common import iree_device_map

    driver = iree_device_map(driver)
    device_list = get_all_devices(driver)
    device_map = dict()

    def get_output_value(dev_dict):
        if key_combination == 1:
            return f"{driver}://{dev_dict['path']}"
        if key_combination == 2:
            return dev_dict["name"]
        if key_combination == 3:
            return dev_dict["name"], f"{driver}://{dev_dict['path']}"

    # mapping driver name to default device (driver://0)
    device_map[f"{driver}"] = get_output_value(device_list[0])
    for i, device in enumerate(device_list):
        # mapping with index
        device_map[f"{driver}://{i}"] = get_output_value(device)
        # mapping with full path
        device_map[f"{driver}://{device['path']}"] = get_output_value(device)
    return device_map


def map_device_to_name_path(device, key_combination=3):
    """Gives the appropriate device data (supported name/path) for user
    selected execution device

    Args:
        device (str): user selected device
        key_combination (int, optional): choice for mapping value for
            device name.
            1 : path
            2 : name
            3 : (name, path)
            Defaults to 3.

    Raises:
        ValueError: if the device is not a valid device

    Returns:
        str / tuple: returns the mapping str or tuple of mapping str for
            the device depending on key_combination value
    """
    driver = device.split("://")[0]
    device_map = get_device_mapping(driver, key_combination)
    try:
        device_mapping = device_map[device]
    except KeyError:
        raise ValueError(f"Device '{device}' is not a valid device.")
    return device_mapping
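Resolving a user-selected device string looks like the sketch below; "vulkan://0" is an assumed example and must actually exist on the machine for the lookup to succeed:

    # key_combination=3 returns the (name, path) tuple form.
    name, path = map_device_to_name_path("vulkan://0", key_combination=3)
    print(f"running on {name} via {path}")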
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
    if isinstance(seed_input, str):
        try:
            seed_input = json.loads(seed_input)
        except (ValueError, TypeError):
            seed_input = None

    if isinstance(seed_input, int):
        return [seed_input]

    if isinstance(seed_input, list) and all(
        type(seed) is int for seed in seed_input
    ):
        return seed_input

    raise TypeError(
        "Seed input must be an integer or an array of integers in JSON format"
    )


# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
    seed = int(seed)
    uint32_info = np.iinfo(np.uint32)
    uint32_min, uint32_max = uint32_info.min, uint32_info.max
    if seed < uint32_min or seed >= uint32_max:
        seed = randint(uint32_min, uint32_max)
    return seed
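A few representative inputs for the seed helpers above:

    parse_seed_input("42")          # -> [42]
    parse_seed_input("[1, 2, 3]")   # -> [1, 2, 3]
    parse_seed_input(7)             # -> [7]
    sanitize_seed(-1)               # -> fresh random uint32, since -1 is out of range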
# take a seed expression in an input format and convert it to
|
||||
# a list of integers, where possible
|
||||
def parse_seed_input(seed_input: str | list | int):
|
||||
if isinstance(seed_input, str):
|
||||
try:
|
||||
seed_input = json.loads(seed_input)
|
||||
except (ValueError, TypeError):
|
||||
seed_input = None
|
||||
|
||||
if isinstance(seed_input, int):
|
||||
return [seed_input]
|
||||
|
||||
if isinstance(seed_input, list) and all(
|
||||
type(seed) is int for seed in seed_input
|
||||
):
|
||||
return seed_input
|
||||
|
||||
raise TypeError(
|
||||
"Seed input must be an integer or an array of integers in JSON format"
|
||||
)
|
||||
|
||||
apps/shark_studio/modules/embeddings.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import os
import sys

import torch
from safetensors.torch import load_file

# NOTE: update_lora_weight_for_unet() below also calls get_path_stem(),
# whose import is not part of this hunk; it resolves elsewhere in the app.


def processLoRA(model, use_lora, splitting_prefix):
    state_dict = ""
    if ".safetensors" in use_lora:
        state_dict = load_file(use_lora)
    else:
        state_dict = torch.load(use_lora)
    alpha = 0.75
    visited = []

    # directly update weight in model
    process_unet = "te" not in splitting_prefix
    for key in state_dict:
        if ".alpha" in key or key in visited:
            continue

        curr_layer = model
        if ("text" not in key and process_unet) or (
            "text" in key and not process_unet
        ):
            layer_infos = (
                key.split(".")[0].split(splitting_prefix)[-1].split("_")
            )
        else:
            continue

        # find the target layer
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = (
                state_dict[pair_keys[0]]
                .squeeze(3)
                .squeeze(2)
                .to(torch.float32)
            )
            weight_down = (
                state_dict[pair_keys[1]]
                .squeeze(3)
                .squeeze(2)
                .to(torch.float32)
            )
            curr_layer.weight.data += alpha * torch.mm(
                weight_up, weight_down
            ).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
        # update visited list
        for item in pair_keys:
            visited.append(item)
    return model


def update_lora_weight_for_unet(unet, use_lora):
    extensions = [".bin", ".safetensors", ".pt"]
    if not any([extension in use_lora for extension in extensions]):
        # We assume it is a HF ID with standalone LoRA weights.
        unet.load_attn_procs(use_lora)
        return unet

    main_file_name = get_path_stem(use_lora)
    if ".bin" in use_lora:
        main_file_name += ".bin"
    elif ".safetensors" in use_lora:
        main_file_name += ".safetensors"
    elif ".pt" in use_lora:
        main_file_name += ".pt"
    else:
        sys.exit("Only .bin and .safetensors format for LoRA is supported")

    try:
        dir_name = os.path.dirname(use_lora)
        unet.load_attn_procs(dir_name, weight_name=main_file_name)
        return unet
    except Exception:
        return processLoRA(unet, use_lora, "lora_unet_")


def update_lora_weight(model, use_lora, model_name):
    if "unet" in model_name:
        return update_lora_weight_for_unet(model, use_lora)
    try:
        return processLoRA(model, use_lora, "lora_te_")
    except Exception:
        return None
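A sketch of applying standalone LoRA weights with the helpers above; `unet` and `text_encoder` are assumed to be already-loaded diffusers modules, and the weight path is an example. A model_name containing "unet" routes through load_attn_procs first and falls back to the manual key-merging in processLoRA:

    lora_unet = update_lora_weight(unet, "path/to/lora.safetensors", "unet")
    lora_te = update_lora_weight(text_encoder, "path/to/lora.safetensors", "text_encoder")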
apps/shark_studio/modules/img_processing.py (new file, 171 lines)
@@ -0,0 +1,171 @@
import os
import re
import json

from csv import DictWriter
from datetime import datetime as dt
from pathlib import Path
from PIL import Image, PngImagePlugin

# NOTE: the import line opening this file is truncated in the diff; the
# imports above and the two below are inferred from usage in this file.
from apps.shark_studio.api.utils import (
    get_generated_imgs_path,
    get_generated_imgs_todays_subdir,
)
from apps.shark_studio.modules.shared_cmd_opts import args


# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed, extra_info=None):
    if extra_info is None:
        extra_info = {}
    generated_imgs_path = Path(
        get_generated_imgs_path(), get_generated_imgs_todays_subdir()
    )
    generated_imgs_path.mkdir(parents=True, exist_ok=True)
    csv_path = Path(generated_imgs_path, "imgs_details.csv")

    prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
    out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"

    img_model = args.hf_model_id
    if args.ckpt_loc:
        img_model = Path(os.path.basename(args.ckpt_loc)).stem

    img_vae = None
    if args.custom_vae:
        img_vae = Path(os.path.basename(args.custom_vae)).stem

    img_lora = None
    if args.use_lora:
        img_lora = Path(os.path.basename(args.use_lora)).stem

    if args.output_img_format == "jpg":
        out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
        output_img.save(out_img_path, quality=95, subsampling=0)
    else:
        out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
        pngInfo = PngImagePlugin.PngInfo()

        if args.write_metadata_to_png:
            # Using a conditional expression caused problems, so setting a new
            # variable for now.
            if args.use_hiresfix:
                png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}"
            else:
                png_size_text = f"{args.width}x{args.height}"

            pngInfo.add_text(
                "parameters",
                f"{args.prompts[0]}"
                f"\nNegative prompt: {args.negative_prompts[0]}"
                f"\nSteps: {args.steps}, "
                f"Sampler: {args.scheduler}, "
                f"CFG scale: {args.guidance_scale}, "
                f"Seed: {img_seed}, "
                f"Size: {png_size_text}, "
                f"Model: {img_model}, "
                f"VAE: {img_vae}, "
                f"LoRA: {img_lora}",
            )

        output_img.save(out_img_path, "PNG", pnginfo=pngInfo)

    if args.output_img_format not in ["png", "jpg"]:
        print(
            f"[ERROR] Format {args.output_img_format} is not "
            f"supported yet. Image saved as png instead. "
            f"Supported formats: png / jpg"
        )

    # To be as low-impact as possible to the existing CSV format, we append
    # "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
    # importance for each data point. Something to consider.
    new_entry = {
        "VARIANT": img_model,
        "SCHEDULER": args.scheduler,
        "PROMPT": args.prompts[0],
        "NEG_PROMPT": args.negative_prompts[0],
        "SEED": img_seed,
        "CFG_SCALE": args.guidance_scale,
        "PRECISION": args.precision,
        "STEPS": args.steps,
        "HEIGHT": args.height
        if not args.use_hiresfix
        else args.hiresfix_height,
        "WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width,
        "MAX_LENGTH": args.max_length,
        "OUTPUT": out_img_path,
        "VAE": img_vae,
        "LORA": img_lora,
    }

    new_entry.update(extra_info)

    csv_mode = "a" if os.path.isfile(csv_path) else "w"
    with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
        dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
        if csv_mode == "w":
            dictwriter_obj.writeheader()
        dictwriter_obj.writerow(new_entry)

    if args.save_metadata_to_json:
        del new_entry["OUTPUT"]
        json_path = Path(generated_imgs_path, f"{out_img_name}.json")
        with open(json_path, "w") as f:
            json.dump(new_entry, f, indent=4)
def get_generation_text_info(seeds, device):
    text_output = f"prompt={args.prompts}"
    text_output += f"\nnegative prompt={args.negative_prompts}"
    text_output += (
        f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}"
    )
    text_output += f"\nscheduler={args.scheduler}, " f"device={device}"
    text_output += (
        f"\nsteps={args.steps}, "
        f"guidance_scale={args.guidance_scale}, "
        f"seed={seeds}"
    )
    text_output += (
        f"\nsize={args.height}x{args.width}, "
        if not args.use_hiresfix
        else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, "
    )
    text_output += (
        f"batch_count={args.batch_count}, "
        f"batch_size={args.batch_size}, "
        f"max_length={args.max_length}"
    )

    return text_output
# For stencil, the input image can be of any size, but we need to ensure that
# it conforms with our model constraints:
# both width and height should be in the range [128, 768] and a multiple of 8.
# This utility function performs the transformation on the input image while
# also maintaining the aspect ratio before sending it to the stencil pipeline.
def resize_stencil(image: Image.Image, width, height):
    aspect_ratio = width / height
    min_size = min(width, height)
    if min_size < 128:
        n_size = 128
        if width == min_size:
            width = n_size
            height = n_size / aspect_ratio
        else:
            height = n_size
            width = n_size * aspect_ratio
    width = int(width)
    height = int(height)
    n_width = width // 8
    n_height = height // 8
    n_width *= 8
    n_height *= 8

    min_size = min(width, height)
    if min_size > 768:
        n_size = 768
        if width == min_size:
            height = n_size
            width = n_size * aspect_ratio
        else:
            width = n_size
            height = n_size / aspect_ratio
    width = int(width)
    height = int(height)
    n_width = width // 8
    n_height = height // 8
    n_width *= 8
    n_height *= 8
    new_image = image.resize((n_width, n_height))
    return new_image, n_width, n_height
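Example usage: clamp an arbitrary input image into the multiple-of-8 range the stencil pipeline expects ("input.png" is a placeholder path):

    from PIL import Image

    img = Image.open("input.png")  # any size
    stencil_img, w, h = resize_stencil(img, img.width, img.height)
    assert w % 8 == 0 and h % 8 == 0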
apps/shark_studio/modules/shared_cmd_opts.py (new file, 771 lines)
@@ -0,0 +1,771 @@
import argparse
import os
from pathlib import Path

from apps.stable_diffusion.src.utils.resamplers import resampler_list


def path_expand(s):
    return Path(s).expanduser().resolve()


def is_valid_file(arg):
    if not os.path.exists(arg):
        return None
    else:
        return arg


p = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

##############################################################################
# Stable Diffusion Params
##############################################################################

p.add_argument(
    "-a",
    "--app",
    default="txt2img",
    help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.",
)
p.add_argument(
    "-p",
    "--prompts",
    nargs="+",
    default=[
        "a photo taken of the front of a super-car drifting on a road near "
        "mountains at high speeds with smokes coming off the tires, front "
        "angle, front point of view, trees in the mountains of the "
        "background, ((sharp focus))"
    ],
    help="Text of which images to be generated.",
)

p.add_argument(
    "--negative_prompts",
    nargs="+",
    default=[
        "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
        "blurry, ugly, blur, oversaturated, cropped"
    ],
    help="Text you don't want to see in the generated image.",
)

p.add_argument(
    "--img_path",
    type=str,
    help="Path to the image input for img2img/inpainting.",
)

p.add_argument(
    "--steps",
    type=int,
    default=50,
    help="The number of steps to do the sampling.",
)

p.add_argument(
    "--seed",
    type=str,
    default=-1,
    help="The seed or list of seeds to use. -1 for a random one.",
)

p.add_argument(
    "--batch_size",
    type=int,
    default=1,
    choices=range(1, 4),
    help="The number of inferences to be made in a single `batch_count`.",
)

p.add_argument(
    "--height",
    type=int,
    default=512,
    choices=range(128, 1025, 8),
    help="The height of the output image.",
)

p.add_argument(
    "--width",
    type=int,
    default=512,
    choices=range(128, 1025, 8),
    help="The width of the output image.",
)

p.add_argument(
    "--guidance_scale",
    type=float,
    default=7.5,
    help="The value to be used for guidance scaling.",
)

p.add_argument(
    "--noise_level",
    type=int,
    default=20,
    help="The value to be used for the noise level of the upscaler.",
)

p.add_argument(
    "--max_length",
    type=int,
    default=64,
    help="Max length of the tokenizer output, options are 64 and 77.",
)

p.add_argument(
    "--max_embeddings_multiples",
    type=int,
    default=5,
    help="The max multiple length of prompt embeddings compared to the max "
    "output length of text encoder.",
)

p.add_argument(
    "--strength",
    type=float,
    default=0.8,
    help="The strength of change applied on the given input image for "
    "img2img.",
)

# NOTE: originally `type=bool`, which argparse handles unreliably (any
# non-empty string parses as True); converted to a flag for consistency
# with the rest of this file.
p.add_argument(
    "--use_hiresfix",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Use Hires Fix to do higher resolution images, while trying to "
    "avoid the issues that come with it. This is accomplished by first "
    "generating an image using txt2img, then running it through img2img.",
)

p.add_argument(
    "--hiresfix_height",
    type=int,
    default=768,
    choices=range(128, 769, 8),
    help="The height of the Hires Fix image.",
)

p.add_argument(
    "--hiresfix_width",
    type=int,
    default=768,
    choices=range(128, 769, 8),
    help="The width of the Hires Fix image.",
)

p.add_argument(
    "--hiresfix_strength",
    type=float,
    default=0.6,
    help="The denoising strength to apply for the Hires Fix.",
)

p.add_argument(
    "--resample_type",
    type=str,
    default="Nearest Neighbor",
    choices=resampler_list,
    help="The resample type to use when resizing an image before being run "
    "through stable diffusion.",
)

##############################################################################
# Stable Diffusion Training Params
##############################################################################

p.add_argument(
    "--lora_save_dir",
    type=str,
    default="models/lora/",
    help="Directory to save the lora fine-tuned model.",
)

p.add_argument(
    "--training_images_dir",
    type=str,
    default="models/lora/training_images/",
    help="Directory containing images that are an example of the prompt.",
)

p.add_argument(
    "--training_steps",
    type=int,
    default=2000,
    help="The number of steps to train.",
)

##############################################################################
# Inpainting and Outpainting Params
##############################################################################

p.add_argument(
    "--mask_path",
    type=str,
    help="Path to the mask image input for inpainting.",
)

p.add_argument(
    "--inpaint_full_res",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Whether to inpaint only the masked area or the whole picture.",
)

p.add_argument(
    "--inpaint_full_res_padding",
    type=int,
    default=32,
    choices=range(0, 257, 4),
    help="Number of pixels for only-masked padding.",
)

p.add_argument(
    "--pixels",
    type=int,
    default=128,
    choices=range(8, 257, 8),
    help="Number of expanded pixels for one direction for outpainting.",
)

p.add_argument(
    "--mask_blur",
    type=int,
    default=8,
    choices=range(0, 65),
    help="Number of blur pixels for outpainting.",
)

p.add_argument(
    "--left",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Whether to extend left for outpainting.",
)

p.add_argument(
    "--right",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Whether to extend right for outpainting.",
)

p.add_argument(
    "--up",
    "--top",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Whether to extend top for outpainting.",
)

p.add_argument(
    "--down",
    "--bottom",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Whether to extend bottom for outpainting.",
)

p.add_argument(
    "--noise_q",
    type=float,
    default=1.0,
    help="Fall-off exponent for outpainting (lower=higher detail) "
    "(min=0.0, max=4.0).",
)

p.add_argument(
    "--color_variation",
    type=float,
    default=0.05,
    help="Color variation for outpainting (min=0.0, max=1.0).",
)

##############################################################################
# Model Config and Usage Params
##############################################################################

p.add_argument(
    "--device", type=str, default="vulkan", help="Device to run the model."
)

p.add_argument(
    "--precision", type=str, default="fp16", help="Precision to run the model."
)

p.add_argument(
    "--import_mlir",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Imports the model from torch module to shark_module, otherwise "
    "downloads the model from shark_tank.",
)

p.add_argument(
    "--load_vmfb",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Attempts to load the model from a precompiled flatbuffer "
    "and compiles + saves it if not found.",
)

p.add_argument(
    "--save_vmfb",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Saves the compiled flatbuffer to the local directory.",
)

p.add_argument(
    "--use_tuned",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Download and use the tuned version of the model if available.",
)

p.add_argument(
    "--use_base_vae",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Do conversion from the VAE output to pixel space on cpu.",
)

p.add_argument(
    "--scheduler",
    type=str,
    default="SharkEulerDiscrete",
    help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
    "DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
    "DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
    "DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, "
    "HeunDiscrete].",
)

p.add_argument(
    "--output_img_format",
    type=str,
    default="png",
    help="Specify the format in which the output image is saved. "
    "Supported options: jpg / png.",
)

p.add_argument(
    "--output_dir",
    type=str,
    default=None,
    help="Directory path to save the output images and json.",
)

p.add_argument(
    "--batch_count",
    type=int,
    default=1,
    help="Number of batches to be generated with random seeds in a "
    "single execution.",
)

p.add_argument(
    "--repeatable_seeds",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="The seed of the first batch will be used as the rng seed to "
    "generate the subsequent seeds for subsequent batches in that run.",
)

p.add_argument(
    "--ckpt_loc",
    type=str,
    default="",
    help="Path to SD's .ckpt file.",
)

p.add_argument(
    "--custom_vae",
    type=str,
    default="",
    help="HuggingFace repo-id or path to SD model's checkpoint whose VAE "
    "needs to be plugged in.",
)

p.add_argument(
    "--hf_model_id",
    type=str,
    default="stabilityai/stable-diffusion-2-1-base",
    help="The Hugging Face repo-id of the model.",
)

p.add_argument(
    "--low_cpu_mem_usage",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Use the accelerate package to reduce cpu memory consumption.",
)

p.add_argument(
    "--attention_slicing",
    type=str,
    default="none",
    help="Amount of attention slicing to use (one of 'max', 'auto', 'none', "
    "or an integer).",
)

p.add_argument(
    "--use_stencil",
    choices=["canny", "openpose", "scribble", "zoedepth"],
    help="Enable the stencil feature.",
)

p.add_argument(
    "--control_mode",
    choices=["Prompt", "Balanced", "Controlnet"],
    default="Balanced",
    help="How Controlnet injection should be prioritized.",
)

p.add_argument(
    "--use_lora",
    type=str,
    default="",
    help="Use standalone LoRA weights using a HF ID or a checkpoint "
    "file (~3 MB).",
)

p.add_argument(
    "--use_quantize",
    type=str,
    default="none",
    help="Runs the quantized version of the stable diffusion model. "
    "This is currently in experimental phase. "
    "Currently, only runs the stable-diffusion-2-1-base model in "
    "int8 quantization.",
)

p.add_argument(
    "--ondemand",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Load and unload models for low VRAM.",
)

p.add_argument(
    "--hf_auth_token",
    type=str,
    default=None,
    help="Specify your own huggingface authentication token for models like Llama2.",
)

p.add_argument(
    "--device_allocator_heap_key",
    type=str,
    default="",
    help="Specify heap key for device caching allocator. "
    "Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count. "
    "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte).",
)

# NOTE: originally `type=bool, default="False"`, which evaluates truthy;
# fixed to a proper boolean flag.
p.add_argument(
    "--autogen",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Only used for a gradio workaround.",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################

p.add_argument(
    "--iree_vulkan_target_triple",
    type=str,
    default="",
    help="Specify target triple for vulkan.",
)

p.add_argument(
    "--iree_metal_target_platform",
    type=str,
    default="",
    help="Specify target platform for metal.",
)

##############################################################################
# Misc. Debug and Optimization flags
##############################################################################

p.add_argument(
    "--use_compiled_scheduler",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Use the default scheduler precompiled into the model if available.",
)

p.add_argument(
    "--local_tank_cache",
    default="",
    help="Specify where to save downloaded shark_tank artifacts. "
    "If this is not set, the default is ~/.local/shark_tank/.",
)

p.add_argument(
    "--dump_isa",
    default=False,
    action="store_true",
    help="When enabled, call amdllpc to get ISA dumps. "
    "Use with dispatch benchmarks.",
)

p.add_argument(
    "--dispatch_benchmarks",
    default=None,
    help="Dispatches to return benchmark data on. "
    'Use "All" for all, and None for none.',
)

p.add_argument(
    "--dispatch_benchmarks_dir",
    default="temp_dispatch_benchmarks",
    help="Directory where you want to store dispatch data "
    'generated with "--dispatch_benchmarks".',
)

p.add_argument(
    "--enable_rgp",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Flag for inserting debug frames between iterations "
    "for use with rgp.",
)

p.add_argument(
    "--hide_steps",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Flag for hiding the details of iterations/sec for each step.",
)

p.add_argument(
    "--warmup_count",
    type=int,
    default=0,
    help="Flag setting warmup count for CLIP and VAE [>= 0].",
)

p.add_argument(
    "--clear_all",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Flag to clear all mlir and vmfb from common locations. "
    "Recompiling will take several minutes.",
)

p.add_argument(
    "--save_metadata_to_json",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Flag for whether or not to save a generation information "
    "json file with the image.",
)

p.add_argument(
    "--write_metadata_to_png",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Flag for whether or not to save generation information in "
    "PNG chunk text to generated images.",
)

p.add_argument(
    "--import_debug",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="If import_mlir is True, saves mlir via the debug option "
    "in shark importer. Does nothing if import_mlir is false (the default).",
)

p.add_argument(
    "--compile_debug",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Flag to toggle debug assert/verify flags for imported IR in the "
    "iree-compiler. Defaults to false.",
)

p.add_argument(
    "--iree_constant_folding",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Controls constant folding in iree-compile for all SD models.",
)

p.add_argument(
    "--data_tiling",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Controls data tiling in iree-compile for all SD models.",
)

##############################################################################
# Web UI flags
##############################################################################

p.add_argument(
    "--progress_bar",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Flag for toggling the progress bar animation during "
    "image generation.",
)

p.add_argument(
    "--ckpt_dir",
    type=str,
    default="",
    help="Path to directory where all .ckpts are stored in order to populate "
    "them in the web UI.",
)
# TODO: replace API flag when these can be run together
p.add_argument(
    "--ui",
    type=str,
    default="app" if os.name == "nt" else "web",
    help="One of: [api, app, web].",
)

p.add_argument(
    "--share",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Flag for generating a public URL.",
)

p.add_argument(
    "--server_port",
    type=int,
    default=8080,
    help="Flag for setting server port.",
)

p.add_argument(
    "--api",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Flag for enabling rest API.",
)

p.add_argument(
    "--api_accept_origin",
    action="append",
    type=str,
    help="An origin to be accepted by the REST api for Cross-Origin "
    "Resource Sharing (CORS). Use multiple times for multiple origins, "
    'or use --api_accept_origin="*" to accept all origins. If no origins '
    "are set, no CORS headers will be returned by the api. Use, for "
    "instance, if you need to access the REST api from Javascript running "
    "in a web browser.",
)

p.add_argument(
    "--debug",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Flag for enabling debugging log in WebUI.",
)

p.add_argument(
    "--output_gallery",
    default=True,
    action=argparse.BooleanOptionalAction,
    help="Flag for toggling the output gallery tab, to avoid exposing "
    "images under --output_dir in the UI.",
)

p.add_argument(
    "--output_gallery_followlinks",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Flag for whether the output gallery tab in the UI should "
    "follow symlinks when listing subdirectories under --output_dir.",
)


##############################################################################
# SD model auto-annotation flags
##############################################################################

p.add_argument(
    "--annotation_output",
    type=path_expand,
    default="./",
    help="Directory to save the annotated mlir file.",
)

p.add_argument(
    "--annotation_model",
    type=str,
    default="unet",
    help="Options are unet and vae.",
)

p.add_argument(
    "--save_annotation",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Save annotated mlir file.",
)
##############################################################################
# SD model auto-tuner flags
##############################################################################

p.add_argument(
    "--tuned_config_dir",
    type=path_expand,
    default="./",
    help="Directory to save the tuned config file.",
)

p.add_argument(
    "--num_iters",
    type=int,
    default=400,
    help="Number of iterations for tuning.",
)

p.add_argument(
    "--search_op",
    type=str,
    default="all",
    help="Op to be optimized, options are matmul, bmm, conv and all.",
)

##############################################################################
# DocuChat Flags
##############################################################################

p.add_argument(
    "--run_docuchat_web",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Specifies whether docuchat's web version is running or not.",
)

##############################################################################
# rocm Flags
##############################################################################

p.add_argument(
    "--iree_rocm_target_chip",
    type=str,
    default="",
    help="Add the rocm device architecture, e.g. gfx1100, gfx90a. Use "
    "`hipinfo` or `iree-run-module --dump_devices=rocm` to get the "
    "desired arch name.",
)

args, unknown = p.parse_known_args()

# NOTE: other modules in this commit import this namespace as `cmd_opts`
# (e.g. `from apps.shark_studio.modules.shared_cmd_opts import cmd_opts`);
# the alias below is inferred from those imports.
cmd_opts = args

if args.import_debug:
    os.environ["IREE_SAVE_TEMPS"] = os.path.join(
        os.getcwd(), args.hf_model_id.replace("/", "_")
    )
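Other modules consume the parsed flags by importing this namespace; a minimal sketch, assuming the cmd_opts alias above:

    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

    if cmd_opts.api:
        print(f"serving REST API on port {cmd_opts.server_port}")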
@@ -6,6 +6,7 @@ import logging
from ui.chat import chat_element
from ui.sd import sd_element
from ui.outputgallery import outputgallery_element

from modules import timer, initialize

@@ -185,7 +186,8 @@ def webui():
        # further down with the new id.
        with gr.TabItem(label="Stable Diffusion", id=0):
            sd_element.render()
        # with gr.TabItem(label="Output Gallery", id=1):
        with gr.TabItem(label="Output Gallery", id=1):
            outputgallery_element.render()
        with gr.TabItem(label="Chat Bot", id=2):
            chat_element.render()
@@ -1,145 +0,0 @@
import importlib
import logging
import os
import signal
import sys
import re
import warnings
import json
from threading import Thread

from modules.timer import startup_timer


def imports():
    import torch  # noqa: F401

    startup_timer.record("import torch")
    warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="torch")
    warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")

    import gradio  # noqa: F401

    startup_timer.record("import gradio")

    from apps.shark_studio.modules import shared_init
    shared_init.initialize()
    startup_timer.record("initialize shared")

    from apps.shark_studio.modules import processing, gradio_extensons, ui  # noqa: F401
    startup_timer.record("other imports")


def initialize():
    configure_sigint_handler()
    configure_opts_onchange()

    from apps.shark_studio.modules import modelloader
    modelloader.cleanup_models()

    from apps.shark_studio.modules import sd_models
    sd_models.setup_model()
    startup_timer.record("setup SD model")

    # from apps.shark_studio.modules.shared_cmd_options import cmd_opts

    # from apps.shark_studio.modules import codeformer_model
    # warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
    # codeformer_model.setup_model(cmd_opts.codeformer_models_path)
    # startup_timer.record("setup codeformer")

    # from apps.shark_studio.modules import gfpgan_model
    # gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
    # startup_timer.record("setup gfpgan")

    initialize_rest(reload_script_modules=False)


def dumpstacks():
    import threading
    import traceback

    id2name = {th.ident: th.name for th in threading.enumerate()}
    code = []
    for threadId, stack in sys._current_frames().items():
        code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
        for filename, lineno, name, line in traceback.extract_stack(stack):
            code.append(f"""File: "{filename}", line {lineno}, in {name}""")
            if line:
                code.append("  " + line.strip())

    print("\n".join(code))


def configure_sigint_handler():
    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(sig, frame):
        print(f"Interrupted with signal {sig} in {frame}")

        dumpstacks()

        os._exit(0)

    if not os.environ.get("COVERAGE_RUN"):
        # Don't install the immediate-quit handler when running under coverage,
        # as then the coverage report won't be generated.
        signal.signal(signal.SIGINT, sigint_handler)


def initialize_rest(*, reload_script_modules=False):
    """
    Called both from initialize() and when reloading the webui.
    """
    from apps.shark_studio.modules.shared_cmd_options import cmd_opts

    from apps.shark_studio.modules import sd_samplers
    sd_samplers.set_samplers()
    startup_timer.record("set samplers")

    restore_config_state_file()
    startup_timer.record("restore config state file")

    from apps.shark_studio.modules import sd_models
    sd_models.list_models()
    startup_timer.record("list SD models")

    with startup_timer.subcategory("load scripts"):
        scripts.load_scripts()

    if reload_script_modules:
        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)
        startup_timer.record("reload script modules")

    from apps.shark_studio.modules import sd_vae
    sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")

    # from apps.shark_studio.modules import textual_inversion
    # textual_inversion.textual_inversion.list_textual_inversion_templates()
    # startup_timer.record("refresh textual inversion templates")

    from apps.shark_studio.modules import sd_unet
    sd_unet.list_unets()
    startup_timer.record("scripts list_unets")

    def load_model():
        """
        Accesses shared.sd_model property to load model.
        """
        shared.sd_model  # noqa: B018

    Thread(target=load_model).start()
apps/shark_studio/web/ui/sd.py (new file, 679 lines)
@@ -0,0 +1,679 @@
import os
import time
import gradio as gr
import PIL
import json
import sys

import numpy as np  # used by create_canvas() below; missing from the original hunk

from math import ceil
from inspect import signature
from PIL import Image
from pathlib import Path
from datetime import datetime as dt
from gradio.components.image_editor import (
    Brush,
    Eraser,
    EditorValue,
)

from apps.shark_studio.api.utils import (
    get_available_devices,
    get_generated_imgs_path,
    get_checkpoints_path,
    get_checkpoints,
)
from apps.shark_studio.api.sd import (
    sd_model_map,
    SharkStableDiffusionPipeline,
)
from apps.shark_studio.api.schedulers import (
    scheduler_model_map,
)
from apps.shark_studio.api.controlnet import (
    resampler_list,
    preprocessor_model_map,
    control_adapter_model_map,
    PreprocessorModel,
)
from apps.shark_studio.modules.img_processing import (
    resampler_list,
    resize_stencil,
)
from apps.shark_studio.web.ui.utils import (
    get_generation_text_info,
    nodlogo_loc,
)
from apps.shark_studio.web.ui.common_events import lora_changed

sd_pipe = None
sd_pipe_kwargs = None  # inferred from the `global sd_pipe_kwargs` declaration below


# NOTE: Each `hf_model_id` should have its own starting configuration.

# model_vmfb_key = ""
def shark_sd_fn(
    prompt: str,
    negative_prompt: str,
    image_dict,
    height: int,
    width: int,
    steps: int,
    strength: float,
    guidance_scale: float,
    seed: str | int,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    base_model_id: str,
    custom_checkpoints: str,
    custom_vae: str,
    precision: str,
    device: str,
    lora_weights: str,
    lora_hf_id: str,
    ondemand: bool,
    repeatable_seeds: bool,
    resample_type: str,
    control_mode: str,
    stencils: list,
    images: list,
    preprocessed_hints: list,
    progress=gr.Progress(),
):
    # Handling gradio ImageEditor datatypes so we have unified inputs to the SD API
    for i, stencil in enumerate(stencils):
        if images[i] is None and stencil is not None:
            continue
        elif stencil is None and any(
            img is not None for img in [images[i], preprocessed_hints[i]]
        ):
            images[i] = None
            preprocessed_hints[i] = None
        elif images[i] is not None:
            if isinstance(images[i], dict):
                images[i] = images[i]["composite"]
            images[i] = images[i].convert("RGB")

    if isinstance(image_dict, PIL.Image.Image):
        image = image_dict.convert("RGB")
    elif image_dict:
        image = image_dict["image"].convert("RGB")
    else:
        image = None
    if image:
        image, _, _ = resize_stencil(image, width, height)

    device_id = None

    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

    submit_pipe_kwargs = {
        "base_model_id": base_model_id,
        "custom_vae": custom_vae,
        "import_mlir": cmd_opts.import_mlir,
        # NOTE: the remaining entries of this dict are truncated in this hunk.
    }

    global sd_pipe
    global sd_pipe_kwargs

    # NOTE: a loop over the submitted kwargs is truncated in this hunk.

    if sd_pipe is None:
        history[-1][-1] = "Getting the pipeline ready..."
        yield history, ""

        # Initializes the pipeline and retrieves IR based on all
        # parameters that are static in the turbine output format,
        # which is currently MLIR in the torch dialect.
        sd_pipe = SharkStableDiffusionPipeline(
            base_model_id=base_model_id,
            custom_vae=custom_vae,
            import_mlir=cmd_opts.import_mlir,
            device=device.split("=>", 1)[1].strip(),
            precision=precision,
            max_length=512,
            height=height,
            width=width,
        )

    for prompt, msg, exec_time in progress.tqdm(
        sd_pipe.generate_images(
            prompt,
            negative_prompt,
        ),
        desc="Generating Image...",
    ):
        pass  # NOTE: the loop body is truncated in this hunk.

    return history, ""
def view_json_file(file_obj):
|
||||
content = ""
|
||||
with open(file_obj.name, "r") as fopen:
|
||||
content = fopen.read()
|
||||
return content
|
||||
sd_fn_sig = signature(shark_sd_fn)
max_controlnets = 5
max_loras = 5


def show_loras(k):
    k = int(k)
    return [gr.Dropdown(visible=True)] * k + [
        gr.Dropdown(visible=False, value="None")
    ] * (max_loras - k)


def show_controlnets(k):
    k = int(k)
    return [gr.Row(visible=True)] * k + [gr.Row(visible=False)] * (
        max_controlnets - k
    )


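# These helpers return one gradio update per component slot; they are wired
# to the count sliders created in the UI below, e.g.:
#
#   num_loras.change(show_loras, num_loras, loras)
#   num_cnets.change(show_controlnets, num_cnets, cnet_rows)

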
def create_canvas(width, height):
    # Start from an all-white RGB canvas of the requested size.
    data = Image.fromarray(
        np.zeros(
            shape=(height, width, 3),
            dtype=np.uint8,
        )
        + 255
    )
    img_dict = {
        "background": data,
        "layers": [data],
        "composite": None,
    }
    return EditorValue(img_dict)


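# Usage sketch: the make_canvas.click wiring below swaps this blank canvas
# into the controlnet ImageEditor.
#
#   make_canvas.click(create_canvas, [canvas_width, canvas_height], [cnet_input])

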
def import_original(original_img, width, height):
    resized_img, _, _ = resize_stencil(original_img, width, height)
    img_dict = {
        "background": resized_img,
        "layers": [resized_img],
        "composite": None,
    }
    return gr.ImageEditor(
        value=EditorValue(img_dict),
        crop_size=(width, height),
    )


def update_cn_input(
    model,
    width,
    height,
    stencils,
    images,
    preprocessed_hints,
    index,
):
    if model is None:
        stencils[index] = None
        images[index] = None
        preprocessed_hints[index] = None
        return [
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            stencils,
            images,
            preprocessed_hints,
        ]
    elif model == "scribble":
        return [
            gr.ImageEditor(
                visible=True,
                interactive=True,
                show_label=False,
                image_mode="RGB",
                type="pil",
                brush=Brush(
                    colors=["#000000"],
                    color_mode="fixed",
                    default_size=5,
                ),
            ),
            gr.Image(
                visible=True,
                show_label=False,
                interactive=True,
                show_download_button=False,
            ),
            gr.Slider(visible=True, label="Canvas Width"),
            gr.Slider(visible=True, label="Canvas Height"),
            gr.Button(visible=True),
            gr.Button(visible=False),
            stencils,
            images,
            preprocessed_hints,
        ]
    else:
        return [
            gr.ImageEditor(
                visible=True,
                interactive=True,
                show_label=False,
                image_mode="RGB",
                type="pil",
            ),
            gr.Image(
                visible=True,
                show_label=False,
                interactive=True,
                show_download_button=False,
            ),
            gr.Slider(visible=True, label="Canvas Width"),
            gr.Slider(visible=True, label="Canvas Height"),
            gr.Button(visible=True),
            gr.Button(visible=False),
            stencils,
            images,
            preprocessed_hints,
        ]


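# The nine returned slots line up positionally with the outputs list of the
# cnet_processor.change wiring below: [cnet_input, cnet_output, canvas_width,
# canvas_height, make_canvas, use_input_img, stencils, images,
# preprocessed_hints].

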
with gr.Blocks(title="Stable Diffusion") as sd_element:
    # Get a list of arguments needed for the API call, then initialize an
    # empty list that will manage the corresponding gradio values.
    inputs_list = list(sd_fn_sig.parameters.keys())
    inputs_args = gr.State([None] * len(inputs_list))
    with gr.Row(elem_id="ui_title"):
        nod_logo = Image.open(nodlogo_loc())
        with gr.Row():
            with gr.Column(scale=1, elem_id="demo_title_outer"):
                gr.Image(
                    value=nod_logo,
                    show_label=False,
                    interactive=False,
                    elem_id="top_logo",
                    width=150,
                    height=50,
                    show_download_button=False,
                )
    save_sd_config = gr.Button(label="Save Config", scale=1)
    load_sd_config = gr.FileExplorer("Load Config", scale=1)
    clear_sd_config = gr.ClearButton("Clear Config", scale=1)
    with gr.Column(elem_id="ui_body"):
        with gr.Row():
            with gr.Column(scale=1, min_width=600):
                with gr.Group():
                    sd_model_info = (
                        f"Checkpoint Path: {str(get_checkpoints_path())}"
                    )
                    sd_base = gr.Dropdown(
                        label="Base Model",
                        info="Select or enter HF model ID",
                        elem_id="custom_model",
                        value="stabilityai/stable-diffusion-2.1-base",
                        choices=get_base_models(),
                    )  # base_model_id
                    sd_checkpoint = gr.Dropdown(
                        label="Checkpoints (optional)",
                        info="Select or enter HF model ID",
                        elem_id="custom_model",
                        value="None",
                        choices=get_checkpoints(sd_base),
                    )  # custom_checkpoints
                    sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
                        "\\", "\n\\"
                    )
                    sd_vae_info = f"VAE Path: {sd_vae_info}"
                    sd_custom_vae = gr.Dropdown(
                        label="Custom VAE Models",
                        info=sd_vae_info,
                        elem_id="custom_model",
                        value=os.path.basename(cmd_opts.custom_vae)
                        if cmd_opts.custom_vae
                        else "None",
                        choices=["None"] + get_checkpoints("vae"),
                        allow_custom_value=True,
                        scale=1,
                    )

                with gr.Group(elem_id="prompt_box_outer"):
                    prompt = gr.Textbox(
                        label="Prompt",
                        value=cmd_opts.prompts[0],
                        lines=2,
                        elem_id="prompt_box",
                    )
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value=cmd_opts.negative_prompts[0],
                        lines=2,
                        elem_id="negative_prompt_box",
                    )

                with gr.Accordion(label="Input Image", open=False):
                    # TODO: make this import image prompt info if it exists
                    sd_init_image = gr.Image(
                        label="Input Image",
                        type="pil",
                        height=300,
                        interactive=True,
                    )
                with gr.Accordion(label="Embeddings options", open=False):
                    sd_lora_info = (
                        str(get_checkpoints_path("loras"))
                    ).replace("\\", "\n\\")
                    num_loras = gr.Slider(
                        1, max_loras, value=1, step=1, label="LoRA Count"
                    )
                    loras = []
                    for i in range(max_loras):
                        lora_opt = gr.Dropdown(
                            allow_custom_value=False,
                            label="Standalone LoRA Weights",
                            info=sd_lora_info,
                            elem_id="lora_weights",
                            value="None",
                            choices=["None"] + get_custom_model_files("lora"),
                        )
                        loras.append(lora_opt)
                    num_loras.change(show_loras, num_loras, loras)
                    # Only the first LoRA slot is passed through to the API
                    # for now; lora_tags displays its trained keywords.
                    lora_weights = loras[0]
                    lora_hf_id = gr.Textbox(
                        label="HuggingFace LoRA ID",
                        value="",
                    )
                    lora_tags = gr.HTML(value="")
                with gr.Accordion(label="Advanced Options", open=True):
                    with gr.Row():
                        scheduler = gr.Dropdown(
                            elem_id="scheduler",
                            label="Scheduler",
                            value="EulerDiscrete",
                            choices=scheduler_list,
                            allow_custom_value=False,
                        )
                    with gr.Row():
                        height = gr.Slider(
                            384, 768, value=cmd_opts.height, step=8, label="Height"
                        )
                        width = gr.Slider(
                            384, 768, value=cmd_opts.width, step=8, label="Width"
                        )
                    with gr.Row():
                        with gr.Column(scale=3):
                            steps = gr.Slider(
                                1, 100, value=cmd_opts.steps, step=1, label="Steps"
                            )
                        with gr.Column(scale=3):
                            strength = gr.Slider(
                                0,
                                1,
                                value=cmd_opts.strength,
                                step=0.01,
                                label="Denoising Strength",
                            )
                        resample_type = gr.Dropdown(
                            value=cmd_opts.resample_type,
                            choices=resampler_list,
                            label="Resample Type",
                            allow_custom_value=True,
                        )
                        ondemand = gr.Checkbox(
                            value=cmd_opts.lowvram,
                            label="Low VRAM",
                            interactive=True,
                        )
                        precision = gr.Radio(
                            label="Precision",
                            value=cmd_opts.precision,
                            choices=[
                                "fp16",
                                "fp32",
                            ],
                            visible=True,
                        )
                    with gr.Row():
                        with gr.Column(scale=3):
                            guidance_scale = gr.Slider(
                                0,
                                50,
                                value=cmd_opts.guidance_scale,
                                step=0.1,
                                label="CFG Scale",
                            )
                        with gr.Column(scale=3):
                            batch_count = gr.Slider(
                                1,
                                100,
                                value=cmd_opts.batch_count,
                                step=1,
                                label="Batch Count",
                                interactive=True,
                            )
                        repeatable_seeds = gr.Checkbox(
                            cmd_opts.repeatable_seeds,
                            label="Repeatable Seeds",
                        )
                    with gr.Row():
                        batch_size = gr.Slider(
                            1,
                            4,
                            value=cmd_opts.batch_size,
                            step=1,
                            label="Batch Size",
                            interactive=True,
                            visible=True,
                        )
                    with gr.Row():
                        seed = gr.Textbox(
                            value=cmd_opts.seed,
                            label="Seed",
                            info="An integer or a JSON list of integers, -1 for random",
                        )
                        device = gr.Dropdown(
                            elem_id="device",
                            label="Device",
                            value=get_available_devices()[0],
                            choices=get_available_devices(),
                            allow_custom_value=False,
                        )
                with gr.Accordion(label="Controlnet Options", open=False):
                    sd_cnet_info = (
                        str(get_checkpoints_path("controlnet"))
                    ).replace("\\", "\n\\")
                    num_cnets = gr.Slider(
                        1,
                        max_controlnets,
                        value=1,
                        step=1,
                        label="Controlnet Count",
                    )
                    control_mode = gr.Radio(
                        choices=["Prompt", "Balanced", "Controlnet"],
                        value="Balanced",
                        label="Control Mode",
                    )
                    cnet_rows = []
                    # Track per-controlnet state so event handlers can read
                    # and mutate these lists.
                    stencils = gr.State([None] * max_controlnets)
                    images = gr.State([None] * max_controlnets)
                    preprocessed_hints = gr.State([None] * max_controlnets)
                    for i in range(max_controlnets):
                        with gr.Row() as cnet_row:
                            with gr.Column():
                                cnet_gen = gr.Button(
                                    value="Preprocess controlnet input",
                                )
                                cnet_processor = gr.Dropdown(
                                    allow_custom_value=True,
                                    label="Controlnet Preprocessor",
                                    info=sd_cnet_info,
                                    elem_id="cnet_preprocessor",
                                    value="None",
                                    choices=["None"]
                                    + controlnet_list
                                    + get_custom_model_files("controlnet"),
                                )
                                cnet_adapter = gr.Dropdown(
                                    allow_custom_value=True,
                                    label="Controlnet Adapter",
                                    info=sd_cnet_info,
                                    elem_id="cnet_adapter",
                                    value="None",
                                    choices=["None"]
                                    + controlnet_list
                                    + get_custom_model_files("controlnet"),
                                )
                                canvas_width = gr.Slider(
                                    label="Canvas Width",
                                    minimum=256,
                                    maximum=1024,
                                    value=512,
                                    step=1,
                                    visible=False,
                                )
                                canvas_height = gr.Slider(
                                    label="Canvas Height",
                                    minimum=256,
                                    maximum=1024,
                                    value=512,
                                    step=1,
                                    visible=False,
                                )
                                make_canvas = gr.Button(
                                    value="Make Canvas!",
                                    visible=False,
                                )
                                use_input_img = gr.Button(
                                    value="Use Original Image",
                                    visible=False,
                                )
                            cnet_input = gr.ImageEditor(
                                visible=True,
                                image_mode="RGB",
                                interactive=True,
                                show_label=True,
                                label="Input Image",
                                type="pil",
                            )
                            cnet_output = gr.Image(
                                value=None,
                                visible=True,
                                label="Preprocessed Hint",
                                interactive=True,
                                show_label=True,
                            )
                        use_input_img.click(
                            import_original,
                            [sd_init_image, canvas_width, canvas_height],
                            [cnet_input],
                        )
                        cnet_processor.change(
                            # Bind this row's index so update_cn_input knows
                            # which controlnet slot to modify.
                            fn=lambda m, w, h, s, im, ph, idx=i: update_cn_input(
                                m, w, h, s, im, ph, idx
                            ),
                            inputs=[
                                cnet_processor,
                                canvas_width,
                                canvas_height,
                                stencils,
                                images,
                                preprocessed_hints,
                            ],
                            outputs=[
                                cnet_input,
                                cnet_output,
                                canvas_width,
                                canvas_height,
                                make_canvas,
                                use_input_img,
                                stencils,
                                images,
                                preprocessed_hints,
                            ],
                        )
                        make_canvas.click(
                            create_canvas,
                            [canvas_width, canvas_height],
                            [
                                cnet_input,
                            ],
                        )
                        gr.on(
                            triggers=[cnet_gen.click],
                            fn=cnet_preview,
                            inputs=[
                                cnet_processor,
                                cnet_input,
                                stencils,
                                images,
                                preprocessed_hints,
                            ],
                            outputs=[
                                cnet_output,
                                stencils,
                                images,
                                preprocessed_hints,
                            ],
                        )
                        cnet_rows.append(cnet_row)

                    num_cnets.change(show_controlnets, num_cnets, cnet_rows)
            with gr.Column(scale=1, min_width=600):
                with gr.Group():
                    sd_gallery = gr.Gallery(
                        label="Generated images",
                        show_label=False,
                        elem_id="gallery",
                        columns=2,
                        object_fit="contain",
                    )
                    std_output = gr.Textbox(
                        value=f"{sd_model_info}\n"
                        f"Images will be saved at "
                        f"{get_generated_imgs_path()}",
                        lines=2,
                        elem_id="std_output",
                        show_label=False,
                    )
                    sd_status = gr.Textbox(visible=False)
                with gr.Row():
                    stable_diffusion = gr.Button("Generate Image(s)")
                    random_seed = gr.Button("Randomize Seed")
                    random_seed.click(
                        lambda: -1,
                        inputs=[],
                        outputs=[seed],
                        queue=False,
                    )
                    stop_batch = gr.Button("Stop Batch")

    kwargs = dict(
        fn=shark_sd_fn,
        inputs=[
            prompt,
            negative_prompt,
            sd_init_image,
            height,
            width,
            steps,
            strength,
            guidance_scale,
            seed,
            batch_count,
            batch_size,
            scheduler,
            sd_base,
            sd_checkpoint,
            sd_custom_vae,
            precision,
            device,
            lora_weights,
            lora_hf_id,
            ondemand,
            repeatable_seeds,
            resample_type,
            control_mode,
            stencils,
            images,
            preprocessed_hints,
        ],
        outputs=[
            sd_gallery,
            std_output,
            sd_status,
            stencils,
            images,
        ],
        show_progress="minimal" if cmd_opts.progress_bar else "none",
    )
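
    # Note: the `inputs` list above must match shark_sd_fn's positional
    # signature exactly; gradio passes component values by position.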

    status_kwargs = dict(
        fn=lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs),
        inputs=[batch_count, batch_size],
        outputs=sd_status,
    )

    prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
    neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**kwargs)
    generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
    stop_batch.click(
        fn=cancel_sd,
        cancels=[prompt_submit, neg_prompt_submit, generate_click],
    )

    lora_weights.change(
        fn=lora_changed,
        inputs=[lora_weights],
        outputs=[lora_tags],
        queue=True,
    )
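
# A minimal sketch (hypothetical) for exercising this Blocks app standalone;
# the real entrypoint mounts sd_element alongside the API in index.py.
#
#   if __name__ == "__main__":
#       sd_element.queue().launch()
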

9
apps/shark_studio/web/ui/utils.py
Normal file
@@ -0,0 +1,9 @@
def nodlogo_loc():
    return "foo"

def get_checkpoints_path(model_type: str = None):
    return "foo"

def get_checkpoints(model_type: str = None):
    return "foo"