mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
clean up reduce a little
This commit is contained in:
@@ -37,14 +37,17 @@ def binary_op(op, x, y, ret):
|
||||
return ret
|
||||
|
||||
def reduce_op(op, inp, ret):
|
||||
if inp.shape == ret.shape: # this is just a copy
|
||||
if inp.shape == ret.shape: # this is just a copy, regardless of the reduce op
|
||||
ret[:] = inp
|
||||
return ret
|
||||
if ret.shape == (1,): axis=tuple(range(len(inp.shape)))
|
||||
else: axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, ret.shape)) if a != b])
|
||||
if op == ReduceOps.SUM: ret[:] = inp.sum(axis, keepdims=True)
|
||||
elif op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
else:
|
||||
if ret.shape == (1,): # full reduce
|
||||
axis = tuple(range(len(inp.shape)))
|
||||
else:
|
||||
assert len(inp.shape) == len(ret.shape)
|
||||
axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, ret.shape)) if a != b])
|
||||
if op == ReduceOps.SUM: ret[:] = inp.sum(axis, keepdims=True)
|
||||
elif op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
return ret
|
||||
|
||||
def reshape(x, shape):
|
||||
|
||||
Reference in New Issue
Block a user