diff --git a/shark/examples/shark_inference/stable_diffusion/main.py b/shark/examples/shark_inference/stable_diffusion/main.py index a0f5fc58..3b7c3220 100644 --- a/shark/examples/shark_inference/stable_diffusion/main.py +++ b/shark/examples/shark_inference/stable_diffusion/main.py @@ -21,6 +21,9 @@ UNET_FP16 = "unet_fp16" UNET_FP32 = "unet_fp32" IREE_EXTRA_ARGS = [] +TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn" +UNET_FP16_TUNED = "unet_fp16_tuned" + # Helper function to profile the vulkan device. def start_profiling(file_path="foo.rdc", profiling_mode="queue"): if args.vulkan_debug_utils and "vulkan" in args.device: @@ -42,24 +45,41 @@ def get_models(): global IREE_EXTRA_ARGS if args.precision == "fp16": IREE_EXTRA_ARGS += [ - "--iree-flow-enable-conv-nchw-to-nhwc-transform", "--iree-flow-enable-padding-linalg-ops", "--iree-flow-linalg-ops-padding-size=32", - "--iree-spirv-unify-aliased-resources=false", ] + if args.use_tuned: + unet_gcloud_bucket = TUNED_GCLOUD_BUCKET + vae_gcloud_bucket = GCLOUD_BUCKET + unet_args = IREE_EXTRA_ARGS + vae_args = IREE_EXTRA_ARGS + [ + "--iree-flow-enable-conv-nchw-to-nhwc-transform" + ] + unet_name = UNET_FP16_TUNED + vae_name = VAE_FP16 + else: + unet_gcloud_bucket = GCLOUD_BUCKET + vae_gcloud_bucket = GCLOUD_BUCKET + IREE_EXTRA_ARGS += [ + "--iree-flow-enable-conv-nchw-to-nhwc-transform" + ] + unet_args = IREE_EXTRA_ARGS + vae_args = IREE_EXTRA_ARGS + unet_name = UNET_FP16 + vae_name = VAE_FP16 if args.import_mlir == True: return get_vae16(model_name=VAE_FP16), get_unet16_wrapped( model_name=UNET_FP16 ) else: return get_shark_model( - GCLOUD_BUCKET, - VAE_FP16, - IREE_EXTRA_ARGS, + vae_gcloud_bucket, + vae_name, + vae_args, ), get_shark_model( - GCLOUD_BUCKET, - UNET_FP16, - IREE_EXTRA_ARGS, + unet_gcloud_bucket, + unet_name, + unet_args, ) elif args.precision == "fp32": diff --git a/shark/examples/shark_inference/stable_diffusion/stable_args.py b/shark/examples/shark_inference/stable_diffusion/stable_args.py index 7bdd8834..c4b80c36 100644 --- a/shark/examples/shark_inference/stable_diffusion/stable_args.py +++ b/shark/examples/shark_inference/stable_diffusion/stable_args.py @@ -78,4 +78,11 @@ p.add_argument( help="Profiles vulkan device and collects the .rdc info", ) +p.add_argument( + "--use_tuned", + default=True, + action=argparse.BooleanOptionalAction, + help="Download and use the tuned version of the model if available", +) + args = p.parse_args()