remove useless reshape

This commit is contained in:
George Hotz
2022-06-16 10:15:43 -07:00
parent 89db797e57
commit b5796ae4f9
3 changed files with 2 additions and 4 deletions

View File

@@ -102,6 +102,7 @@ def reduce_op(op, inp, ret):
loop_end.append(f"idx += {acc}; }} idx -= {shp*acc};")
acc *= shp
# TODO: support multistage reduces
prg = """
__kernel void reduce(__global const float *a_g, __global float *res_g) {
int gid = get_global_id(0); int idx = gid;"""+view.expr.replace('//', '/')+""";

View File

@@ -27,6 +27,4 @@ from tinygrad.ops import ProcessingOps
def processing_op(op,x,w,ret,C):
  """Execute a processing op with torch as the backend; only CONV is implemented.

  op:  must be ProcessingOps.CONV (anything else trips the assert).
  x:   input tensor, w: weight tensor.
  ret: output tensor, filled in place via slice assignment.
  C:   conv-args object carrying stride (ys, xs), dilation (dy, dx),
       padding (py, px), and groups — presumably produced by get_conv_args;
       TODO(review): confirm attribute meanings against the caller.

  Why plain torch.conv2d covers everything here:
  - stride equals running the full conv then slicing the result with that stride
  - dilation equals convolving with a larger weight matrix that has 0s inserted
  """
  assert op == ProcessingOps.CONV, f"{op} isn't supported"
  stride = (C.ys, C.xs)
  dilation = (C.dy, C.dx)
  padding = (C.py, C.px)
  out = torch.conv2d(x, w, stride=stride, groups=C.groups, dilation=dilation, padding=padding)
  ret[:] = out

View File

@@ -194,10 +194,9 @@ class Conv2D(Function):
xdw = ctx.movement_op(MovementOps.PERMUTE, xdw, (2,1,0,3,4))
xdw = ctx.movement_op(MovementOps.RESHAPE, xdw, (C.cin, C.groups*C.bs, C.iy, C.ix))
grad_output_dw = ctx.movement_op(MovementOps.PERMUTE, grad_output, (1,0,2,3))
grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.groups * C.rcout, C.bs, C.oy, C.ox))
grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.cout, C.bs, C.oy, C.ox))
Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.groups)
grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (Cdw.bs, Cdw.cout, Cdw.oy, Cdw.ox), Cdw)
grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.cin, C.groups*C.rcout, Cdw.oy, Cdw.ox))
grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3))
# TODO: remove this slice using asymmetric padding
dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3])))