diff --git a/test/external/external_test_am.py b/test/external/external_test_am.py index ece920d053..eb4e368612 100644 --- a/test/external/external_test_am.py +++ b/test/external/external_test_am.py @@ -1,6 +1,7 @@ import unittest from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraverseContext from tinygrad.runtime.support.am.ip import AM_GMC +from tinygrad.runtime.support.amd import import_module from tinygrad.helpers import mv_address class FakeGMC(AM_GMC): @@ -171,5 +172,19 @@ class TestAMPageTable(unittest.TestCase): must_cover_checker(va, sz) not_cover_checker(va, sz) +class TestAM(unittest.TestCase): + def test_imports(self): + with self.assertRaises(ImportError): import_module("gc", (7, 0, 0)) + x = import_module("gc", (11, 0, 0)) + assert x.__name__ == "tinygrad.runtime.autogen.am.gc_11_0_0" + x = import_module("gc", (11, 6, 0)) + assert x.__name__ == "tinygrad.runtime.autogen.am.gc_11_0_0" + x = import_module("gc", (12, 0, 0)) + assert x.__name__ == "tinygrad.runtime.autogen.am.gc_12_0_0" + x = import_module("gc", (10, 3, 0)) + assert x.__name__ == "tinygrad.runtime.autogen.am.gc_10_3_0" + x = import_module("gc", (10, 3, 3)) + assert x.__name__ == "tinygrad.runtime.autogen.am.gc_10_3_0" + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/runtime/support/amd.py b/tinygrad/runtime/support/amd.py index edc8932072..11f873b036 100644 --- a/tinygrad/runtime/support/amd.py +++ b/tinygrad/runtime/support/amd.py @@ -27,6 +27,6 @@ def collect_registers(module, cls=AMDRegBase) -> dict[str, AMDRegBase]: def import_module(name:str, version:tuple[int, ...], version_prefix:str=""): for ver in [version, version[:2]+(0,), version[:1]+(0, 0)]: - try: return importlib.import_module(f"tinygrad.runtime.autogen.am.{name}_{version_prefix}{'_'.join(map(str, version))}") + try: return importlib.import_module(f"tinygrad.runtime.autogen.am.{name}_{version_prefix}{'_'.join(map(str, ver))}") except ImportError: pass raise ImportError(f"Failed to load autogen module for {name.upper()} {'.'.join(map(str, version))}")