From 02cd8510cb09bb739861d938b8ee572866f2ad3d Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 3 Jul 2022 17:23:20 -0700 Subject: [PATCH] cleanups --- tinygrad/llops/ops_cpu.py | 2 -- tinygrad/mlops.py | 10 +++------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index 070a3617fc..07d55fd1b2 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -33,7 +33,6 @@ class CPUBuffer(np.ndarray): if x.shape == new_shape: return x[:] # this is just a copy, regardless of the reduce op elif op == ReduceOps.SUM: return x.sum(axis, keepdims=True) elif op == ReduceOps.MAX: return x.amax(axis, keepdims=True) - else: raise Exception(f"{op} isn't supported") def movement_op(x, op, arg=None): if op == MovementOps.RESHAPE: return x.reshape(arg) @@ -43,7 +42,6 @@ class CPUBuffer(np.ndarray): padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)] return x.custompad(padding)[tuple(slice(p[0] + padding[i][0], p[1] + padding[i][0], None) for i,p in enumerate(arg))] elif op == MovementOps.EXPAND: return x.expand(arg) - else: raise Exception(f"{op} isn't supported") def processing_op(x,op,w,C): assert op == ProcessingOps.CONV, f"{op} isn't supported" diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 6ffacad0d8..319ea7aa60 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -176,10 +176,8 @@ class Conv2D(Function): xt = xt.movement_op(MovementOps.RESHAPE, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1)) xt = xt.movement_op(MovementOps.SLICE, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.sy), (0,xt.shape[4]), (0,C.sx))) xt = xt.movement_op(MovementOps.RESHAPE, (xt.shape[0], xt.shape[1], xt.shape[2]*C.sy, xt.shape[4]*C.sx)) - wt = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W)) - wt = wt.movement_op(MovementOps.FLIP, (3, 4)) - wt = wt.movement_op(MovementOps.PERMUTE, (0, 2, 1, 3, 4)) - wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W)) + wt = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W)).movement_op(MovementOps.PERMUTE, (0, 2, 1, 3, 4)) + wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W)).movement_op(MovementOps.FLIP, (2, 3)) py, px = (C.H-1)*C.dy - C.py, (C.W-1)*C.dx - C.px py_ = x.shape[2] - xt.shape[2] + C.py px_ = x.shape[3] - xt.shape[3] + C.px @@ -187,11 +185,9 @@ class Conv2D(Function): dx = ctx._conv(xt, wt, Cdx) if ctx.needs_input_grad[1]: # compute derivative of weights using ProcessingOps.CONV - xdw = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix)) - xdw = xdw.movement_op(MovementOps.PERMUTE, (2,1,0,3,4)) + xdw = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix)).movement_op(MovementOps.PERMUTE, (2, 1, 0, 3, 4)) xdw = xdw.movement_op(MovementOps.RESHAPE, (C.cin, C.groups*C.bs, C.iy, C.ix)) grad_output_dw = grad_output.movement_op(MovementOps.PERMUTE, (1,0,2,3)) - grad_output_dw = grad_output_dw.movement_op(MovementOps.RESHAPE, (C.cout, C.bs, C.oy, C.ox)) py_ = (w.shape[2] - 1) * C.dy - xdw.shape[2] - C.py + C.sy * (grad_output_dw.shape[2]-1) + 1 px_ = (w.shape[3] - 1) * C.dx - xdw.shape[3] - C.px + C.sx * (grad_output_dw.shape[3]-1) + 1 Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.px, px_, C.py, py_), stride=(C.dy, C.dx), dilation=(C.sy, C.sx), groups=C.groups)