mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Cdx without SLICE
This commit is contained in:
@@ -203,10 +203,11 @@ class Conv2D(Function):
|
||||
wt = ctx.movement_op(MovementOps.FLIP, wt, (3, 4))
|
||||
wt = ctx.movement_op(MovementOps.PERMUTE, wt, (0, 2, 1, 3, 4))
|
||||
wt = ctx.movement_op(MovementOps.RESHAPE, wt, (C.groups*C.cin, C.rcout, C.H, C.W))
|
||||
Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=((C.W-1)*C.dx-C.px, (C.W-1)*C.dx-C.px_, (C.H-1)*C.dy-C.py, (C.H-1)*C.dy-C.py_), groups=C.groups)
|
||||
# TODO: this shape can be wrong strided. support asymmetric padding to remove the slice
|
||||
py, px = (C.H-1)*C.dy - C.py, (C.W-1)*C.dx - C.px
|
||||
py_ = x.shape[2] - xt.shape[2] + C.py
|
||||
px_ = x.shape[3] - xt.shape[3] + C.px
|
||||
Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=(px, px_, py, py_), groups=C.groups)
|
||||
dx = ctx._conv(xt, wt, Cdx)
|
||||
dx = ctx.movement_op(MovementOps.SLICE, dx, [(0,s) for s in x.shape])
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
# compute derivative of weights using ProcessingOps.CONV
|
||||
|
||||
Reference in New Issue
Block a user