Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-22 21:38:10 -05:00)
move unary and binary op mem alloc to Ops class
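In short: before this change every Function allocated its own output with ctx.buffer(...) and passed it as the last argument to unary_op / binary_op. The commit moves that allocation into the Ops wrappers, so callers only pass inputs and get a freshly allocated result back (unary ops allocate from the input's shape, binary ops from binary_broadcast of the two input shapes). reduce_op and movement_op keep their explicit ret argument.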
@@ -8,21 +8,20 @@ from tinygrad.tensor import Function
 class _UnaryOp(Function):
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    return ctx.unary_op(ctx.fop, input, ctx.buffer(input.shape))
+    return ctx.unary_op(ctx.fop, input)
 
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
-    return ctx.binary_op(ctx.bop, input, grad_output, ctx.buffer(input.shape))
+    return ctx.binary_op(ctx.bop, input, grad_output)
 
 class ReLU(_UnaryOp):
   fop = UnaryOps.RELU
 
   def backward(ctx, grad_output):
     input, = ctx.saved_tensors
-    ret = ctx.buffer(input.shape)
-    ctx.unary_op(UnaryOps.SIGN, input, ret)
-    ctx.unary_op(UnaryOps.RELU, ret, ret)
-    return ctx.binary_op(BinaryOps.MUL, ret, grad_output, ret)
+    ret = ctx.unary_op(UnaryOps.SIGN, input)
+    ret = ctx.unary_op(UnaryOps.RELU, ret)
+    return ctx.binary_op(BinaryOps.MUL, ret, grad_output)
 
 class Log(_UnaryOp):
   fop = UnaryOps.LOG
@@ -30,7 +29,7 @@ class Log(_UnaryOp):
 
 class Exp(_UnaryOp):
   def forward(ctx, input):
-    ret = ctx.unary_op(UnaryOps.EXP, input, ctx.buffer(input.shape))
+    ret = ctx.unary_op(UnaryOps.EXP, input)
     ctx.save_for_backward(ret) # we save the output here, not the input
     return ret
 
@@ -50,7 +49,7 @@ class Sum(Function):
     shape_input, = ctx.saved_tensors
     # NOTE: the b Buffer isn't used, since this is just for broadcast
     ret = ctx.buffer(shape_input)
-    return ctx.binary_op(BinaryOps.A, grad_output, ret, ret)
+    return ctx.binary_op(BinaryOps.A, grad_output, ret)
 
 class Max(Function):
   def forward(ctx, input, axis=None):
@@ -60,10 +59,10 @@ class Max(Function):
 
   def backward(ctx, grad_output):
     input, ret = ctx.saved_tensors
-    ret2 = ctx.binary_op(BinaryOps.CMPEQ, input, ret, ctx.buffer(input.shape))
+    ret2 = ctx.binary_op(BinaryOps.CMPEQ, input, ret)
     div = ctx.reduce_op(ReduceOps.SUM, ret2, ctx.buffer(grad_output.shape))
-    ctx.binary_op(BinaryOps.DIV, div, ret2, ret2)
-    return ctx.binary_op(BinaryOps.MUL, ret2, grad_output, ret2)
+    ret2 = ctx.binary_op(BinaryOps.DIV, div, ret2)
+    return ctx.binary_op(BinaryOps.MUL, ret2, grad_output)
 
 # ************* binary ops *************
 
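The Max backward above spreads the incoming gradient evenly over every position that attained the maximum: CMPEQ builds a 0/1 mask, ReduceOps.SUM counts the ties, and the DIV / MUL steps scale grad_output by the mask divided by that count. A rough NumPy sketch of the same idea, assuming a full (axis-less) reduction and plain NumPy arrays standing in for device buffers:

import numpy as np

def max_backward(input, ret, grad_output):
  # ret is the saved forward output (the max); split the gradient among tied maxima
  mask = (input == ret).astype(input.dtype)   # BinaryOps.CMPEQ
  count = mask.sum()                          # ReduceOps.SUM: number of tied maxima
  return (mask / count) * grad_output         # DIV then MUL by the incoming gradient

x = np.array([1.0, 3.0, 3.0, 2.0])
print(max_backward(x, x.max(), np.float64(1.0)))  # [0.  0.5 0.5 0. ]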
@@ -73,8 +72,7 @@ def unbroadcast(ctx, out, in_sh):
 class Add(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
-    buf = ctx.buffer(binary_broadcast(x.shape, y.shape))
-    return ctx.binary_op(BinaryOps.ADD, x, y, buf) #ctx.buffer(binary_broadcast(x.shape, y.shape)))
+    return ctx.binary_op(BinaryOps.ADD, x, y)
 
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
@@ -84,41 +82,38 @@ class Add(Function):
 class Sub(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
-    return ctx.binary_op(BinaryOps.SUB, x, y, ctx.buffer(binary_broadcast(x.shape, y.shape)))
+    return ctx.binary_op(BinaryOps.SUB, x, y)
 
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
-    neg_grad_output = ctx.unary_op(UnaryOps.NEG, grad_output, ctx.buffer(grad_output.shape))
+    neg_grad_output = ctx.unary_op(UnaryOps.NEG, grad_output)
     return unbroadcast(ctx, grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
            unbroadcast(ctx, neg_grad_output, shape_y) if ctx.needs_input_grad[1] else None
 
 class Mul(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x, y)
-    return ctx.binary_op(BinaryOps.MUL, x, y, ctx.buffer(binary_broadcast(x.shape, y.shape)))
+    return ctx.binary_op(BinaryOps.MUL, x, y)
 
   def backward(ctx, grad_output):
     x,y = ctx.saved_tensors
-    tmp = ctx.buffer(grad_output.shape)
-    grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, y, grad_output, tmp), x.shape) if ctx.needs_input_grad[0] else None
-    grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, x, grad_output, tmp), y.shape) if ctx.needs_input_grad[1] else None
+    grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, y, grad_output), x.shape) if ctx.needs_input_grad[0] else None
+    grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, x, grad_output), y.shape) if ctx.needs_input_grad[1] else None
     return grad_x, grad_y
 
 class Pow(Function):
   def forward(ctx, x, y):
-    ret = ctx.buffer(binary_broadcast(x.shape, y.shape))
+    ret = ctx.binary_op(BinaryOps.POW, x, y)
     ctx.save_for_backward(x, y, ret)
-    return ctx.binary_op(BinaryOps.POW, x, y, ret)
+    return ret
 
   def backward(ctx, grad_output):
     x,y,powxy = ctx.saved_tensors
-    tmp = ctx.buffer(grad_output.shape)
-    ctx.binary_op(BinaryOps.DIV, x, powxy, tmp) # pow(x,y)/x
-    ctx.binary_op(BinaryOps.MUL, y, tmp, tmp) # y * pow(x,y)/x
-    grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp, tmp), x.shape) if ctx.needs_input_grad[0] else None
-    log_x = ctx.unary_op(UnaryOps.LOG, x, ctx.buffer(x.shape))
-    ctx.binary_op(BinaryOps.MUL, log_x, powxy, tmp) # log(x) * pow(x,y)
-    grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp, tmp), y.shape) if ctx.needs_input_grad[1] else None
+    tmp = ctx.binary_op(BinaryOps.DIV, x, powxy) # pow(x,y)/x
+    tmp = ctx.binary_op(BinaryOps.MUL, y, tmp) # y * pow(x,y)/x
+    grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), x.shape) if ctx.needs_input_grad[0] else None
+    tmp = ctx.binary_op(BinaryOps.MUL, ctx.unary_op(UnaryOps.LOG, x), powxy) # log(x) * pow(x,y)
+    grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), y.shape) if ctx.needs_input_grad[1] else None
     return grad_x, grad_y
 
 # ************* movement ops *************
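For reference, the Pow backward follows from d(x^y)/dx = y * x^(y-1) = y * pow(x,y)/x and d(x^y)/dy = log(x) * pow(x,y); the DIV/MUL chains compute exactly those factors (reusing the saved forward output powxy rather than recomputing the power) before scaling by grad_output. The remaining hunk is from the file that defines log_op and the Ops class, which is where the buffer allocation now lives: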
@@ -45,18 +45,23 @@ def log_op(op, ret, inp):
     G.nodes[nm(ret)]['fillcolor'] = top_colors[top]
     G.nodes[nm(ret)]['style'] = 'filled'
 
+from tinygrad.helpers import binary_broadcast
 class Ops:
-  def unary_op(ctx, op:UnaryOps, x, ret):
+  def unary_op(ctx, op:UnaryOps, x):
+    ret = ctx.buffer(x.shape)
+    ctx.op.unary_op(op, x, ret)
     log_op(op, ret, [x])
-    return ctx.op.unary_op(op, x, ret)
+    return ret
 
   def reduce_op(ctx, op:BinaryOps, x, ret):
     log_op(op, ret, [x])
     return ctx.op.reduce_op(op, x, ret)
 
-  def binary_op(ctx, op:ReduceOps, x, y, ret):
+  def binary_op(ctx, op:ReduceOps, x, y):
+    ret = ctx.buffer(binary_broadcast(x.shape, y.shape))
+    ctx.op.binary_op(op, x, y, ret)
     log_op(op, ret, [x, y])
-    return ctx.op.binary_op(op, x, y, ret)
+    return ret
 
   def movement_op(ctx, op:MovementOps, x, ret, arg=None):
     log_op(op, ret, [x])
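To make the new calling convention concrete, here is a minimal, self-contained sketch of the pattern after this commit: the Ops wrappers own output allocation and the backend only fills a preallocated ret. Everything below is illustrative rather than tinygrad's real backend: the NumpyBackend class and the enum values are made up for the example, and np.broadcast_shapes stands in for tinygrad.helpers.binary_broadcast.

from enum import Enum
import numpy as np

class UnaryOps(Enum): RELU = 0; LOG = 1; EXP = 2; NEG = 3; SIGN = 4
class BinaryOps(Enum): ADD = 0; SUB = 1; MUL = 2; DIV = 3; POW = 4

class NumpyBackend:
  # the backend keeps the old signature: it writes the result into a preallocated ret
  @staticmethod
  def unary_op(op, x, ret):
    fxn = {UnaryOps.RELU: lambda a: np.maximum(a, 0), UnaryOps.LOG: np.log,
           UnaryOps.EXP: np.exp, UnaryOps.NEG: np.negative, UnaryOps.SIGN: np.sign}
    ret[...] = fxn[op](x)
    return ret
  @staticmethod
  def binary_op(op, x, y, ret):
    fxn = {BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract, BinaryOps.MUL: np.multiply,
           BinaryOps.DIV: np.divide, BinaryOps.POW: np.power}
    ret[...] = fxn[op](x, y)
    return ret

class Ops:
  op = NumpyBackend
  def buffer(ctx, shape): return np.empty(shape, dtype=np.float32)
  # after this commit: allocation happens here, callers only pass inputs
  def unary_op(ctx, op, x):
    ret = ctx.buffer(x.shape)
    ctx.op.unary_op(op, x, ret)
    return ret
  def binary_op(ctx, op, x, y):
    # np.broadcast_shapes stands in for tinygrad.helpers.binary_broadcast
    ret = ctx.buffer(np.broadcast_shapes(x.shape, y.shape))
    ctx.op.binary_op(op, x, y, ret)
    return ret

ctx = Ops()
x = np.array([[-1.0, 2.0]], dtype=np.float32)   # shape (1, 2)
y = np.array([[3.0], [4.0]], dtype=np.float32)  # shape (2, 1)
print(ctx.unary_op(UnaryOps.RELU, x))           # [[0. 2.]]
print(ctx.binary_op(BinaryOps.MUL, x, y).shape) # (2, 2): broadcast shape allocated inside

The refactor shows up in the last two calls: callers no longer thread ctx.buffer(...) through every op, and the broadcast-shape allocation lives in exactly one place.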