clean up reduce a little

This commit is contained in:
George Hotz
2022-06-09 09:34:19 -07:00
parent d3e7238bdd
commit 259b536e3a

View File

@@ -37,14 +37,17 @@ def binary_op(op, x, y, ret):
return ret
def reduce_op(op, inp, ret):
if inp.shape == ret.shape: # this is just a copy
if inp.shape == ret.shape: # this is just a copy, regardless of the reduce op
ret[:] = inp
return ret
if ret.shape == (1,): axis=tuple(range(len(inp.shape)))
else: axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, ret.shape)) if a != b])
if op == ReduceOps.SUM: ret[:] = inp.sum(axis, keepdims=True)
elif op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
else: raise Exception(f"{op} isn't supported")
else:
if ret.shape == (1,): # full reduce
axis = tuple(range(len(inp.shape)))
else:
assert len(inp.shape) == len(ret.shape)
axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, ret.shape)) if a != b])
if op == ReduceOps.SUM: ret[:] = inp.sum(axis, keepdims=True)
elif op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
else: raise Exception(f"{op} isn't supported")
return ret
def reshape(x, shape):