movement op to SSA

This commit is contained in:
George Hotz
2022-06-11 16:44:24 -07:00
parent 6685807df7
commit 6d5591f7a3
2 changed files with 13 additions and 12 deletions

View File

@@ -119,35 +119,31 @@ class Reshape(Function):
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape)
return ctx.movement_op(MovementOps.RESHAPE, x, ctx.buffer(shape))
return ctx.movement_op(MovementOps.RESHAPE, x, shape)
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return ctx.movement_op(MovementOps.RESHAPE, grad_output, ctx.buffer(in_shape))
return ctx.movement_op(MovementOps.RESHAPE, grad_output, in_shape)
class Permute(Function):
def forward(ctx, x, order=(1,0)):
ctx.save_for_backward(order)
ret = ctx.buffer([x.shape[i] for i in order])
return ctx.movement_op(MovementOps.PERMUTE, x, ret, order)
return ctx.movement_op(MovementOps.PERMUTE, x, order)
def backward(ctx, grad_output):
order, = ctx.saved_tensors
norder = np.argsort(order).tolist()
ret = ctx.buffer([grad_output.shape[i] for i in norder])
return ctx.movement_op(MovementOps.PERMUTE, grad_output, ret, norder)
return ctx.movement_op(MovementOps.PERMUTE, grad_output, norder)
class Slice(Function):
def forward(ctx, x, arg=None):
ctx.save_for_backward(x.shape, arg)
ret = ctx.buffer([y[1]-y[0] for y in arg])
return ctx.movement_op(MovementOps.SLICE, x, ret, arg)
return ctx.movement_op(MovementOps.SLICE, x, arg)
def backward(ctx, grad_output):
shape, arg = ctx.saved_tensors
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
ret = ctx.buffer([y[1]-y[0] for y in narg])
return ctx.movement_op(MovementOps.SLICE, grad_output, ret, narg)
return ctx.movement_op(MovementOps.SLICE, grad_output, narg)
# ************* processing ops *************

View File

@@ -64,9 +64,14 @@ class Ops:
log_op(op, ret, [x, y])
return ret
def movement_op(ctx, op:MovementOps, x, ret, arg=None):
def movement_op(ctx, op:MovementOps, x, arg=None):
if op == MovementOps.RESHAPE: new_shape = arg
if op == MovementOps.PERMUTE: new_shape = [x.shape[i] for i in arg]
if op == MovementOps.SLICE: new_shape = [y-x for x,y in arg]
ret = ctx.buffer(new_shape)
ctx.op.movement_op(op, x, ret, arg)
log_op(op, ret, [x])
return ctx.op.movement_op(op, x, ret, arg)
return ret
def processing_op(ctx, op:ProcessingOps, x, y, ret, stride, groups):
log_op(op, ret, [x, y])