mirror of https://github.com/tinygrad/tinygrad.git
and processing op
@@ -151,10 +151,10 @@ class Conv2D(Function):
   def forward(ctx, x, w, stride=1, groups=1):
     C = get_conv_args(x.shape, w.shape, stride, groups)
     ctx.save_for_backward(x,w,(C.ys,C.xs), C.groups)
-    return ctx.processing_op(ProcessingOps.CONV, x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), (C.ys,C.xs), C.groups)
+    return ctx.processing_op(ProcessingOps.CONV, x, w, (C.bs, C.groups*C.rcout, C.oy, C.ox), (C.ys,C.xs), C.groups)
 
   def backward(ctx, grad_output):
     x, w, stride, groups = ctx.saved_tensors
-    dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, ctx.buffer(x.shape), stride, groups) if ctx.needs_input_grad[0] else None
-    dw = ctx.processing_op(ProcessingOps.CONVDW, x, grad_output, ctx.buffer(w.shape), stride, groups) if ctx.needs_input_grad[1] else None
+    dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, stride, groups) if ctx.needs_input_grad[0] else None
+    dw = ctx.processing_op(ProcessingOps.CONVDW, x, grad_output, w.shape, stride, groups) if ctx.needs_input_grad[1] else None
     return dx, dw
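For context, a minimal sketch of where the (C.bs, C.groups*C.rcout, C.oy, C.ox) tuple comes from, assuming get_conv_args computes a valid (unpadded) convolution with a square stride. The helper below and its formulas are illustrative guesses, not tinygrad's actual get_conv_args:

# Illustrative sketch only: assumed semantics of get_conv_args.
from collections import namedtuple

ConvArgs = namedtuple("ConvArgs", ["bs", "groups", "rcout", "oy", "ox", "ys", "xs"])

def get_conv_args_sketch(x_shape, w_shape, stride=1, groups=1):
  bs, cin, iy, ix = x_shape        # batch, input channels, input height/width
  cout, _, H, W = w_shape          # output channels, cin//groups, kernel height/width
  ys = xs = stride                 # square stride, matching (C.ys, C.xs) in the diff
  oy, ox = (iy-(H-ys))//ys, (ix-(W-xs))//xs   # output spatial dims, no padding
  return ConvArgs(bs, groups, cout//groups, oy, ox, ys, xs)

C = get_conv_args_sketch((2, 3, 8, 8), (4, 3, 3, 3))
print((C.bs, C.groups*C.rcout, C.oy, C.ox))  # (2, 4, 6, 6), the shape passed to processing_op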
@@ -73,6 +73,9 @@ class Ops:
     log_op(op, ret, [x])
     return ret
 
-  def processing_op(ctx, op:ProcessingOps, x, y, ret, stride, groups):
+  def processing_op(ctx, op:ProcessingOps, x, y, out_shape, stride, groups):
+    # TODO: can we do better than out_shape?
+    ret = ctx.buffer(out_shape)
+    ctx.op.processing_op(op, x, y, ret, stride, groups)
     log_op(op, ret, [x, y])
-    return ctx.op.processing_op(op, x, y, ret, stride, groups)
+    return ret
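The shape of the change, in a self-contained sketch with mock objects (not tinygrad's API): processing_op now receives an output shape, allocates the buffer itself, lets the backend kernel fill it in place, logs, and returns it, so call sites like Conv2D.forward/backward no longer call ctx.buffer directly:

# Illustrative sketch only: mock ctx, buffer, and kernel stand-ins.
import numpy as np

class CtxSketch:
  def buffer(self, shape):
    # stand-in for ctx.buffer: allocate the output buffer
    return np.empty(shape, dtype=np.float32)

  def processing_op(self, op, x, y, out_shape, stride, groups):
    ret = self.buffer(out_shape)     # allocation moved inside the wrapper
    op(x, y, ret, stride, groups)    # backend kernel writes into ret in place
    print("log_op:", op.__name__, "->", ret.shape)  # stand-in for log_op
    return ret

def conv_kernel_stub(x, y, ret, stride, groups):
  ret[:] = 0.0                       # placeholder for the real CONV kernel

ctx = CtxSketch()
out = ctx.processing_op(conv_kernel_stub, None, None, (2, 4, 6, 6), (1, 1), 1)
assert out.shape == (2, 4, 6, 6)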