mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add TF ResNet50 to tank tests. (#261)
* Add TensorFlow Resnet50 test to shark tank.
This commit is contained in:
@@ -93,8 +93,11 @@ def save_torch_model(torch_model_list):
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list):
|
||||
from tank.model_utils_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import get_causal_image_model
|
||||
from tank.model_utils_tf import (
|
||||
get_causal_image_model,
|
||||
get_causal_lm_model,
|
||||
get_keras_model,
|
||||
)
|
||||
|
||||
with open(tf_model_list) as csvfile:
|
||||
tf_reader = csv.reader(csvfile, delimiter=",")
|
||||
@@ -110,7 +113,8 @@ def save_tf_model(tf_model_list):
|
||||
model, input, _ = get_causal_lm_model(tf_model_name)
|
||||
if model_type == "img":
|
||||
model, input, _ = get_causal_image_model(tf_model_name)
|
||||
|
||||
if model_type == "keras":
|
||||
model, input, _ = get_keras_model(tf_model_name)
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
|
||||
@@ -34,9 +34,12 @@ def tensor_to_type_str(input_tensors: tuple, mlir_dialect: str):
|
||||
dtype_string = str(input_tensor.dtype).replace("torch.", "")
|
||||
elif mlir_dialect in ["mhlo", "tflite"]:
|
||||
dtype = input_tensor.dtype
|
||||
dtype_string = re.findall("'[^\"]*'", str(dtype))[0].replace(
|
||||
"'", ""
|
||||
)
|
||||
try:
|
||||
dtype_string = re.findall("'[^\"]*'", str(dtype))[0].replace(
|
||||
"'", ""
|
||||
)
|
||||
except IndexError:
|
||||
dtype_string = str(dtype)
|
||||
regex_split = re.compile("([a-zA-Z]+)([0-9]+)")
|
||||
match = regex_split.match(dtype_string)
|
||||
mlir_type_string = str(match.group(1)[0]) + str(match.group(2))
|
||||
|
||||
@@ -124,15 +124,15 @@ def export_iree_module_to_vmfb(
|
||||
module,
|
||||
device: str,
|
||||
directory: str,
|
||||
frontend: str = "torch",
|
||||
mlir_dialect: str = "linalg",
|
||||
func_name: str = "forward",
|
||||
model_config_path: str = None,
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, func_name, model_config_path
|
||||
module, device, mlir_dialect, func_name, model_config_path
|
||||
)
|
||||
module_name = f"{frontend}_{func_name}_{device}"
|
||||
module_name = f"{mlir_dialect}_{func_name}_{device}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
with open(filename, "wb") as f:
|
||||
|
||||
@@ -34,22 +34,21 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
function_name: str = "forward",
|
||||
device: str = "none",
|
||||
mlir_dialect: str = "linalg",
|
||||
frontend: str = "torch",
|
||||
):
|
||||
self.device = shark_args.device if device == "none" else device
|
||||
self.frontend = frontend
|
||||
self.frontend_model = None
|
||||
self.vmfb_file = None
|
||||
self.mlir_dialect = mlir_dialect
|
||||
SharkRunner.__init__(
|
||||
self,
|
||||
mlir_module,
|
||||
function_name,
|
||||
device,
|
||||
mlir_dialect,
|
||||
self.mlir_dialect,
|
||||
)
|
||||
if self.vmfb_file == None:
|
||||
self.vmfb_file = export_iree_module_to_vmfb(
|
||||
mlir_module, device, shark_args.repro_dir, self.frontend
|
||||
mlir_module, device, shark_args.repro_dir, self.mlir_dialect
|
||||
)
|
||||
|
||||
def setup_cl(self, input_tensors):
|
||||
@@ -59,11 +58,12 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
input_tensors,
|
||||
mlir_dialect=self.mlir_dialect,
|
||||
)
|
||||
print(self.benchmark_cl)
|
||||
|
||||
def benchmark_frontend(self, inputs, modelname):
|
||||
if self.frontend in ["pytorch", "torch"]:
|
||||
if self.mlir_dialect in ["linalg", "torch"]:
|
||||
return self.benchmark_torch(modelname)
|
||||
elif self.frontend in ["tensorflow", "tf"]:
|
||||
elif self.mlir_dialect in ["mhlo", "tf"]:
|
||||
return self.benchmark_tf(inputs, modelname)
|
||||
|
||||
def benchmark_torch(self, modelname):
|
||||
@@ -99,26 +99,27 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
]
|
||||
|
||||
def benchmark_tf(self, frontend_model, inputs):
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(*inputs)
|
||||
# for i in range(shark_args.num_warmup_iterations):
|
||||
# frontend_model.forward(*inputs)
|
||||
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(*inputs)
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
break
|
||||
print(
|
||||
f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
]
|
||||
# begin = time.time()
|
||||
# for i in range(shark_args.num_iterations):
|
||||
# out = frontend_model.forward(*inputs)
|
||||
# if i == shark_args.num_iterations - 1:
|
||||
# end = time.time()
|
||||
# break
|
||||
# print(
|
||||
# f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
# )
|
||||
# return [
|
||||
# f"{shark_args.num_iterations/(end-begin)}",
|
||||
# f"{((end-begin)/shark_args.num_iterations)*1000}",
|
||||
# ]
|
||||
return ["n/a", "n/a"]
|
||||
|
||||
def benchmark_c(self):
|
||||
result = run_benchmark_module(self.benchmark_cl)
|
||||
print(f"Shark-{self.frontend} C-benchmark:{result} iter/second")
|
||||
print(f"Shark-IREE-C benchmark:{result} iter/second")
|
||||
return [f"{result}", f"{1000/result}"]
|
||||
|
||||
def benchmark_python(self, inputs):
|
||||
@@ -132,7 +133,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
if i == shark_args.num_iterations - 1:
|
||||
end = time.time()
|
||||
print(
|
||||
f"Shark-{self.frontend} Python-benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
f"Shark-IREE Python benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
return [
|
||||
f"{shark_args.num_iterations/(end-begin)}",
|
||||
|
||||
@@ -147,6 +147,50 @@ def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
##################### TensorFlow Keras Resnet Models #########################################################
|
||||
# Static shape, including batch size (1).
|
||||
# Can be dynamic once dynamic shape support is ready.
|
||||
INPUT_SHAPE = [1, 224, 224, 3]
|
||||
|
||||
tf_model = tf.keras.applications.resnet50.ResNet50(
|
||||
weights="imagenet", include_top=True, input_shape=tuple(INPUT_SHAPE[1:])
|
||||
)
|
||||
|
||||
|
||||
class ResNetModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(ResNetModule, self).__init__()
|
||||
self.m = tf_model
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
|
||||
@tf.function(input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])
|
||||
def forward(self, inputs):
|
||||
return self.m.predict(inputs)
|
||||
|
||||
|
||||
def load_image(path_to_image):
|
||||
image = tf.io.read_file(path_to_image)
|
||||
image = tf.image.decode_image(image, channels=3)
|
||||
image = tf.image.resize(image, (224, 224))
|
||||
image = image[tf.newaxis, :]
|
||||
return image
|
||||
|
||||
|
||||
def get_keras_model(modelname):
|
||||
model = ResNetModule()
|
||||
content_path = tf.keras.utils.get_file(
|
||||
"YellowLabradorLooking_new.jpg",
|
||||
"https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg",
|
||||
)
|
||||
content_image = load_image(content_path)
|
||||
input_tensor = tf.keras.applications.resnet50.preprocess_input(
|
||||
content_image
|
||||
)
|
||||
input_data = tf.expand_dims(input_tensor, 0)
|
||||
actual_out = model.forward(*input_data)
|
||||
return model, input_data, actual_out
|
||||
|
||||
|
||||
##################### Tensorflow Hugging Face Image Classification Models ###################################
|
||||
from transformers import TFAutoModelForImageClassification
|
||||
from transformers import ConvNextFeatureExtractor, ViTFeatureExtractor
|
||||
|
||||
67
tank/resnet50_tf/resnet50_tf_test.py
Normal file
67
tank/resnet50_tf/resnet50_tf_test.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_downloader import download_tf_model
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Resnet50ModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
benchmark=False,
|
||||
):
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, func_name, inputs, golden_out = download_tf_model("resnet50")
|
||||
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
func_name,
|
||||
device=device,
|
||||
mlir_dialect="mhlo",
|
||||
is_benchmark=self.benchmark,
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.shark_runner.benchmark_all_csv(
|
||||
(inputs), "resnet50", dynamic, device, "tensorflow"
|
||||
)
|
||||
|
||||
|
||||
class Resnet50ModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = Resnet50ModuleTester(self)
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -11,6 +11,7 @@ microsoft/layoutlm-base-uncased,hf
|
||||
google/mobilebert-uncased,hf
|
||||
microsoft/mpnet-base,hf
|
||||
roberta-base,hf
|
||||
resnet50,keras
|
||||
xlm-roberta-base,hf
|
||||
microsoft/MiniLM-L12-H384-uncased,hf
|
||||
funnel-transformer/small,hf
|
||||
|
||||
|
Reference in New Issue
Block a user