mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[SD] Fix cuda OTF annotation (#1008)
This commit is contained in:
@@ -22,15 +22,15 @@ def get_device():
|
||||
|
||||
def get_device_args():
|
||||
device = get_device()
|
||||
device_spec_args = ""
|
||||
device_spec_args = []
|
||||
if device == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
gpu_flags = get_iree_gpu_args()
|
||||
for flag in gpu_flags:
|
||||
device_spec_args += flag + " "
|
||||
device_spec_args.append(flag)
|
||||
elif device == "vulkan":
|
||||
device_spec_args = (
|
||||
device_spec_args.append(
|
||||
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
|
||||
)
|
||||
return device, device_spec_args
|
||||
@@ -83,23 +83,24 @@ def load_lower_configs():
|
||||
|
||||
config_bucket = "gs://shark_tank/sd_tuned/configs/"
|
||||
config_version = version
|
||||
config_max_length = args.max_length
|
||||
if variant in ["anythingv3", "analogdiffusion"]:
|
||||
args.max_length = 77
|
||||
config_max_length = 77
|
||||
config_version = "v1_4"
|
||||
if args.annotation_model == "vae":
|
||||
args.max_length = 77
|
||||
config_max_length = 77
|
||||
|
||||
device, device_spec_args = get_device_args()
|
||||
spec = ""
|
||||
if get_device_args:
|
||||
spec = device_spec_args.split("=")[-1].strip()
|
||||
if device_spec_args:
|
||||
spec = device_spec_args[-1].split("=")[-1].strip()
|
||||
if device == "vulkan":
|
||||
spec = spec.split("-")[0]
|
||||
|
||||
if spec in ["rdna3", "sm_80"]:
|
||||
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{config_max_length}_{device}.json"
|
||||
else:
|
||||
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}_{spec}.json"
|
||||
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{config_max_length}_{device}_{spec}.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
print("Loading lowering config file from ", lowering_config_dir)
|
||||
@@ -148,9 +149,9 @@ def dump_after_mlir(input_mlir, use_winograd):
|
||||
dump_module = ireec.compile_str(
|
||||
input_mlir,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=[
|
||||
extra_args=device_spec_args
|
||||
+ [
|
||||
preprocess_flag,
|
||||
device_spec_args,
|
||||
"--compile-to=preprocessing",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -252,7 +252,14 @@ def set_init_device_flags():
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80"]:
|
||||
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
|
||||
args.use_tuned = False
|
||||
|
||||
elif (
|
||||
"cuda" in args.device
|
||||
and get_cuda_sm_cc() == "sm_89"
|
||||
and args.hf_model_id != "stabilityai/stable-diffusion-2-1-base"
|
||||
):
|
||||
args.use_tuned = False
|
||||
|
||||
elif args.use_base_vae and args.hf_model_id not in [
|
||||
|
||||
Reference in New Issue
Block a user