mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-22 13:28:06 -05:00)
dispatcher: now this is nice
@@ -369,9 +369,11 @@ class Function:
     cls.backward = staticmethod(cls.backward)
     return super().__new__(cls)
 
-  def __init__(self, *tensors):
+  def __init__(self, device, *tensors):
+    self.device = device
     self.parents = tensors
-    self.requires_grad = any(t.requires_grad for t in tensors)
+    self.needs_input_grad = [t.requires_grad for t in tensors]
+    self.requires_grad = any(self.needs_input_grad)
     self.saved_tensors = []
 
+  buffer = property(lambda self: Device.buffers[self.device])
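
The upshot of this hunk: an op's context now knows its device from construction, keeps a per-input needs_input_grad list, and can reach the right buffer class through the buffer property. As a rough illustration, a backward pass can consult needs_input_grad to skip gradients nobody asked for; the Mul op below is a toy written against the interface above, not code from this commit:

class Mul(Function):
  @staticmethod
  def forward(ctx, x, y):
    ctx.save_for_backward(x, y)
    return x * y

  @staticmethod
  def backward(ctx, grad_output):
    x, y = ctx.saved_tensors
    # only produce gradients for the inputs that require them
    return (grad_output * y if ctx.needs_input_grad[0] else None,
            grad_output * x if ctx.needs_input_grad[1] else None)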
@@ -381,26 +383,25 @@ class Function:
     if self.requires_grad:
       self.saved_tensors.extend(x)
 
-  def apply(self, device, *x, **kwargs):
-    ctx = self(*x) # self - operation i.e 'add', 'sub', etc.
-    ctx.device = device
-    ctx.needs_input_grad = [t.requires_grad for t in x]
+  @classmethod
+  def apply(cls, *x, **kwargs):
+    tt = [arg for arg in x if isinstance(arg, Tensor)][0] # this is the prototype tensor
+
+    # create tensors from number arguments
+    x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
+    assert all([tt.device == t.device for t in x]), "All tensors are not on the same device"
+
+    ctx = cls(tt.device, *x)
     with ProfileOp(ctx, ctx.__class__.__name__, x) as po:
-      ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
-                   device=ctx.device, requires_grad=any(ctx.needs_input_grad))
+      ret = Tensor(cls.forward(ctx, *[t.data for t in x], **kwargs),
+                   device=ctx.device, requires_grad=ctx.requires_grad)
       po.output = [ret]
     if ret.requires_grad:
       ret._ctx = ctx # used by autograd engine
     return ret
 
 def register(name, fxn):
-  def dispatch(*x, **kwargs):
-    tt = [arg for arg in x if isinstance(arg, Tensor)]
-    assert all([tt[0].device == t.device for t in tt]), "All tensors are not on the same device"
-    # create tensors from number arguments
-    x = [Tensor(np.array([arg], dtype=tt[0].dtype), device=tt[0].device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
-    return fxn.apply(fxn, tt[0].device, *x, **kwargs)
-
+  def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs) # TODO: there's probably a very pythonic thing to replace this with
   setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
   if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
     setattr(Tensor, f"__{name}__", dispatch)
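
For readers skimming the diff, here is a minimal standalone sketch of the dispatcher pattern this commit lands: a toy Tensor and Function, one op, and register wiring it onto Tensor. Everything here is illustrative (single device, no ProfileOp, no backward pass), not the real tinygrad classes:

import numpy as np

class Tensor:
  def __init__(self, data, device="cpu", requires_grad=True):
    self.data, self.device, self.requires_grad = data, device, requires_grad
    self._ctx = None

class Function:
  def __init__(self, device, *tensors):
    self.device, self.parents = device, tensors
    self.needs_input_grad = [t.requires_grad for t in tensors]
    self.requires_grad = any(self.needs_input_grad)

  @classmethod
  def apply(cls, *x, **kwargs):
    tt = [arg for arg in x if isinstance(arg, Tensor)][0]  # prototype tensor
    # promote plain numbers to Tensors on the prototype's device
    x = [arg if isinstance(arg, Tensor) else Tensor(np.array([arg]), device=tt.device, requires_grad=False) for arg in x]
    assert all(tt.device == t.device for t in x), "All tensors are not on the same device"
    ctx = cls(tt.device, *x)
    ret = Tensor(cls.forward(ctx, *[t.data for t in x], **kwargs),
                 device=ctx.device, requires_grad=ctx.requires_grad)
    if ret.requires_grad: ret._ctx = ctx  # used by the autograd engine
    return ret

class Mul(Function):
  @staticmethod
  def forward(ctx, x, y): return x * y

def register(name, fxn):
  def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs)
  setattr(Tensor, name, dispatch)
  setattr(Tensor, f"__{name}__", dispatch)

register("mul", Mul)
a = Tensor(np.array([2.0]))
print((a * 3).data)  # [6.] -- and (a * 3)._ctx is the Mul context

The design point the commit message is happy about: dispatch no longer does any device or type plumbing. It is a one-line trampoline into the classmethod apply, which owns argument promotion, the device check, and context construction in one place.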