mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-02-19 11:56:43 -05:00
79 lines
2.6 KiB
Python
79 lines
2.6 KiB
Python
import os
|
|
import subprocess
|
|
from apps.stable_diffusion.src.utils.resources import (
|
|
get_json_file,
|
|
)
|
|
from shark.shark_downloader import download_public_file
|
|
from image_comparison import compare_images
|
|
import argparse
|
|
from glob import glob
|
|
import shutil
|
|
|
|
# Location of the stable diffusion model configuration. The path is resolved
# against the current working directory, so this script presumably has to be
# launched from the repository root (the txt2img.py path used below is
# relative as well).
_MODEL_CONFIG_PATH = os.path.join(
    os.getcwd(),
    "apps/stable_diffusion/src/utils/resources/model_config.json",
)

# Parsed contents of model_config.json, as returned by get_json_file.
model_config_dicts = get_json_file(_MODEL_CONFIG_PATH)
|
|
|
|
|
|
def _compare_with_golden(model_name):
    """Fetch golden images for *model_name* and compare the newest output.

    Golden images are downloaded from
    ``gs://shark_tank/testdata/golden/<model_name>`` into
    ``./test_images/golden/<model_name>``. The most recently modified PNG in
    ``<cwd>/test_images/<model_name>/generated_imgs`` is then passed to
    ``compare_images`` together with the first golden PNG (in sorted order).

    A missing generated or golden image is reported and skipped, so that
    the remaining configurations still run instead of the whole loop
    aborting with an ``IndexError``.
    """
    golden_dir = "./test_images/golden/" + model_name
    os.makedirs(golden_dir, exist_ok=True)
    download_public_file(
        "gs://shark_tank/testdata/golden/" + model_name,
        golden_dir,
    )

    generated_dir = os.path.join(
        os.getcwd(), "test_images", model_name, "generated_imgs"
    )
    generated_files = glob(generated_dir + "/*.png")
    if not generated_files:
        print("no generated image found in " + generated_dir)
        return
    # The tuned and untuned runs share one output directory and glob order is
    # arbitrary, so pick the image written by the run that just finished.
    test_file = max(generated_files, key=os.path.getmtime)

    # Sort so the same golden file is chosen on every run.
    golden_files = sorted(glob(golden_dir + "/*.png"))
    if not golden_files:
        print("no golden image found in " + golden_dir)
        return

    compare_images(test_file, golden_files[0])


def test_loop(device="vulkan", beta=False, extra_flags=None):
    """Run txt2img.py for every configured model and compare with goldens.

    For each value of ``model_config_dicts[0]`` (used as the Hugging Face
    model id), txt2img.py is run twice, once with ``--no-use_tuned`` and once
    with ``--use_tuned``, using a fixed prompt. Each successful run is
    compared against the golden image downloaded from the shark tank bucket.
    Any existing ``./test_images`` directory is deleted first.

    Args:
        device: Value forwarded to txt2img.py as ``--device=<device>``.
        beta: When true, ``--beta_models=True`` is added to every command.
        extra_flags: Additional command-line flags appended to every
            command. The list is copied and never modified.
    """
    import sys

    # Work on a copy so neither a caller's list nor a shared default grows
    # each time the beta flag is added.
    flags = list(extra_flags) if extra_flags else []
    if beta:
        flags.append("--beta_models=True")

    # Get golden values from tank: start from an empty ./test_images tree.
    shutil.rmtree("./test_images", ignore_errors=True)
    os.mkdir("./test_images")
    os.mkdir("./test_images/golden")

    hf_model_names = model_config_dicts[0].values()
    # Both options need the leading "--". A bare "use_tuned" would reach
    # txt2img.py as a stray positional argument.
    tuned_options = ["--no-use_tuned", "--use_tuned"]

    for model_name in hf_model_names:
        for use_tune in tuned_options:
            command = [
                # Run txt2img.py with the interpreter executing this script,
                # not whatever "python" resolves to on PATH.
                sys.executable,
                "apps/stable_diffusion/scripts/txt2img.py",
                "--device=" + device,
                "--prompt=cyberpunk forest by Salvador Dali",
                "--output_dir="
                + os.path.join(os.getcwd(), "test_images", model_name),
                "--hf_model_id=" + model_name,
                use_tune,
                *flags,
            ]
            # A zero exit status is treated as "image generated".
            return_code = subprocess.call(command, stdout=subprocess.DEVNULL)
            print(" ".join(command))
            if return_code == 0:
                print("Successfully generated image")
                _compare_with_golden(model_name)
            else:
                print("failed to generate image for this configuration")
|
|
|
|
|
|
# Command-line interface for running the golden-image regression loop.
parser = argparse.ArgumentParser(
    description="Run txt2img.py for every configured model and compare "
    "the generated images against golden images."
)
parser.add_argument(
    "-d",
    "--device",
    default="vulkan",
    help="device forwarded to txt2img.py via --device",
)
# BooleanOptionalAction provides both --beta and --no-beta.
parser.add_argument(
    "-b",
    "--beta",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="also pass --beta_models=True to txt2img.py",
)
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the CLI, echo the resolved options, then run the full loop with
    # no additional txt2img.py flags.
    cli_args = parser.parse_args()
    print(cli_args)
    test_loop(device=cli_args.device, beta=cli_args.beta, extra_flags=[])
|