[WEB] Update models to 8dec and also default values (#620)

1. Update the models to 8 dec.
2. precision is default to `fp16` in CLI.
3. version is default to `v2.1base` in CLI as well as web.
4. The default scheduler is set to `EulerDiscrete` now.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
Gaurav Shukla
2022-12-14 02:38:33 +05:30
committed by GitHub
parent 08e373aef4
commit d913453e57
5 changed files with 22 additions and 11 deletions

View File

@@ -23,7 +23,7 @@ p.add_argument(
p.add_argument(
"--version",
type=str,
default="v1.4",
default="v2.1base",
help="Specify version of stable diffusion model",
)
@@ -48,7 +48,7 @@ p.add_argument(
)
p.add_argument(
"--precision", type=str, default="fp32", help="precision to run the model."
"--precision", type=str, default="fp16", help="precision to run the model."
)
p.add_argument(

View File

@@ -101,13 +101,13 @@ with gr.Blocks(css=demo_css) as shark_web:
)
version = gr.Radio(
label="Version",
value="v1.4",
value="v2.1base",
choices=["v1.4", "v2.1base"],
)
with gr.Row():
scheduler_key = gr.Dropdown(
label="Scheduler",
value="DPMSolverMultistep",
value="EulerDiscrete",
choices=[
"DDIM",
"PNDM",
@@ -174,9 +174,9 @@ with gr.Blocks(css=demo_css) as shark_web:
outputs=[generated_img, std_output],
)
shark_web.queue()
shark_web.launch(
share=False,
server_name="0.0.0.0",
server_port=8080,
enable_queue=True,
)

View File

@@ -5,6 +5,7 @@ import torch
model_config = {
"v2": "stabilityai/stable-diffusion-2",
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
"v1.4": "CompVis/stable-diffusion-v1-4",
}
@@ -19,6 +20,16 @@ model_input = {
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v2.1base": {
"clip": (torch.randint(1, 2, (1, 77)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
torch.randn(1, 4, 64, 64), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 1024), # embedding
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v1.4": {
"clip": (torch.randint(1, 2, (1, 77)),),
"vae": (torch.randn(1, 4, 64, 64),),

View File

@@ -22,7 +22,7 @@ def get_unet(args):
return get_shark_model(args, bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_1dec_fp16"
model_name = "unet_8dec_fp16"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16"
iree_flags += [
@@ -56,7 +56,7 @@ def get_vae(args):
)
if args.precision == "fp16":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_1dec_fp16"
model_name = "vae_8dec_fp16"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16"
iree_flags += [
@@ -119,7 +119,7 @@ def get_clip(args):
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
bucket = "gs://shark_tank/stable_diffusion"
model_name = "clip_1dec_fp32"
model_name = "clip_8dec_fp32"
if args.version == "v2.1base":
model_name = "clip2base_8dec_fp32"
iree_flags += [

View File

@@ -25,7 +25,7 @@ p.add_argument(
p.add_argument(
"--version",
type=str,
default="v1.4",
default="v2.1base",
help="Specify version of stable diffusion model",
)
@@ -60,8 +60,8 @@ p.add_argument(
p.add_argument(
"--scheduler",
type=str,
default="DPMSolverMultistep",
help="can be [PNDM, LMSDiscrete, DDIM, DPMSolverMultistep]",
default="EulerDiscrete",
help="can be [PNDM, LMSDiscrete, DDIM, DPMSolverMultistep, EulerDiscrete]",
)
p.add_argument(