mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
hotfix: amd imports (#9620)
This commit is contained in:
15
test/external/external_test_am.py
vendored
15
test/external/external_test_am.py
vendored
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraverseContext
|
||||
from tinygrad.runtime.support.am.ip import AM_GMC
|
||||
from tinygrad.runtime.support.amd import import_module
|
||||
from tinygrad.helpers import mv_address
|
||||
|
||||
class FakeGMC(AM_GMC):
|
||||
@@ -171,5 +172,19 @@ class TestAMPageTable(unittest.TestCase):
|
||||
must_cover_checker(va, sz)
|
||||
not_cover_checker(va, sz)
|
||||
|
||||
class TestAM(unittest.TestCase):
|
||||
def test_imports(self):
|
||||
with self.assertRaises(ImportError): import_module("gc", (7, 0, 0))
|
||||
x = import_module("gc", (11, 0, 0))
|
||||
assert x.__name__ == "tinygrad.runtime.autogen.am.gc_11_0_0"
|
||||
x = import_module("gc", (11, 6, 0))
|
||||
assert x.__name__ == "tinygrad.runtime.autogen.am.gc_11_0_0"
|
||||
x = import_module("gc", (12, 0, 0))
|
||||
assert x.__name__ == "tinygrad.runtime.autogen.am.gc_12_0_0"
|
||||
x = import_module("gc", (10, 3, 0))
|
||||
assert x.__name__ == "tinygrad.runtime.autogen.am.gc_10_3_0"
|
||||
x = import_module("gc", (10, 3, 3))
|
||||
assert x.__name__ == "tinygrad.runtime.autogen.am.gc_10_3_0"
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -27,6 +27,6 @@ def collect_registers(module, cls=AMDRegBase) -> dict[str, AMDRegBase]:
|
||||
|
||||
def import_module(name:str, version:tuple[int, ...], version_prefix:str=""):
|
||||
for ver in [version, version[:2]+(0,), version[:1]+(0, 0)]:
|
||||
try: return importlib.import_module(f"tinygrad.runtime.autogen.am.{name}_{version_prefix}{'_'.join(map(str, version))}")
|
||||
try: return importlib.import_module(f"tinygrad.runtime.autogen.am.{name}_{version_prefix}{'_'.join(map(str, ver))}")
|
||||
except ImportError: pass
|
||||
raise ImportError(f"Failed to load autogen module for {name.upper()} {'.'.join(map(str, version))}")
|
||||
|
||||
Reference in New Issue
Block a user