From 8c084b8c12720154bbbec2585d21cc30cdc4087f Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Thu, 9 Jun 2022 10:28:46 -0700
Subject: [PATCH] dispatcher: now this is nice

---
 tinygrad/tensor.py | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 0d629ab7f1..aa57e2bae8 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -369,9 +369,11 @@ class Function:
     cls.backward = staticmethod(cls.backward)
     return super().__new__(cls)
 
-  def __init__(self, *tensors):
+  def __init__(self, device, *tensors):
+    self.device = device
     self.parents = tensors
-    self.requires_grad = any(t.requires_grad for t in tensors)
+    self.needs_input_grad = [t.requires_grad for t in tensors]
+    self.requires_grad = any(self.needs_input_grad)
     self.saved_tensors = []
 
   buffer = property(lambda self: Device.buffers[self.device])
@@ -381,26 +383,25 @@ class Function:
     if self.requires_grad:
       self.saved_tensors.extend(x)
 
-  def apply(self, device, *x, **kwargs):
-    ctx = self(*x) # self - operation i.e 'add', 'sub', etc.
-    ctx.device = device
-    ctx.needs_input_grad = [t.requires_grad for t in x]
+  @classmethod
+  def apply(cls, *x, **kwargs):
+    tt = [arg for arg in x if isinstance(arg, Tensor)][0]  # this is the prototype tensor
+
+    # create tensors from number arguments
+    x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
+    assert all([tt.device == t.device for t in x]), "All tensors are not on the same device"
+
+    ctx = cls(tt.device, *x)
     with ProfileOp(ctx, ctx.__class__.__name__, x) as po:
-      ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
-                   device=ctx.device, requires_grad=any(ctx.needs_input_grad))
+      ret = Tensor(cls.forward(ctx, *[t.data for t in x], **kwargs),
+                   device=ctx.device, requires_grad=ctx.requires_grad)
       po.output = [ret]
     if ret.requires_grad:
       ret._ctx = ctx  # used by autograd engine
     return ret
 
 def register(name, fxn):
-  def dispatch(*x, **kwargs):
-    tt = [arg for arg in x if isinstance(arg, Tensor)]
-    assert all([tt[0].device == t.device for t in tt]), "All tensors are not on the same device"
-    # create tensors from number arguments
-    x = [Tensor(np.array([arg], dtype=tt[0].dtype), device=tt[0].device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
-    return fxn.apply(fxn, tt[0].device, *x, **kwargs)
-
+  def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs)  # TODO: there's probably a very pythonic thing to replace this with
   setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
   if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
     setattr(Tensor, f"__{name}__", dispatch)
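
Note: the refactor moves argument wrapping out of each registered dispatch closure and into a single Function.apply classmethod. The first Tensor argument acts as the prototype that fixes device and dtype, bare numbers are promoted to constant tensors, and the context is constructed with the device up front, so register's dispatch shrinks to a one-liner. Below is a minimal, self-contained sketch of that pattern; ToyTensor, ToyFunction, and Add are hypothetical stand-ins, not tinygrad's actual classes, and the real code additionally threads ProfileOp profiling and the Device buffer map through.

  import numpy as np

  class ToyTensor:  # hypothetical stand-in for tinygrad's Tensor
    def __init__(self, data, device="CPU", requires_grad=True):
      self.data, self.device, self.requires_grad = np.asarray(data), device, requires_grad
      self._ctx = None
    @property
    def dtype(self): return self.data.dtype

  class ToyFunction:  # hypothetical stand-in for tinygrad's Function
    def __init__(self, device, *tensors):
      self.device = device
      self.parents = tensors
      self.needs_input_grad = [t.requires_grad for t in tensors]
      self.requires_grad = any(self.needs_input_grad)

    @classmethod
    def apply(cls, *x, **kwargs):
      tt = [a for a in x if isinstance(a, ToyTensor)][0]  # prototype tensor: fixes device and dtype
      # promote bare numbers to constant tensors on the prototype's device
      x = [a if isinstance(a, ToyTensor) else
           ToyTensor(np.array([a], dtype=tt.dtype), device=tt.device, requires_grad=False) for a in x]
      assert all(tt.device == t.device for t in x), "All tensors are not on the same device"
      ctx = cls(tt.device, *x)  # cls is the concrete op, e.g. Add
      ret = ToyTensor(cls.forward(ctx, *[t.data for t in x], **kwargs),
                      device=ctx.device, requires_grad=ctx.requires_grad)
      if ret.requires_grad: ret._ctx = ctx  # the autograd engine walks these contexts backward
      return ret

  class Add(ToyFunction):
    @staticmethod
    def forward(ctx, a, b): return a + b

  def register(name, fxn):
    # dispatch is now a one-liner: all the wrapping lives in apply
    def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs)
    setattr(ToyTensor, name, dispatch)

  register("add", Add)
  print(ToyTensor([1.0, 2.0]).add(3).data)  # scalar 3 is auto-promoted -> [4. 5.]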