mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
new split reduce heuristic try 2 (#4294)
* new split reduce heuristic * update comment * rename --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import math
|
||||
from typing import Union, Optional, Any, Tuple, List
|
||||
from tinygrad.dtype import dtypes, DType, ConstType, least_upper_dtype
|
||||
from tinygrad.helpers import prod, getenv, all_int, all_same
|
||||
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
|
||||
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -181,13 +181,21 @@ class LazyBuffer:
|
||||
if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
|
||||
prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
|
||||
return self._reduce_op(op, axis, acc_dt)
|
||||
heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, s))/(st or math.inf), divisor, i) for i,(s,st) in \
|
||||
enumerate(zip(self.shape, self.st.real_strides())) if i in axis and (st is None or isinstance(st, int)))
|
||||
if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, axis, acc_dt)
|
||||
# choose largest divisor (>=16) to split on, penalize large strides
|
||||
def splitted_shape(dim_aft_div):
|
||||
return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
|
||||
return self.reshape(splitted_shape((divisor,)))._reduce_op(op, (dim_to_split+1,), acc_dt).reshape(splitted_shape(()))._reduce_op(op, axis, acc_dt)
|
||||
|
||||
# if there are few globals, make some reduces into globals by splitting into two kernels
|
||||
# cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
|
||||
# ~2**10 should be enough if GROUP is used
|
||||
# 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
|
||||
# split is moved to the end to provide maximum locality for the second phase reduce.
|
||||
self_real_strides = self.st.real_strides(ignore_valid=True)
|
||||
split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
|
||||
if self.shape[i] % x == 0 and self_real_strides[i] != 0]
|
||||
if not split_candidates: return self._reduce_op(op, axis)
|
||||
dim_to_split, divisor = split_candidates[0]
|
||||
splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
|
||||
splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
|
||||
if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
|
||||
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
|
||||
|
||||
# *** movement ops ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user