hotfix: amd imports (#9620)

This commit is contained in:
nimlgen
2025-03-29 20:19:53 +07:00
committed by GitHub
parent 5908b89f71
commit 118bd1cbed
2 changed files with 16 additions and 1 deletions

View File

@@ -1,6 +1,7 @@
import unittest
from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraverseContext
from tinygrad.runtime.support.am.ip import AM_GMC
from tinygrad.runtime.support.amd import import_module
from tinygrad.helpers import mv_address
class FakeGMC(AM_GMC):
@@ -171,5 +172,19 @@ class TestAMPageTable(unittest.TestCase):
must_cover_checker(va, sz)
not_cover_checker(va, sz)
class TestAM(unittest.TestCase):
def test_imports(self):
with self.assertRaises(ImportError): import_module("gc", (7, 0, 0))
x = import_module("gc", (11, 0, 0))
assert x.__name__ == "tinygrad.runtime.autogen.am.gc_11_0_0"
x = import_module("gc", (11, 6, 0))
assert x.__name__ == "tinygrad.runtime.autogen.am.gc_11_0_0"
x = import_module("gc", (12, 0, 0))
assert x.__name__ == "tinygrad.runtime.autogen.am.gc_12_0_0"
x = import_module("gc", (10, 3, 0))
assert x.__name__ == "tinygrad.runtime.autogen.am.gc_10_3_0"
x = import_module("gc", (10, 3, 3))
assert x.__name__ == "tinygrad.runtime.autogen.am.gc_10_3_0"
if __name__ == "__main__":
unittest.main()

View File

@@ -27,6 +27,6 @@ def collect_registers(module, cls=AMDRegBase) -> dict[str, AMDRegBase]:
def import_module(name:str, version:tuple[int, ...], version_prefix:str=""):
for ver in [version, version[:2]+(0,), version[:1]+(0, 0)]:
try: return importlib.import_module(f"tinygrad.runtime.autogen.am.{name}_{version_prefix}{'_'.join(map(str, version))}")
try: return importlib.import_module(f"tinygrad.runtime.autogen.am.{name}_{version_prefix}{'_'.join(map(str, ver))}")
except ImportError: pass
raise ImportError(f"Failed to load autogen module for {name.upper()} {'.'.join(map(str, version))}")