mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Updates to sd api and UI for sd3, sdxl QOL
This commit is contained in:
@@ -32,7 +32,6 @@ from apps.shark_studio.modules.img_processing import (
|
||||
from subprocess import check_output
|
||||
EMPTY_SD_MAP = {
|
||||
"clip": None,
|
||||
"scheduler": None,
|
||||
"unet": None,
|
||||
"vae_decode": None,
|
||||
}
|
||||
@@ -41,14 +40,12 @@ EMPTY_SDXL_MAP = {
|
||||
"prompt_encoder": None,
|
||||
"unet": None,
|
||||
"vae_decode": None,
|
||||
"scheduler": None,
|
||||
}
|
||||
|
||||
EMPTY_SD3_MAP = {
|
||||
"clip": None,
|
||||
"mmdit": None,
|
||||
"vae": None,
|
||||
"scheduler": None,
|
||||
}
|
||||
|
||||
EMPTY_FLAGS = {
|
||||
@@ -90,10 +87,10 @@ class StableDiffusion:
|
||||
height: int,
|
||||
width: int,
|
||||
batch_size: int,
|
||||
steps: int,
|
||||
scheduler: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
steps: int = 50,
|
||||
scheduler_id: str = None,
|
||||
clip_device: str = None,
|
||||
vae_device: str = None,
|
||||
target_triple: str = None,
|
||||
@@ -103,6 +100,7 @@ class StableDiffusion:
|
||||
is_controlled: bool = False,
|
||||
external_weights: str = "safetensors",
|
||||
vae_precision: str = "fp16",
|
||||
cpu_scheduling: bool = False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
progress(0, desc="Initializing pipeline...")
|
||||
@@ -119,19 +117,15 @@ class StableDiffusion:
|
||||
devices = {
|
||||
"clip": clip_device,
|
||||
"mmdit": backend,
|
||||
"unet": backend,
|
||||
"vae": vae_device,
|
||||
}
|
||||
targets = {
|
||||
"clip": clip_target,
|
||||
"mmdit": target,
|
||||
"unet": target,
|
||||
"vae": vae_target,
|
||||
}
|
||||
pipe_device_id = backend
|
||||
target_triple = target
|
||||
for key in devices:
|
||||
if devices[key] != backend:
|
||||
pipe_device_id = "hybrid"
|
||||
target_triple = "_".join([clip_target, target, vae_target])
|
||||
self.precision = precision
|
||||
self.compiled_pipeline = False
|
||||
self.base_model_id = base_model_id
|
||||
@@ -171,27 +165,11 @@ class StableDiffusion:
|
||||
targets = target
|
||||
max_length = 64
|
||||
|
||||
pipe_id_list = [
|
||||
safe_name(base_model_id),
|
||||
str(batch_size),
|
||||
str(max_length),
|
||||
f"{str(height)}x{str(width)}",
|
||||
precision,
|
||||
target_triple,
|
||||
]
|
||||
if num_loras > 0:
|
||||
pipe_id_list.append(str(num_loras) + "lora")
|
||||
if is_controlled:
|
||||
pipe_id_list.append("controlled")
|
||||
if custom_vae:
|
||||
pipe_id_list.append(custom_vae)
|
||||
self.pipe_id = "_".join(pipe_id_list)
|
||||
self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
|
||||
self.weights_path = Path(
|
||||
os.path.join(
|
||||
get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)
|
||||
)
|
||||
)
|
||||
self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), "vmfbs"))
|
||||
self.weights_path = Path(os.path.join(get_checkpoints_path(), "weights"))
|
||||
if not os.path.exists(self.pipeline_dir):
|
||||
os.mkdir(self.pipeline_dir)
|
||||
if not os.path.exists(self.weights_path):
|
||||
os.mkdir(self.weights_path)
|
||||
|
||||
@@ -211,7 +189,7 @@ class StableDiffusion:
|
||||
progress(0.5, desc="Initializing pipeline...")
|
||||
self.sd_pipe = self.turbine_pipe(
|
||||
hf_model_name=base_model_id,
|
||||
scheduler_id=scheduler,
|
||||
scheduler_id=scheduler_id,
|
||||
height=height,
|
||||
width=width,
|
||||
precision=precision,
|
||||
@@ -227,6 +205,7 @@ class StableDiffusion:
|
||||
external_weights_dir=self.weights_path,
|
||||
external_weights=external_weights,
|
||||
vae_precision=vae_precision,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
)
|
||||
progress(1, desc="Pipeline initialized!...")
|
||||
gc.collect()
|
||||
@@ -331,16 +310,23 @@ class StableDiffusion:
|
||||
resample_type,
|
||||
control_mode,
|
||||
hints,
|
||||
progress=gr.Progress(),
|
||||
steps=None,
|
||||
cpu_scheduling=False,
|
||||
scheduler_id=None,
|
||||
progress=gr.Progress(track_tqdm=True),
|
||||
):
|
||||
|
||||
img = self.sd_pipe.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
1,
|
||||
guidance_scale,
|
||||
seed,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
batch_count=1,
|
||||
guidance_scale=guidance_scale,
|
||||
seed=seed,
|
||||
return_imgs=True,
|
||||
steps=steps,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
scheduler_id=scheduler_id,
|
||||
progress=gr.Progress(track_tqdm=True),
|
||||
)
|
||||
return img
|
||||
|
||||
@@ -362,11 +348,8 @@ def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
|
||||
if not sd_kwargs["device"]:
|
||||
gr.Warning("No device specified. Please specify a device.")
|
||||
return None, ""
|
||||
if sd_kwargs["height"] not in [512, 1024]:
|
||||
gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
|
||||
return None, ""
|
||||
if sd_kwargs["height"] != sd_kwargs["width"]:
|
||||
gr.Warning("Height and width must be the same. This is a temporary limitation.")
|
||||
if sd_kwargs["height"] != 512 and sd_kwargs["width"] != 512 and sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
|
||||
gr.Warning("SDXL turbo output size must be 512x512. This is a temporary limitation.")
|
||||
return None, ""
|
||||
if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
|
||||
if sd_kwargs["steps"] > 10:
|
||||
@@ -416,6 +399,7 @@ def shark_sd_fn(
|
||||
clip_device: str = None,
|
||||
vae_device: str = None,
|
||||
vae_precision: str = None,
|
||||
cpu_scheduling: bool = False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
sd_kwargs = locals()
|
||||
@@ -471,8 +455,6 @@ def shark_sd_fn(
|
||||
"num_loras": num_loras,
|
||||
"import_ir": import_ir,
|
||||
"is_controlled": is_controlled,
|
||||
"steps": steps,
|
||||
"scheduler": scheduler,
|
||||
"vae_precision": vae_precision,
|
||||
}
|
||||
submit_prep_kwargs = {
|
||||
@@ -486,6 +468,7 @@ def shark_sd_fn(
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"image": sd_init_image,
|
||||
"steps": steps,
|
||||
"strength": strength,
|
||||
"guidance_scale": guidance_scale,
|
||||
"seed": seed,
|
||||
@@ -493,10 +476,12 @@ def shark_sd_fn(
|
||||
"resample_type": resample_type,
|
||||
"control_mode": control_mode,
|
||||
"hints": hints,
|
||||
"cpu_scheduling": cpu_scheduling,
|
||||
"scheduler_id": scheduler,
|
||||
}
|
||||
if global_obj.get_sd_obj() and global_obj.get_sd_obj().dynamic_steps:
|
||||
submit_run_kwargs["steps"] = submit_pipe_kwargs["steps"]
|
||||
submit_pipe_kwargs.pop("steps")
|
||||
if compiled_pipeline:
|
||||
submit_pipe_kwargs["steps"] = submit_run_kwargs["steps"]
|
||||
submit_pipe_kwargs["scheduler_id"] = submit_run_kwargs["scheduler_id"]
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
|
||||
@@ -545,14 +530,12 @@ def shark_sd_fn(
|
||||
seed,
|
||||
sd_kwargs,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
breakpoint()
|
||||
generated_imgs.extend(out_imgs[batch])
|
||||
yield generated_imgs, status_label(
|
||||
"Stable Diffusion", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
if batch_count > 1:
|
||||
submit_run_kwargs["seed"] = get_next_seed(seed, seed_increment)
|
||||
|
||||
return (generated_imgs, "")
|
||||
|
||||
|
||||
@@ -575,7 +558,11 @@ def unload_sd():
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
print("Inject call to cancel longer API calls.")
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
print("Cancelling...")
|
||||
global_obj.get_sd_obj()._interrupt = True
|
||||
while global_obj.get_sd_obj()._interrupt:
|
||||
time.sleep(0.1)
|
||||
return
|
||||
|
||||
|
||||
|
||||
@@ -101,18 +101,18 @@ def export_scheduler_model(model):
|
||||
|
||||
|
||||
scheduler_model_map = {
|
||||
# "PNDM": export_scheduler_model("PNDMScheduler"),
|
||||
# "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
|
||||
"PNDM": export_scheduler_model("PNDMScheduler"),
|
||||
"DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
|
||||
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
|
||||
"EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
|
||||
# "LCM": export_scheduler_model("LCMScheduler"),
|
||||
# "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
|
||||
# "DDPM": export_scheduler_model("DDPMScheduler"),
|
||||
# "DDIM": export_scheduler_model("DDIMScheduler"),
|
||||
# "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
|
||||
# "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
|
||||
# "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
|
||||
# "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
|
||||
# "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
|
||||
# "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
|
||||
"LCM": export_scheduler_model("LCMScheduler"),
|
||||
"LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
|
||||
"DDPM": export_scheduler_model("DDPMScheduler"),
|
||||
"DDIM": export_scheduler_model("DDIMScheduler"),
|
||||
"DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
|
||||
"KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
|
||||
"DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
|
||||
"DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
|
||||
"KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
|
||||
"HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
|
||||
}
|
||||
|
||||
@@ -121,6 +121,7 @@ def pull_sd_configs(
|
||||
custom_weights,
|
||||
custom_vae,
|
||||
precision,
|
||||
vae_precision,
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
@@ -128,6 +129,7 @@ def pull_sd_configs(
|
||||
resample_type,
|
||||
controlnets,
|
||||
embeddings,
|
||||
cpu_scheduling,
|
||||
):
|
||||
sd_args = str_none_to_none(locals())
|
||||
sd_cfg = {}
|
||||
@@ -189,6 +191,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_json["custom_weights"],
|
||||
sd_json["custom_vae"],
|
||||
sd_json["precision"],
|
||||
sd_json["vae_precision"],
|
||||
sd_json["device"],
|
||||
sd_json["target_triple"],
|
||||
sd_json["ondemand"],
|
||||
@@ -196,6 +199,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_json["resample_type"],
|
||||
sd_json["controlnets"],
|
||||
sd_json["embeddings"],
|
||||
sd_json["cpu_scheduling"],
|
||||
sd_json,
|
||||
]
|
||||
|
||||
@@ -248,44 +252,17 @@ def base_model_changed(base_model_id):
|
||||
new_choices = get_checkpoints(
|
||||
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
|
||||
) + get_checkpoints(model_type="checkpoints")
|
||||
if "turbo" in base_model_id:
|
||||
new_steps = gr.Dropdown(
|
||||
value=2,
|
||||
choices=[1, 2],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
if "stable-diffusion-xl-base-1.0" in base_model_id:
|
||||
new_steps = gr.Dropdown(
|
||||
value=40,
|
||||
choices=[20, 25, 30, 35, 40, 45, 50],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
elif ".py" in base_model_id:
|
||||
new_steps = gr.Dropdown(
|
||||
value=20,
|
||||
choices=[10, 15, 20],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
else:
|
||||
new_steps = gr.Dropdown(
|
||||
value=20,
|
||||
choices=[10, 20, 30, 40, 50],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
return [
|
||||
gr.Dropdown(
|
||||
value=new_choices[0] if len(new_choices) > 0 else "None",
|
||||
choices=["None"] + new_choices,
|
||||
),
|
||||
new_steps,
|
||||
gr.update(),
|
||||
]
|
||||
|
||||
init_config = global_obj.get_init_config()
|
||||
if not os.path.exists(init_config):
|
||||
write_default_sd_configs(get_configs_path())
|
||||
init_config = none_to_str_none(json.loads(view_json_file(init_config)))
|
||||
|
||||
with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
@@ -339,23 +316,43 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
visible=True,
|
||||
)
|
||||
vae_precision = gr.Radio(
|
||||
label="VAE Precision",
|
||||
value=init_config["precision"],
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
cpu_scheduling = gr.Checkbox(
|
||||
value=init_config["ondemand"],
|
||||
label="CPU scheduling",
|
||||
interactive=True,
|
||||
visible=True,
|
||||
)
|
||||
compiled_pipeline = gr.Checkbox(
|
||||
value=False,
|
||||
label="Faster txt2img (SDXL only)",
|
||||
visible=False, # DEMO
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
512,
|
||||
256,
|
||||
1024,
|
||||
value=512,
|
||||
step=512,
|
||||
step=256,
|
||||
label="\U00002195\U0000FE0F Height",
|
||||
interactive=True, # DEMO
|
||||
visible=True, # DEMO
|
||||
)
|
||||
width = gr.Slider(
|
||||
512,
|
||||
256,
|
||||
1024,
|
||||
value=512,
|
||||
step=512,
|
||||
step=256,
|
||||
label="\U00002194\U0000FE0F Width",
|
||||
interactive=True, # DEMO
|
||||
visible=True, # DEMO
|
||||
@@ -405,23 +402,24 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
allow_custom_value=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Dropdown(
|
||||
value=20,
|
||||
choices=[10, 15, 20],
|
||||
steps = gr.Slider(
|
||||
1,
|
||||
50,
|
||||
value=init_config["steps"],
|
||||
step=1,
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
5, #DEMO
|
||||
value=4,
|
||||
20, #DEMO
|
||||
value=7,
|
||||
step=0.1,
|
||||
label="\U0001F5C3\U0000FE0F CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
10,
|
||||
value=init_config["batch_count"],
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
@@ -437,11 +435,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
interactive=False, # DEMO
|
||||
visible=True,
|
||||
)
|
||||
compiled_pipeline = gr.Checkbox(
|
||||
value=init_config["compiled_pipeline"],
|
||||
label="Faster txt2img (SDXL only)",
|
||||
visible=False, # DEMO
|
||||
)
|
||||
with gr.Row(elem_classes=["fill"], visible=False):
|
||||
Path(get_configs_path()).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
@@ -479,7 +472,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
with gr.Accordion(
|
||||
label="\U00002696\U0000FE0F Model Weights",
|
||||
open=False,
|
||||
visible=False, # DEMO
|
||||
visible=True, # DEMO
|
||||
):
|
||||
with gr.Column():
|
||||
custom_weights = gr.Dropdown(
|
||||
@@ -503,6 +496,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
choices=["None"] + get_checkpoints("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
visible=False,
|
||||
)
|
||||
sd_lora_info = (str(get_checkpoints_path("loras"))).replace(
|
||||
"\\", "\n\\"
|
||||
@@ -516,13 +510,16 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
multiselect=True,
|
||||
choices=[] + get_checkpoints("lora"),
|
||||
scale=2,
|
||||
visible=False,
|
||||
)
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
visible=False,
|
||||
)
|
||||
embeddings_config = gr.JSON(
|
||||
label="Embeddings Options", min_width=50, scale=1
|
||||
label="Embeddings Options", min_width=50, scale=1,
|
||||
visible=False,
|
||||
)
|
||||
gr.on(
|
||||
triggers=[lora_opt.change],
|
||||
@@ -700,13 +697,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Start")
|
||||
stop_batch = gr.Button("Stop", visible=True)
|
||||
unload = gr.Button("Unload Models")
|
||||
unload.click(
|
||||
fn=unload_sd,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop", visible=False)
|
||||
# with gr.Tab(label="Config", id=102) as sd_tab_config:
|
||||
# with gr.Group():#elem_classes=["sd-right-panel"]):
|
||||
# with gr.Row(elem_classes=["fill"], visible=False):
|
||||
@@ -784,6 +781,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
custom_weights,
|
||||
custom_vae,
|
||||
precision,
|
||||
vae_precision,
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
@@ -791,6 +789,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
resample_type,
|
||||
cnet_config,
|
||||
embeddings_config,
|
||||
cpu_scheduling,
|
||||
sd_json,
|
||||
],
|
||||
)
|
||||
@@ -818,6 +817,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
custom_weights,
|
||||
custom_vae,
|
||||
precision,
|
||||
vae_precision,
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
@@ -825,6 +825,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
resample_type,
|
||||
cnet_config,
|
||||
embeddings_config,
|
||||
cpu_scheduling,
|
||||
],
|
||||
outputs=[
|
||||
sd_json,
|
||||
|
||||
Reference in New Issue
Block a user