Updates to the SD API and UI for SD3; SDXL quality-of-life improvements

This commit is contained in:
Ean Garvey
2024-06-19 04:22:29 -05:00
parent 88db3457e2
commit f092ebddd2
3 changed files with 102 additions and 114 deletions

View File

@@ -32,7 +32,6 @@ from apps.shark_studio.modules.img_processing import (
from subprocess import check_output
EMPTY_SD_MAP = {
"clip": None,
"scheduler": None,
"unet": None,
"vae_decode": None,
}
@@ -41,14 +40,12 @@ EMPTY_SDXL_MAP = {
"prompt_encoder": None,
"unet": None,
"vae_decode": None,
"scheduler": None,
}
EMPTY_SD3_MAP = {
"clip": None,
"mmdit": None,
"vae": None,
"scheduler": None,
}
EMPTY_FLAGS = {
@@ -90,10 +87,10 @@ class StableDiffusion:
height: int,
width: int,
batch_size: int,
steps: int,
scheduler: str,
precision: str,
device: str,
steps: int = 50,
scheduler_id: str = None,
clip_device: str = None,
vae_device: str = None,
target_triple: str = None,
@@ -103,6 +100,7 @@ class StableDiffusion:
is_controlled: bool = False,
external_weights: str = "safetensors",
vae_precision: str = "fp16",
cpu_scheduling: bool = False,
progress=gr.Progress(),
):
progress(0, desc="Initializing pipeline...")
@@ -119,19 +117,15 @@ class StableDiffusion:
devices = {
"clip": clip_device,
"mmdit": backend,
"unet": backend,
"vae": vae_device,
}
targets = {
"clip": clip_target,
"mmdit": target,
"unet": target,
"vae": vae_target,
}
pipe_device_id = backend
target_triple = target
for key in devices:
if devices[key] != backend:
pipe_device_id = "hybrid"
target_triple = "_".join([clip_target, target, vae_target])
self.precision = precision
self.compiled_pipeline = False
self.base_model_id = base_model_id
@@ -171,27 +165,11 @@ class StableDiffusion:
targets = target
max_length = 64
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
str(max_length),
f"{str(height)}x{str(width)}",
precision,
target_triple,
]
if num_loras > 0:
pipe_id_list.append(str(num_loras) + "lora")
if is_controlled:
pipe_id_list.append("controlled")
if custom_vae:
pipe_id_list.append(custom_vae)
self.pipe_id = "_".join(pipe_id_list)
self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.weights_path = Path(
os.path.join(
get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)
)
)
self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), "vmfbs"))
self.weights_path = Path(os.path.join(get_checkpoints_path(), "weights"))
if not os.path.exists(self.pipeline_dir):
os.mkdir(self.pipeline_dir)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
@@ -211,7 +189,7 @@ class StableDiffusion:
progress(0.5, desc="Initializing pipeline...")
self.sd_pipe = self.turbine_pipe(
hf_model_name=base_model_id,
scheduler_id=scheduler,
scheduler_id=scheduler_id,
height=height,
width=width,
precision=precision,
@@ -227,6 +205,7 @@ class StableDiffusion:
external_weights_dir=self.weights_path,
external_weights=external_weights,
vae_precision=vae_precision,
cpu_scheduling=cpu_scheduling,
)
progress(1, desc="Pipeline initialized!...")
gc.collect()
@@ -331,16 +310,23 @@ class StableDiffusion:
resample_type,
control_mode,
hints,
progress=gr.Progress(),
steps=None,
cpu_scheduling=False,
scheduler_id=None,
progress=gr.Progress(track_tqdm=True),
):
img = self.sd_pipe.generate_images(
prompt,
negative_prompt,
1,
guidance_scale,
seed,
prompt=prompt,
negative_prompt=negative_prompt,
batch_count=1,
guidance_scale=guidance_scale,
seed=seed,
return_imgs=True,
steps=steps,
cpu_scheduling=cpu_scheduling,
scheduler_id=scheduler_id,
progress=gr.Progress(track_tqdm=True),
)
return img
@@ -362,11 +348,8 @@ def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
if not sd_kwargs["device"]:
gr.Warning("No device specified. Please specify a device.")
return None, ""
if sd_kwargs["height"] not in [512, 1024]:
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
return None, ""
if sd_kwargs["height"] != sd_kwargs["width"]:
gr.Warning("Height and width must be the same. This is a temporary limitation.")
if sd_kwargs["height"] != 512 and sd_kwargs["width"] != 512 and sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
gr.Warning("SDXL turbo output size must be 512x512. This is a temporary limitation.")
return None, ""
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
if sd_kwargs["steps"] > 10:
@@ -416,6 +399,7 @@ def shark_sd_fn(
clip_device: str = None,
vae_device: str = None,
vae_precision: str = None,
cpu_scheduling: bool = False,
progress=gr.Progress(),
):
sd_kwargs = locals()
@@ -471,8 +455,6 @@ def shark_sd_fn(
"num_loras": num_loras,
"import_ir": import_ir,
"is_controlled": is_controlled,
"steps": steps,
"scheduler": scheduler,
"vae_precision": vae_precision,
}
submit_prep_kwargs = {
@@ -486,6 +468,7 @@ def shark_sd_fn(
"prompt": prompt,
"negative_prompt": negative_prompt,
"image": sd_init_image,
"steps": steps,
"strength": strength,
"guidance_scale": guidance_scale,
"seed": seed,
@@ -493,10 +476,12 @@ def shark_sd_fn(
"resample_type": resample_type,
"control_mode": control_mode,
"hints": hints,
"cpu_scheduling": cpu_scheduling,
"scheduler_id": scheduler,
}
if global_obj.get_sd_obj() and global_obj.get_sd_obj().dynamic_steps:
submit_run_kwargs["steps"] = submit_pipe_kwargs["steps"]
submit_pipe_kwargs.pop("steps")
if compiled_pipeline:
submit_pipe_kwargs["steps"] = submit_run_kwargs["steps"]
submit_pipe_kwargs["scheduler_id"] = submit_run_kwargs["scheduler_id"]
if (
not global_obj.get_sd_obj()
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
@@ -545,14 +530,12 @@ def shark_sd_fn(
seed,
sd_kwargs,
)
generated_imgs.extend(out_imgs)
breakpoint()
generated_imgs.extend(out_imgs[batch])
yield generated_imgs, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
)
if batch_count > 1:
submit_run_kwargs["seed"] = get_next_seed(seed, seed_increment)
return (generated_imgs, "")
@@ -575,7 +558,11 @@ def unload_sd():
def cancel_sd():
print("Inject call to cancel longer API calls.")
import apps.shark_studio.web.utils.globals as global_obj
print("Cancelling...")
global_obj.get_sd_obj()._interrupt = True
while global_obj.get_sd_obj()._interrupt:
time.sleep(0.1)
return

View File

@@ -101,18 +101,18 @@ def export_scheduler_model(model):
scheduler_model_map = {
# "PNDM": export_scheduler_model("PNDMScheduler"),
# "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
"PNDM": export_scheduler_model("PNDMScheduler"),
"DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
"EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
# "LCM": export_scheduler_model("LCMScheduler"),
# "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
# "DDPM": export_scheduler_model("DDPMScheduler"),
# "DDIM": export_scheduler_model("DDIMScheduler"),
# "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
# "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
# "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
# "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
# "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
# "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
"LCM": export_scheduler_model("LCMScheduler"),
"LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
"DDPM": export_scheduler_model("DDPMScheduler"),
"DDIM": export_scheduler_model("DDIMScheduler"),
"DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
"KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
"DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
"DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
"KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
"HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
}

View File

@@ -121,6 +121,7 @@ def pull_sd_configs(
custom_weights,
custom_vae,
precision,
vae_precision,
device,
target_triple,
ondemand,
@@ -128,6 +129,7 @@ def pull_sd_configs(
resample_type,
controlnets,
embeddings,
cpu_scheduling,
):
sd_args = str_none_to_none(locals())
sd_cfg = {}
@@ -189,6 +191,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_json["custom_weights"],
sd_json["custom_vae"],
sd_json["precision"],
sd_json["vae_precision"],
sd_json["device"],
sd_json["target_triple"],
sd_json["ondemand"],
@@ -196,6 +199,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_json["resample_type"],
sd_json["controlnets"],
sd_json["embeddings"],
sd_json["cpu_scheduling"],
sd_json,
]
@@ -248,44 +252,17 @@ def base_model_changed(base_model_id):
new_choices = get_checkpoints(
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
) + get_checkpoints(model_type="checkpoints")
if "turbo" in base_model_id:
new_steps = gr.Dropdown(
value=2,
choices=[1, 2],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
if "stable-diffusion-xl-base-1.0" in base_model_id:
new_steps = gr.Dropdown(
value=40,
choices=[20, 25, 30, 35, 40, 45, 50],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
elif ".py" in base_model_id:
new_steps = gr.Dropdown(
value=20,
choices=[10, 15, 20],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
else:
new_steps = gr.Dropdown(
value=20,
choices=[10, 20, 30, 40, 50],
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
return [
gr.Dropdown(
value=new_choices[0] if len(new_choices) > 0 else "None",
choices=["None"] + new_choices,
),
new_steps,
gr.update(),
]
init_config = global_obj.get_init_config()
if not os.path.exists(init_config):
write_default_sd_configs(get_configs_path())
init_config = none_to_str_none(json.loads(view_json_file(init_config)))
with gr.Blocks(title="Stable Diffusion") as sd_element:
@@ -339,23 +316,43 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
"fp16",
"fp32",
],
visible=False,
visible=True,
)
vae_precision = gr.Radio(
label="VAE Precision",
value=init_config["precision"],
choices=[
"fp16",
"fp32",
],
visible=True,
)
cpu_scheduling = gr.Checkbox(
value=init_config["ondemand"],
label="CPU scheduling",
interactive=True,
visible=True,
)
compiled_pipeline = gr.Checkbox(
value=False,
label="Faster txt2img (SDXL only)",
visible=False, # DEMO
)
with gr.Row():
height = gr.Slider(
512,
256,
1024,
value=512,
step=512,
step=256,
label="\U00002195\U0000FE0F Height",
interactive=True, # DEMO
visible=True, # DEMO
)
width = gr.Slider(
512,
256,
1024,
value=512,
step=512,
step=256,
label="\U00002194\U0000FE0F Width",
interactive=True, # DEMO
visible=True, # DEMO
@@ -405,23 +402,24 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
allow_custom_value=False,
)
with gr.Row():
steps = gr.Dropdown(
value=20,
choices=[10, 15, 20],
steps = gr.Slider(
1,
50,
value=init_config["steps"],
step=1,
label="\U0001F3C3\U0000FE0F Steps",
allow_custom_value=True,
)
guidance_scale = gr.Slider(
0,
5, #DEMO
value=4,
20, #DEMO
value=7,
step=0.1,
label="\U0001F5C3\U0000FE0F CFG Scale",
)
with gr.Row():
batch_count = gr.Slider(
1,
100,
10,
value=init_config["batch_count"],
step=1,
label="Batch Count",
@@ -437,11 +435,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
interactive=False, # DEMO
visible=True,
)
compiled_pipeline = gr.Checkbox(
value=init_config["compiled_pipeline"],
label="Faster txt2img (SDXL only)",
visible=False, # DEMO
)
with gr.Row(elem_classes=["fill"], visible=False):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
@@ -479,7 +472,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Accordion(
label="\U00002696\U0000FE0F Model Weights",
open=False,
visible=False, # DEMO
visible=True, # DEMO
):
with gr.Column():
custom_weights = gr.Dropdown(
@@ -503,6 +496,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
choices=["None"] + get_checkpoints("vae"),
allow_custom_value=True,
scale=1,
visible=False,
)
sd_lora_info = (str(get_checkpoints_path("loras"))).replace(
"\\", "\n\\"
@@ -516,13 +510,16 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
multiselect=True,
choices=[] + get_checkpoints("lora"),
scale=2,
visible=False,
)
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
visible=False,
)
embeddings_config = gr.JSON(
label="Embeddings Options", min_width=50, scale=1
label="Embeddings Options", min_width=50, scale=1,
visible=False,
)
gr.on(
triggers=[lora_opt.change],
@@ -700,13 +697,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
)
with gr.Row():
stable_diffusion = gr.Button("Start")
stop_batch = gr.Button("Stop", visible=True)
unload = gr.Button("Unload Models")
unload.click(
fn=unload_sd,
queue=False,
show_progress=False,
)
stop_batch = gr.Button("Stop", visible=False)
# with gr.Tab(label="Config", id=102) as sd_tab_config:
# with gr.Group():#elem_classes=["sd-right-panel"]):
# with gr.Row(elem_classes=["fill"], visible=False):
@@ -784,6 +781,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
custom_weights,
custom_vae,
precision,
vae_precision,
device,
target_triple,
ondemand,
@@ -791,6 +789,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
resample_type,
cnet_config,
embeddings_config,
cpu_scheduling,
sd_json,
],
)
@@ -818,6 +817,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
custom_weights,
custom_vae,
precision,
vae_precision,
device,
target_triple,
ondemand,
@@ -825,6 +825,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
resample_type,
cnet_config,
embeddings_config,
cpu_scheduling,
],
outputs=[
sd_json,