From 70bb3a7976fa0176bbcd64d47183f0884bb61b8c Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 9 Jun 2022 10:06:01 -0700 Subject: [PATCH] remove more weird logic in the dispatcher --- tinygrad/mlops.py | 9 +++++---- tinygrad/tensor.py | 8 -------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 4906784b79..be28c323e4 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -139,19 +139,20 @@ class Transpose(Function): return ctx.op.perm_axis(x, order, ret) def backward(ctx, grad_output): - norder = np.argsort(ctx.order).tolist() + order, = ctx.saved_tensors + norder = np.argsort(order).tolist() ret = ctx.buffer([grad_output.shape[i] for i in norder]) return ctx.op.perm_axis(grad_output, norder, ret) class Slice(Function): def forward(ctx, x, arg=None): - ctx.save_for_backward(x.shape) + ctx.save_for_backward(x.shape, arg) ret = ctx.buffer([y[1]-y[0] for y in arg]) return ctx.op.inner_slice(x, arg, ret) def backward(ctx, grad_output): - shape, = ctx.saved_tensors - narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)] + shape, arg = ctx.saved_tensors + narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)] ret = ctx.buffer([y[1]-y[0] for y in narg]) return ctx.op.inner_slice(grad_output, narg, ret) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 79029bfe01..0d629ab7f1 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -384,14 +384,6 @@ class Function: def apply(self, device, *x, **kwargs): ctx = self(*x) # self - operation i.e 'add', 'sub', etc. ctx.device = device - # use default params - params = inspect.signature(self.forward).parameters - for p in params.values(): - if p.default is not p.empty: - setattr(ctx, p.name, p.default) - # overwrite with passed params - for k, v in kwargs.items(): - setattr(ctx, k, v) ctx.needs_input_grad = [t.requires_grad for t in x] with ProfileOp(ctx, ctx.__class__.__name__, x) as po: ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),