mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
register the operators outside
This commit is contained in:
@@ -406,19 +406,21 @@ class Function(Ops):
|
||||
ret._ctx = ctx # used by autograd engine
|
||||
return ret
|
||||
|
||||
def register(name, fxn):
|
||||
def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs) # TODO: there's probably a very pythonic thing to replace this with
|
||||
setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
|
||||
if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
|
||||
setattr(Tensor, f"__{name}__", dispatch)
|
||||
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
|
||||
setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(x,self))
|
||||
|
||||
# register functions to move between devices
|
||||
for device in [device for device in Device.__dict__.keys() if device[0] != "_"]:
|
||||
setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, Device.__dict__[device]))
|
||||
setattr(Tensor, f"{device.lower()}_", functools.partialmethod(Tensor.to_, Device.__dict__[device]))
|
||||
|
||||
# this registers all the mlops "math" operations
|
||||
# register all the mlops "math" operations
|
||||
def register(name, fxn):
|
||||
def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs) # TODO: there's probably a very pythonic thing to replace this with
|
||||
setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
|
||||
for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), inspect.isclass):
|
||||
if name[0] != "_" and name != "Function" and not name.endswith("Ops"): register(name.lower(), cls)
|
||||
|
||||
# register the operators
|
||||
def register_op(name, fxn):
|
||||
setattr(Tensor, f"__{name}__", fxn)
|
||||
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(fxn(self,x)))
|
||||
setattr(Tensor, f"__r{name}__", lambda self,x: fxn(x,self))
|
||||
for name in ['add', 'sub', 'mul', 'pow', 'matmul']: register_op(name, getattr(Tensor, name))
|
||||
Reference in New Issue
Block a user