mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-02-19 11:56:43 -05:00
Update Stable diffusion script to enable use of tuned models (#443)
This commit is contained in:
@@ -21,6 +21,9 @@ UNET_FP16 = "unet_fp16"
|
||||
UNET_FP32 = "unet_fp32"
|
||||
IREE_EXTRA_ARGS = []
|
||||
|
||||
TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
|
||||
UNET_FP16_TUNED = "unet_fp16_tuned"
|
||||
|
||||
# Helper function to profile the vulkan device.
|
||||
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
|
||||
if args.vulkan_debug_utils and "vulkan" in args.device:
|
||||
@@ -42,24 +45,41 @@ def get_models():
|
||||
global IREE_EXTRA_ARGS
|
||||
if args.precision == "fp16":
|
||||
IREE_EXTRA_ARGS += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-spirv-unify-aliased-resources=false",
|
||||
]
|
||||
if args.use_tuned:
|
||||
unet_gcloud_bucket = TUNED_GCLOUD_BUCKET
|
||||
vae_gcloud_bucket = GCLOUD_BUCKET
|
||||
unet_args = IREE_EXTRA_ARGS
|
||||
vae_args = IREE_EXTRA_ARGS + [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
|
||||
]
|
||||
unet_name = UNET_FP16_TUNED
|
||||
vae_name = VAE_FP16
|
||||
else:
|
||||
unet_gcloud_bucket = GCLOUD_BUCKET
|
||||
vae_gcloud_bucket = GCLOUD_BUCKET
|
||||
IREE_EXTRA_ARGS += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
|
||||
]
|
||||
unet_args = IREE_EXTRA_ARGS
|
||||
vae_args = IREE_EXTRA_ARGS
|
||||
unet_name = UNET_FP16
|
||||
vae_name = VAE_FP16
|
||||
if args.import_mlir == True:
|
||||
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
|
||||
model_name=UNET_FP16
|
||||
)
|
||||
else:
|
||||
return get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
VAE_FP16,
|
||||
IREE_EXTRA_ARGS,
|
||||
vae_gcloud_bucket,
|
||||
vae_name,
|
||||
vae_args,
|
||||
), get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
UNET_FP16,
|
||||
IREE_EXTRA_ARGS,
|
||||
unet_gcloud_bucket,
|
||||
unet_name,
|
||||
unet_args,
|
||||
)
|
||||
|
||||
elif args.precision == "fp32":
|
||||
|
||||
@@ -78,4 +78,11 @@ p.add_argument(
|
||||
help="Profiles vulkan device and collects the .rdc info",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available",
|
||||
)
|
||||
|
||||
args = p.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user