Perf/cache string ops (#1078)

* perf: remove extra function, include in cached getitem

* perf: only calculate hash once per node

---------

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
Roelof van Dijk
2023-06-29 22:23:11 +02:00
committed by GitHub
parent e234bf2298
commit 542b2d93a5
2 changed files with 6 additions and 3 deletions

View File

@@ -290,9 +290,10 @@ class _Device:
self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) or self._default_device()
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return self._get_device(x.split(":")[0].upper())
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def _get_device(self, x:str) -> Union[Interpreted, Compiled]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]:
x = x.split(":")[0].upper()
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
def _default_device(self) -> str:
for device in ["METAL", "CUDA", "GPU"]:
try:

View File

@@ -18,8 +18,10 @@ class Node:
return ops[type(self)](self, ops, ctx)
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@functools.cached_property
def hash(self) -> int: return hash(self.key)
def __repr__(self): return "<"+self.key+">"
def __hash__(self): return hash(self.__repr__())
def __hash__(self): return self.hash
def __eq__(self, other:object) -> bool:
if not isinstance(other, Node): return NotImplemented
return self.key == other.key