mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Enable conv nchw-to-nhwc flag by default for most models + minor fixes (#584)
This commit is contained in:
@@ -15,6 +15,7 @@ import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.iree_utils._common import iree_device_map, iree_target_map
|
||||
from shark.iree_utils.benchmark_utils import *
|
||||
from shark.parser import shark_args
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
@@ -66,6 +67,16 @@ def get_iree_common_args():
|
||||
]
|
||||
|
||||
|
||||
# Args that are suitable only for certain models or groups of models.
|
||||
# shark_args are passed down from pytests to control which models compile with these flags,
|
||||
# but they can also be set in shark/parser.py
|
||||
def get_model_specific_args():
|
||||
ms_args = []
|
||||
if shark_args.enable_conv_transform == True:
|
||||
ms_args += ["--iree-flow-enable-conv-nchw-to-nhwc-transform"]
|
||||
return ms_args
|
||||
|
||||
|
||||
def create_dispatch_dirs(bench_dir, device):
|
||||
protected_files = ["ordered-dispatches.txt"]
|
||||
bench_dir_path = bench_dir.split("/")
|
||||
@@ -213,14 +224,22 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
|
||||
|
||||
|
||||
def compile_module_to_flatbuffer(
|
||||
module, device, frontend, func_name, model_config_path, extra_args
|
||||
module,
|
||||
device,
|
||||
frontend,
|
||||
func_name,
|
||||
model_config_path,
|
||||
extra_args,
|
||||
model_name="None",
|
||||
):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = ""
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args()
|
||||
args += get_model_specific_args()
|
||||
args += extra_args
|
||||
print(args)
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "mhlo"
|
||||
|
||||
@@ -105,4 +105,11 @@ parser.add_argument(
|
||||
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable_conv_transform",
|
||||
default=True,
|
||||
action="store",
|
||||
help="Enables the --iree-flow-enable-conv-nchw-to-nhwc-transform flag.",
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
@@ -143,14 +143,14 @@ def get_vision_model(torch_model):
|
||||
import torchvision.models as models
|
||||
|
||||
vision_models_dict = {
|
||||
"alexnet": models.alexnet(pretrained=True),
|
||||
"resnet18": models.resnet18(pretrained=True),
|
||||
"resnet50": models.resnet50(pretrained=True),
|
||||
"resnet101": models.resnet101(pretrained=True),
|
||||
"squeezenet1_0": models.squeezenet1_0(pretrained=True),
|
||||
"wide_resnet50_2": models.wide_resnet50_2(pretrained=True),
|
||||
"mobilenet_v3_small": models.mobilenet_v3_small(pretrained=True),
|
||||
"mnasnet1_0": models.mnasnet1_0(pretrained=True),
|
||||
"alexnet": models.alexnet(weights="DEFAULT"),
|
||||
"resnet18": models.resnet18(weights="DEFAULT"),
|
||||
"resnet50": models.resnet50(weights="DEFAULT"),
|
||||
"resnet101": models.resnet101(weights="DEFAULT"),
|
||||
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
|
||||
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
|
||||
"mobilenet_v3_small": models.mobilenet_v3_small(weights="DEFAULT"),
|
||||
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
|
||||
}
|
||||
if isinstance(torch_model, str):
|
||||
torch_model = vision_models_dict[torch_model]
|
||||
|
||||
@@ -127,8 +127,11 @@ class SharkModuleTester:
|
||||
self.config = config
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
|
||||
shark_args.local_tank_cache = self.local_tank_cache
|
||||
shark_args.update_tank = self.update_tank
|
||||
if self.config["model_name"] in ["alexnet", "resnet18"]:
|
||||
shark_args.enable_conv_transform = False
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
@@ -347,7 +350,16 @@ class SharkModuleTest(unittest.TestCase):
|
||||
pytest.xfail(
|
||||
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
|
||||
)
|
||||
if config["model_name"] == "mobilenet_v3_small":
|
||||
if config["model_name"] == "mobilenet_v3_small" and device not in [
|
||||
"cpu"
|
||||
]:
|
||||
pytest.xfail(
|
||||
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
|
||||
)
|
||||
if config["model_name"] == "mnasnet1_0" and device not in [
|
||||
"cpu",
|
||||
"cuda",
|
||||
]:
|
||||
pytest.xfail(
|
||||
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user