remove more weird logic in the dispatcher

This commit is contained in:
George Hotz
2022-06-09 10:06:01 -07:00
parent e58d344759
commit 70bb3a7976
2 changed files with 5 additions and 12 deletions

View File

@@ -139,19 +139,20 @@ class Transpose(Function):
return ctx.op.perm_axis(x, order, ret)
def backward(ctx, grad_output):
norder = np.argsort(ctx.order).tolist()
order, = ctx.saved_tensors
norder = np.argsort(order).tolist()
ret = ctx.buffer([grad_output.shape[i] for i in norder])
return ctx.op.perm_axis(grad_output, norder, ret)
class Slice(Function):
def forward(ctx, x, arg=None):
ctx.save_for_backward(x.shape)
ctx.save_for_backward(x.shape, arg)
ret = ctx.buffer([y[1]-y[0] for y in arg])
return ctx.op.inner_slice(x, arg, ret)
def backward(ctx, grad_output):
shape, = ctx.saved_tensors
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
shape, arg = ctx.saved_tensors
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
ret = ctx.buffer([y[1]-y[0] for y in narg])
return ctx.op.inner_slice(grad_output, narg, ret)

View File

@@ -384,14 +384,6 @@ class Function:
def apply(self, device, *x, **kwargs):
ctx = self(*x) # self - operation i.e 'add', 'sub', etc.
ctx.device = device
# use default params
params = inspect.signature(self.forward).parameters
for p in params.values():
if p.default is not p.empty:
setattr(ctx, p.name, p.default)
# overwrite with passed params
for k, v in kwargs.items():
setattr(ctx, k, v)
ctx.needs_input_grad = [t.requires_grad for t in x]
with ProfileOp(ctx, ctx.__class__.__name__, x) as po:
ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),