Disable winograd on VAE with rdna2 and fix unet tuning. (#1313)

* Disable winograd on VAE with rdna2 and fix unet tuning.

* Fix batch size 1 downloads and clear_all on windows.
This commit is contained in:
Ean Garvey
2023-04-18 15:55:10 -05:00
committed by GitHub
parent b70919b38d
commit 1afe07c296
3 changed files with 14 additions and 7 deletions

View File

@@ -233,11 +233,14 @@ def sd_model_annotation(mlir_model, model_name, base_model_id=None):
winograd_model, lowering_config_dir, model_name, use_winograd
)
elif args.annotation_model == "vae" and device == "vulkan":
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
use_winograd = True
winograd_config_dir = load_winograd_configs()
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
else:
tuned_model = mlir_model
else:
use_winograd = False
lowering_config_dir = load_lower_configs(base_model_id)

View File

@@ -141,6 +141,8 @@ def compile_through_fx(
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
if "unet" in model_name.split("_")[0]:
args.annotation_model = "unet"
mlir_module = sd_model_annotation(
mlir_module, extended_model_name, base_model_id
)
@@ -663,7 +665,9 @@ def clear_all():
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
shutil.rmtree(
os.path.join(home, ".local/shark_tank"), ignore_errors=True
)
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))

View File

@@ -196,7 +196,7 @@ def download_model(
tank_url=None,
frontend=None,
tuned=None,
import_args={"batch_size": "1"},
import_args=None,
):
model_name = model_name.replace("/", "_")
dyn_str = "_dynamic" if dynamic else ""