Fix sourcing for canonical MiniLM shark_tank model artifacts. (#278)

* Fix generation of MiniLM artifacts.

* Fix MiniLM output for validation. Xfail numerics failure on mpnet.

* Update distilbert-base-uncased_tf_test.py

* Add try-except to handle the transition of the MiniLM model
This commit is contained in:
Ean Garvey
2022-08-17 23:03:47 -05:00
committed by GitHub
parent 82c541dfb8
commit a3654f33da
6 changed files with 22 additions and 4 deletions

View File

@@ -98,6 +98,7 @@ def save_tf_model(tf_model_list):
get_causal_image_model,
get_causal_lm_model,
get_keras_model,
get_TFhf_model,
)
with open(tf_model_list) as csvfile:
@@ -109,13 +110,16 @@ def save_tf_model(tf_model_list):
model = None
input = None
print(model_type)
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_causal_lm_model(tf_model_name)
if model_type == "img":
model, input, _ = get_causal_image_model(tf_model_name)
if model_type == "keras":
model, input, _ = get_keras_model(tf_model_name)
if model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name)
tf_model_name = tf_model_name.replace("/", "_")
tf_model_dir = os.path.join(WORKDIR, str(tf_model_name) + "_tf")
os.makedirs(tf_model_dir, exist_ok=True)

View File

@@ -199,9 +199,11 @@ class SharkImporter:
)
elif golden_out is tuple:
golden_out = self.convert_to_numpy(golden_out)
else:
elif hasattr(golden_out, "logits"):
# from transformers import TFSequenceClassifierOutput
golden_out = golden_out.logits
else:
golden_out = golden_out.last_hidden_state
# Save the artifacts in the directory dir.
self.save_data(
dir,

View File

@@ -50,7 +50,13 @@ class MiniLMModuleTester:
rtol = 1e-02
atol = 1e-03
result = shark_module.forward(inputs)
# TODO: Remove catch once new MiniLM stable
try:
result = shark_module.forward(inputs)[0][1].to_host()
except:
result = shark_module.forward(inputs)
np.testing.assert_allclose(golden_out, result, rtol=rtol, atol=atol)

View File

@@ -34,11 +34,13 @@ class DistilBertModuleTest(unittest.TestCase):
self.module_tester = DistilBertModuleTester(self)
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
@pytest.mark.xfail(reason="shark_tank hash issues -- awaiting triage")
def test_module_static_cpu(self):
dynamic = False
device = "cpu"
self.module_tester.create_and_check_module(dynamic, device)
@pytest.mark.xfail(reason="shark_tank hash issues -- awaiting triage")
@pytest.mark.skipif(
check_device_drivers("gpu"), reason=device_driver_info("gpu")
)
@@ -47,6 +49,7 @@ class DistilBertModuleTest(unittest.TestCase):
device = "gpu"
self.module_tester.create_and_check_module(dynamic, device)
@pytest.mark.xfail(reason="shark_tank hash issues -- awaiting triage")
@pytest.mark.skipif(
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
)

View File

@@ -34,11 +34,13 @@ class MpNetModuleTest(unittest.TestCase):
self.module_tester = MpNetModuleTester(self)
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
@pytest.mark.xfail(reason="https://github.com/nod-ai/SHARK/issues/203")
def test_module_static_cpu(self):
dynamic = False
device = "cpu"
self.module_tester.create_and_check_module(dynamic, device)
@pytest.mark.xfail(reason="https://github.com/nod-ai/SHARK/issues/203")
@pytest.mark.skipif(
check_device_drivers("gpu"), reason=device_driver_info("gpu")
)
@@ -47,6 +49,7 @@ class MpNetModuleTest(unittest.TestCase):
device = "gpu"
self.module_tester.create_and_check_module(dynamic, device)
@pytest.mark.xfail(reason="https://github.com/nod-ai/SHARK/issues/203")
@pytest.mark.skipif(
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
)

View File

@@ -13,7 +13,7 @@ microsoft/mpnet-base,hf
roberta-base,hf
resnet50,keras
xlm-roberta-base,hf
microsoft/MiniLM-L12-H384-uncased,hf
microsoft/MiniLM-L12-H384-uncased,TFhf
funnel-transformer/small,hf
microsoft/mpnet-base,hf
facebook/convnext-tiny-224,img
1 model_name model_type
13 roberta-base hf
14 resnet50 keras
15 xlm-roberta-base hf
16 microsoft/MiniLM-L12-H384-uncased hf TFhf
17 funnel-transformer/small hf
18 microsoft/mpnet-base hf
19 facebook/convnext-tiny-224 img