Mirror of https://github.com/nod-ai/SHARK-Studio.git
Fix pytest benchmarks and shark_tank generation. (#1632)
- Fix setup_venv.sh for benchmarks/imports etc.
- Fix torch benchmarks in SharkBenchmarkRunner.
- Generate SD artifacts using build_tools/stable_diffusion_testing.py and --import_mlir.
- Decouple SD gen from tank/generate_sharktank for now.
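For context, the decoupled generation flow described above can be exercised directly from Python; a minimal sketch, assuming build_tools/stable_diffusion_testing.py is importable from the repository root (hypothetical driver, not part of this commit):

    # Sketch (not part of this commit): drive the new decoupled SD flow.
    # test_loop(do_gen=True) appends --import_debug to each run, then calls
    # prepare_artifacts() to collect the dumped clip/unet/vae directories
    # into ./gen_shark_tank/.
    from build_tools.stable_diffusion_testing import test_loop

    test_loop(device="vulkan", beta=False, extra_flags=[], do_gen=True)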
@@ -24,13 +24,13 @@ def get_image(url, local_filename):
         shutil.copyfileobj(res.raw, f)


-def compare_images(new_filename, golden_filename):
+def compare_images(new_filename, golden_filename, upload=False):
     new = np.array(Image.open(new_filename)) / 255.0
     golden = np.array(Image.open(golden_filename)) / 255.0
     diff = np.abs(new - golden)
     mean = np.mean(diff)
     if mean > 0.1:
-        if os.name != "nt":
+        if os.name != "nt" and upload == True:
             subprocess.run(
                 [
                     "gsutil",
@@ -39,7 +39,7 @@ def compare_images(new_filename, golden_filename):
                     "gs://shark_tank/testdata/builder/",
                 ]
             )
-        raise SystemExit("new and golden not close")
+        raise AssertionError("new and golden not close")
     else:
         print("SUCCESS")
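Note on the change above: raising AssertionError instead of SystemExit lets callers catch a mismatch and decide whether it is fatal. A minimal sketch of the new contract (file names are placeholders):

    # Sketch: AssertionError, unlike SystemExit, is caught by a plain
    # except clause, so the caller controls whether a mismatch aborts.
    try:
        compare_images("new.png", "golden.png", upload=False)
    except AssertionError as err:
        print(f"image mismatch: {err}")  # log and continue, or re-raise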
@@ -1,5 +1,6 @@
 #!/bin/bash

-IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
+IMPORTER=1 BENCHMARK=1 NO_BREVITAS=1 ./setup_venv.sh
 source $GITHUB_WORKSPACE/shark.venv/bin/activate
+python build_tools/stable_diffusion_testing.py --gen
 python tank/generate_sharktank.py
@@ -63,7 +63,14 @@ def get_inpaint_inputs():
     open("./test_images/inputs/mask.png", "wb").write(mask.content)


-def test_loop(device="vulkan", beta=False, extra_flags=[]):
+def test_loop(
+    device="vulkan",
+    beta=False,
+    extra_flags=[],
+    upload_bool=True,
+    exit_on_fail=True,
+    do_gen=False,
+):
     # Get golden values from tank
     shutil.rmtree("./test_images", ignore_errors=True)
     model_metrics = []
@@ -81,6 +88,8 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
     if beta:
         extra_flags.append("--beta_models=True")
     extra_flags.append("--no-progress_bar")
+    if do_gen:
+        extra_flags.append("--import_debug")
     to_skip = [
         "Linaqruf/anything-v3.0",
         "prompthero/openjourney",
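One caveat with the expanded signature: the mutable default extra_flags=[] is evaluated once and shared across calls, and test_loop appends to it ("--no-progress_bar" above), so flags accumulate if the function is called repeatedly without an explicit list. A minimal demonstration of the pitfall (toy function, not the real test_loop):

    # Demonstration: a mutable default argument persists between calls.
    def loop(flags=[]):
        flags.append("--no-progress_bar")
        return flags

    print(loop())  # ['--no-progress_bar']
    print(loop())  # ['--no-progress_bar', '--no-progress_bar']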
@@ -181,7 +190,14 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
                 "./test_images/golden/" + model_name + "/*.png"
             )
             golden_file = glob(golden_path)[0]
-            compare_images(test_file, golden_file)
+            try:
+                compare_images(
+                    test_file, golden_file, upload=upload_bool
+                )
+            except AssertionError as e:
+                print(e)
+                if exit_on_fail == True:
+                    raise
         else:
             print(command)
             print("failed to generate image for this configuration")
@@ -200,6 +216,9 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
         extra_flags.remove(
             "--iree_vulkan_target_triple=rdna2-unknown-windows"
         )
+    if do_gen:
+        prepare_artifacts()
+
     with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
         header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
         f.write(header)
@@ -218,15 +237,49 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
         f.write(";".join(output) + "\n")


+def prepare_artifacts():
+    gen_path = os.path.join(os.getcwd(), "gen_shark_tank")
+    if not os.path.isdir(gen_path):
+        os.mkdir(gen_path)
+    for dirname in os.listdir(os.getcwd()):
+        for modelname in ["clip", "unet", "vae"]:
+            if modelname in dirname and "vmfb" not in dirname:
+                if not os.path.isdir(os.path.join(gen_path, dirname)):
+                    shutil.move(os.path.join(os.getcwd(), dirname), gen_path)
+                    print(f"Moved dir: {dirname} to {gen_path}.")
+
+
 parser = argparse.ArgumentParser()

 parser.add_argument("-d", "--device", default="vulkan")
 parser.add_argument(
     "-b", "--beta", action=argparse.BooleanOptionalAction, default=False
 )
 parser.add_argument("-e", "--extra_args", type=str, default=None)
+parser.add_argument(
+    "-u", "--upload", action=argparse.BooleanOptionalAction, default=True
+)
+parser.add_argument(
+    "-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
+)
+parser.add_argument(
+    "-g", "--gen", action=argparse.BooleanOptionalAction, default=False
+)

 if __name__ == "__main__":
     args = parser.parse_args()
     print(args)
-    test_loop(args.device, args.beta, [])
+    extra_args = []
+    if args.extra_args:
+        for arg in args.extra_args.split(","):
+            extra_args.append(arg)
+    test_loop(
+        args.device,
+        args.beta,
+        extra_args,
+        args.upload,
+        args.exit_on_fail,
+        args.gen,
+    )
+    if args.gen:
+        prepare_artifacts()
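With argparse.BooleanOptionalAction, each of the new flags also gets a generated --no-* form, so a CI job can toggle them explicitly. A sketch of an invocation (flag spellings follow argparse's generated names; the script path comes from the commit message):

    # Sketch: invoking the testing script with the new flags.
    import subprocess

    subprocess.run(
        [
            "python", "build_tools/stable_diffusion_testing.py",
            "--device", "vulkan",
            "--no-upload",        # skip the gsutil upload of mismatching images
            "--no-exit_on_fail",  # log AssertionErrors but keep testing
            "--gen",              # dump MLIR artifacts, then prepare_artifacts()
        ],
        check=True,
    )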
@@ -5,7 +5,7 @@ requires = [
     "packaging",

     "numpy>=1.22.4",
-    "torch-mlir>=20221021.633",
+    "torch-mlir>=20230620.875",
     "iree-compiler>=20221022.190",
     "iree-runtime>=20221022.190",
 ]
@@ -3,7 +3,7 @@

 numpy>1.22.4
 pytorch-triton
-torchvision==0.16.0.dev20230322
+torchvision
 tabulate

 tqdm
@@ -15,7 +15,7 @@ iree-tools-tf

 # TensorFlow and JAX.
 gin-config
-tensorflow>2.11
+tf-nightly
 keras
 #tf-models-nightly
 #tensorflow-text-nightly
@@ -128,7 +128,7 @@ if [[ ! -z "${IMPORTER}" ]]; then
   fi
 fi

-$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
+$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/cpu/

 if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
   T_VER=$($PYTHON -m pip show torch | grep Version)
@@ -145,14 +145,8 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
   fi
 fi

-if [[ ! -z "${ONNX}" ]]; then
-  echo "${Yellow}Installing ONNX and onnxruntime for benchmarks..."
-  $PYTHON -m pip install onnx onnxruntime psutil
-  if [ $? -eq 0 ];then
-    echo "Successfully installed ONNX and ONNX runtime."
-  else
-    echo "Could not install ONNX." >&2
-  fi
+if [[ -z "${NO_BREVITAS}" ]]; then
+  $PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@llm
 fi

 if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
@@ -124,42 +124,41 @@ class SharkBenchmarkRunner(SharkRunner):
         elif self.mlir_dialect in ["mhlo", "tf"]:
             return self.benchmark_tf(modelname)

-    def benchmark_torch(self, modelname):
+    def benchmark_torch(self, modelname, device="cpu"):
         import torch
         from tank.model_utils import get_torch_model

-        if self.device == "cuda":
-            torch.set_default_tensor_type(torch.cuda.FloatTensor)
-            if self.enable_tf32:
-                torch.backends.cuda.matmul.allow_tf32 = True
+        # TODO: Pass this as an arg. currently the best way is to setup with BENCHMARK=1 if we want to use torch+cuda, else use cpu.
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if device == "cuda":
+            torch.set_default_device("cuda:0")
+            # if self.enable_tf32:
+            #     torch.backends.cuda.matmul.allow_tf32 = True
         else:
-            torch.set_default_tensor_type(torch.FloatTensor)
-        torch_device = torch.device(
-            "cuda:0" if self.device == "cuda" else "cpu"
-        )
+            torch.set_default_dtype(torch.float32)
+            torch.set_default_device("cpu")
+        torch_device = torch.device("cuda:0" if device == "cuda" else "cpu")
         HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
         frontend_model = HFmodel.model
-        frontend_model.to(torch_device)
-        input.to(torch_device)

-        # TODO: re-enable as soon as pytorch CUDA context issues are resolved
-        try:
-            frontend_model = torch.compile(
-                frontend_model, mode="max-autotune", backend="inductor"
-            )
-        except RuntimeError:
-            frontend_model = HFmodel.model
+        if device == "cuda":
+            frontend_model.cuda()
+            input.to(torch.device("cuda:0"))
+            print(input)
+        else:
+            frontend_model.cpu()
+            input.cpu()

         for i in range(shark_args.num_warmup_iterations):
             frontend_model.forward(input)

-        if self.device == "cuda":
+        if device == "cuda":
             torch.cuda.reset_peak_memory_stats()
         begin = time.time()
         for i in range(shark_args.num_iterations):
             out = frontend_model.forward(input)
         end = time.time()
-        if self.device == "cuda":
+        if device == "cuda":
             stats = torch.cuda.memory_stats()
             device_peak_b = stats["allocated_bytes.all.peak"]
             frontend_model.to(torch.device("cpu"))
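The rework above swaps the deprecated torch.set_default_tensor_type for torch.set_default_device (available since PyTorch 2.0) and keys every branch off a locally computed device instead of self.device. Note also that Tensor.to returns a new tensor rather than moving in place, so the unassigned input.to(...) calls (kept from the original) do not relocate input. A standalone sketch of the device-selection pattern:

    # Sketch of the pattern (PyTorch >= 2.0): factory calls allocate on
    # the default device instead of relying on set_default_tensor_type().
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        torch.set_default_device("cuda:0")
    else:
        torch.set_default_dtype(torch.float32)
        torch.set_default_device("cpu")

    x = torch.zeros(2, 2)  # lands on the default device chosen above
    print(x.device)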
@@ -171,7 +170,7 @@ class SharkBenchmarkRunner(SharkRunner):
         print(
             f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
         )
-        if self.device == "cuda":
+        if device == "cuda":
             # Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
             torch_device = torch.device("cpu")
         return [
@@ -13,7 +13,6 @@ google/vit-base-patch16-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,
 microsoft/MiniLM-L12-H384-uncased,stablehlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
 microsoft/layoutlm-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
 microsoft/mpnet-base,stablehlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
-albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir",""
 alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
 bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
 bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
@@ -16,12 +16,6 @@ import subprocess as sp
 import hashlib
 import numpy as np
 from pathlib import Path
-from apps.stable_diffusion.src.models import (
-    model_wrappers as mw,
-)
-from apps.stable_diffusion.src.utils.stable_args import (
-    args,
-)


 def create_hash(file_name):
@@ -60,31 +54,6 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
         print("generating artifacts for: " + torch_model_name)
         model = None
         input = None
-        if model_type == "stable_diffusion":
-            args.use_tuned = False
-            args.import_mlir = True
-            args.local_tank_cache = local_tank_cache
-
-            precision_values = ["fp16"]
-            seq_lengths = [64, 77]
-            for precision_value in precision_values:
-                args.precision = precision_value
-                for length in seq_lengths:
-                    model = mw.SharkifyStableDiffusionModel(
-                        model_id=torch_model_name,
-                        custom_weights="",
-                        precision=precision_value,
-                        max_len=length,
-                        width=512,
-                        height=512,
-                        use_base_vae=False,
-                        custom_vae="",
-                        debug=True,
-                        sharktank_dir=local_tank_cache,
-                        generate_vmfb=False,
-                    )
-                    model()
-            continue
         if model_type == "vision":
             model, input, _ = get_vision_model(
                 torch_model_name, import_args
@@ -103,10 +72,11 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
             model, input, _ = get_hf_img_cls_model(
                 torch_model_name, import_args
             )
-        elif model_type == "fp16":
-            model, input, _ = get_fp16_model(torch_model_name, import_args)
         torch_model_name = torch_model_name.replace("/", "_")
-        if import_args["batch_size"] != 1:
+        if import_args["batch_size"] > 1:
+            print(
+                f"Batch size for this model set to {import_args['batch_size']}"
+            )
         torch_model_dir = os.path.join(
             local_tank_cache,
             str(torch_model_name)
@@ -391,7 +361,7 @@ if __name__ == "__main__":

     # old_import_args = parser.parse_import_args()
     import_args = {
-        "batch_size": "1",
+        "batch_size": 1,
     }
     print(import_args)
     home = str(Path.home())
@@ -404,11 +374,6 @@ if __name__ == "__main__":
         os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
     )

-    save_torch_model(
-        os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
-        WORKDIR,
-        import_args,
-    )
     save_torch_model(torch_model_csv, WORKDIR, import_args)
-    save_tf_model(tf_model_csv, WORKDIR, import_args)
-    save_tflite_model(tflite_model_csv, WORKDIR, import_args)
+    # save_tf_model(tf_model_csv, WORKDIR, import_args)
+    # save_tflite_model(tflite_model_csv, WORKDIR, import_args)
@@ -278,7 +278,7 @@ def get_vision_model(torch_model, import_args):
         int(import_args["batch_size"]), 3, *input_image_size
     )
     actual_out = model(test_input)
-    if fp16_model is not None:
+    if fp16_model == True:
         test_input_fp16 = test_input.to(
             device=torch.device("cuda"), dtype=torch.half
         )
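The guard change above is behavioral, not cosmetic: "is not None" also matches False, so fp16-disabled models previously fell into the half-precision CUDA path. A minimal illustration:

    # Illustration: why "is not None" was the wrong guard here.
    fp16_model = False
    print(fp16_model is not None)  # True  -> old code entered the fp16 branch
    print(fp16_model == True)      # False -> new code skips it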
@@ -5,7 +5,6 @@ microsoft/MiniLM-L12-H384-uncased,True,hf,True,linalg,False,66M,"nlp;bert-varian
 bert-base-uncased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
 bert-base-cased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
 google/mobilebert-uncased,True,hf,True,linalg,False,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
-alexnet,False,vision,True,linalg,False,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
 resnet18,False,vision,True,linalg,False,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
 resnet50,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
 resnet101,False,vision,True,linalg,False,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
@@ -18,11 +17,9 @@ facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,linalg,False,22M
 microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,linalg,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
 nvidia/mit-b0,True,hf_img_cls,False,linalg,False,3.7M,"image-classification,transformer-encoder",SegFormer
 mnasnet1_0,False,vision,True,linalg,False,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
-resnet50_fp16,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
-bert-base-uncased_fp16,True,fp16,False,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
 bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
+bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
 gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
 facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
 distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
-microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
 microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"