Minor cleanups (#1911)

* cleanups

* remove that simplify
This commit is contained in:
George Hotz
2023-09-24 21:32:50 +08:00
committed by GitHub
parent 20059dc55b
commit 6d9065ed1c
2 changed files with 4 additions and 31 deletions

View File

@@ -117,6 +117,7 @@ class TestOptBinOp(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestOptReduceLoop(unittest.TestCase):
@unittest.skip("this is broken")
def test_loop_left(self):
a = Tensor.randn(16, 16)
b = Tensor.randn(16, 16)

View File

@@ -24,43 +24,16 @@ P2P = getenv("P2P", 0)
# TODO: movement ops that only change shape are really nops. treat them as such
REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS, SIMPLIFY_SUM_RESHAPE_EXPAND_SUM = OPT>=2, OPT>=2, OPT>=2 # shuffle pad ops is fine now since we only push to merge binops
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2
PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
def _simplify_sum_reshape_expand_sum(self:LazyBuffer, src: Any, prev_src: Any) -> Optional[LazyOp]:
if prev_src.op.op == MovementOps.EXPAND:
if src.op.op == ReduceOps.SUM:
if src.shape == self.shape:
dim_difference = [i for i, (a, b) in enumerate(zip(prev_src.shape, self.shape)) if a != b]
# NOTE: we can probably also handle the case where more than one dimension is different with more thought
if len(dim_difference) == 1:
expansion_index = dim_difference[0]
expansion_size = prev_src.shape[expansion_index]
return LazyOp(BinaryOps.MUL, (src, LazyBuffer.const(src, expansion_size)))
return None
# **** realize functions ****
def _ast_reduceops(self:LazyBuffer) -> LazyOp:
# TODO: this can also corealize a binary op after the reduce, not just before
src = self.op.src[0]
if not src.realized:
assert isinstance(src.op, LazyOp), "if not src.realized, then src.op must be a LazyOp"
# When a tensor is reduced, reshaped/expanded back and then reduced again along the same axis,
# it's equivalent to performing the initial reduction and multiplying the result
# by the size of the expanded dimension.
if SIMPLIFY_SUM_RESHAPE_EXPAND_SUM and src.op.op == MovementOps.EXPAND:
expanded = src.op.src[0]
assert isinstance(expanded.op, LazyOp)
if expanded.op.op == MovementOps.RESHAPE:
reshaped = expanded.op.src[0]
simplified = _simplify_sum_reshape_expand_sum(self, reshaped, src)
else:
simplified = _simplify_sum_reshape_expand_sum(self, expanded, src)
if simplified: return simplified
if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1:
# If we did remove an expand above, we might stumble back into a case where the reduction is not necessary
if src.shape == self.shape: return src.op
src = src.op
if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1: src = src.op
return LazyOp(self.op.op, (src,), self.op.arg)
# this supports late merging an upstream Reduce op and even an Elementwise op above that
@@ -323,8 +296,7 @@ class LazyBuffer:
return self.shuffle_and_prune_movement_ops(self.st.shrink(arg), MovementOps.SHRINK, arg)
def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
local_st = ShapeTracker.from_shape(self.shape).stride(arg)
if self.shape == local_st.shape and local_st.contiguous: return self
if all(a == 1 for a in arg): return self
if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
return self.shuffle_and_prune_movement_ops(self.st.stride(arg), MovementOps.STRIDE, arg)