const fold ReduceOps (#4059)

This commit is contained in:
chenyu
2024-04-03 14:39:28 -04:00
committed by GitHub
parent fe03725b21
commit 406cb5fd90
4 changed files with 33 additions and 5 deletions

View File

@@ -159,6 +159,10 @@ class LazyBuffer:
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
# TODO: this logic should move to the scheduler
if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
if self.is_unrealized_unpadded_const():
return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
return self._reduce_op(op, axis)