reduce_op in SSA format too
@@ -3,6 +3,9 @@ from collections import namedtuple
 
 def prod(x): return int(np.prod(x))
 
+def reduce_shape(shape, axis):
+  return [1 if i in axis else shape[i] for i in range(len(shape))]
+
 def binary_broadcast(x_shape, y_shape, extra=False):
   n_dims = max(len(x_shape), len(y_shape))
   shape_x, shape_y = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)
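For reference, the new reduce_shape helper keeps the input's rank and collapses each reduced axis to 1, which is exactly the output shape a keepdim-style reduce needs. Two quick checks of the comprehension:

  reduce_shape((2, 3, 4), axis=(1, 2))  # -> [2, 1, 1]
  reduce_shape((2, 3, 4), axis=(0,))    # -> [1, 3, 4]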
@@ -1,5 +1,5 @@
 import numpy as np # TODO: remove this, it's used for np.prod and np.argsort
-from tinygrad.helpers import prod, binary_broadcast, get_conv_args
+from tinygrad.helpers import prod, reduce_shape, get_conv_args
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
 from tinygrad.tensor import Function
@@ -37,13 +37,10 @@ class Exp(_UnaryOp):
 
 # ************* reduce ops *************
 
-def reduce_shape(shape, axis):
-  return [1 if i in axis else shape[i] for i in range(len(shape))]
-
 class Sum(Function):
   def forward(ctx, input, axis=None):
     ctx.save_for_backward(input.shape)
-    return ctx.reduce_op(ReduceOps.SUM, input, ctx.buffer(reduce_shape(input.shape, axis)))
+    return ctx.reduce_op(ReduceOps.SUM, input, reduce_shape(input.shape, axis))
 
   def backward(ctx, grad_output):
     shape_input, = ctx.saved_tensors
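With reduce_shape moved into helpers, the mlops call sites shrink to one step: they hand reduce_op the target shape and let it allocate the output. Under the new convention, summing a (2, 3) input over axis=(1,) looks like this (a sketch of the forward above, not new API):

  # reduce_shape((2, 3), (1,)) == [2, 1]
  out = ctx.reduce_op(ReduceOps.SUM, input, reduce_shape(input.shape, (1,)))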
@@ -53,21 +50,21 @@ class Sum(Function):
 
 class Max(Function):
   def forward(ctx, input, axis=None):
-    ret = ctx.reduce_op(ReduceOps.MAX, input, ctx.buffer(reduce_shape(input.shape, axis)))
+    ret = ctx.reduce_op(ReduceOps.MAX, input, reduce_shape(input.shape, axis))
     ctx.save_for_backward(input, ret)
     return ret
 
   def backward(ctx, grad_output):
     input, ret = ctx.saved_tensors
     ret2 = ctx.binary_op(BinaryOps.CMPEQ, input, ret)
-    div = ctx.reduce_op(ReduceOps.SUM, ret2, ctx.buffer(grad_output.shape))
+    div = ctx.reduce_op(ReduceOps.SUM, ret2, grad_output.shape)
     ret2 = ctx.binary_op(BinaryOps.DIV, div, ret2)
     return ctx.binary_op(BinaryOps.MUL, ret2, grad_output)
 
 # ************* binary ops *************
 
 def unbroadcast(ctx, out, in_sh):
-  return ctx.reduce_op(ReduceOps.SUM, out, ctx.buffer(in_sh))
+  return ctx.reduce_op(ReduceOps.SUM, out, in_sh)
 
 class Add(Function):
   def forward(ctx, x, y):
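The logic of Max.backward is unchanged by this refactor: CMPEQ builds a mask of where the input hit the max, the SUM reduce (now taking grad_output.shape directly) counts ties, and the incoming gradient is split evenly among them. A rough numpy sketch of that semantics, with illustrative names rather than tinygrad API:

  import numpy as np

  x = np.array([[1., 3., 3.],
                [2., 0., 1.]])
  grad_output = np.ones((2, 1))             # upstream gradient for max over axis 1
  ret = x.max(axis=1, keepdims=True)        # forward result saved by ctx
  mask = (x == ret).astype(np.float32)      # CMPEQ: 1.0 at every max position
  count = mask.sum(axis=1, keepdims=True)   # SUM reduced to grad_output.shape
  grad_input = mask / count * grad_output   # ties share the gradient equally

In the first row both 3s are maxima, so each receives half the incoming gradient.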
@@ -45,7 +45,7 @@ def log_op(op, ret, inp):
     G.nodes[nm(ret)]['fillcolor'] = top_colors[top]
     G.nodes[nm(ret)]['style'] = 'filled'
 
-from tinygrad.helpers import binary_broadcast
+from tinygrad.helpers import binary_broadcast, reduce_shape
 class Ops:
   def unary_op(ctx, op:UnaryOps, x):
     ret = ctx.buffer(x.shape)
@@ -53,7 +53,8 @@ class Ops:
     log_op(op, ret, [x])
     return ret
 
-  def reduce_op(ctx, op:BinaryOps, x, ret):
+  def reduce_op(ctx, op:BinaryOps, x, new_shape):
+    ret = ctx.buffer(new_shape)
     log_op(op, ret, [x])
     return ctx.op.reduce_op(op, x, ret)
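This last hunk is the point of the commit message: reduce_op no longer writes into a caller-supplied buffer but allocates its own output from new_shape, so every buffer is produced by exactly one op, mirroring single static assignment (SSA) form. At a call site the change reads:

  # before: caller pre-allocates the output and reduce_op fills it in
  ret = ctx.reduce_op(ReduceOps.SUM, x, ctx.buffer(target_shape))

  # after: caller passes only the target shape; reduce_op makes the buffer
  ret = ctx.reduce_op(ReduceOps.SUM, x, target_shape)

(target_shape stands for whatever the caller computes, e.g. reduce_shape(x.shape, axis).)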