[SD] Fix cuda OTF annotation (#1008)

This commit is contained in:
yzhang93
2023-02-13 12:32:50 -08:00
committed by GitHub
parent dd2e482214
commit 5167df08b9
2 changed files with 21 additions and 13 deletions

View File

@@ -22,15 +22,15 @@ def get_device():
def get_device_args():
device = get_device()
device_spec_args = ""
device_spec_args = []
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
gpu_flags = get_iree_gpu_args()
for flag in gpu_flags:
device_spec_args += flag + " "
device_spec_args.append(flag)
elif device == "vulkan":
device_spec_args = (
device_spec_args.append(
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
return device, device_spec_args
@@ -83,23 +83,24 @@ def load_lower_configs():
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_version = version
config_max_length = args.max_length
if variant in ["anythingv3", "analogdiffusion"]:
args.max_length = 77
config_max_length = 77
config_version = "v1_4"
if args.annotation_model == "vae":
args.max_length = 77
config_max_length = 77
device, device_spec_args = get_device_args()
spec = ""
if get_device_args:
spec = device_spec_args.split("=")[-1].strip()
if device_spec_args:
spec = device_spec_args[-1].split("=")[-1].strip()
if device == "vulkan":
spec = spec.split("-")[0]
if spec in ["rdna3", "sm_80"]:
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
if not spec or spec in ["rdna3", "sm_80"]:
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{config_max_length}_{device}.json"
else:
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}_{spec}.json"
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{config_max_length}_{device}_{spec}.json"
full_gs_url = config_bucket + config_name
lowering_config_dir = f"{WORKDIR}configs/" + config_name
print("Loading lowering config file from ", lowering_config_dir)
@@ -148,9 +149,9 @@ def dump_after_mlir(input_mlir, use_winograd):
dump_module = ireec.compile_str(
input_mlir,
target_backends=[iree_target_map(device)],
extra_args=[
extra_args=device_spec_args
+ [
preprocess_flag,
device_spec_args,
"--compile-to=preprocessing",
],
)

View File

@@ -252,7 +252,14 @@ def set_init_device_flags():
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80"]:
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
args.use_tuned = False
elif (
"cuda" in args.device
and get_cuda_sm_cc() == "sm_89"
and args.hf_model_id != "stabilityai/stable-diffusion-2-1-base"
):
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [