From 35e55afe179c07d325a577fd07dd19c7a61a6caf Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 11 Jun 2022 16:46:38 -0700 Subject: [PATCH] and processing op --- tinygrad/mlops.py | 6 +++--- tinygrad/ops.py | 7 +++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 7101945281..1acda2f2fc 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -151,10 +151,10 @@ class Conv2D(Function): def forward(ctx, x, w, stride=1, groups=1): C = get_conv_args(x.shape, w.shape, stride, groups) ctx.save_for_backward(x,w,(C.ys,C.xs), C.groups) - return ctx.processing_op(ProcessingOps.CONV, x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), (C.ys,C.xs), C.groups) + return ctx.processing_op(ProcessingOps.CONV, x, w, (C.bs, C.groups*C.rcout, C.oy, C.ox), (C.ys,C.xs), C.groups) def backward(ctx, grad_output): x, w, stride, groups = ctx.saved_tensors - dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, ctx.buffer(x.shape), stride, groups) if ctx.needs_input_grad[0] else None - dw = ctx.processing_op(ProcessingOps.CONVDW, x, grad_output, ctx.buffer(w.shape), stride, groups) if ctx.needs_input_grad[1] else None + dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, stride, groups) if ctx.needs_input_grad[0] else None + dw = ctx.processing_op(ProcessingOps.CONVDW, x, grad_output, w.shape, stride, groups) if ctx.needs_input_grad[1] else None return dx, dw \ No newline at end of file diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 0f27badae8..20173bbc68 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -73,6 +73,9 @@ class Ops: log_op(op, ret, [x]) return ret - def processing_op(ctx, op:ProcessingOps, x, y, ret, stride, groups): + def processing_op(ctx, op:ProcessingOps, x, y, out_shape, stride, groups): + # TODO: can we do better than out_shape? + ret = ctx.buffer(out_shape) + ctx.op.processing_op(op, x, y, ret, stride, groups) log_op(op, ret, [x, y]) - return ctx.op.processing_op(op, x, y, ret, stride, groups) \ No newline at end of file + return ret \ No newline at end of file