mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-25 03:00:12 -04:00
Update Stable Diffusion script to enable use of tuned models (#443)
This commit is contained in:
@@ -21,6 +21,9 @@ UNET_FP16 = "unet_fp16"
|
|||||||
UNET_FP32 = "unet_fp32"
|
UNET_FP32 = "unet_fp32"
|
||||||
IREE_EXTRA_ARGS = []
|
IREE_EXTRA_ARGS = []
|
||||||
|
|
||||||
|
TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
|
||||||
|
UNET_FP16_TUNED = "unet_fp16_tuned"
|
||||||
|
|
||||||
# Helper function to profile the vulkan device.
|
# Helper function to profile the vulkan device.
|
||||||
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
|
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
|
||||||
if args.vulkan_debug_utils and "vulkan" in args.device:
|
if args.vulkan_debug_utils and "vulkan" in args.device:
|
||||||
@@ -42,24 +45,41 @@ def get_models():
|
|||||||
global IREE_EXTRA_ARGS
|
global IREE_EXTRA_ARGS
|
||||||
if args.precision == "fp16":
|
if args.precision == "fp16":
|
||||||
IREE_EXTRA_ARGS += [
|
IREE_EXTRA_ARGS += [
|
||||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
|
||||||
"--iree-flow-enable-padding-linalg-ops",
|
"--iree-flow-enable-padding-linalg-ops",
|
||||||
"--iree-flow-linalg-ops-padding-size=32",
|
"--iree-flow-linalg-ops-padding-size=32",
|
||||||
"--iree-spirv-unify-aliased-resources=false",
|
|
||||||
]
|
]
|
||||||
|
if args.use_tuned:
|
||||||
|
unet_gcloud_bucket = TUNED_GCLOUD_BUCKET
|
||||||
|
vae_gcloud_bucket = GCLOUD_BUCKET
|
||||||
|
unet_args = IREE_EXTRA_ARGS
|
||||||
|
vae_args = IREE_EXTRA_ARGS + [
|
||||||
|
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
|
||||||
|
]
|
||||||
|
unet_name = UNET_FP16_TUNED
|
||||||
|
vae_name = VAE_FP16
|
||||||
|
else:
|
||||||
|
unet_gcloud_bucket = GCLOUD_BUCKET
|
||||||
|
vae_gcloud_bucket = GCLOUD_BUCKET
|
||||||
|
IREE_EXTRA_ARGS += [
|
||||||
|
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
|
||||||
|
]
|
||||||
|
unet_args = IREE_EXTRA_ARGS
|
||||||
|
vae_args = IREE_EXTRA_ARGS
|
||||||
|
unet_name = UNET_FP16
|
||||||
|
vae_name = VAE_FP16
|
||||||
if args.import_mlir == True:
|
if args.import_mlir == True:
|
||||||
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
|
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
|
||||||
model_name=UNET_FP16
|
model_name=UNET_FP16
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return get_shark_model(
|
return get_shark_model(
|
||||||
GCLOUD_BUCKET,
|
vae_gcloud_bucket,
|
||||||
VAE_FP16,
|
vae_name,
|
||||||
IREE_EXTRA_ARGS,
|
vae_args,
|
||||||
), get_shark_model(
|
), get_shark_model(
|
||||||
GCLOUD_BUCKET,
|
unet_gcloud_bucket,
|
||||||
UNET_FP16,
|
unet_name,
|
||||||
IREE_EXTRA_ARGS,
|
unet_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif args.precision == "fp32":
|
elif args.precision == "fp32":
|
||||||
|
|||||||
@@ -78,4 +78,11 @@ p.add_argument(
|
|||||||
help="Profiles vulkan device and collects the .rdc info",
|
help="Profiles vulkan device and collects the .rdc info",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
p.add_argument(
|
||||||
|
"--use_tuned",
|
||||||
|
default=True,
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
help="Download and use the tuned version of the model if available",
|
||||||
|
)
|
||||||
|
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|||||||
Reference in New Issue
Block a user