register the operators outside

This commit is contained in:
George Hotz
2022-06-12 10:26:34 -07:00
parent 33f18c61a1
commit 5cf7649eda

View File

@@ -406,19 +406,21 @@ class Function(Ops):
ret._ctx = ctx # used by autograd engine
return ret
def register(name, fxn):
def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs) # TODO: there's probably a very pythonic thing to replace this with
setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
setattr(Tensor, f"__{name}__", dispatch)
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(x,self))
# register functions to move between devices
for device in [device for device in Device.__dict__.keys() if device[0] != "_"]:
setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, Device.__dict__[device]))
setattr(Tensor, f"{device.lower()}_", functools.partialmethod(Tensor.to_, Device.__dict__[device]))
# this registers all the mlops "math" operations
# register all the mlops "math" operations
def register(name, fxn):
def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs) # TODO: there's probably a very pythonic thing to replace this with
setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), inspect.isclass):
if name[0] != "_" and name != "Function" and not name.endswith("Ops"): register(name.lower(), cls)
# register the operators
def register_op(name, fxn):
setattr(Tensor, f"__{name}__", fxn)
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(fxn(self,x)))
setattr(Tensor, f"__r{name}__", lambda self,x: fxn(x,self))
for name in ['add', 'sub', 'mul', 'pow', 'matmul']: register_op(name, getattr(Tensor, name))