Simplify ui further, add CLI option to load a default config

This commit is contained in:
Ean Garvey
2024-06-06 13:21:25 -05:00
parent 67b438eb9f
commit 5b3b262359
4 changed files with 295 additions and 238 deletions

View File

@@ -23,7 +23,6 @@ p = argparse.ArgumentParser(
##############################################################################
# Stable Diffusion Params
##############################################################################
p.add_argument(
"-a",
"--app",
@@ -595,9 +594,10 @@ p.add_argument(
# Web UI flags
##############################################################################
p.add_argument(
"--default_config",
"--defaults",
default="sdxl-turbo.json",
type=str,
help="Path to the default API request .json file. Works for CLI and webui."
)
p.add_argument(

View File

@@ -14,6 +14,7 @@ from apps.shark_studio.web.utils.file_utils import (
get_checkpoints_path,
get_checkpoints,
get_configs_path,
get_configs,
write_default_sd_configs,
)
from apps.shark_studio.api.sd import (
@@ -148,7 +149,14 @@ def pull_sd_configs(
def load_sd_cfg(sd_json: dict, load_sd_config: str):
new_sd_config = none_to_str_none(json.loads(view_json_file(load_sd_config)))
if os.path.exists(load_sd_config):
config = load_sd_config
elif os.path.exists(os.path.join(get_configs_path(), load_sd_config)):
config = os.path.join(get_configs_path(), load_sd_config)
else:
print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.")
config = sd_json
new_sd_config = none_to_str_none(json.loads(view_json_file(config)))
if sd_json:
for key in new_sd_config:
sd_json[key] = new_sd_config[key]
@@ -241,17 +249,17 @@ def base_model_changed(base_model_id):
) + get_checkpoints(model_type="checkpoints")
if "turbo" in base_model_id:
new_steps = gr.Dropdown(
value=cmd_opts.steps,
value=2,
choices=[1, 2],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=False,
allow_custom_value=True,
)
if "stable-diffusion-xl-base-1.0" in base_model_id:
new_steps = gr.Dropdown(
value=40,
choices=[20, 25, 30, 35, 40, 45, 50],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=False,
allow_custom_value=True,
)
elif ".py" in base_model_id:
new_steps = gr.Dropdown(
@@ -262,7 +270,7 @@ def base_model_changed(base_model_id):
)
else:
new_steps = gr.Dropdown(
value=cmd_opts.steps,
value=20,
choices=[10, 20, 30, 40, 50],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
@@ -276,70 +284,197 @@ def base_model_changed(base_model_id):
new_steps,
]
init_config = global_obj.get_init_config()
init_config = none_to_str_none(json.loads(view_json_file(init_config)))
with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Column(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=2, min_width=600):
with gr.Accordion(
label="\U0001F4D0\U0000FE0F Device Settings", open=False
):
device = gr.Dropdown(
elem_id="device",
label="Device",
value=global_obj.get_device_list()[0],
choices=global_obj.get_device_list(),
allow_custom_value=False,
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="\U00002795\U0000FE0F Prompt",
value=init_config["prompt"][0],
lines=4,
elem_id="prompt_box",
show_copy_button=True,
)
with gr.Row():
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
label="Low VRAM",
interactive=True,
visible=False,
negative_prompt = gr.Textbox(
label="\U00002796\U0000FE0F Negative Prompt",
value=init_config["negative_prompt"][0],
lines=4,
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Accordion(
label="\U0001F4D0\U0000FE0F Advanced Settings", open=True
):
with gr.Accordion(
label="Device Settings", open=False
):
device = gr.Dropdown(
elem_id="device",
label="Device",
value=init_config["device"] if init_config["device"] else "rocm",
choices=global_obj.get_device_list(),
allow_custom_value=True,
)
target_triple = gr.Textbox(
elem_id="target_triple",
label="Architecture",
value="",
value=init_config["target_triple"],
)
precision = gr.Radio(
label="Precision",
value=cmd_opts.precision,
choices=[
"fp16",
"fp32",
],
with gr.Row():
ondemand = gr.Checkbox(
value=init_config["ondemand"],
label="Low VRAM",
interactive=True,
visible=False,
)
precision = gr.Radio(
label="Precision",
value=init_config["precision"],
choices=[
"fp16",
"fp32",
],
visible=False,
)
with gr.Row():
height = gr.Slider(
512,
1024,
value=512,
step=512,
label="\U00002195\U0000FE0F Height",
interactive=False, # DEMO
visible=False, # DEMO
)
width = gr.Slider(
512,
1024,
value=512,
step=512,
label="\U00002194\U0000FE0F Width",
interactive=False, # DEMO
visible=False, # DEMO
)
with gr.Accordion(
label="\U0001F9EA\U0000FE0F Input Image Processing",
open=False,
visible=False,
):
strength = gr.Slider(
0,
1,
value=init_config["strength"],
step=0.01,
label="Denoising Strength",
)
resample_type = gr.Dropdown(
value=init_config["resample_type"],
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
with gr.Row():
sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}"
base_model_id = gr.Dropdown(
label="\U000026F0\U0000FE0F Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value=init_config["base_model_id"],
choices=sd_default_models,
allow_custom_value=True,
) # base_model_id
with gr.Row(equal_height=True):
seed = gr.Textbox(
value=init_config["seed"],
label="\U0001F331\U0000FE0F Seed",
info="An integer, -1 for random",
show_copy_button=True,
)
scheduler = gr.Dropdown(
elem_id="scheduler",
label="\U0001F4C5\U0000FE0F Scheduler",
info="\U000E0020", # forces same height as seed
value=init_config["scheduler"],
choices=scheduler_model_map.keys(),
allow_custom_value=False,
)
with gr.Row():
steps = gr.Dropdown(
value=20,
choices=[10, 15, 20],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
guidance_scale = gr.Slider(
0,
5, #DEMO
value=4,
step=0.1,
label="\U0001F5C3\U0000FE0F CFG Scale",
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=init_config["batch_count"],
step=1,
label="Batch Count",
interactive=True,
visible=True,
)
sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}"
base_model_id = gr.Dropdown(
label="\U000026F0\U0000FE0F Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/sdxl-turbo",
choices=sd_default_models,
allow_custom_value=True,
) # base_model_id
with gr.Row():
height = gr.Slider(
512,
1024,
value=512,
step=512,
label="\U00002195\U0000FE0F Height",
interactive=False, # DEMO
visible=False, # DEMO
)
width = gr.Slider(
512,
1024,
value=512,
step=512,
label="\U00002194\U0000FE0F Width",
interactive=False, # DEMO
visible=False, # DEMO
)
batch_size = gr.Slider(
1,
4,
value=init_config["batch_size"],
step=1,
label="Batch Size",
interactive=False, # DEMO
visible=True,
)
compiled_pipeline = gr.Checkbox(
value=init_config["compiled_pipeline"],
label="Faster txt2img (SDXL only)",
visible=False, # DEMO
)
with gr.Row(elem_classes=["fill"], visible=False):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
write_default_sd_configs(get_configs_path())
default_config_file = global_obj.get_init_config()
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
)
with gr.Row():
with gr.Row():
load_sd_config = gr.Dropdown(
label="Load Config",
value=cmd_opts.defaults,
choices=get_configs(),
allow_custom_value=True,
)
with gr.Row():
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
clear_sd_config = gr.ClearButton(
value="Clear Config",
size="sm",
components=sd_json,
)
# with gr.Row():
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
show_label=False,
)
with gr.Accordion(
label="\U00002696\U0000FE0F Model Weights",
open=False,
@@ -350,7 +485,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
label="Checkpoint Weights",
info="Select or enter HF model ID",
elem_id="custom_model",
value="None",
value=init_config["custom_weights"],
allow_custom_value=True,
choices=["None"]
+ get_checkpoints(os.path.basename(str(base_model_id))),
@@ -363,11 +498,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
label=f"VAE Model",
info=sd_vae_info,
elem_id="custom_model",
value=(
os.path.basename(cmd_opts.custom_vae)
if cmd_opts.custom_vae
else "None"
),
value=init_config["custom_vae"],
choices=["None"] + get_checkpoints("vae"),
allow_custom_value=True,
scale=1,
@@ -380,7 +511,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
label=f"Standalone LoRA Weights",
info=sd_lora_info,
elem_id="lora_weights",
value=None,
value=init_config["embeddings"][0] if (len(init_config["embeddings"].keys()) > 1) else "None",
multiselect=True,
choices=[] + get_checkpoints("lora"),
scale=2,
@@ -405,68 +536,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
outputs=[embeddings_config],
show_progress=False,
)
with gr.Accordion(
label="\U0001F9EA\U0000FE0F Input Image Processing",
open=False,
visible=False,
):
strength = gr.Slider(
0,
1,
value=cmd_opts.strength,
step=0.01,
label="Denoising Strength",
)
resample_type = gr.Dropdown(
value=cmd_opts.resample_type,
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="\U00002795\U0000FE0F Prompt",
value=cmd_opts.prompt[0],
lines=2,
elem_id="prompt_box",
show_copy_button=True,
)
negative_prompt = gr.Textbox(
label="\U00002796\U0000FE0F Negative Prompt",
value=cmd_opts.negative_prompt[0],
lines=2,
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Row(equal_height=True):
seed = gr.Textbox(
value=cmd_opts.seed,
label="\U0001F331\U0000FE0F Seed",
info="An integer, -1 for random",
show_copy_button=True,
)
scheduler = gr.Dropdown(
elem_id="scheduler",
label="\U0001F4C5\U0000FE0F Scheduler",
info="\U000E0020", # forces same height as seed
value="EulerAncestralDiscrete",
choices=scheduler_model_map.keys(),
allow_custom_value=False,
)
with gr.Row():
steps = gr.Dropdown(
value=cmd_opts.steps,
choices=[1, 2],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
guidance_scale = gr.Slider(
0,
5, #DEMO
value=cmd_opts.guidance_scale,
step=0.1,
label="\U0001F5C3\U0000FE0F CFG Scale",
)
with gr.Accordion(
label="Controlnet Options",
open=False,
@@ -628,30 +697,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
object_fit="fit",
preview=True,
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=cmd_opts.batch_count,
step=1,
label="Batch Count",
interactive=True,
visible=True,
)
batch_size = gr.Slider(
1,
4,
value=cmd_opts.batch_size,
step=1,
label="Batch Size",
interactive=True,
visible=False, # DEMO
)
compiled_pipeline = gr.Checkbox(
True,
label="Faster txt2img (SDXL only)",
visible=False, # DEMO
)
with gr.Row():
stable_diffusion = gr.Button("Start")
unload = gr.Button("Unload Models")
@@ -661,90 +706,43 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
show_progress=False,
)
stop_batch = gr.Button("Stop", visible=False)
with gr.Tab(label="Config", id=102) as sd_tab_config:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
write_default_sd_configs(get_configs_path())
default_config_file = os.path.join(
get_configs_path(),
"sdxl-turbo.json",
)
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
)
with gr.Row():
with gr.Column(scale=3):
load_sd_config = gr.FileExplorer(
label="Load Config",
file_count="single",
root_dir=(
cmd_opts.configs_path
if cmd_opts.configs_path
else get_configs_path()
),
height=200,
)
with gr.Column(scale=1):
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
clear_sd_config = gr.ClearButton(
value="Clear Config",
size="sm",
components=sd_json,
)
# with gr.Row():
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
show_label=False,
)
load_sd_config.change(
fn=load_sd_cfg,
inputs=[sd_json, load_sd_config],
outputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
sd_json,
],
)
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Tab(label="Log", id=103) as sd_tab_log:
# with gr.Tab(label="Config", id=102) as sd_tab_config:
# with gr.Group():#elem_classes=["sd-right-panel"]):
# with gr.Row(elem_classes=["fill"], visible=False):
# Path(get_configs_path()).mkdir(
# parents=True, exist_ok=True
# )
# write_default_sd_configs(get_configs_path())
# default_config_file = global_obj.get_init_config()
# sd_json = gr.JSON(
# elem_classes=["fill"],
# value=view_json_file(default_config_file),
# )
# with gr.Row():
# with gr.Row():
# load_sd_config = gr.Dropdown(
# label="Load Config",
# value=cmd_opts.defaults,
# choices=get_configs(),
# allow_custom_value=True,
# )
# with gr.Row():
# save_sd_config = gr.Button(
# value="Save Config", size="sm"
# )
# clear_sd_config = gr.ClearButton(
# value="Clear Config",
# size="sm",
# components=sd_json,
# )
# # with gr.Row():
# sd_config_name = gr.Textbox(
# value="Config Name",
# info="Name of the file this config will be saved to.",
# interactive=True,
# show_label=False,
# )
with gr.Tab(label="Log", id=103, visible=False) as sd_tab_log:
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
@@ -765,7 +763,41 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
inputs=[base_model_id],
outputs=[custom_weights, steps],
)
load_sd_config.change(
fn=load_sd_cfg,
inputs=[sd_json, load_sd_config],
outputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
target_triple,
ondemand,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
sd_json,
],
)
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
pull_kwargs = dict(
fn=pull_sd_configs,
inputs=[

View File

@@ -100,6 +100,15 @@ def get_checkpoints(model_type="checkpoints"):
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)
def get_configs():
return sorted(
[
os.path.basename(x)
for x in glob.glob(os.path.join(get_configs_path(), "*.json"))
],
key=str.casefold,
)
def get_checkpoint_pathfile(checkpoint_name, model_type="checkpoints"):
return os.path.join(get_checkpoints_path(model_type), checkpoint_name)

View File

@@ -1,12 +1,18 @@
import gc
from ...api.utils import get_available_devices
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import os
from apps.shark_studio.web.utils.file_utils import get_configs_path
"""
The global objects include SD pipeline and config.
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
Also we could avoid memory leak when switching models by clearing the cache.
"""
def view_json_file(file_path):
content = ""
with open(file_path, "r") as fopen:
content = fopen.read()
return content
def _init():
global _sd_obj
@@ -89,6 +95,16 @@ def get_device_list():
global _devices
return _devices
def get_init_config():
global _init_config
if os.path.exists(cmd_opts.defaults):
_init_config = cmd_opts.defaults
elif os.path.exists(os.path.join(get_configs_path(), cmd_opts.defaults)):
_init_config = os.path.join(get_configs_path(), cmd_opts.defaults)
else:
print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.")
_init_config = os.path.join(get_configs_path(), "sdxl-turbo.json")
return _init_config
def get_sd_status():
global _sd_obj