[WEB] Add schedulers in the web UI (#594)

1. Add a scheduler selection option to the web UI.
2. Remove the random seed checkbox, since the same behavior can be
   achieved by passing -1 (or any negative number) as the seed
   (see the seed-handling sketch below).

Signed-off-by: Gaurav Shukla
Author: Gaurav Shukla
Date: 2022-12-09 03:23:20 +05:30
Committed by: GitHub
Parent: 0225292a44
Commit: b62ee3fcb9
3 changed files with 41 additions and 44 deletions
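
The seed change in item 2 leans on the range check in stable_diff_inf (last file below): any value outside the uint32 range, including -1 or any other negative number, is replaced with a freshly generated random seed. A minimal standalone sketch of that convention (normalize_seed is a hypothetical helper name, not part of the commit):

import numpy as np
from random import randint

def normalize_seed(seed: int) -> int:
    # Mirror the web UI's seed handling: any out-of-range value
    # (e.g. -1) is swapped for a random uint32 seed.
    uint32_info = np.iinfo(np.uint32)
    if seed < uint32_info.min or seed >= uint32_info.max:
        return randint(uint32_info.min, uint32_info.max)
    return seed

print(normalize_seed(-1))   # a fresh random uint32
print(normalize_seed(42))   # 42, reproducible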

Changed file 1 of 3: the Gradio web UI (shark_web Blocks layout and event wiring).

@@ -91,28 +91,33 @@ with gr.Blocks(css=demo_css) as shark_web:
elem_id="prompt_examples",
)
with gr.Row():
-                    with gr.Group():
-                        steps = gr.Slider(
-                            1, 100, value=50, step=1, label="Steps"
-                        )
-                        guidance = gr.Slider(
-                            0,
-                            50,
-                            value=7.5,
-                            step=0.1,
-                            label="Guidance Scale",
-                        )
+                    steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
+                    guidance = gr.Slider(
+                        0,
+                        50,
+                        value=7.5,
+                        step=0.1,
+                        label="Guidance Scale",
+                    )
+                with gr.Row():
+                    scheduler_key = gr.Dropdown(
+                        label="Scheduler",
+                        value="DPMSolverMultistep",
+                        choices=[
+                            "DDIM",
+                            "PNDM",
+                            "LMSDiscrete",
+                            "DPMSolverMultistep",
+                        ],
+                    )
with gr.Group():
random_seed = gr.Button("Randomize Seed").style(
full_width=True
)
uint32_info = np.iinfo(np.uint32)
-                        rand_seed = randint(uint32_info.min, uint32_info.max)
+                        random_val = randint(uint32_info.min, uint32_info.max)
                        seed = gr.Number(
-                            value=rand_seed, precision=0, show_label=False
-                        )
-                        generate_seed = gr.Checkbox(
-                            value=False, label="use random seed"
+                            value=random_val, precision=0, show_label=False
)
u32_min = gr.Number(
value=uint32_info.min, visible=False
@@ -145,7 +150,7 @@ with gr.Blocks(css=demo_css) as shark_web:
steps,
guidance,
seed,
-                generate_seed,
+                scheduler_key,
],
outputs=[generated_img, std_output],
)
@@ -156,7 +161,7 @@ with gr.Blocks(css=demo_css) as shark_web:
steps,
guidance,
seed,
-                generate_seed,
+                scheduler_key,
],
outputs=[generated_img, std_output],
)
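
For readers less familiar with Gradio's event wiring: the hunks above simply add scheduler_key to the inputs list, and its current dropdown value arrives as a positional argument of the inference callback, in the order the components are listed. A stripped-down sketch of that pattern (the component names and fake_generate are illustrative, not taken from the commit):

import gradio as gr

def fake_generate(prompt, scheduler_key):
    # Stand-in for stable_diff_inf: Gradio passes each input component's
    # current value positionally, in the order given in `inputs=[...]`.
    return f"would run '{prompt}' with the {scheduler_key} scheduler"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    scheduler_key = gr.Dropdown(
        label="Scheduler",
        value="DPMSolverMultistep",
        choices=["DDIM", "PNDM", "LMSDiscrete", "DPMSolverMultistep"],
    )
    generated_txt = gr.Textbox(label="Output")
    generate = gr.Button("Generate")
    generate.click(
        fn=fake_generate,
        inputs=[prompt, scheduler_key],
        outputs=[generated_txt],
    )

# demo.launch()  # uncomment to serve the sketch locally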

Changed file 2 of 3: the model/scheduler cache module (models.stable_diffusion.cache_objects).

@@ -11,23 +11,17 @@ from models.stable_diffusion.stable_args import args
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="scheduler",
)
schedulers["LMSDiscrete"] = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="scheduler",
)
schedulers["DDIM"] = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="scheduler",
)
schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
"CompVis/stable-diffusion-v1-4",
@@ -50,6 +44,3 @@ cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
cache_obj["unet"],
cache_obj["clip"],
) = (get_vae(args), get_unet(args), get_clip(args))
-# cache scheduler
-cache_obj["scheduler"] = schedulers[args.scheduler]

Changed file 3 of 3: the Stable Diffusion inference module (defines set_ui_params and stable_diff_inf).

@@ -1,18 +1,19 @@
import torch
from PIL import Image
from tqdm.auto import tqdm
-from models.stable_diffusion.cache_objects import cache_obj
+from models.stable_diffusion.cache_objects import cache_obj, schedulers
from models.stable_diffusion.stable_args import args
from random import randint
import numpy as np
import time


-def set_ui_params(prompt, steps, guidance, seed):
+def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
args.prompt = [prompt]
args.steps = steps
args.guidance = guidance
args.seed = seed
+    args.scheduler = scheduler_key


def stable_diff_inf(
@@ -20,29 +21,29 @@ def stable_diff_inf(
steps: int,
guidance: float,
seed: int,
-    generate_seed: bool,
+    scheduler_key: str,
):
# Handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
-    if generate_seed or seed < uint32_min or seed >= uint32_max:
+    if seed < uint32_min or seed >= uint32_max:
        seed = randint(uint32_min, uint32_max)
-    set_ui_params(prompt, steps, guidance, seed)
+    set_ui_params(prompt, steps, guidance, seed, scheduler_key)
dtype = torch.float32 if args.precision == "fp32" else torch.half
generator = torch.manual_seed(
args.seed
) # Seed generator to create the inital latent noise
guidance_scale = torch.tensor(args.guidance).to(torch.float32)
# Initialize vae and unet models.
-    vae, unet, clip, tokenizer, scheduler = (
+    vae, unet, clip, tokenizer = (
        cache_obj["vae"],
        cache_obj["unet"],
        cache_obj["clip"],
        cache_obj["tokenizer"],
-        cache_obj["scheduler"],
    )
+    scheduler = schedulers[args.scheduler]

start = time.time()
text_input = tokenizer(
@@ -124,7 +125,7 @@ def stable_diff_inf(
total_time = time.time() - start
text_output = f"prompt={args.prompt}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, seed={args.seed}, scheduler={args.scheduler}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}"
text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
print(f"\nAverage step time: {avg_ms}ms/it")
text_output += "\nTotal image generation time: {0:.2f}sec".format(