[WEB][SD] Make unet tuned model default for rdna3 devices (#642)

This commit is contained in:
Gaurav Shukla
2022-12-16 01:32:03 +05:30
committed by GitHub
parent 2928179331
commit e7e763551a
3 changed files with 9 additions and 2 deletions

View File

@@ -9,6 +9,7 @@ from diffusers import (
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
from models.stable_diffusion.utils import set_iree_runtime_flags
from models.stable_diffusion.stable_args import args
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
model_config = {
@@ -39,6 +40,9 @@ schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
subfolder="scheduler",
)
# set use_tuned
if "rdna3" not in get_vulkan_triple_flag():
args.use_tuned = False
# set iree-runtime flags
set_iree_runtime_flags(args)

View File

@@ -18,7 +18,10 @@ def get_unet(args):
if args.precision == "fp16":
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
model_name = "unet_1dec_fp16_tuned"
if args.version == "v1.4":
model_name = "unet_1dec_fp16_tuned"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16_tuned"
return get_shark_model(args, bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"

View File

@@ -105,7 +105,7 @@ p.add_argument(
p.add_argument(
"--use_tuned",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)