Update Stable diffusion script to enable use of tuned models (#443)

This commit is contained in:
Quinn Dawkins
2022-10-29 01:42:49 -04:00
committed by GitHub
parent 7f37599a60
commit 239c19eb12
2 changed files with 35 additions and 8 deletions

View File

@@ -21,6 +21,9 @@ UNET_FP16 = "unet_fp16"
UNET_FP32 = "unet_fp32" UNET_FP32 = "unet_fp32"
IREE_EXTRA_ARGS = [] IREE_EXTRA_ARGS = []
TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
UNET_FP16_TUNED = "unet_fp16_tuned"
# Helper function to profile the vulkan device. # Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"): def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device: if args.vulkan_debug_utils and "vulkan" in args.device:
@@ -42,24 +45,41 @@ def get_models():
global IREE_EXTRA_ARGS global IREE_EXTRA_ARGS
if args.precision == "fp16": if args.precision == "fp16":
IREE_EXTRA_ARGS += [ IREE_EXTRA_ARGS += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops", "--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32", "--iree-flow-linalg-ops-padding-size=32",
"--iree-spirv-unify-aliased-resources=false",
] ]
if args.use_tuned:
unet_gcloud_bucket = TUNED_GCLOUD_BUCKET
vae_gcloud_bucket = GCLOUD_BUCKET
unet_args = IREE_EXTRA_ARGS
vae_args = IREE_EXTRA_ARGS + [
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
]
unet_name = UNET_FP16_TUNED
vae_name = VAE_FP16
else:
unet_gcloud_bucket = GCLOUD_BUCKET
vae_gcloud_bucket = GCLOUD_BUCKET
IREE_EXTRA_ARGS += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
]
unet_args = IREE_EXTRA_ARGS
vae_args = IREE_EXTRA_ARGS
unet_name = UNET_FP16
vae_name = VAE_FP16
if args.import_mlir == True: if args.import_mlir == True:
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped( return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
model_name=UNET_FP16 model_name=UNET_FP16
) )
else: else:
return get_shark_model( return get_shark_model(
GCLOUD_BUCKET, vae_gcloud_bucket,
VAE_FP16, vae_name,
IREE_EXTRA_ARGS, vae_args,
), get_shark_model( ), get_shark_model(
GCLOUD_BUCKET, unet_gcloud_bucket,
UNET_FP16, unet_name,
IREE_EXTRA_ARGS, unet_args,
) )
elif args.precision == "fp32": elif args.precision == "fp32":

View File

@@ -78,4 +78,11 @@ p.add_argument(
help="Profiles vulkan device and collects the .rdc info", help="Profiles vulkan device and collects the .rdc info",
) )
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
args = p.parse_args() args = p.parse_args()