mirror of https://github.com/tinygrad/tinygrad.git
remove more weird logic in the dispatcher
@@ -139,19 +139,20 @@ class Transpose(Function):
     return ctx.op.perm_axis(x, order, ret)

   def backward(ctx, grad_output):
-    norder = np.argsort(ctx.order).tolist()
+    order, = ctx.saved_tensors
+    norder = np.argsort(order).tolist()
     ret = ctx.buffer([grad_output.shape[i] for i in norder])
     return ctx.op.perm_axis(grad_output, norder, ret)

 class Slice(Function):
   def forward(ctx, x, arg=None):
-    ctx.save_for_backward(x.shape)
+    ctx.save_for_backward(x.shape, arg)
     ret = ctx.buffer([y[1]-y[0] for y in arg])
     return ctx.op.inner_slice(x, arg, ret)

   def backward(ctx, grad_output):
-    shape, = ctx.saved_tensors
-    narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
+    shape, arg = ctx.saved_tensors
+    narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
     ret = ctx.buffer([y[1]-y[0] for y in narg])
     return ctx.op.inner_slice(grad_output, narg, ret)

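The Transpose and Slice changes above both move to the same pattern: whatever backward needs is stashed with ctx.save_for_backward during forward and unpacked from ctx.saved_tensors, so nothing has to be copied onto the context by the dispatcher. A minimal, self-contained sketch of that flow, using a NumPy-only stand-in for the op (the Context and ToySlice classes below are illustrative, not tinygrad's actual implementation):

import numpy as np

class Context:
  # stand-in for tinygrad's function context: just a place to stash values
  def __init__(self): self.saved_tensors = []
  def save_for_backward(self, *args): self.saved_tensors.extend(args)

class ToySlice:
  # forward stashes shape and arg itself; backward reads them back from
  # ctx.saved_tensors instead of expecting the dispatcher to set ctx.arg
  @staticmethod
  def forward(ctx, x, arg=None):
    ctx.save_for_backward(x.shape, arg)   # arg: list of (start, end) per axis
    return x[tuple(slice(s, e) for s, e in arg)]

  @staticmethod
  def backward(ctx, grad_output):
    shape, arg = ctx.saved_tensors
    grad = np.zeros(shape, dtype=grad_output.dtype)
    grad[tuple(slice(s, e) for s, e in arg)] = grad_output
    return grad

ctx = Context()
x = np.arange(12.0).reshape(3, 4)
y = ToySlice.forward(ctx, x, arg=[(0, 2), (1, 3)])   # in-bounds slices only
dx = ToySlice.backward(ctx, np.ones_like(y))
assert dx.shape == x.shape and dx.sum() == y.size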
@@ -384,14 +384,6 @@ class Function:
   def apply(self, device, *x, **kwargs):
     ctx = self(*x) # self - operation i.e 'add', 'sub', etc.
     ctx.device = device
-    # use default params
-    params = inspect.signature(self.forward).parameters
-    for p in params.values():
-      if p.default is not p.empty:
-        setattr(ctx, p.name, p.default)
-    # overwrite with passed params
-    for k, v in kwargs.items():
-      setattr(ctx, k, v)
     ctx.needs_input_grad = [t.requires_grad for t in x]
     with ProfileOp(ctx, ctx.__class__.__name__, x) as po:
       ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
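With operation-specific state handled by save_for_backward, apply no longer needs the inspect.signature/setattr loops: kwargs are forwarded straight to forward, and forward's own default parameters (e.g. arg=None above) cover anything not passed. A rough sketch of that simplified dispatch, with device handling and ProfileOp omitted and ToyTensor/ToyFunction/Mul as hypothetical stand-ins rather than tinygrad's real classes:

class ToyTensor:
  def __init__(self, data, requires_grad=True):
    self.data, self.requires_grad = data, requires_grad

class ToyFunction:
  def __init__(self, *tensors): self.parents, self.saved_tensors = tensors, []
  def save_for_backward(self, *args): self.saved_tensors.extend(args)

  @classmethod
  def apply(cls, *x, **kwargs):
    ctx = cls(*x)
    ctx.needs_input_grad = [t.requires_grad for t in x]
    # no "use default params" / "overwrite with passed params" loops:
    # kwargs go straight to forward, whose signature supplies the defaults
    return ToyTensor(cls.forward(ctx, *[t.data for t in x], **kwargs))

class Mul(ToyFunction):
  @staticmethod
  def forward(ctx, x, y, scale=1.0):
    ctx.save_for_backward(x, y, scale)
    return x * y * scale

  @staticmethod
  def backward(ctx, grad_output):
    x, y, scale = ctx.saved_tensors
    return grad_output * y * scale, grad_output * x * scale

out = Mul.apply(ToyTensor(3.0), ToyTensor(4.0))              # scale defaults to 1.0
out2 = Mul.apply(ToyTensor(3.0), ToyTensor(4.0), scale=2.0)
assert out.data == 12.0 and out2.data == 24.0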