diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 1806337cff..22435c6f47 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -93,7 +93,7 @@ def reduce_op(op, inp, ret): # this takes a ret index to an inp index, indexing 0 on the reduced strides view = View(ret.shape, strides_for_shape(inp.shape)) - # combined adjacent reduce axis + # generate loops with combined adjacent reduce axis acc = 1 loop_start, loop_end = [], [] for shp,stride in st.views[-1].shape_strides[::-1]: diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 98b5759488..f1b87a197c 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -192,19 +192,14 @@ class Conv2D(Function): # compute derivative of weights using ProcessingOps.CONV # TODO: there has to be a way to do this without the expand/reduce for at least matmul # since it's ctx.op.matmul(input, grad_output, ctx.buffer(weight.shape), transpose_a=True) - xdw = ctx.movement_op(MovementOps.RESHAPE, x, (1, C.bs * C.groups * C.cin, C.iy, C.ix)) - grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output, (C.bs * C.groups, 1, C.rcout, C.oy, C.ox)) - # this expand is slow - grad_output_dw = ctx.movement_op(MovementOps.EXPAND, grad_output_dw, (C.bs * C.groups, C.cin, C.rcout, C.oy, C.ox)) - grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.bs * C.groups * C.cin * C.rcout, 1, C.oy, C.ox)) - # padding is the same, stride and dilation are flipped - Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.bs*C.groups*C.cin) + xdw = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs*C.groups, C.cin, C.iy, C.ix)) + xdw = ctx.movement_op(MovementOps.PERMUTE, xdw, (1,0,2,3)) + grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output, (C.bs * C.groups * C.rcout, 1, C.oy, C.ox)) + Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.bs*C.groups) grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (Cdw.bs, Cdw.cout, Cdw.oy, Cdw.ox), Cdw) - grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.bs, C.groups, C.cin, C.rcout, Cdw.oy, Cdw.ox)) - # sum across the batch dimension + grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.cin, C.bs, C.groups, C.rcout, Cdw.oy, Cdw.ox)) + grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,2,3,0,4,5)) grad_weight = ctx.reduce_op(ReduceOps.SUM, grad_weight, (1, *grad_weight.shape[1:])) - # flip channels out and in - grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (0,1,3,2,4,5)) grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.groups*C.rcout, C.cin, Cdw.oy, Cdw.ox)) dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3]))) - return dx, dw \ No newline at end of file + return dx, dw diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d3e7fa748a..62a2b69b10 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -48,13 +48,13 @@ class Ops: log_op(op, ret, [x]) return ret - def reduce_op(ctx, op:BinaryOps, x, new_shape): + def reduce_op(ctx, op:ReduceOps, x, new_shape): ret = ctx.buffer(new_shape) ctx.op.reduce_op(op, x, ret) log_op(op, ret, [x]) return ret - def binary_op(ctx, op:ReduceOps, x, y): + def binary_op(ctx, op:BinaryOps, x, y): assert x.shape == y.shape ret = ctx.buffer(x.shape) ctx.op.binary_op(op, x, y, ret)