mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
[WEB] Update models to 8dec and also default values (#620)
1. Update the models to 8 dec. 2. precision is default to `fp16` in CLI. 3. version is default to `v2.1base` in CLI as well as web. 4. The default scheduler is set to `EulerDiscrete` now. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com> Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -23,7 +23,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="v1.4",
|
||||
default="v2.1base",
|
||||
help="Specify version of stable diffusion model",
|
||||
)
|
||||
|
||||
@@ -48,7 +48,7 @@ p.add_argument(
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp32", help="precision to run the model."
|
||||
"--precision", type=str, default="fp16", help="precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
|
||||
@@ -101,13 +101,13 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
)
|
||||
version = gr.Radio(
|
||||
label="Version",
|
||||
value="v1.4",
|
||||
value="v2.1base",
|
||||
choices=["v1.4", "v2.1base"],
|
||||
)
|
||||
with gr.Row():
|
||||
scheduler_key = gr.Dropdown(
|
||||
label="Scheduler",
|
||||
value="DPMSolverMultistep",
|
||||
value="EulerDiscrete",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
@@ -174,9 +174,9 @@ with gr.Blocks(css=demo_css) as shark_web:
|
||||
outputs=[generated_img, std_output],
|
||||
)
|
||||
|
||||
shark_web.queue()
|
||||
shark_web.launch(
|
||||
share=False,
|
||||
server_name="0.0.0.0",
|
||||
server_port=8080,
|
||||
enable_queue=True,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
|
||||
model_config = {
|
||||
"v2": "stabilityai/stable-diffusion-2",
|
||||
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
|
||||
"v1.4": "CompVis/stable-diffusion-v1-4",
|
||||
}
|
||||
|
||||
@@ -19,6 +20,16 @@ model_input = {
|
||||
torch.tensor(1).to(torch.float32), # guidance_scale
|
||||
),
|
||||
},
|
||||
"v2.1base": {
|
||||
"clip": (torch.randint(1, 2, (1, 77)),),
|
||||
"vae": (torch.randn(1, 4, 64, 64),),
|
||||
"unet": (
|
||||
torch.randn(1, 4, 64, 64), # latents
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, 77, 1024), # embedding
|
||||
torch.tensor(1).to(torch.float32), # guidance_scale
|
||||
),
|
||||
},
|
||||
"v1.4": {
|
||||
"clip": (torch.randint(1, 2, (1, 77)),),
|
||||
"vae": (torch.randn(1, 4, 64, 64),),
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_unet(args):
|
||||
return get_shark_model(args, bucket, model_name, iree_flags)
|
||||
else:
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "unet_1dec_fp16"
|
||||
model_name = "unet_8dec_fp16"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "unet2base_8dec_fp16"
|
||||
iree_flags += [
|
||||
@@ -56,7 +56,7 @@ def get_vae(args):
|
||||
)
|
||||
if args.precision == "fp16":
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "vae_1dec_fp16"
|
||||
model_name = "vae_8dec_fp16"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "vae2base_8dec_fp16"
|
||||
iree_flags += [
|
||||
@@ -119,7 +119,7 @@ def get_clip(args):
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "clip_1dec_fp32"
|
||||
model_name = "clip_8dec_fp32"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "clip2base_8dec_fp32"
|
||||
iree_flags += [
|
||||
|
||||
@@ -25,7 +25,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="v1.4",
|
||||
default="v2.1base",
|
||||
help="Specify version of stable diffusion model",
|
||||
)
|
||||
|
||||
@@ -60,8 +60,8 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="DPMSolverMultistep",
|
||||
help="can be [PNDM, LMSDiscrete, DDIM, DPMSolverMultistep]",
|
||||
default="EulerDiscrete",
|
||||
help="can be [PNDM, LMSDiscrete, DDIM, DPMSolverMultistep, EulerDiscrete]",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user