mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
movement op to SSA
This commit is contained in:
@@ -119,35 +119,31 @@ class Reshape(Function):
|
||||
def forward(ctx, x, shape):
|
||||
ctx.save_for_backward(x.shape)
|
||||
shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape)
|
||||
return ctx.movement_op(MovementOps.RESHAPE, x, ctx.buffer(shape))
|
||||
return ctx.movement_op(MovementOps.RESHAPE, x, shape)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
in_shape, = ctx.saved_tensors
|
||||
return ctx.movement_op(MovementOps.RESHAPE, grad_output, ctx.buffer(in_shape))
|
||||
return ctx.movement_op(MovementOps.RESHAPE, grad_output, in_shape)
|
||||
|
||||
class Permute(Function):
|
||||
def forward(ctx, x, order=(1,0)):
|
||||
ctx.save_for_backward(order)
|
||||
ret = ctx.buffer([x.shape[i] for i in order])
|
||||
return ctx.movement_op(MovementOps.PERMUTE, x, ret, order)
|
||||
return ctx.movement_op(MovementOps.PERMUTE, x, order)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
order, = ctx.saved_tensors
|
||||
norder = np.argsort(order).tolist()
|
||||
ret = ctx.buffer([grad_output.shape[i] for i in norder])
|
||||
return ctx.movement_op(MovementOps.PERMUTE, grad_output, ret, norder)
|
||||
return ctx.movement_op(MovementOps.PERMUTE, grad_output, norder)
|
||||
|
||||
class Slice(Function):
|
||||
def forward(ctx, x, arg=None):
|
||||
ctx.save_for_backward(x.shape, arg)
|
||||
ret = ctx.buffer([y[1]-y[0] for y in arg])
|
||||
return ctx.movement_op(MovementOps.SLICE, x, ret, arg)
|
||||
return ctx.movement_op(MovementOps.SLICE, x, arg)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
shape, arg = ctx.saved_tensors
|
||||
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
|
||||
ret = ctx.buffer([y[1]-y[0] for y in narg])
|
||||
return ctx.movement_op(MovementOps.SLICE, grad_output, ret, narg)
|
||||
return ctx.movement_op(MovementOps.SLICE, grad_output, narg)
|
||||
|
||||
# ************* processing ops *************
|
||||
|
||||
|
||||
@@ -64,9 +64,14 @@ class Ops:
|
||||
log_op(op, ret, [x, y])
|
||||
return ret
|
||||
|
||||
def movement_op(ctx, op:MovementOps, x, ret, arg=None):
|
||||
def movement_op(ctx, op:MovementOps, x, arg=None):
|
||||
if op == MovementOps.RESHAPE: new_shape = arg
|
||||
if op == MovementOps.PERMUTE: new_shape = [x.shape[i] for i in arg]
|
||||
if op == MovementOps.SLICE: new_shape = [y-x for x,y in arg]
|
||||
ret = ctx.buffer(new_shape)
|
||||
ctx.op.movement_op(op, x, ret, arg)
|
||||
log_op(op, ret, [x])
|
||||
return ctx.op.movement_op(op, x, ret, arg)
|
||||
return ret
|
||||
|
||||
def processing_op(ctx, op:ProcessingOps, x, y, ret, stride, groups):
|
||||
log_op(op, ret, [x, y])
|
||||
|
||||
Reference in New Issue
Block a user