mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Pass the flags to vae. (#422)
This commit is contained in:
@@ -26,8 +26,22 @@ def get_models():
|
||||
if args.import_mlir == True:
|
||||
return get_vae16(), get_unet16_wrapped()
|
||||
else:
|
||||
return get_shark_model(GCLOUD_BUCKET, VAE_FP16), get_shark_model(
|
||||
GCLOUD_BUCKET, UNET_FP16
|
||||
return get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
VAE_FP16,
|
||||
[
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
],
|
||||
), get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
UNET_FP16,
|
||||
[
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
],
|
||||
)
|
||||
|
||||
elif args.precision == "fp32":
|
||||
|
||||
Reference in New Issue
Block a user