mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
[WEB] Add schedulers in the web UI (#594)
1. Add schedulers option in web UI. 2. Remove random seed checkbox as the same functionality can be achieved by passing -1(or any negative number) to the seed. Signed-Off-by: Gaurav Shukla Signed-off-by: Gaurav Shukla
This commit is contained in:
41
web/index.py
41
web/index.py
@@ -91,28 +91,33 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
elem_id="prompt_examples",
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Group():
|
||||
steps = gr.Slider(
|
||||
1, 100, value=50, step=1, label="Steps"
|
||||
)
|
||||
guidance = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=7.5,
|
||||
step=0.1,
|
||||
label="Guidance Scale",
|
||||
)
|
||||
steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
|
||||
guidance = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=7.5,
|
||||
step=0.1,
|
||||
label="Guidance Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
scheduler_key = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value="DPMSolverMultistep",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"DPMSolverMultistep",
|
||||
],
|
||||
)
|
||||
with gr.Group():
|
||||
random_seed = gr.Button("Randomize Seed").style(
|
||||
full_width=True
|
||||
)
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
rand_seed = randint(uint32_info.min, uint32_info.max)
|
||||
random_val = randint(uint32_info.min, uint32_info.max)
|
||||
seed = gr.Number(
|
||||
value=rand_seed, precision=0, show_label=False
|
||||
)
|
||||
generate_seed = gr.Checkbox(
|
||||
value=False, label="use random seed"
|
||||
value=random_val, precision=0, show_label=False
|
||||
)
|
||||
u32_min = gr.Number(
|
||||
value=uint32_info.min, visible=False
|
||||
@@ -145,7 +150,7 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
steps,
|
||||
guidance,
|
||||
seed,
|
||||
generate_seed,
|
||||
scheduler_key,
|
||||
],
|
||||
outputs=[generated_img, std_output],
|
||||
)
|
||||
@@ -156,7 +161,7 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
steps,
|
||||
guidance,
|
||||
seed,
|
||||
generate_seed,
|
||||
scheduler_key,
|
||||
],
|
||||
outputs=[generated_img, std_output],
|
||||
)
|
||||
|
||||
@@ -11,23 +11,17 @@ from models.stable_diffusion.stable_args import args
|
||||
|
||||
|
||||
schedulers = dict()
|
||||
schedulers["PNDM"] = PNDMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LMSDiscrete"] = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DDIM"] = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
schedulers["DDIM"] = DDIMScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
@@ -50,6 +44,3 @@ cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||
cache_obj["unet"],
|
||||
cache_obj["clip"],
|
||||
) = (get_vae(args), get_unet(args), get_clip(args))
|
||||
|
||||
# cache scheduler
|
||||
cache_obj["scheduler"] = schedulers[args.scheduler]
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from models.stable_diffusion.cache_objects import cache_obj
|
||||
from models.stable_diffusion.cache_objects import cache_obj, schedulers
|
||||
from models.stable_diffusion.stable_args import args
|
||||
from random import randint
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
def set_ui_params(prompt, steps, guidance, seed):
|
||||
def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
|
||||
args.prompt = [prompt]
|
||||
args.steps = steps
|
||||
args.guidance = guidance
|
||||
args.seed = seed
|
||||
args.scheduler = scheduler_key
|
||||
|
||||
|
||||
def stable_diff_inf(
|
||||
@@ -20,29 +21,29 @@ def stable_diff_inf(
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed: int,
|
||||
generate_seed: bool,
|
||||
scheduler_key: str,
|
||||
):
|
||||
|
||||
# Handle out of range seeds.
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if generate_seed or seed < uint32_min or seed >= uint32_max:
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
|
||||
set_ui_params(prompt, steps, guidance, seed)
|
||||
set_ui_params(prompt, steps, guidance, seed, scheduler_key)
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
generator = torch.manual_seed(
|
||||
args.seed
|
||||
) # Seed generator to create the inital latent noise
|
||||
guidance_scale = torch.tensor(args.guidance).to(torch.float32)
|
||||
# Initialize vae and unet models.
|
||||
vae, unet, clip, tokenizer, scheduler = (
|
||||
vae, unet, clip, tokenizer = (
|
||||
cache_obj["vae"],
|
||||
cache_obj["unet"],
|
||||
cache_obj["clip"],
|
||||
cache_obj["tokenizer"],
|
||||
cache_obj["scheduler"],
|
||||
)
|
||||
scheduler = schedulers[args.scheduler]
|
||||
|
||||
start = time.time()
|
||||
text_input = tokenizer(
|
||||
@@ -124,7 +125,7 @@ def stable_diff_inf(
|
||||
total_time = time.time() - start
|
||||
|
||||
text_output = f"prompt={args.prompt}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, seed={args.seed}, scheduler={args.scheduler}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}"
|
||||
text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
|
||||
print(f"\nAverage step time: {avg_ms}ms/it")
|
||||
text_output += "\nTotal image generation time: {0:.2f}sec".format(
|
||||
|
||||
Reference in New Issue
Block a user