From b5796ae4f9579a706e85aa3697bc5eb3acedf946 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 16 Jun 2022 10:15:43 -0700 Subject: [PATCH] remove useless reshape --- tinygrad/llops/ops_gpu.py | 1 + tinygrad/llops/ops_torch.py | 2 -- tinygrad/mlops.py | 3 +-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 22435c6f47..33c429f3eb 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -102,6 +102,7 @@ def reduce_op(op, inp, ret): loop_end.append(f"idx += {acc}; }} idx -= {shp*acc};") acc *= shp + # TODO: support multistage reduces prg = """ __kernel void reduce(__global const float *a_g, __global float *res_g) { int gid = get_global_id(0); int idx = gid;"""+view.expr.replace('//', '/')+"""; diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py index 9471a9cbee..d8c6ca7cb4 100644 --- a/tinygrad/llops/ops_torch.py +++ b/tinygrad/llops/ops_torch.py @@ -27,6 +27,4 @@ from tinygrad.ops import ProcessingOps def processing_op(op,x,w,ret,C): assert op == ProcessingOps.CONV, f"{op} isn't supported" - # stride is the same as doing the full conv and slicing with stride at the end - # dilation is the same as conving with a larger weight matrix with 0s added ret[:] = torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px)) diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index f91d55aa78..2be7fe9920 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -194,10 +194,9 @@ class Conv2D(Function): xdw = ctx.movement_op(MovementOps.PERMUTE, xdw, (2,1,0,3,4)) xdw = ctx.movement_op(MovementOps.RESHAPE, xdw, (C.cin, C.groups*C.bs, C.iy, C.ix)) grad_output_dw = ctx.movement_op(MovementOps.PERMUTE, grad_output, (1,0,2,3)) - grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.groups * C.rcout, C.bs, C.oy, C.ox)) + grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.cout, C.bs, C.oy, C.ox)) Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.groups) grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (Cdw.bs, Cdw.cout, Cdw.oy, Cdw.ox), Cdw) - grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.cin, C.groups*C.rcout, Cdw.oy, Cdw.ox)) grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3)) # TODO: remove this slice using asymmetric padding dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3])))