Pass the flags to vae.

This commit is contained in:
Prashant Kumar
2022-10-23 11:19:24 -07:00
parent 2741b8be53
commit a48eaaed20

View File

@@ -48,7 +48,15 @@ def get_models():
if args.import_mlir == True:
return get_vae32(), get_unet32_wrapped()
else:
return get_shark_model(GCLOUD_BUCKET, VAE_FP32), get_shark_model(
return get_shark_model(
GCLOUD_BUCKET,
VAE_FP32,
[
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
],
), get_shark_model(
GCLOUD_BUCKET,
UNET_FP32,
[