From 681332ef3288ff3fc0e8681f538fe5efe187d79d Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Thu, 16 Feb 2023 12:57:50 -0600 Subject: [PATCH] fix tests after default flag changes (#1009) * fix tests after default flag changes also adds support for import-mlir * Update setup_venv.ps1 --------- --- .github/workflows/test-models.yml | 1 - build_tools/image_comparison.py | 12 ++- build_tools/stable_diffusion_testing.py | 97 ++++++++++++++++--------- 3 files changed, 70 insertions(+), 40 deletions(-) diff --git a/.github/workflows/test-models.yml b/.github/workflows/test-models.yml index 16c648a7..3babe3cd 100644 --- a/.github/workflows/test-models.yml +++ b/.github/workflows/test-models.yml @@ -158,5 +158,4 @@ jobs: if: matrix.suite == 'vulkan' && matrix.os == '7950x' run: | ./setup_venv.ps1 - ./shark.venv/Scripts/activate python build_tools/stable_diffusion_testing.py --device=vulkan diff --git a/build_tools/image_comparison.py b/build_tools/image_comparison.py index 2f3318c2..4071bd2a 100644 --- a/build_tools/image_comparison.py +++ b/build_tools/image_comparison.py @@ -30,9 +30,15 @@ def compare_images(new_filename, golden_filename): diff = np.abs(new - golden) mean = np.mean(diff) if mean > 0.1: - subprocess.run( - ["gsutil", "cp", new_filename, "gs://shark_tank/testdata/builder/"] - ) + if os.name != "nt": + subprocess.run( + [ + "gsutil", + "cp", + new_filename, + "gs://shark_tank/testdata/builder/", + ] + ) raise SystemExit("new and golden not close") else: print("SUCCESS") diff --git a/build_tools/stable_diffusion_testing.py b/build_tools/stable_diffusion_testing.py index 063bc5d0..ae867f97 100644 --- a/build_tools/stable_diffusion_testing.py +++ b/build_tools/stable_diffusion_testing.py @@ -1,8 +1,10 @@ import os +from sys import executable import subprocess from apps.stable_diffusion.src.utils.resources import ( get_json_file, ) +from datetime import datetime as dt from shark.shark_downloader import download_public_file from image_comparison import compare_images import argparse @@ -23,45 +25,68 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]): os.mkdir("./test_images") os.mkdir("./test_images/golden") hf_model_names = model_config_dicts[0].values() - tuned_options = ["--no-use_tuned", "use_tuned"] + tuned_options = ["--no-use_tuned", "--use_tuned"] + import_options = ["--import_mlir", "--no-import_mlir"] + prompt_text = "--prompt=cyberpunk forest by Salvador Dali" + if os.name == "nt": + prompt_text = '--prompt="cyberpunk forest by Salvador Dali"' if beta: extra_flags.append("--beta_models=True") - for model_name in hf_model_names: - for use_tune in tuned_options: - command = [ - "python", - "apps/stable_diffusion/scripts/txt2img.py", - "--device=" + device, - "--prompt=cyberpunk forest by Salvador Dali", - "--output_dir=" - + os.path.join(os.getcwd(), "test_images", model_name), - "--hf_model_id=" + model_name, - use_tune, - ] - command += extra_flags - generated_image = not subprocess.call( - command, stdout=subprocess.DEVNULL - ) - if generated_image: - print(" ".join(command)) - print("Successfully generated image") - os.makedirs( - "./test_images/golden/" + model_name, exist_ok=True + for import_opt in import_options: + for model_name in hf_model_names: + for use_tune in tuned_options: + command = [ + executable, # executable is the python from the venv used to run this + "apps/stable_diffusion/scripts/txt2img.py", + "--device=" + device, + prompt_text, + "--negative_prompts=" + '""', + "--seed=42", + import_opt, + "--output_dir=" + + os.path.join(os.getcwd(), "test_images", model_name), + "--hf_model_id=" + model_name, + use_tune, + ] + command += extra_flags + if os.name == "nt": + command = " ".join(command) + generated_image = not subprocess.call( + command, stdout=subprocess.DEVNULL ) - download_public_file( - "gs://shark_tank/testdata/golden/" + model_name, - "./test_images/golden/" + model_name, - ) - test_file_path = os.path.join( - os.getcwd(), "test_images", model_name, "generated_imgs" - ) - test_file = glob(test_file_path + "/*.png")[0] - golden_path = "./test_images/golden/" + model_name + "/*.png" - golden_file = glob(golden_path)[0] - compare_images(test_file, golden_file) - else: - print(" ".join(command)) - print("failed to generate image for this configuration") + if os.name != "nt": + command = " ".join(command) + if generated_image: + print(command) + print("Successfully generated image") + os.makedirs( + "./test_images/golden/" + model_name, exist_ok=True + ) + download_public_file( + "gs://shark_tank/testdata/golden/" + model_name, + "./test_images/golden/" + model_name, + ) + test_file_path = os.path.join( + os.getcwd(), + "test_images", + model_name, + "generated_imgs", + dt.now().strftime("%Y%m%d"), + "*.png", + ) + test_file = glob(test_file_path)[0] + + golden_path = ( + "./test_images/golden/" + model_name + "/*.png" + ) + golden_file = glob(golden_path)[0] + compare_images(test_file, golden_file) + else: + print(command) + print("failed to generate image for this configuration") + if "2_1_base" in model_name: + print("failed a known successful model.") + exit(1) parser = argparse.ArgumentParser()