Enable conv nchw-to-nhwc flag by default for most models + minor fixes (#584)

This commit is contained in:
Ean Garvey
2022-12-07 18:24:02 -06:00
committed by GitHub
parent d2475ec169
commit 40eea21863
4 changed files with 48 additions and 10 deletions

View File

@@ -15,6 +15,7 @@ import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.benchmark_utils import *
from shark.parser import shark_args
import numpy as np
import os
import re
@@ -66,6 +67,16 @@ def get_iree_common_args():
]
# Args that are suitable only for certain models or groups of models.
# shark_args are passed down from pytest to control which models compile
# with these flags, but they can also be set in shark/parser.py.
def get_model_specific_args():
    """Return extra IREE compile flags that apply only to certain models.

    Reads toggles from ``shark_args`` (set in shark/parser.py or overridden
    by pytest) and translates them into IREE command-line flags.

    Returns:
        list[str]: IREE compiler flags to append for the current model.
    """
    ms_args = []
    enabled = shark_args.enable_conv_transform
    # The matching argparse option uses action="store", so a value supplied
    # on the command line arrives as a *string* ("True"/"False"); the old
    # `== True` comparison rejected every string, silently disabling the
    # transform whenever the user passed it explicitly. Normalize first.
    if isinstance(enabled, str):
        enabled = enabled.strip().lower() in ("1", "true", "yes")
    if enabled:
        ms_args += ["--iree-flow-enable-conv-nchw-to-nhwc-transform"]
    return ms_args
def create_dispatch_dirs(bench_dir, device):
protected_files = ["ordered-dispatches.txt"]
bench_dir_path = bench_dir.split("/")
@@ -213,14 +224,22 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
def compile_module_to_flatbuffer(
module, device, frontend, func_name, model_config_path, extra_args
module,
device,
frontend,
func_name,
model_config_path,
extra_args,
model_name="None",
):
# Setup Compile arguments wrt to frontends.
input_type = ""
args = get_iree_frontend_args(frontend)
args += get_iree_device_args(device, extra_args)
args += get_iree_common_args()
args += get_model_specific_args()
args += extra_args
print(args)
if frontend in ["tensorflow", "tf"]:
input_type = "mhlo"

View File

@@ -105,4 +105,11 @@ parser.add_argument(
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)
# --enable_conv_transform toggles IREE's NCHW->NHWC convolution transform.
# NOTE: action="store" passes CLI values through as strings, so without a
# type= converter "--enable_conv_transform False" would be a truthy string;
# convert explicitly so both the bool default and CLI strings behave.
parser.add_argument(
    "--enable_conv_transform",
    default=True,
    type=lambda s: str(s).strip().lower() in ("1", "true", "yes"),
    action="store",
    help="Enables the --iree-flow-enable-conv-nchw-to-nhwc-transform flag.",
)
shark_args, unknown = parser.parse_known_args()

View File

@@ -143,14 +143,14 @@ def get_vision_model(torch_model):
import torchvision.models as models
vision_models_dict = {
"alexnet": models.alexnet(pretrained=True),
"resnet18": models.resnet18(pretrained=True),
"resnet50": models.resnet50(pretrained=True),
"resnet101": models.resnet101(pretrained=True),
"squeezenet1_0": models.squeezenet1_0(pretrained=True),
"wide_resnet50_2": models.wide_resnet50_2(pretrained=True),
"mobilenet_v3_small": models.mobilenet_v3_small(pretrained=True),
"mnasnet1_0": models.mnasnet1_0(pretrained=True),
"alexnet": models.alexnet(weights="DEFAULT"),
"resnet18": models.resnet18(weights="DEFAULT"),
"resnet50": models.resnet50(weights="DEFAULT"),
"resnet101": models.resnet101(weights="DEFAULT"),
"squeezenet1_0": models.squeezenet1_0(weights="DEFAULT"),
"wide_resnet50_2": models.wide_resnet50_2(weights="DEFAULT"),
"mobilenet_v3_small": models.mobilenet_v3_small(weights="DEFAULT"),
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
}
if isinstance(torch_model, str):
torch_model = vision_models_dict[torch_model]

View File

@@ -127,8 +127,11 @@ class SharkModuleTester:
self.config = config
def create_and_check_module(self, dynamic, device):
shark_args.local_tank_cache = self.local_tank_cache
shark_args.update_tank = self.update_tank
if self.config["model_name"] in ["alexnet", "resnet18"]:
shark_args.enable_conv_transform = False
model, func_name, inputs, golden_out = download_model(
self.config["model_name"],
tank_url=self.tank_url,
@@ -347,7 +350,16 @@ class SharkModuleTest(unittest.TestCase):
pytest.xfail(
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
)
if config["model_name"] == "mobilenet_v3_small":
if config["model_name"] == "mobilenet_v3_small" and device not in [
"cpu"
]:
pytest.xfail(
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
)
if config["model_name"] == "mnasnet1_0" and device not in [
"cpu",
"cuda",
]:
pytest.xfail(
reason="Numerics Issues: https://github.com/nod-ai/SHARK/issues/388"
)