mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
canonicalize Device.DEFAULT (#9835)
This commit is contained in:
@@ -24,6 +24,12 @@ class TestDevice(unittest.TestCase):
|
||||
with self.assertRaises(ModuleNotFoundError):
|
||||
Device["TYPO"]
|
||||
|
||||
def test_lowercase_canonicalizes(self):
|
||||
device = Device.DEFAULT
|
||||
Device.DEFAULT = device.lower()
|
||||
self.assertEqual(Device.canonicalize(None), device)
|
||||
Device.DEFAULT = device
|
||||
|
||||
class MockCompiler(Compiler):
|
||||
def __init__(self, key): super().__init__(key)
|
||||
def compile(self, src) -> bytes: return src.encode()
|
||||
|
||||
@@ -18,7 +18,7 @@ class _Device:
|
||||
@functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
|
||||
# NOTE: you can't cache canonicalize in case Device.DEFAULT changes
|
||||
def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
|
||||
def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device if device is not None else Device.DEFAULT)
|
||||
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
|
||||
@functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def __get_canonicalized_item(self, ix:str) -> Compiled:
|
||||
|
||||
Reference in New Issue
Block a user