mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
remove out_shape from processing_op
This commit is contained in:
@@ -60,7 +60,7 @@ def movement_op(ctx, op, x, arg=None):
|
||||
elif op == MovementOps.EXPAND: return x.expand(arg)
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
|
||||
def processing_op(ctx, op,x,w,out_shape,C):
|
||||
def processing_op(ctx,op,x,w,C):
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
|
||||
gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
|
||||
|
||||
@@ -127,8 +127,8 @@ def contiguous(ctx, x, st, ret=None):
|
||||
def movement_op(ctx, op, x, arg=None):
|
||||
return contiguous(ctx, x, ShapeTracker(*x.shape).movement_op(op, arg))
|
||||
|
||||
def processing_op(ctx,op,x,w,out_shape,C):
|
||||
ret = ctx.buffer(out_shape)
|
||||
def processing_op(ctx,op,x,w,C):
|
||||
ret = ctx.buffer((C.bs, C.cout, C.oy, C.ox))
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
# input = (bs, groups, cin, iy, ix)
|
||||
# weight = (groups, rcout, cin, H, W)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../accel/opencl/ops_opencl.py
|
||||
@@ -25,6 +25,6 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op
|
||||
|
||||
from tinygrad.ops import ProcessingOps
|
||||
|
||||
def processing_op(ctx,op,x,w,out_shape,C):
|
||||
def processing_op(ctx,op,x,w,C):
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
|
||||
|
||||
@@ -170,7 +170,7 @@ class Conv2D(Function):
|
||||
def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0):
|
||||
C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
|
||||
ctx.save_for_backward(x,w,C)
|
||||
return ctx.processing_op(ProcessingOps.CONV, x, w, (C.bs, C.cout, C.oy, C.ox), C)
|
||||
return ctx.processing_op(ProcessingOps.CONV, x, w, C)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
x, w, C = ctx.saved_tensors
|
||||
@@ -188,7 +188,7 @@ class Conv2D(Function):
|
||||
wt = ctx.movement_op(MovementOps.RESHAPE, wt, (C.groups*C.cin, C.rcout, C.H, C.W))
|
||||
Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=C.groups)
|
||||
# TODO: this shape can be wrong. support asymmetric padding to remove the slice
|
||||
dx = ctx.processing_op(ProcessingOps.CONV, xt, wt, (Cdx.bs, Cdx.cout, Cdx.oy, Cdx.ox), Cdx)
|
||||
dx = ctx.processing_op(ProcessingOps.CONV, xt, wt, Cdx)
|
||||
dx = ctx.movement_op(MovementOps.SLICE, dx, [(0,s) for s in x.shape])
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
@@ -199,7 +199,7 @@ class Conv2D(Function):
|
||||
grad_output_dw = ctx.movement_op(MovementOps.PERMUTE, grad_output, (1,0,2,3))
|
||||
grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.cout, C.bs, C.oy, C.ox))
|
||||
Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.groups)
|
||||
grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (C.cin, C.cout, Cdw.oy, Cdw.ox), Cdw)
|
||||
grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, Cdw)
|
||||
grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3))
|
||||
# TODO: remove this slice using asymmetric padding
|
||||
dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3])))
|
||||
|
||||
@@ -72,12 +72,11 @@ class Ops:
|
||||
assert ret.shape == ShapeTracker(*x.shape).movement_op(op, arg).shape
|
||||
return ret
|
||||
|
||||
def processing_op(ctx, op:ProcessingOps, x, y, out_shape, C):
|
||||
# TODO: can we do better than out_shape?
|
||||
if getattr(ctx.op, "preprocessing_op", None) is not None: x,y,C = ctx.op.preprocessing_op(ctx, op, x, y, out_shape, C)
|
||||
ret = ctx.op.processing_op(ctx, op, x, y, out_shape, C)
|
||||
def processing_op(ctx, op:ProcessingOps, x, y, C):
|
||||
if getattr(ctx.op, "preprocessing_op", None) is not None: x,y,C = ctx.op.preprocessing_op(ctx, op, x, y, C)
|
||||
ret = ctx.op.processing_op(ctx, op, x, y, C)
|
||||
log_op(op, ret, [x, y])
|
||||
if getattr(ctx.op, "postprocessing_op", None) is not None: ret = ctx.op.postprocessing_op(ctx, op, ret, out_shape, C)
|
||||
if getattr(ctx.op, "postprocessing_op", None) is not None: ret = ctx.op.postprocessing_op(ctx, op, ret, C)
|
||||
assert isinstance(ret, ctx.buffer)
|
||||
assert ret.shape == out_shape
|
||||
assert ret.shape == (C.bs, C.cout, C.oy, C.ox)
|
||||
return ret
|
||||
Reference in New Issue
Block a user