diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index e1efdc56f4..19eb2be763 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -8,21 +8,20 @@ from tinygrad.tensor import Function
 class _UnaryOp(Function):
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    return ctx.unary_op(ctx.fop, input, ctx.buffer(input.shape))
+    return ctx.unary_op(ctx.fop, input)
 
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
-    return ctx.binary_op(ctx.bop, input, grad_output, ctx.buffer(input.shape))
+    return ctx.binary_op(ctx.bop, input, grad_output)
 
 class ReLU(_UnaryOp):
   fop = UnaryOps.RELU
 
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
-    ret = ctx.buffer(input.shape)
-    ctx.unary_op(UnaryOps.SIGN, input, ret)
-    ctx.unary_op(UnaryOps.RELU, ret, ret)
-    return ctx.binary_op(BinaryOps.MUL, ret, grad_output, ret)
+    ret = ctx.unary_op(UnaryOps.SIGN, input)
+    ret = ctx.unary_op(UnaryOps.RELU, ret)
+    return ctx.binary_op(BinaryOps.MUL, ret, grad_output)
 
 class Log(_UnaryOp):
   fop = UnaryOps.LOG
@@ -30,7 +29,7 @@ class Log(_UnaryOp):
 
 class Exp(_UnaryOp):
   def forward(ctx, input):
-    ret = ctx.unary_op(UnaryOps.EXP, input, ctx.buffer(input.shape))
+    ret = ctx.unary_op(UnaryOps.EXP, input)
     ctx.save_for_backward(ret)  # we save the output here, not the input
     return ret
 
@@ -50,7 +49,7 @@ class Sum(Function):
     shape_input, = ctx.saved_tensors
     # NOTE: the b Buffer isn't used, since this is just for broadcast
     ret = ctx.buffer(shape_input)
-    return ctx.binary_op(BinaryOps.A, grad_output, ret, ret)
+    return ctx.binary_op(BinaryOps.A, grad_output, ret)
 
 class Max(Function):
   def forward(ctx, input, axis=None):
@@ -60,10 +59,10 @@ class Max(Function):
 
   def backward(ctx, grad_output):
     input, ret = ctx.saved_tensors
-    ret2 = ctx.binary_op(BinaryOps.CMPEQ, input, ret, ctx.buffer(input.shape))
+    ret2 = ctx.binary_op(BinaryOps.CMPEQ, input, ret)
     div = ctx.reduce_op(ReduceOps.SUM, ret2, ctx.buffer(grad_output.shape))
-    ctx.binary_op(BinaryOps.DIV, div, ret2, ret2)
-    return ctx.binary_op(BinaryOps.MUL, ret2, grad_output, ret2)
+    ret2 = ctx.binary_op(BinaryOps.DIV, div, ret2)
+    return ctx.binary_op(BinaryOps.MUL, ret2, grad_output)
 
 # ************* binary ops *************
 
@@ -73,8 +72,7 @@ def unbroadcast(ctx, out, in_sh):
 class Add(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
-    buf = ctx.buffer(binary_broadcast(x.shape, y.shape))
-    return ctx.binary_op(BinaryOps.ADD, x, y, buf) #ctx.buffer(binary_broadcast(x.shape, y.shape)))
+    return ctx.binary_op(BinaryOps.ADD, x, y)
 
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
@@ -84,41 +82,38 @@ class Add(Function):
 class Sub(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
-    return ctx.binary_op(BinaryOps.SUB, x, y, ctx.buffer(binary_broadcast(x.shape, y.shape)))
+    return ctx.binary_op(BinaryOps.SUB, x, y)
 
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
-    neg_grad_output = ctx.unary_op(UnaryOps.NEG, grad_output, ctx.buffer(grad_output.shape))
+    neg_grad_output = ctx.unary_op(UnaryOps.NEG, grad_output)
     return unbroadcast(ctx, grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
            unbroadcast(ctx, neg_grad_output, shape_y) if ctx.needs_input_grad[1] else None
 
 class Mul(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
-    return ctx.binary_op(BinaryOps.MUL, x, y, ctx.buffer(binary_broadcast(x.shape, y.shape)))
+    return ctx.binary_op(BinaryOps.MUL, x, y)
 
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
-    tmp = ctx.buffer(grad_output.shape)
-    grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, y, grad_output, tmp), x.shape) if ctx.needs_input_grad[0] else None
-    grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, x, grad_output, tmp), y.shape) if ctx.needs_input_grad[1] else None
+    grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, y, grad_output), x.shape) if ctx.needs_input_grad[0] else None
+    grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, x, grad_output), y.shape) if ctx.needs_input_grad[1] else None
     return grad_x, grad_y
 
 class Pow(Function):
   def forward(ctx, x, y):
-    ret = ctx.buffer(binary_broadcast(x.shape, y.shape))
+    ret = ctx.binary_op(BinaryOps.POW, x, y)
     ctx.save_for_backward(x, y, ret)
-    return ctx.binary_op(BinaryOps.POW, x, y, ret)
+    return ret
 
   def backward(ctx, grad_output):
     x,y,powxy = ctx.saved_tensors
-    tmp = ctx.buffer(grad_output.shape)
-    ctx.binary_op(BinaryOps.DIV, x, powxy, tmp)      # pow(x,y)/x
-    ctx.binary_op(BinaryOps.MUL, y, tmp, tmp)        # y * pow(x,y)/x
-    grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp, tmp), x.shape) if ctx.needs_input_grad[0] else None
-    log_x = ctx.unary_op(UnaryOps.LOG, x, ctx.buffer(x.shape))
-    ctx.binary_op(BinaryOps.MUL, log_x, powxy, tmp)  # log(x) * pow(x,y)
-    grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp, tmp), y.shape) if ctx.needs_input_grad[1] else None
+    tmp = ctx.binary_op(BinaryOps.DIV, x, powxy)     # pow(x,y)/x
+    tmp = ctx.binary_op(BinaryOps.MUL, y, tmp)       # y * pow(x,y)/x
+    grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), x.shape) if ctx.needs_input_grad[0] else None
+    tmp = ctx.binary_op(BinaryOps.MUL, ctx.unary_op(UnaryOps.LOG, x), powxy)  # log(x) * pow(x,y)
+    grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), y.shape) if ctx.needs_input_grad[1] else None
    return grad_x, grad_y
 
 # ************* movement ops *************
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index a961a44a46..d119c04810 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -45,18 +45,23 @@ def log_op(op, ret, inp):
     G.nodes[nm(ret)]['fillcolor'] = top_colors[top]
     G.nodes[nm(ret)]['style'] = 'filled'
 
+from tinygrad.helpers import binary_broadcast
 class Ops:
-  def unary_op(ctx, op:UnaryOps, x, ret):
+  def unary_op(ctx, op:UnaryOps, x):
+    ret = ctx.buffer(x.shape)
+    ctx.op.unary_op(op, x, ret)
     log_op(op, ret, [x])
-    return ctx.op.unary_op(op, x, ret)
+    return ret
 
   def reduce_op(ctx, op:BinaryOps, x, ret):
     log_op(op, ret, [x])
     return ctx.op.reduce_op(op, x, ret)
 
-  def binary_op(ctx, op:ReduceOps, x, y, ret):
+  def binary_op(ctx, op:ReduceOps, x, y):
+    ret = ctx.buffer(binary_broadcast(x.shape, y.shape))
+    ctx.op.binary_op(op, x, y, ret)
     log_op(op, ret, [x, y])
-    return ctx.op.binary_op(op, x, y, ret)
+    return ret
 
   def movement_op(ctx, op:MovementOps, x, ret, arg=None):
     log_op(op, ret, [x])
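
The change above is mechanical but broad: output-buffer allocation moves out of every Function in mlops.py and into the Ops.unary_op/binary_op wrappers in ops.py, so call sites shrink from ctx.binary_op(op, x, y, ctx.buffer(...)) to ctx.binary_op(op, x, y). The sketch below is a minimal, self-contained illustration of that pattern only; Ctx, binary_op_old, and broadcast_shape are stand-ins invented for this example, not tinygrad APIs:

import numpy as np

def broadcast_shape(xs, ys):
  # NumPy-style broadcasting: pad the shorter shape with 1s on the left,
  # then each dimension pair must be equal or contain a 1
  xs, ys = (1,)*(len(ys)-len(xs)) + tuple(xs), (1,)*(len(xs)-len(ys)) + tuple(ys)
  assert all(x == y or x == 1 or y == 1 for x, y in zip(xs, ys))
  return tuple(max(x, y) for x, y in zip(xs, ys))

class Ctx:
  def buffer(self, shape):
    return np.empty(shape, dtype=np.float32)

  # old convention: every call site allocates ret and threads it through
  def binary_op_old(self, fxn, x, y, ret):
    ret[:] = fxn(x, y)
    return ret

  # new convention: the wrapper owns allocation, call sites pass only inputs
  def binary_op(self, fxn, x, y):
    ret = self.buffer(broadcast_shape(x.shape, y.shape))
    ret[:] = fxn(x, y)
    return ret

ctx = Ctx()
x, y = np.ones((3, 1), np.float32), np.ones((1, 4), np.float32)
assert ctx.binary_op(np.multiply, x, y).shape == (3, 4)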
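
The rewritten ReLU.backward still composes SIGN and RELU to build the derivative mask: relu(sign(x)) is 1 where x > 0 and 0 elsewhere, which is then multiplied by grad_output. A quick NumPy check of that identity (illustrative only, not tinygrad code):

import numpy as np
x = np.array([-2.0, 0.0, 3.0])
mask = np.maximum(np.sign(x), 0.0)   # relu(sign(x)): 1 where x > 0, else 0
grad_output = np.array([10.0, 10.0, 10.0])
assert np.array_equal(mask * grad_output, np.array([0.0, 0.0, 10.0]))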
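
Max.backward follows the same shape: CMPEQ builds a mask of elements equal to the saved max, reduce_op sums that mask to count ties, and the division makes tied maxima share the gradient equally before the multiply by grad_output. In NumPy terms (again just an illustration of the math, with no claim about the DIV argument order in the backend):

import numpy as np
x = np.array([1.0, 5.0, 5.0, 2.0])
m = x.max(keepdims=True)                 # forward result, saved for backward
mask = (x == m).astype(np.float32)       # CMPEQ: 1.0 at each location of the max
count = mask.sum(keepdims=True)          # ReduceOps.SUM counts tied maxima
grad_output = np.array([6.0])
grad_input = mask / count * grad_output  # ties split the gradient equally
assert np.allclose(grad_input, [0.0, 3.0, 3.0, 0.0])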
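
Pow.backward, per its own comments, computes grad_x as grad_output * y * pow(x,y)/x and grad_y as grad_output * log(x) * pow(x,y), which are the standard partial derivatives of x**y (writing x**y/x avoids a second POW for x**(y-1)). A small numerical check of those identities in plain NumPy:

import numpy as np
x, y = np.array([2.0, 3.0]), np.array([4.0, 0.5])
powxy = x ** y
grad_output = np.ones_like(x)
grad_x = grad_output * y * powxy / x      # d(x**y)/dx = y * x**(y-1)
grad_y = grad_output * np.log(x) * powxy  # d(x**y)/dy = log(x) * x**y
assert np.allclose(grad_x, y * x ** (y - 1))
assert np.allclose(grad_y, np.log(x) * x ** y)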