Pass the flags to vae. (#422)

This commit is contained in:
Prashant Kumar
2022-10-23 23:53:13 +05:30
committed by GitHub
parent 4f906a265c
commit 2741b8be53

View File

@@ -26,8 +26,22 @@ def get_models():
if args.import_mlir == True:
return get_vae16(), get_unet16_wrapped()
else:
return get_shark_model(GCLOUD_BUCKET, VAE_FP16), get_shark_model(
GCLOUD_BUCKET, UNET_FP16
return get_shark_model(
GCLOUD_BUCKET,
VAE_FP16,
[
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
],
), get_shark_model(
GCLOUD_BUCKET,
UNET_FP16,
[
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
],
)
elif args.precision == "fp32":