use functools.partialmethod (#369)

Co-authored-by: Kyle <kposborne@gmail.com>
This commit is contained in:
kposborne2
2022-08-21 12:13:31 -07:00
committed by GitHub
parent e7a4cd91ba
commit ec5d9b355c

View File

@@ -318,8 +318,7 @@ for device in [device for device in Device.__dict__.keys() if device[0] != "_"]:
# register all the mlops "math" operations
def register(name:str, fxn:Function):
def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs) # TODO: there's probably a very pythonic thing to replace this with
setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, functools.partialmethod(fxn.apply))
for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), inspect.isclass):
if name[0] != "_" and name != "Function" and not name.endswith("Ops"): register(name.lower(), cls)