move unary and binary op mem alloc to Ops class

This commit is contained in:
George Hotz
2022-06-11 16:35:03 -07:00
parent 1511cbf9c5
commit bbf231da34
2 changed files with 32 additions and 32 deletions

View File

@@ -8,21 +8,20 @@ from tinygrad.tensor import Function
class _UnaryOp(Function):
def forward(ctx, input):
ctx.save_for_backward(input)
return ctx.unary_op(ctx.fop, input, ctx.buffer(input.shape))
return ctx.unary_op(ctx.fop, input)
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return ctx.binary_op(ctx.bop, input, grad_output, ctx.buffer(input.shape))
return ctx.binary_op(ctx.bop, input, grad_output)
class ReLU(_UnaryOp):
fop = UnaryOps.RELU
def backward(ctx, grad_output):
input, = ctx.saved_tensors
ret = ctx.buffer(input.shape)
ctx.unary_op(UnaryOps.SIGN, input, ret)
ctx.unary_op(UnaryOps.RELU, ret, ret)
return ctx.binary_op(BinaryOps.MUL, ret, grad_output, ret)
ret = ctx.unary_op(UnaryOps.SIGN, input)
ret = ctx.unary_op(UnaryOps.RELU, ret)
return ctx.binary_op(BinaryOps.MUL, ret, grad_output)
class Log(_UnaryOp):
fop = UnaryOps.LOG
@@ -30,7 +29,7 @@ class Log(_UnaryOp):
class Exp(_UnaryOp):
def forward(ctx, input):
ret = ctx.unary_op(UnaryOps.EXP, input, ctx.buffer(input.shape))
ret = ctx.unary_op(UnaryOps.EXP, input)
ctx.save_for_backward(ret) # we save the output here, not the input
return ret
@@ -50,7 +49,7 @@ class Sum(Function):
shape_input, = ctx.saved_tensors
# NOTE: the b Buffer isn't used, since this is just for broadcast
ret = ctx.buffer(shape_input)
return ctx.binary_op(BinaryOps.A, grad_output, ret, ret)
return ctx.binary_op(BinaryOps.A, grad_output, ret)
class Max(Function):
def forward(ctx, input, axis=None):
@@ -60,10 +59,10 @@ class Max(Function):
def backward(ctx, grad_output):
input, ret = ctx.saved_tensors
ret2 = ctx.binary_op(BinaryOps.CMPEQ, input, ret, ctx.buffer(input.shape))
ret2 = ctx.binary_op(BinaryOps.CMPEQ, input, ret)
div = ctx.reduce_op(ReduceOps.SUM, ret2, ctx.buffer(grad_output.shape))
ctx.binary_op(BinaryOps.DIV, div, ret2, ret2)
return ctx.binary_op(BinaryOps.MUL, ret2, grad_output, ret2)
ret2 = ctx.binary_op(BinaryOps.DIV, div, ret2)
return ctx.binary_op(BinaryOps.MUL, ret2, grad_output)
# ************* binary ops *************
@@ -73,8 +72,7 @@ def unbroadcast(ctx, out, in_sh):
class Add(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x.shape, y.shape)
buf = ctx.buffer(binary_broadcast(x.shape, y.shape))
return ctx.binary_op(BinaryOps.ADD, x, y, buf) #ctx.buffer(binary_broadcast(x.shape, y.shape)))
return ctx.binary_op(BinaryOps.ADD, x, y)
def backward(ctx, grad_output):
shape_x, shape_y = ctx.saved_tensors
@@ -84,41 +82,38 @@ class Add(Function):
class Sub(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x.shape, y.shape)
return ctx.binary_op(BinaryOps.SUB, x, y, ctx.buffer(binary_broadcast(x.shape, y.shape)))
return ctx.binary_op(BinaryOps.SUB, x, y)
def backward(ctx, grad_output):
shape_x, shape_y = ctx.saved_tensors
neg_grad_output = ctx.unary_op(UnaryOps.NEG, grad_output, ctx.buffer(grad_output.shape))
neg_grad_output = ctx.unary_op(UnaryOps.NEG, grad_output)
return unbroadcast(ctx, grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
unbroadcast(ctx, neg_grad_output, shape_y) if ctx.needs_input_grad[1] else None
class Mul(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return ctx.binary_op(BinaryOps.MUL, x, y, ctx.buffer(binary_broadcast(x.shape, y.shape)))
return ctx.binary_op(BinaryOps.MUL, x, y)
def backward(ctx, grad_output):
x,y = ctx.saved_tensors
tmp = ctx.buffer(grad_output.shape)
grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, y, grad_output, tmp), x.shape) if ctx.needs_input_grad[0] else None
grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, x, grad_output, tmp), y.shape) if ctx.needs_input_grad[1] else None
grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, y, grad_output), x.shape) if ctx.needs_input_grad[0] else None
grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, x, grad_output), y.shape) if ctx.needs_input_grad[1] else None
return grad_x, grad_y
class Pow(Function):
def forward(ctx, x, y):
ret = ctx.buffer(binary_broadcast(x.shape, y.shape))
ret = ctx.binary_op(BinaryOps.POW, x, y)
ctx.save_for_backward(x, y, ret)
return ctx.binary_op(BinaryOps.POW, x, y, ret)
return ret
def backward(ctx, grad_output):
x,y,powxy = ctx.saved_tensors
tmp = ctx.buffer(grad_output.shape)
ctx.binary_op(BinaryOps.DIV, x, powxy, tmp) # pow(x,y)/x
ctx.binary_op(BinaryOps.MUL, y, tmp, tmp) # y * pow(x,y)/x
grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp, tmp), x.shape) if ctx.needs_input_grad[0] else None
log_x = ctx.unary_op(UnaryOps.LOG, x, ctx.buffer(x.shape))
ctx.binary_op(BinaryOps.MUL, log_x, powxy, tmp) # log(x) * pow(x,y)
grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp, tmp), y.shape) if ctx.needs_input_grad[1] else None
tmp = ctx.binary_op(BinaryOps.DIV, x, powxy) # pow(x,y)/x
tmp = ctx.binary_op(BinaryOps.MUL, y, tmp) # y * pow(x,y)/x
grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), x.shape) if ctx.needs_input_grad[0] else None
tmp = ctx.binary_op(BinaryOps.MUL, ctx.unary_op(UnaryOps.LOG, x), powxy) # log(x) * pow(x,y)
grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), y.shape) if ctx.needs_input_grad[1] else None
return grad_x, grad_y
# ************* movement ops *************

View File

@@ -45,18 +45,23 @@ def log_op(op, ret, inp):
G.nodes[nm(ret)]['fillcolor'] = top_colors[top]
G.nodes[nm(ret)]['style'] = 'filled'
from tinygrad.helpers import binary_broadcast
class Ops:
def unary_op(ctx, op:UnaryOps, x, ret):
def unary_op(ctx, op:UnaryOps, x):
ret = ctx.buffer(x.shape)
ctx.op.unary_op(op, x, ret)
log_op(op, ret, [x])
return ctx.op.unary_op(op, x, ret)
return ret
def reduce_op(ctx, op:BinaryOps, x, ret):
log_op(op, ret, [x])
return ctx.op.reduce_op(op, x, ret)
def binary_op(ctx, op:ReduceOps, x, y, ret):
def binary_op(ctx, op:ReduceOps, x, y):
ret = ctx.buffer(binary_broadcast(x.shape, y.shape))
ctx.op.binary_op(op, x, y, ret)
log_op(op, ret, [x, y])
return ctx.op.binary_op(op, x, y, ret)
return ret
def movement_op(ctx, op:MovementOps, x, ret, arg=None):
log_op(op, ret, [x])