new split reduce heuristic try 2 (#4294)

* new split reduce heuristic

* update comment

* rename

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
David Hou
2024-04-25 15:14:15 -07:00
committed by GitHub
parent f1ebcffb87
commit c2dbe2a78b

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import math
from typing import Union, Optional, Any, Tuple, List
from tinygrad.dtype import dtypes, DType, ConstType, least_upper_dtype
from tinygrad.helpers import prod, getenv, all_int, all_same
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
from tinygrad.shape.symbolic import sint
from tinygrad.shape.shapetracker import ShapeTracker
@@ -181,13 +181,21 @@ class LazyBuffer:
if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
return self._reduce_op(op, axis, acc_dt)
heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, s))/(st or math.inf), divisor, i) for i,(s,st) in \
enumerate(zip(self.shape, self.st.real_strides())) if i in axis and (st is None or isinstance(st, int)))
if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, axis, acc_dt)
# choose largest divisor (>=16) to split on, penalize large strides
def splitted_shape(dim_aft_div):
# Rebuild the shape with axis `dim_to_split` divided by `divisor`, splicing the
# tuple `dim_aft_div` in immediately after the divided axis. Closes over
# `self`, `dim_to_split`, and `divisor` from the enclosing reduce method
# (not visible in this hunk) — presumably divisor evenly divides that axis; confirm in caller.
return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
return self.reshape(splitted_shape((divisor,)))._reduce_op(op, (dim_to_split+1,), acc_dt).reshape(splitted_shape(()))._reduce_op(op, axis, acc_dt)
# if there are few globals, make some reduces into globals by splitting into two kernels
# cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
# ~2**10 should be enough if GROUP is used
# 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
# split is moved to the end to provide maximum locality for the second phase reduce.
self_real_strides = self.st.real_strides(ignore_valid=True)
split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
if self.shape[i] % x == 0 and self_real_strides[i] != 0]
if not split_candidates: return self._reduce_op(op, axis)
dim_to_split, divisor = split_candidates[0]
splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
# *** movement ops ***