diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index e4a8a22497..45a80853df 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -2,7 +2,7 @@ from __future__ import annotations import math from typing import Union, Optional, Any, Tuple, List from tinygrad.dtype import dtypes, DType, ConstType, least_upper_dtype -from tinygrad.helpers import prod, getenv, all_int, all_same +from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu from tinygrad.shape.symbolic import sint from tinygrad.shape.shapetracker import ShapeTracker @@ -181,13 +181,21 @@ class LazyBuffer: if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \ prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, axis, acc_dt) - heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, s))/(st or math.inf), divisor, i) for i,(s,st) in \ - enumerate(zip(self.shape, self.st.real_strides())) if i in axis and (st is None or isinstance(st, int))) - if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, axis, acc_dt) - # choose largest divisor (>=16) to split on, penalize large strides - def splitted_shape(dim_aft_div): - return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:] - return self.reshape(splitted_shape((divisor,)))._reduce_op(op, (dim_to_split+1,), acc_dt).reshape(splitted_shape(()))._reduce_op(op, axis, acc_dt) + + # if there are few globals, make some reduces into globals by splitting into two kernels + # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm + # ~2**10 should be enough if GROUP is used + # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum. + # split is moved to the end to provide maximum locality for the second phase reduce. + self_real_strides = self.st.real_strides(ignore_valid=True) + split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1) + if self.shape[i] % x == 0 and self_real_strides[i] != 0] + if not split_candidates: return self._reduce_op(op, axis) + dim_to_split, divisor = split_candidates[0] + splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:] + splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split])) + if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}") + return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split # *** movement ops ***