mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
get rid of reduce using channels
This commit is contained in:
@@ -190,16 +190,15 @@ class Conv2D(Function):
|
||||
dx = ctx.movement_op(MovementOps.SLICE, dx, [(0,s) for s in x.shape])
|
||||
|
||||
# compute derivative of weights using ProcessingOps.CONV
|
||||
# TODO: there has to be a way to do this without the expand/reduce for at least matmul
|
||||
# since it's ctx.op.matmul(input, grad_output, ctx.buffer(weight.shape), transpose_a=True)
|
||||
xdw = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs*C.groups, C.cin, C.iy, C.ix))
|
||||
xdw = ctx.movement_op(MovementOps.PERMUTE, xdw, (1,0,2,3))
|
||||
grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output, (C.bs * C.groups * C.rcout, 1, C.oy, C.ox))
|
||||
Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.bs*C.groups)
|
||||
xdw = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups, C.cin, C.iy, C.ix))
|
||||
xdw = ctx.movement_op(MovementOps.PERMUTE, xdw, (2,1,0,3,4))
|
||||
xdw = ctx.movement_op(MovementOps.RESHAPE, xdw, (C.cin, C.groups*C.bs, C.iy, C.ix))
|
||||
grad_output_dw = ctx.movement_op(MovementOps.PERMUTE, grad_output, (1,0,2,3))
|
||||
grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.groups * C.rcout, C.bs, C.oy, C.ox))
|
||||
Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.groups)
|
||||
grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (Cdw.bs, Cdw.cout, Cdw.oy, Cdw.ox), Cdw)
|
||||
grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.cin, C.bs, C.groups, C.rcout, Cdw.oy, Cdw.ox))
|
||||
grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,2,3,0,4,5))
|
||||
grad_weight = ctx.reduce_op(ReduceOps.SUM, grad_weight, (1, *grad_weight.shape[1:]))
|
||||
grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.groups*C.rcout, C.cin, Cdw.oy, Cdw.ox))
|
||||
grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.cin, C.groups*C.rcout, Cdw.oy, Cdw.ox))
|
||||
grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3))
|
||||
# TODO: remove this slice using asymmetric padding
|
||||
dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3])))
|
||||
return dx, dw
|
||||
|
||||
Reference in New Issue
Block a user