Update Stable diffusion script to enable use of tuned models (#443)

This commit is contained in:
Quinn Dawkins
2022-10-29 01:42:49 -04:00
committed by GitHub
parent 7f37599a60
commit 239c19eb12
2 changed files with 35 additions and 8 deletions

View File

@@ -21,6 +21,9 @@ UNET_FP16 = "unet_fp16"
UNET_FP32 = "unet_fp32"
IREE_EXTRA_ARGS = []
TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
UNET_FP16_TUNED = "unet_fp16_tuned"
# Helper function to profile the vulkan device.
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
if args.vulkan_debug_utils and "vulkan" in args.device:
@@ -42,24 +45,41 @@ def get_models():
global IREE_EXTRA_ARGS
if args.precision == "fp16":
IREE_EXTRA_ARGS += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-spirv-unify-aliased-resources=false",
]
if args.use_tuned:
unet_gcloud_bucket = TUNED_GCLOUD_BUCKET
vae_gcloud_bucket = GCLOUD_BUCKET
unet_args = IREE_EXTRA_ARGS
vae_args = IREE_EXTRA_ARGS + [
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
]
unet_name = UNET_FP16_TUNED
vae_name = VAE_FP16
else:
unet_gcloud_bucket = GCLOUD_BUCKET
vae_gcloud_bucket = GCLOUD_BUCKET
IREE_EXTRA_ARGS += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
]
unet_args = IREE_EXTRA_ARGS
vae_args = IREE_EXTRA_ARGS
unet_name = UNET_FP16
vae_name = VAE_FP16
if args.import_mlir == True:
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
model_name=UNET_FP16
)
else:
return get_shark_model(
GCLOUD_BUCKET,
VAE_FP16,
IREE_EXTRA_ARGS,
vae_gcloud_bucket,
vae_name,
vae_args,
), get_shark_model(
GCLOUD_BUCKET,
UNET_FP16,
IREE_EXTRA_ARGS,
unet_gcloud_bucket,
unet_name,
unet_args,
)
elif args.precision == "fp32":

View File

@@ -78,4 +78,11 @@ p.add_argument(
help="Profiles vulkan device and collects the .rdc info",
)
p.add_argument(
"--use_tuned",
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
args = p.parse_args()