mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 10:31:41 -05:00
Perf/cache string ops (#1078)
* perf: remove extra function, include in cached getitem * perf: only calculate hash once per node --------- Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
@@ -290,9 +290,10 @@ class _Device:
|
||||
self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) or self._default_device()
|
||||
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
|
||||
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return self._get_device(x.split(":")[0].upper())
|
||||
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def _get_device(self, x:str) -> Union[Interpreted, Compiled]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
|
||||
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]:
|
||||
x = x.split(":")[0].upper()
|
||||
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
|
||||
def _default_device(self) -> str:
|
||||
for device in ["METAL", "CUDA", "GPU"]:
|
||||
try:
|
||||
|
||||
@@ -18,8 +18,10 @@ class Node:
|
||||
return ops[type(self)](self, ops, ctx)
|
||||
@functools.cached_property
|
||||
def key(self) -> str: return self.render(ctx="DEBUG")
|
||||
@functools.cached_property
|
||||
def hash(self) -> int: return hash(self.key)
|
||||
def __repr__(self): return "<"+self.key+">"
|
||||
def __hash__(self): return hash(self.__repr__())
|
||||
def __hash__(self): return self.hash
|
||||
def __eq__(self, other:object) -> bool:
|
||||
if not isinstance(other, Node): return NotImplemented
|
||||
return self.key == other.key
|
||||
|
||||
Reference in New Issue
Block a user