Files
SHARK-Studio/build_tools/stable_diffusion_testing.py
pdhirajkumarprasad fe03539901 Migration to AMDShark (#2182)
Signed-off-by: pdhirajkumarprasad <dhirajp@amd.com>
2025-11-20 12:52:07 +05:30

285 lines
10 KiB
Python

import os
from sys import executable
import subprocess
from apps.stable_diffusion.src.utils.resources import (
get_json_file,
)
from datetime import datetime as dt
from amdshark.amdshark_downloader import download_public_file
from image_comparison import compare_images
import argparse
from glob import glob
import shutil
import requests
# Model configuration bundled with the stable-diffusion app. The path is
# resolved against the current working directory, so this script has to be
# launched from the directory that contains ``apps/``.
_MODEL_CONFIG_RELPATH = "apps/stable_diffusion/src/utils/resources/model_config.json"
model_config_dicts = get_json_file(os.path.join(os.getcwd(), _MODEL_CONFIG_RELPATH))
def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir):
    """Collect timing metrics from the captured log of one generation run.

    For each reported metric ("Clip Inference time", "Average step",
    "VAE Inference time", "Total image generation") the value kept is the
    last whitespace-separated token of the last log line mentioning it. The
    "ms/it" and "sec" unit suffixes are removed from the step and total
    times. The run configuration is then attached.

    Args:
        filename: Path to the run's combined stdout/stderr log.
        command: Command string of the run, stored under "command".
        device: Stored under "device".
        use_tune: Tuning flag of the run, stored under "use_tune".
        model_name: Model id of the run, stored under "model_name".
        import_mlir: Import flag of the run, stored under "import_mlir".

    Returns:
        dict mapping metric / configuration names to string values.

    Raises:
        KeyError: If the log has no "Average step" or no
            "Total image generation" line.
    """
    vals_to_read = (
        "Clip Inference time",
        "Average step",
        "VAE Inference time",
        "Total image generation",
    )
    metrics = {}
    # Read-only mode is enough; "r+" would additionally require write
    # permission on the log file.
    with open(filename, "r") as f:
        for line in f:
            for val in vals_to_read:
                if val in line:
                    # split() with no argument ignores trailing spaces/tabs
                    # and the newline; split(" ") would leave an empty last
                    # token when the line ends with a space.
                    metrics[val] = line.split()[-1]
    # removesuffix drops the exact unit text, whereas strip("ms/it") would
    # remove any of the characters m, s, /, i, t from both ends.
    metrics["Average step"] = metrics["Average step"].removesuffix("ms/it")
    metrics["Total image generation"] = metrics[
        "Total image generation"
    ].removesuffix("sec")
    metrics["device"] = device
    metrics["use_tune"] = use_tune
    metrics["model_name"] = model_name
    metrics["import_mlir"] = import_mlir
    metrics["command"] = command
    return metrics
def get_inpaint_inputs(timeout=60):
    """Download the inpainting benchmark image and mask into ./test_images/inputs.

    The files are saved as ``image.png`` and ``mask.png``. The directory is
    created if it does not exist.

    Args:
        timeout: Seconds to wait on each HTTP request before giving up.

    Raises:
        requests.HTTPError: If a download returns an error status, so an
            error page is never written out as a PNG.
        requests.Timeout: If a download exceeds ``timeout``.
    """
    inputs_dir = os.path.join(".", "test_images", "inputs")
    os.makedirs(inputs_dir, exist_ok=True)
    base_url = (
        "https://huggingface.co/datasets/diffusers/test-arrays/resolve"
        "/main/stable_diffusion_inpaint/"
    )
    for remote_name, local_name in (
        ("input_bench_image.png", "image.png"),
        ("input_bench_mask.png", "mask.png"),
    ):
        response = requests.get(base_url + remote_name, timeout=timeout)
        response.raise_for_status()
        # Context manager so the file is flushed and closed deterministically.
        with open(os.path.join(inputs_dir, local_name), "wb") as f:
            f.write(response.content)
def _newest_match(pattern):
    """Return the most recently modified path matching glob ``pattern``.

    Raises:
        FileNotFoundError: If nothing matches ``pattern``.
    """
    matches = glob(pattern)
    if not matches:
        raise FileNotFoundError(f"no file matches {pattern!r}")
    return max(matches, key=os.path.getmtime)


def test_loop(
    device="vulkan",
    beta=False,
    extra_flags=None,
    upload_bool=True,
    exit_on_fail=True,
    do_gen=False,
):
    """Run the configured stable-diffusion models and compare against goldens.

    For every import option x model x tuning option that is not skipped,
    this runs ``txt2img.py`` (or ``inpaint.py`` for inpainting models) in a
    subprocess whose output goes to a per-model log file. It parses the
    timing metrics, downloads the golden image from the tank and compares
    the newly generated image against it. Finally it writes all collected
    metrics to ``sd_testing_metrics.csv`` in the current directory.

    Args:
        device: Value passed as ``--device``.
        beta: Also pass ``--beta_models=True``.
        extra_flags: Extra flags appended to every command. The list is
            copied and never modified. ``None`` means no extra flags.
        upload_bool: Forwarded as ``upload`` to ``compare_images``.
        exit_on_fail: Re-raise the ``AssertionError`` of a failed image
            comparison instead of only printing it.
        do_gen: Pass ``--import_debug`` and call ``prepare_artifacts()``
            after all runs.

    Raises:
        SystemExit: With status 1 when a generation subprocess exits
            non-zero (its log is printed first).
        AssertionError: From ``compare_images`` when ``exit_on_fail`` is set.
        FileNotFoundError: If the generated or golden image is missing.
    """
    # Work on a private copy. Flags are appended/removed below, and mutating
    # the caller's list (or a shared mutable default) would leak across calls.
    extra_flags = [] if extra_flags is None else list(extra_flags)
    rdna2_flag = "--iree_vulkan_target_triple=rdna2-unknown-windows"

    # Get golden values from tank
    shutil.rmtree("./test_images", ignore_errors=True)
    model_metrics = []
    os.mkdir("./test_images")
    os.mkdir("./test_images/golden")
    get_inpaint_inputs()
    hf_model_names = model_config_dicts[0].values()
    tuned_options = [
        "--no-use_tuned",
        "--use_tuned",
    ]
    import_options = ["--import_mlir", "--no-import_mlir"]
    if os.name == "nt":
        # On Windows the command is run as one joined string, so the prompt
        # values carry explicit quotes.
        prompt_text = '--prompt="cyberpunk forest by Salvador Dali"'
        inpaint_prompt_text = (
            '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
        )
    else:
        prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
        inpaint_prompt_text = (
            "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
        )
    if beta:
        extra_flags.append("--beta_models=True")
    extra_flags.append("--no-progress_bar")
    if do_gen:
        extra_flags.append("--import_debug")
    to_skip = [
        "Linaqruf/anything-v3.0",
        "prompthero/openjourney",
        "wavymulder/Analog-Diffusion",
        "dreamlike-art/dreamlike-diffusion-1.0",
    ]
    counter = 0
    for import_opt in import_options:
        for model_name in hf_model_names:
            if model_name in to_skip:
                continue
            for use_tune in tuned_options:
                # NOTE(review): the last branch skips every "--use_tuned" run.
                # Together with the first branch, this means
                # stabilityai/stable-diffusion-2-1 is never run at all.
                # Confirm this is intended.
                if (
                    model_name == "stabilityai/stable-diffusion-2-1"
                    and use_tune == tuned_options[0]
                ):
                    continue
                elif (
                    model_name == "stabilityai/stable-diffusion-2-1-base"
                    and use_tune == tuned_options[1]
                ):
                    continue
                elif use_tune == tuned_options[1]:
                    continue

                output_dir_flag = "--output_dir=" + os.path.join(
                    os.getcwd(), "test_images", model_name
                )
                if "inpainting" in model_name:
                    # NOTE(review): inpainting always passes --import_mlir,
                    # whatever import_opt is, while the metrics row still
                    # records import_opt.
                    command = [
                        executable,
                        "apps/stable_diffusion/scripts/inpaint.py",
                        "--device=" + device,
                        inpaint_prompt_text,
                        "--negative_prompts=" + '""',
                        "--img_path=./test_images/inputs/image.png",
                        "--mask_path=./test_images/inputs/mask.png",
                        "--seed=42",
                        "--import_mlir",
                        output_dir_flag,
                        "--hf_model_id=" + model_name,
                        use_tune,
                    ]
                else:
                    command = [
                        executable,  # the python of the venv running this script
                        "apps/stable_diffusion/scripts/txt2img.py",
                        "--device=" + device,
                        prompt_text,
                        "--negative_prompts=" + '""',
                        "--seed=42",
                        import_opt,
                        output_dir_flag,
                        "--hf_model_id=" + model_name,
                        use_tune,
                    ]
                command += extra_flags
                command_str = " ".join(command)

                dumpfile_name = os.path.join(
                    os.getcwd(), "_".join(model_name.split("/")) + ".txt"
                )
                with open(dumpfile_name, "w") as f:
                    # Windows gets the joined string (hence the quoted
                    # prompts); elsewhere the argument list is passed as-is.
                    returncode = subprocess.call(
                        command_str if os.name == "nt" else command,
                        stdout=f,
                        stderr=f,
                    )

                if returncode != 0:
                    print(command_str)
                    print("failed to generate image for this configuration")
                    with open(dumpfile_name, "r") as f:
                        # Print the log verbatim; its lines already end in "\n".
                        print(f.read())
                    raise SystemExit(1)

                model_metrics.append(
                    parse_sd_out(
                        dumpfile_name,
                        command_str,
                        device,
                        use_tune,
                        model_name,
                        import_opt,
                    )
                )
                print(command_str)
                print("Successfully generated image")
                golden_dir = "./test_images/golden/" + model_name
                os.makedirs(golden_dir, exist_ok=True)
                download_public_file(
                    "gs://amdshark_tank/testdata/golden/" + model_name,
                    golden_dir,
                )
                # Match any date folder instead of today's date so a run that
                # crosses midnight still finds its image. Take the newest file
                # because every import_opt pass writes into the same output
                # directory.
                test_file = _newest_match(
                    os.path.join(
                        os.getcwd(),
                        "test_images",
                        model_name,
                        "generated_imgs",
                        "*",
                        "*.png",
                    )
                )
                golden_file = _newest_match(golden_dir + "/*.png")
                try:
                    compare_images(test_file, golden_file, upload=upload_bool)
                except AssertionError as e:
                    print(e)
                    if exit_on_fail:
                        raise

                if os.name == "nt":
                    # Windows only: runs 3, 5, 7, ... get the RDNA2 target
                    # triple. It is added after each even-numbered run and
                    # removed again after each odd-numbered run except the
                    # first.
                    counter += 1
                    if counter % 2 == 0:
                        extra_flags.append(rdna2_flag)
                    elif counter != 1:
                        extra_flags.remove(rdna2_flag)

    if do_gen:
        prepare_artifacts()

    # Dictionary keys of each metrics row, in the column order of the header.
    csv_columns = (
        "model_name",
        "device",
        "use_tune",
        "import_mlir",
        "Clip Inference time",
        "Average step",
        "VAE Inference time",
        "Total image generation",
        "command",
    )
    with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w") as f:
        header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
        f.write(header)
        for metric in model_metrics:
            f.write(";".join(metric[col] for col in csv_columns) + "\n")
def prepare_artifacts():
    """Move generated clip/unet/vae artifacts into ./gen_amdshark_tank.

    Every entry of the current working directory whose name contains "clip",
    "unet" or "vae" but not "vmfb" is moved into ``gen_amdshark_tank``. That
    directory is created if needed. Entries whose name already exists there
    are left in place.
    """
    cwd = os.getcwd()
    gen_path = os.path.join(cwd, "gen_amdshark_tank")
    os.makedirs(gen_path, exist_ok=True)
    keywords = ("clip", "unet", "vae")
    for dirname in os.listdir(cwd):
        if "vmfb" in dirname or not any(k in dirname for k in keywords):
            continue
        # Decide once per entry and test with exists(). With a per-keyword
        # check plus isdir(), a plain file matching several keywords was
        # moved and then moved again, which made shutil.move fail.
        if os.path.exists(os.path.join(gen_path, dirname)):
            continue
        shutil.move(os.path.join(cwd, dirname), gen_path)
        print(f"Moved dir: {dirname} to {gen_path}.")
def _add_bool_flag(arg_parser, short_flag, long_flag, default):
    """Register a ``--flag/--no-flag`` boolean option on ``arg_parser``."""
    arg_parser.add_argument(
        short_flag,
        long_flag,
        action=argparse.BooleanOptionalAction,
        default=default,
    )


# Command-line interface of the testing script (options in help order).
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
_add_bool_flag(parser, "-b", "--beta", False)
parser.add_argument("-e", "--extra_args", type=str, default=None)
_add_bool_flag(parser, "-u", "--upload", True)
_add_bool_flag(parser, "-x", "--exit_on_fail", True)
_add_bool_flag(parser, "-g", "--gen", False)
if __name__ == "__main__":
    args = parser.parse_args()
    print(args)
    # --extra_args is a comma-separated list of flags. Surrounding whitespace
    # and empty items (e.g. from "a,,b" or a trailing comma) are dropped so
    # no blank arguments reach the child process.
    extra_args = []
    if args.extra_args:
        extra_args = [a.strip() for a in args.extra_args.split(",") if a.strip()]
    # test_loop itself calls prepare_artifacts() when do_gen is set, so no
    # second call is made here.
    test_loop(
        device=args.device,
        beta=args.beta,
        extra_flags=extra_args,
        upload_bool=args.upload,
        exit_on_fail=args.exit_on_fail,
        do_gen=args.gen,
    )