reduce_op in SSA format too

This commit is contained in:
George Hotz
2022-06-11 16:40:14 -07:00
parent bbf231da34
commit 6685807df7
3 changed files with 11 additions and 10 deletions

View File

@@ -3,6 +3,9 @@ from collections import namedtuple
def prod(x): return int(np.prod(x))
def reduce_shape(shape, axis):
return [1 if i in axis else shape[i] for i in range(len(shape))]
def binary_broadcast(x_shape, y_shape, extra=False):
n_dims = max(len(x_shape), len(y_shape))
shape_x, shape_y = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)

View File

@@ -1,5 +1,5 @@
import numpy as np # TODO: remove this, it's used for np.prod and np.argsort
from tinygrad.helpers import prod, binary_broadcast, get_conv_args
from tinygrad.helpers import prod, reduce_shape, get_conv_args
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
from tinygrad.tensor import Function
@@ -37,13 +37,10 @@ class Exp(_UnaryOp):
# ************* reduce ops *************
def reduce_shape(shape, axis):
return [1 if i in axis else shape[i] for i in range(len(shape))]
class Sum(Function):
def forward(ctx, input, axis=None):
ctx.save_for_backward(input.shape)
return ctx.reduce_op(ReduceOps.SUM, input, ctx.buffer(reduce_shape(input.shape, axis)))
return ctx.reduce_op(ReduceOps.SUM, input, reduce_shape(input.shape, axis))
def backward(ctx, grad_output):
shape_input, = ctx.saved_tensors
@@ -53,21 +50,21 @@ class Sum(Function):
class Max(Function):
def forward(ctx, input, axis=None):
ret = ctx.reduce_op(ReduceOps.MAX, input, ctx.buffer(reduce_shape(input.shape, axis)))
ret = ctx.reduce_op(ReduceOps.MAX, input, reduce_shape(input.shape, axis))
ctx.save_for_backward(input, ret)
return ret
def backward(ctx, grad_output):
input, ret = ctx.saved_tensors
ret2 = ctx.binary_op(BinaryOps.CMPEQ, input, ret)
div = ctx.reduce_op(ReduceOps.SUM, ret2, ctx.buffer(grad_output.shape))
div = ctx.reduce_op(ReduceOps.SUM, ret2, grad_output.shape)
ret2 = ctx.binary_op(BinaryOps.DIV, div, ret2)
return ctx.binary_op(BinaryOps.MUL, ret2, grad_output)
# ************* binary ops *************
def unbroadcast(ctx, out, in_sh):
return ctx.reduce_op(ReduceOps.SUM, out, ctx.buffer(in_sh))
return ctx.reduce_op(ReduceOps.SUM, out, in_sh)
class Add(Function):
def forward(ctx, x, y):

View File

@@ -45,7 +45,7 @@ def log_op(op, ret, inp):
G.nodes[nm(ret)]['fillcolor'] = top_colors[top]
G.nodes[nm(ret)]['style'] = 'filled'
from tinygrad.helpers import binary_broadcast
from tinygrad.helpers import binary_broadcast, reduce_shape
class Ops:
def unary_op(ctx, op:UnaryOps, x):
ret = ctx.buffer(x.shape)
@@ -53,7 +53,8 @@ class Ops:
log_op(op, ret, [x])
return ret
def reduce_op(ctx, op:BinaryOps, x, ret):
def reduce_op(ctx, op:BinaryOps, x, new_shape):
ret = ctx.buffer(new_shape)
log_op(op, ret, [x])
return ctx.op.reduce_op(op, x, ret)