mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Bug fix: Pass the device attribute appropriately.
Previously the device attribute was not passed and device was hardcoded to "cpu". So every tests were running on cpu.
This commit is contained in:
@@ -26,6 +26,9 @@ def get_vulkan_triple_flag():
|
||||
elif vulkan_device == "A100-SXM4-40GB":
|
||||
print("Found Nvidia Device. Using ampere-rtx3080-linux")
|
||||
return "-iree-vulkan-target-triple=ampere-rtx3080-linux"
|
||||
elif vulkan_device == "3090":
|
||||
print("Found Nvidia Device. Using ampere-rtx3090-linux")
|
||||
return "-iree-vulkan-target-triple=ampere-rtx3090-linux"
|
||||
else:
|
||||
print(
|
||||
"""Optimized kernel for your target device is not added yet.
|
||||
|
||||
@@ -35,7 +35,7 @@ class MiniLMModuleTester:
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic, tracing_required=True)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device=self.device, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
@@ -35,7 +35,7 @@ class AlbertModuleTester:
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic, tracing_required=True)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device=self.device, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
@@ -35,7 +35,7 @@ class BertModuleTester:
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic, tracing_required=True)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device=self.device, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -35,7 +35,7 @@ class DistilBertModuleTester:
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic, tracing_required=True)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device=self.device, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
@@ -50,7 +50,6 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.module_tester = DistilBertModuleTester(self)
|
||||
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
@@ -62,35 +61,27 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
@pytest.mark.xfail(reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
|
||||
@@ -35,7 +35,7 @@ class MobileBertUncasedModuleTester:
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic, tracing_required=True)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device=self.device, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
@@ -60,6 +60,7 @@ class MobileBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="golden and original results mismatch")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
|
||||
@@ -36,7 +36,7 @@ class Resnet101ModuleTester:
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic, tracing_required=False)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device=self.device, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
@@ -36,7 +36,7 @@ class Resnet18ModuleTester:
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic, tracing_required=False)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device=self.device, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
@@ -36,7 +36,7 @@ class Resnet50ModuleTester:
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic, tracing_required=False)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device=self.device, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
@@ -36,7 +36,7 @@ class SqueezenetModuleTester:
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic, tracing_required=False)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device=self.device, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
@@ -36,7 +36,7 @@ class WideResnet50ModuleTester:
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device=self.device, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
Reference in New Issue
Block a user