and processing op

This commit is contained in:
George Hotz
2022-06-11 16:46:38 -07:00
parent 6d5591f7a3
commit 35e55afe17
2 changed files with 8 additions and 5 deletions

View File

@@ -151,10 +151,10 @@ class Conv2D(Function):
def forward(ctx, x, w, stride=1, groups=1):
C = get_conv_args(x.shape, w.shape, stride, groups)
ctx.save_for_backward(x,w,(C.ys,C.xs), C.groups)
return ctx.processing_op(ProcessingOps.CONV, x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), (C.ys,C.xs), C.groups)
return ctx.processing_op(ProcessingOps.CONV, x, w, (C.bs, C.groups*C.rcout, C.oy, C.ox), (C.ys,C.xs), C.groups)
def backward(ctx, grad_output):
x, w, stride, groups = ctx.saved_tensors
dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, ctx.buffer(x.shape), stride, groups) if ctx.needs_input_grad[0] else None
dw = ctx.processing_op(ProcessingOps.CONVDW, x, grad_output, ctx.buffer(w.shape), stride, groups) if ctx.needs_input_grad[1] else None
dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, stride, groups) if ctx.needs_input_grad[0] else None
dw = ctx.processing_op(ProcessingOps.CONVDW, x, grad_output, w.shape, stride, groups) if ctx.needs_input_grad[1] else None
return dx, dw

View File

@@ -73,6 +73,9 @@ class Ops:
log_op(op, ret, [x])
return ret
def processing_op(ctx, op:ProcessingOps, x, y, ret, stride, groups):
def processing_op(ctx, op:ProcessingOps, x, y, out_shape, stride, groups):
# TODO: can we do better than out_shape?
ret = ctx.buffer(out_shape)
ctx.op.processing_op(op, x, y, ret, stride, groups)
log_op(op, ret, [x, y])
return ctx.op.processing_op(op, x, y, ret, stride, groups)
return ret