"""Stable Diffusion test driver.

Runs the txt2img / inpaint scripts for every model listed in
``model_config.json`` across the import and tuning options, compares the
generated image with a downloaded golden image, and writes the collected
timing metrics to ``sd_testing_metrics.csv`` in the current directory.
"""

import argparse
import os
import shutil
import subprocess
from datetime import datetime as dt
from glob import glob
from sys import executable

import requests

from apps.stable_diffusion.src.utils.resources import (
    get_json_file,
)
from shark.shark_downloader import download_public_file
from image_comparison import compare_images

model_config_dicts = get_json_file(
    os.path.join(
        os.getcwd(),
        "apps/stable_diffusion/src/utils/resources/model_config.json",
    )
)

# Flag that test_loop toggles between models when running on Windows.
_WINDOWS_TRIPLE_FLAG = "--iree_vulkan_target_triple=rdna2-unknown-windows"


def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir):
    """Extract timing metrics from the captured output of one SD run.

    For every line of ``filename`` that contains one of the known metric
    labels, the last space-separated token of that line is stored under the
    label.  The remaining arguments are copied into the result unchanged.

    Args:
        filename: Path of the log file holding the script's stdout/stderr.
        command: The command that produced the log.
        device: Device string the run used.
        use_tune: Tuning flag the run used.
        model_name: Model id the run used.
        import_mlir: Import flag the run used.

    Returns:
        A dict keyed by metric label plus ``device``, ``use_tune``,
        ``model_name``, ``import_mlir`` and ``command``.

    Raises:
        KeyError: If the log has no "Average step" or
            "Total image generation" line.
    """
    vals_to_read = [
        "Clip Inference time",
        "Average step",
        "VAE Inference time",
        "Total image generation",
    ]
    metrics = {}
    # The log is only read here, so plain read mode is enough.
    with open(filename, "r") as f:
        for line in f:
            last_token = line.split(" ")[-1].strip("\n")
            for val in vals_to_read:
                if val in line:
                    metrics[val] = last_token
    # Drop the unit suffix; removesuffix removes exactly that suffix rather
    # than any run of its characters from both ends.
    metrics["Average step"] = metrics["Average step"].removesuffix("ms/it")
    metrics["Total image generation"] = metrics[
        "Total image generation"
    ].removesuffix("sec")
    metrics["device"] = device
    metrics["use_tune"] = use_tune
    metrics["model_name"] = model_name
    metrics["import_mlir"] = import_mlir
    metrics["command"] = command
    return metrics


def get_inpaint_inputs():
    """Download the inpainting input image and mask.

    The files are written to ``./test_images/inputs/image.png`` and
    ``./test_images/inputs/mask.png``.

    Raises:
        requests.HTTPError: If either download returns an error status.
    """
    os.makedirs("./test_images/inputs", exist_ok=True)
    img_url = (
        "https://huggingface.co/datasets/diffusers/test-arrays/resolve"
        "/main/stable_diffusion_inpaint/input_bench_image.png"
    )
    mask_url = (
        "https://huggingface.co/datasets/diffusers/test-arrays/resolve"
        "/main/stable_diffusion_inpaint/input_bench_mask.png"
    )
    targets = (
        (img_url, "./test_images/inputs/image.png"),
        (mask_url, "./test_images/inputs/mask.png"),
    )
    for url, path in targets:
        response = requests.get(url)
        # Fail here instead of saving an HTTP error body as a .png.
        response.raise_for_status()
        with open(path, "wb") as out:
            out.write(response.content)


def _build_command(
    model_name, device, import_opt, use_tune, prompt_text, inpaint_prompt_text
):
    """Return the argv list for one run (inpaint script for inpainting models)."""
    output_dir = "--output_dir=" + os.path.join(
        os.getcwd(), "test_images", model_name
    )
    if "inpainting" not in model_name:
        return [
            executable,  # executable is the python from the venv used to run this
            "apps/stable_diffusion/scripts/txt2img.py",
            "--device=" + device,
            prompt_text,
            '--negative_prompts=""',
            "--seed=42",
            import_opt,
            output_dir,
            "--hf_model_id=" + model_name,
            use_tune,
        ]
    # The inpaint run always passes --import_mlir, regardless of import_opt.
    return [
        executable,
        "apps/stable_diffusion/scripts/inpaint.py",
        "--device=" + device,
        inpaint_prompt_text,
        '--negative_prompts=""',
        "--img_path=./test_images/inputs/image.png",
        "--mask_path=./test_images/inputs/mask.png",
        "--seed=42",
        "--import_mlir",
        output_dir,
        "--hf_model_id=" + model_name,
        use_tune,
    ]


def _write_metrics_csv(model_metrics):
    """Write the collected metrics to ``sd_testing_metrics.csv`` (';'-separated)."""
    columns = [
        "model_name",
        "device",
        "use_tune",
        "import_mlir",
        "Clip Inference time",
        "Average step",
        "VAE Inference time",
        "Total image generation",
        "command",
    ]
    with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
        header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
        f.write(header)
        for metric in model_metrics:
            f.write(";".join(metric[col] for col in columns) + "\n")


def test_loop(
    device="vulkan",
    beta=False,
    extra_flags=None,
    upload_bool=True,
    exit_on_fail=True,
    do_gen=False,
):
    """Run every model/import/tuning combination and compare with goldens.

    Recreates ``./test_images``, downloads the inpaint inputs, then for each
    configuration runs the generation script in a subprocess with its output
    captured in ``<model_name with '/' replaced by '_'>.txt``.  After a
    successful run the metrics are parsed, the golden files are downloaded
    from ``gs://shark_tank/testdata/golden/<model_name>`` and the first
    generated PNG is compared with the first golden PNG.  If a run fails, the
    captured output is printed and the process exits with status 1.

    Args:
        device: Value passed as ``--device=``.
        beta: If true, add ``--beta_models=True`` to every command.
        extra_flags: Optional iterable of extra command-line flags.  It is
            copied, so the caller's object is never modified.
        upload_bool: Forwarded to ``compare_images`` as ``upload``.
        exit_on_fail: If true, re-raise an ``AssertionError`` raised by
            ``compare_images``; otherwise only print it.
        do_gen: If true, add ``--import_debug`` and call
            ``prepare_artifacts`` once all runs are done.

    Raises:
        SystemExit: With code 1 when a generation subprocess returns non-zero.
        AssertionError: From ``compare_images`` when ``exit_on_fail`` is true.
        IndexError: If no generated or golden PNG is found.
    """
    # Work on a private copy: the flags are appended to and removed from
    # below, and a shared default list would leak state between calls.
    flags = list(extra_flags) if extra_flags is not None else []

    # Get golden values from tank
    shutil.rmtree("./test_images", ignore_errors=True)
    model_metrics = []
    os.mkdir("./test_images")
    os.mkdir("./test_images/golden")
    get_inpaint_inputs()
    hf_model_names = model_config_dicts[0].values()
    tuned_options = ["--no-use_tuned", "--use_tuned"]
    import_options = ["--import_mlir", "--no-import_mlir"]
    on_windows = os.name == "nt"
    if on_windows:
        # On Windows the command is joined into one string, so the prompt
        # text is wrapped in quotes.
        prompt_text = '--prompt="cyberpunk forest by Salvador Dali"'
        inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
    else:
        prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
        inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
    if beta:
        flags.append("--beta_models=True")
    flags.append("--no-progress_bar")
    if do_gen:
        flags.append("--import_debug")
    to_skip = [
        "Linaqruf/anything-v3.0",
        "prompthero/openjourney",
        "wavymulder/Analog-Diffusion",
        "dreamlike-art/dreamlike-diffusion-1.0",
    ]
    # Model/tuning pairs that are not run.
    excluded_pairs = {
        ("stabilityai/stable-diffusion-2-1", tuned_options[0]),
        ("stabilityai/stable-diffusion-2-1-base", tuned_options[1]),
    }
    counter = 0
    for import_opt in import_options:
        for model_name in hf_model_names:
            if model_name in to_skip:
                continue
            for use_tune in tuned_options:
                if (model_name, use_tune) in excluded_pairs:
                    continue
                command = _build_command(
                    model_name,
                    device,
                    import_opt,
                    use_tune,
                    prompt_text,
                    inpaint_prompt_text,
                )
                command += flags
                if on_windows:
                    command = " ".join(command)
                dumpfile_name = "_".join(model_name.split("/")) + ".txt"
                dumpfile_name = os.path.join(os.getcwd(), dumpfile_name)
                with open(dumpfile_name, "w+") as f:
                    # A zero return code means the image was generated.
                    generated_image = not subprocess.call(
                        command,
                        stdout=f,
                        stderr=f,
                    )
                if not on_windows:
                    # From here on the command is only used as text.
                    command = " ".join(command)
                if not generated_image:
                    print(command)
                    print("failed to generate image for this configuration")
                    with open(dumpfile_name, "r") as f:
                        output = f.readlines()
                    print("\n".join(output))
                    raise SystemExit(1)

                model_metrics.append(
                    parse_sd_out(
                        dumpfile_name,
                        command,
                        device,
                        use_tune,
                        model_name,
                        import_opt,
                    )
                )
                print(command)
                print("Successfully generated image")
                os.makedirs(
                    "./test_images/golden/" + model_name, exist_ok=True
                )
                download_public_file(
                    "gs://shark_tank/testdata/golden/" + model_name,
                    "./test_images/golden/" + model_name,
                )
                test_file_path = os.path.join(
                    os.getcwd(),
                    "test_images",
                    model_name,
                    "generated_imgs",
                    dt.now().strftime("%Y%m%d"),
                    "*.png",
                )
                test_file = glob(test_file_path)[0]
                golden_path = "./test_images/golden/" + model_name + "/*.png"
                golden_file = glob(golden_path)[0]
                try:
                    compare_images(test_file, golden_file, upload=upload_bool)
                except AssertionError as e:
                    print(e)
                    if exit_on_fail:
                        raise
            # NOTE(review): this toggle is reconstructed as running once per
            # non-skipped model (not once per tuning option) - confirm against
            # the intended layout.
            if on_windows:
                counter += 1
                if counter % 2 == 0:
                    flags.append(_WINDOWS_TRIPLE_FLAG)
                elif counter != 1:
                    flags.remove(_WINDOWS_TRIPLE_FLAG)
    if do_gen:
        prepare_artifacts()

    _write_metrics_csv(model_metrics)


def prepare_artifacts():
    """Move generated model artifacts into ``./gen_shark_tank``.

    Every entry of the current directory whose name contains "clip", "unet"
    or "vae" but not "vmfb" is moved into ``gen_shark_tank``, unless a
    directory with that name already exists there.
    """
    cwd = os.getcwd()
    gen_path = os.path.join(cwd, "gen_shark_tank")
    os.makedirs(gen_path, exist_ok=True)
    model_tags = ("clip", "unet", "vae")
    for dirname in os.listdir(cwd):
        if "vmfb" in dirname:
            continue
        # Move each entry at most once, even if its name matches several tags.
        if not any(tag in dirname for tag in model_tags):
            continue
        if os.path.isdir(os.path.join(gen_path, dirname)):
            continue
        shutil.move(os.path.join(cwd, dirname), gen_path)
        print(f"Moved dir: {dirname} to {gen_path}.")


parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
parser.add_argument(
    "-b", "--beta", action=argparse.BooleanOptionalAction, default=False
)
parser.add_argument("-e", "--extra_args", type=str, default=None)
parser.add_argument(
    "-u", "--upload", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
    "-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
    "-g", "--gen", action=argparse.BooleanOptionalAction, default=False
)

if __name__ == "__main__":
    args = parser.parse_args()
    print(args)
    # --extra_args is a comma-separated list of flags.
    extra_args = args.extra_args.split(",") if args.extra_args else []
    test_loop(
        args.device,
        args.beta,
        extra_args,
        args.upload,
        args.exit_on_fail,
        args.gen,
    )
    if args.gen:
        prepare_artifacts()