Minor fix.

This commit is contained in:
Prashant Kumar
2022-10-20 22:56:03 +05:30
parent 38ae6b5af4
commit 5fe22a7980

View File

@@ -23,7 +23,7 @@ UNET_FP32 = "unet_fp32"
def get_models():
if args.precision == "fp16":
if args.import_mlir == True:
return get_unet16_wrapped(), get_vae16()
return get_vae16(), get_unet16_wrapped()
else:
return get_shark_model(GCLOUD_BUCKET, VAE_FP16), get_shark_model(
GCLOUD_BUCKET, UNET_FP16