Fix tuned model selection for non-vulkan devices (#792)

This commit is contained in:
Quinn Dawkins
2023-01-10 19:04:21 -05:00
committed by GitHub
parent e4efdb5cbb
commit 9570045cc3

View File

@@ -33,7 +33,7 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
]
except KeyError:
raise Exception(
f"{bucket}/{model_key} is not present in the models database"
f"{bucket_key}/{model_key} is not present in the models database"
)
if (
@@ -62,7 +62,7 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
def get_unet():
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
if "vulkan" not in args.device and is_tuned:
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{args.variant}/{is_tuned}/{args.device}"
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
else: