diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index a135c7306f..773550c05e 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -290,9 +290,10 @@ class _Device: self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) or self._default_device() @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT - def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return self._get_device(x.split(":")[0].upper()) @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none - def _get_device(self, x:str) -> Union[Interpreted, Compiled]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0] + def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: + x = x.split(":")[0].upper() + return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0] def _default_device(self) -> str: for device in ["METAL", "CUDA", "GPU"]: try: diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 69b70278d9..36cdf6ae3a 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -18,8 +18,10 @@ class Node: return ops[type(self)](self, ops, ctx) @functools.cached_property def key(self) -> str: return self.render(ctx="DEBUG") + @functools.cached_property + def hash(self) -> int: return hash(self.key) def __repr__(self): return "<"+self.key+">" - def __hash__(self): return hash(self.__repr__()) + def __hash__(self): return self.hash def __eq__(self, other:object) -> bool: if not isinstance(other, Node): return NotImplemented return self.key == other.key