mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[WEB][SD] Make unet tuned model default for rdna3 devices (#642)
This commit is contained in:
@@ -9,6 +9,7 @@ from diffusers import (
|
||||
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
|
||||
from models.stable_diffusion.utils import set_iree_runtime_flags
|
||||
from models.stable_diffusion.stable_args import args
|
||||
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
|
||||
|
||||
|
||||
model_config = {
|
||||
@@ -39,6 +40,9 @@ schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
# set use_tuned
|
||||
if "rdna3" not in get_vulkan_triple_flag():
|
||||
args.use_tuned = False
|
||||
|
||||
# set iree-runtime flags
|
||||
set_iree_runtime_flags(args)
|
||||
|
||||
@@ -18,7 +18,10 @@ def get_unet(args):
|
||||
if args.precision == "fp16":
|
||||
if args.use_tuned:
|
||||
bucket = "gs://shark_tank/vivian"
|
||||
model_name = "unet_1dec_fp16_tuned"
|
||||
if args.version == "v1.4":
|
||||
model_name = "unet_1dec_fp16_tuned"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "unet2base_8dec_fp16_tuned"
|
||||
return get_shark_model(args, bucket, model_name, iree_flags)
|
||||
else:
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
|
||||
@@ -105,7 +105,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=False,
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user