mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 21:38:04 -05:00
285 lines
10 KiB
Python
285 lines
10 KiB
Python
import os
|
|
from sys import executable
|
|
import subprocess
|
|
from apps.stable_diffusion.src.utils.resources import (
|
|
get_json_file,
|
|
)
|
|
from datetime import datetime as dt
|
|
from amdshark.amdshark_downloader import download_public_file
|
|
from image_comparison import compare_images
|
|
import argparse
|
|
from glob import glob
|
|
import shutil
|
|
import requests
|
|
|
|
# Model registry for the Stable Diffusion app. The relative JSON path is joined
# onto the current working directory, so the script must be launched from the
# directory that contains `apps/` for the file to be found.
model_config_dicts = get_json_file(
    os.path.join(os.getcwd(), "apps/stable_diffusion/src/utils/resources/model_config.json")
)
|
|
|
|
|
|
def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir):
    """Extract timing metrics from a captured generation-script log.

    For each label in ``vals_to_read`` the LAST log line containing that label
    wins, and its value is the final space-separated token on that line.
    The unit suffixes "ms/it" (Average step) and "sec" (Total image
    generation) are then removed from those two values.

    Args:
        filename: Path of the log file the generation subprocess wrote.
        command: Command that produced the log; stored verbatim.
        device: Device name; stored verbatim.
        use_tune: Tuning flag used for the run; stored verbatim.
        model_name: Model id; stored verbatim.
        import_mlir: Import flag used for the run; stored verbatim.

    Returns:
        dict with one entry per label in ``vals_to_read`` plus "device",
        "use_tune", "model_name", "import_mlir" and "command". A label that
        never appears in the log is reported as "N/A" rather than being
        absent (which previously surfaced as a KeyError).
    """
    vals_to_read = [
        "Clip Inference time",
        "Average step",
        "VAE Inference time",
        "Total image generation",
    ]
    metrics = {}
    # The log is only read, so open it read-only.
    with open(filename, "r") as f:
        for line in f:
            for val in vals_to_read:
                if val in line:
                    metrics[val] = line.split(" ")[-1].strip("\n")

    for val in vals_to_read:
        metrics.setdefault(val, "N/A")

    # str.strip() removes any of the given *characters* from both ends;
    # removesuffix() removes exactly the trailing unit text.
    metrics["Average step"] = metrics["Average step"].removesuffix("ms/it")
    metrics["Total image generation"] = metrics[
        "Total image generation"
    ].removesuffix("sec")

    metrics["device"] = device
    metrics["use_tune"] = use_tune
    metrics["model_name"] = model_name
    metrics["import_mlir"] = import_mlir
    metrics["command"] = command
    return metrics
|
|
|
|
|
|
def get_inpaint_inputs():
    """Download the inpainting benchmark inputs into ./test_images/inputs.

    Writes ``image.png`` and ``mask.png`` fetched from the Hugging Face
    ``diffusers/test-arrays`` dataset.

    Raises:
        requests.HTTPError: if either download returns an error status, so a
            failed fetch is not silently saved to disk as a bogus PNG.
        requests.Timeout: if a download stalls past the timeout.
    """
    inputs_dir = "./test_images/inputs"
    # Also creates ./test_images if missing, and tolerates a re-run.
    os.makedirs(inputs_dir, exist_ok=True)
    base_url = (
        "https://huggingface.co/datasets/diffusers/test-arrays/resolve"
        "/main/stable_diffusion_inpaint/"
    )
    for remote_name, local_name in (
        ("input_bench_image.png", "image.png"),
        ("input_bench_mask.png", "mask.png"),
    ):
        # Timeout is in seconds; without one a hung connection blocks CI forever.
        response = requests.get(base_url + remote_name, timeout=60)
        response.raise_for_status()
        # Context manager guarantees the file is flushed and closed.
        with open(os.path.join(inputs_dir, local_name), "wb") as f:
            f.write(response.content)
|
|
|
|
|
|
def test_loop(
    device="vulkan",
    beta=False,
    extra_flags=None,
    upload_bool=True,
    exit_on_fail=True,
    do_gen=False,
):
    """Run each configured Stable Diffusion model once and check its output.

    For each import option and each model id in ``model_config_dicts[0]``
    (minus ``to_skip``), this does the following:
      * launches txt2img.py, or inpaint.py for ids containing "inpainting",
        in a subprocess;
      * logs the subprocess output to ``<org>_<model>.txt`` in the current
        directory;
      * parses timing metrics from that log;
      * downloads the model's golden image from the tank and compares the
        newly generated image against it.
    Collected metrics are written to ``sd_testing_metrics.csv`` (semicolon
    separated) in the current directory. ``./test_images`` is deleted and
    recreated at the start.

    Args:
        device: Passed to the generation script as ``--device=``.
        beta: If true, ``--beta_models=True`` is added to every command.
        extra_flags: Extra command-line flags appended to every command.
            The caller's list is copied and never modified; ``None`` means none.
        upload_bool: Passed to ``compare_images`` as ``upload``.
        exit_on_fail: If true, an ``AssertionError`` from ``compare_images``
            is re-raised; otherwise it is printed and the loop continues.
        do_gen: If true, ``--import_debug`` is added to every command and
            ``prepare_artifacts()`` is called after each run.

    Raises:
        SystemExit: with status 1 if a generation subprocess exits non-zero
            (its captured log is printed first).
        FileNotFoundError: if no generated or golden PNG can be found.
        AssertionError: from ``compare_images`` when ``exit_on_fail`` is true.
    """
    # Copy, so the appends/removes below never touch the caller's list or a
    # shared default list across calls.
    extra_flags = list(extra_flags) if extra_flags is not None else []

    # Start from a clean output tree; golden images are fetched per model.
    shutil.rmtree("./test_images", ignore_errors=True)
    model_metrics = []
    os.mkdir("./test_images")
    os.mkdir("./test_images/golden")
    get_inpaint_inputs()

    hf_model_names = model_config_dicts[0].values()
    tuned_options = ["--no-use_tuned", "--use_tuned"]
    import_options = ["--import_mlir", "--no-import_mlir"]
    prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
    inpaint_prompt_text = (
        "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
    )
    if os.name == "nt":
        # On Windows the command is run as a single joined string, so the
        # prompt must carry its own quotes to stay one argument.
        prompt_text = '--prompt="cyberpunk forest by Salvador Dali"'
        inpaint_prompt_text = (
            '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
        )

    if beta:
        extra_flags.append("--beta_models=True")
    extra_flags.append("--no-progress_bar")
    if do_gen:
        extra_flags.append("--import_debug")

    to_skip = [
        "Linaqruf/anything-v3.0",
        "prompthero/openjourney",
        "wavymulder/Analog-Diffusion",
        "dreamlike-art/dreamlike-diffusion-1.0",
    ]
    rdna2_triple = "--iree_vulkan_target_triple=rdna2-unknown-windows"
    counter = 0

    for import_opt in import_options:
        for model_name in hf_model_names:
            if model_name in to_skip:
                continue
            for use_tune in tuned_options:
                # NOTE(review): taken together these conditions skip every
                # "--use_tuned" run and also skip "--no-use_tuned" for
                # stabilityai/stable-diffusion-2-1, so that model never runs.
                # Preserved as-is; confirm this is intended.
                if (
                    model_name == "stabilityai/stable-diffusion-2-1"
                    and use_tune == tuned_options[0]
                ):
                    continue
                elif (
                    model_name == "stabilityai/stable-diffusion-2-1-base"
                    and use_tune == tuned_options[1]
                ):
                    continue
                elif use_tune == tuned_options[1]:
                    continue

                command = _build_sd_command(
                    model_name,
                    device,
                    import_opt,
                    use_tune,
                    prompt_text,
                    inpaint_prompt_text,
                    extra_flags,
                )
                command_str = " ".join(command)
                dumpfile_name = os.path.join(
                    os.getcwd(), model_name.replace("/", "_") + ".txt"
                )
                with open(dumpfile_name, "w") as f:
                    # Windows gets one pre-quoted string; elsewhere an argv list.
                    generated_image = not subprocess.call(
                        command_str if os.name == "nt" else command,
                        stdout=f,
                        stderr=f,
                    )

                if not generated_image:
                    print(command_str)
                    print("failed to generate image for this configuration")
                    with open(dumpfile_name, "r") as f:
                        # Lines keep their own newlines; print them as-is.
                        print(f.read())
                    raise SystemExit(1)

                model_metrics.append(
                    parse_sd_out(
                        dumpfile_name,
                        command_str,
                        device,
                        use_tune,
                        model_name,
                        import_opt,
                    )
                )
                print(command_str)
                print("Successfully generated image")
                _check_against_golden(model_name, upload_bool, exit_on_fail)

                if os.name == "nt":
                    # Alternate runs use the RDNA2 Vulkan target triple: it is
                    # added after every even-numbered run and removed after
                    # every odd-numbered run except the first.
                    counter += 1
                    if counter % 2 == 0:
                        extra_flags.append(rdna2_triple)
                    elif counter != 1:
                        extra_flags.remove(rdna2_triple)
                if do_gen:
                    prepare_artifacts()

    _write_metrics_csv(model_metrics)


def _build_sd_command(
    model_name,
    device,
    import_opt,
    use_tune,
    prompt_text,
    inpaint_prompt_text,
    extra_flags,
):
    """Return the argv list for one generation run, with extra_flags appended."""
    output_dir_flag = "--output_dir=" + os.path.join(
        os.getcwd(), "test_images", model_name
    )
    if "inpainting" in model_name:
        command = [
            executable,
            "apps/stable_diffusion/scripts/inpaint.py",
            "--device=" + device,
            inpaint_prompt_text,
            "--negative_prompts=" + '""',
            "--img_path=./test_images/inputs/image.png",
            "--mask_path=./test_images/inputs/mask.png",
            "--seed=42",
            # NOTE(review): inpainting always passes --import_mlir and
            # ignores import_opt; confirm this is intended.
            "--import_mlir",
            output_dir_flag,
            "--hf_model_id=" + model_name,
            use_tune,
        ]
    else:
        command = [
            executable,  # the python of the venv running this script
            "apps/stable_diffusion/scripts/txt2img.py",
            "--device=" + device,
            prompt_text,
            "--negative_prompts=" + '""',
            "--seed=42",
            import_opt,
            output_dir_flag,
            "--hf_model_id=" + model_name,
            use_tune,
        ]
    return command + list(extra_flags)


def _newest_file(pattern):
    """Return the most recently modified path matching *pattern*."""
    matches = glob(pattern)
    if not matches:
        raise FileNotFoundError(f"No image matches {pattern}")
    # The same model runs once per import option into the same output dir,
    # so pick the newest image rather than an arbitrary glob entry.
    return max(matches, key=os.path.getmtime)


def _check_against_golden(model_name, upload_bool, exit_on_fail):
    """Fetch the golden image for *model_name* and compare today's output to it."""
    golden_dir = "./test_images/golden/" + model_name
    os.makedirs(golden_dir, exist_ok=True)
    download_public_file(
        "gs://amdshark_tank/testdata/golden/" + model_name,
        golden_dir,
    )
    test_file = _newest_file(
        os.path.join(
            os.getcwd(),
            "test_images",
            model_name,
            "generated_imgs",
            dt.now().strftime("%Y%m%d"),
            "*.png",
        )
    )
    golden_file = _newest_file(golden_dir + "/*.png")
    try:
        compare_images(test_file, golden_file, upload=upload_bool)
    except AssertionError as e:
        print(e)
        if exit_on_fail:
            raise


def _write_metrics_csv(model_metrics):
    """Write collected metrics to ./sd_testing_metrics.csv (';'-separated)."""
    header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
    columns = (
        "model_name",
        "device",
        "use_tune",
        "import_mlir",
        "Clip Inference time",
        "Average step",
        "VAE Inference time",
        "Total image generation",
        "command",
    )
    with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w") as f:
        f.write(header)
        for metric in model_metrics:
            f.write(";".join(metric[column] for column in columns) + "\n")
|
|
|
|
|
|
def prepare_artifacts():
    """Move generated model artifacts from the CWD into ./gen_amdshark_tank.

    Every entry (file or directory) in the current working directory whose
    name contains "clip", "unet" or "vae", and does not contain "vmfb", is
    moved into ``gen_amdshark_tank``. Entries whose name already exists in
    ``gen_amdshark_tank`` are left in place.
    """
    gen_path = os.path.join(os.getcwd(), "gen_amdshark_tank")
    os.makedirs(gen_path, exist_ok=True)
    for dirname in os.listdir(os.getcwd()):
        if "vmfb" in dirname:
            continue
        for modelname in ["clip", "unet", "vae"]:
            if modelname in dirname:
                # exists() rather than isdir(): a same-named *file* in gen_path
                # would otherwise make shutil.move fail.
                if not os.path.exists(os.path.join(gen_path, dirname)):
                    shutil.move(os.path.join(os.getcwd(), dirname), gen_path)
                    print(f"Moved dir: {dirname} to {gen_path}.")
                # Handle each entry once; re-checking on a second matching
                # keyword would try to move an already-moved file again.
                break
|
|
|
|
|
|
# Command-line interface for the Stable Diffusion generate-and-compare run.
parser = argparse.ArgumentParser(
    description=(
        "Generate images for each configured Stable Diffusion model, compare "
        "them against golden images, and write timing metrics to "
        "sd_testing_metrics.csv."
    )
)

parser.add_argument(
    "-d",
    "--device",
    default="vulkan",
    help="Device passed to the generation scripts as --device (default: %(default)s).",
)
parser.add_argument(
    "-b",
    "--beta",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Also pass --beta_models=True to the generation scripts.",
)
parser.add_argument(
    "-e",
    "--extra_args",
    type=str,
    default=None,
    help=(
        "Comma-separated extra flags forwarded to every generation command, "
        "e.g. --extra_args=--flag_a,--flag_b."
    ),
)
parser.add_argument(
    "-u",
    "--upload",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Value passed as `upload` to compare_images.",
)
parser.add_argument(
    "-x",
    "--exit_on_fail",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Abort (re-raise) when a generated image does not match its golden image.",
)
parser.add_argument(
    "-g",
    "--gen",
    action=argparse.BooleanOptionalAction,
    default=False,
    help=(
        "Pass --import_debug to the generation scripts and collect "
        "clip/unet/vae artifacts into gen_amdshark_tank."
    ),
)
|
|
|
|
if __name__ == "__main__":
    args = parser.parse_args()
    print(args)
    # Split "--a,--b" into separate flags. Blank entries and surrounding
    # whitespace are dropped so "--a, --b" or a trailing comma does not hand
    # the generation script an empty or space-prefixed argument.
    extra_args = (
        [arg.strip() for arg in args.extra_args.split(",") if arg.strip()]
        if args.extra_args
        else []
    )
    test_loop(
        args.device,
        args.beta,
        extra_args,
        args.upload,
        args.exit_on_fail,
        args.gen,
    )
    # test_loop already calls prepare_artifacts() after each run when gen is
    # set; this is one more sweep once the whole loop has finished.
    if args.gen:
        prepare_artifacts()
|