From 16956b79de2ad978347c8893b4b4fdd2cf1955ef Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 10 Apr 2025 23:02:11 +0800 Subject: [PATCH] canonicalize Device.DEFAULT (#9835) --- test/unit/test_device.py | 6 ++++++ tinygrad/device.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/test/unit/test_device.py b/test/unit/test_device.py index e0ffe6c6ea..b4dd37bcc2 100644 --- a/test/unit/test_device.py +++ b/test/unit/test_device.py @@ -24,6 +24,12 @@ class TestDevice(unittest.TestCase): with self.assertRaises(ModuleNotFoundError): Device["TYPO"] + def test_lowercase_canonicalizes(self): + device = Device.DEFAULT + Device.DEFAULT = device.lower() + self.assertEqual(Device.canonicalize(None), device) + Device.DEFAULT = device + class MockCompiler(Compiler): def __init__(self, key): super().__init__(key) def compile(self, src) -> bytes: return src.encode() diff --git a/tinygrad/device.py b/tinygrad/device.py index adbeefb70b..0e58217651 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -18,7 +18,7 @@ class _Device: @functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):]) # NOTE: you can't cache canonicalize in case Device.DEFAULT changes - def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT + def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device if device is not None else Device.DEFAULT) def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix)) @functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none def __get_canonicalized_item(self, ix:str) -> Compiled: