diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index 2856e9226e..0fb848bce0 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -60,7 +60,7 @@ def movement_op(ctx, op, x, arg=None): elif op == MovementOps.EXPAND: return x.expand(arg) else: raise Exception(f"{op} isn't supported") -def processing_op(ctx, op,x,w,out_shape,C): +def processing_op(ctx,op,x,w,C): assert op == ProcessingOps.CONV, f"{op} isn't supported" if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)]) gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3]) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 121e5344af..7bd51e943c 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -127,8 +127,8 @@ def contiguous(ctx, x, st, ret=None): def movement_op(ctx, op, x, arg=None): return contiguous(ctx, x, ShapeTracker(*x.shape).movement_op(op, arg)) -def processing_op(ctx,op,x,w,out_shape,C): - ret = ctx.buffer(out_shape) +def processing_op(ctx,op,x,w,C): + ret = ctx.buffer((C.bs, C.cout, C.oy, C.ox)) assert op == ProcessingOps.CONV, f"{op} isn't supported" # input = (bs, groups, cin, iy, ix) # weight = (groups, rcout, cin, H, W) diff --git a/tinygrad/llops/ops_opencl.py b/tinygrad/llops/ops_opencl.py deleted file mode 120000 index 02c9307b09..0000000000 --- a/tinygrad/llops/ops_opencl.py +++ /dev/null @@ -1 +0,0 @@ -../../accel/opencl/ops_opencl.py \ No newline at end of file diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py index 919dabe342..627939c865 100644 --- a/tinygrad/llops/ops_torch.py +++ b/tinygrad/llops/ops_torch.py @@ -25,6 +25,6 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op from tinygrad.ops import ProcessingOps -def processing_op(ctx,op,x,w,out_shape,C): +def processing_op(ctx,op,x,w,C): assert op == ProcessingOps.CONV, f"{op} isn't supported" return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px)) diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index ba5843e06a..f583401dab 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -170,7 +170,7 @@ class Conv2D(Function): def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0): C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding) ctx.save_for_backward(x,w,C) - return ctx.processing_op(ProcessingOps.CONV, x, w, (C.bs, C.cout, C.oy, C.ox), C) + return ctx.processing_op(ProcessingOps.CONV, x, w, C) def backward(ctx, grad_output): x, w, C = ctx.saved_tensors @@ -188,7 +188,7 @@ class Conv2D(Function): wt = ctx.movement_op(MovementOps.RESHAPE, wt, (C.groups*C.cin, C.rcout, C.H, C.W)) Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=C.groups) # TODO: this shape can be wrong. support asymmetric padding to remove the slice - dx = ctx.processing_op(ProcessingOps.CONV, xt, wt, (Cdx.bs, Cdx.cout, Cdx.oy, Cdx.ox), Cdx) + dx = ctx.processing_op(ProcessingOps.CONV, xt, wt, Cdx) dx = ctx.movement_op(MovementOps.SLICE, dx, [(0,s) for s in x.shape]) if ctx.needs_input_grad[1]: @@ -199,7 +199,7 @@ class Conv2D(Function): grad_output_dw = ctx.movement_op(MovementOps.PERMUTE, grad_output, (1,0,2,3)) grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.cout, C.bs, C.oy, C.ox)) Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.groups) - grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (C.cin, C.cout, Cdw.oy, Cdw.ox), Cdw) + grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, Cdw) grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3)) # TODO: remove this slice using asymmetric padding dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3]))) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 45301b2109..40f221ab4f 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -72,12 +72,11 @@ class Ops: assert ret.shape == ShapeTracker(*x.shape).movement_op(op, arg).shape return ret - def processing_op(ctx, op:ProcessingOps, x, y, out_shape, C): - # TODO: can we do better than out_shape? - if getattr(ctx.op, "preprocessing_op", None) is not None: x,y,C = ctx.op.preprocessing_op(ctx, op, x, y, out_shape, C) - ret = ctx.op.processing_op(ctx, op, x, y, out_shape, C) + def processing_op(ctx, op:ProcessingOps, x, y, C): + if getattr(ctx.op, "preprocessing_op", None) is not None: x,y,C = ctx.op.preprocessing_op(ctx, op, x, y, C) + ret = ctx.op.processing_op(ctx, op, x, y, C) log_op(op, ret, [x, y]) - if getattr(ctx.op, "postprocessing_op", None) is not None: ret = ctx.op.postprocessing_op(ctx, op, ret, out_shape, C) + if getattr(ctx.op, "postprocessing_op", None) is not None: ret = ctx.op.postprocessing_op(ctx, op, ret, C) assert isinstance(ret, ctx.buffer) - assert ret.shape == out_shape + assert ret.shape == (C.bs, C.cout, C.oy, C.ox) return ret \ No newline at end of file