some limit_dims to limit global merging (#5489)

only supports merging dims in a way that does not surpass the limit; no splitting yet
chenyu authored on 2024-07-19 12:17:46 -04:00 (committed by GitHub)
parent e04704faff
commit 3f590c3b31
3 changed files with 36 additions and 24 deletions
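
In short: when the global dims exceed the launch limits, in count or in per-dim size, adjacent dims are now greedily merged left to right until everything fits; if no adjacent pair can merge without surpassing a limit, lowering raises. A self-contained sketch of that rule (renamed `limit_dims` here, mirroring the `_limit_dims` helper added below; example values come from the new tests):

  from typing import Tuple

  def limit_dims(dims: Tuple[int, ...], max_sizes: Tuple[int, ...]) -> Tuple[int, ...]:
    # greedily merge adjacent dims, left to right, until both count and sizes fit
    while len(dims) > len(max_sizes) or any(d > m for d, m in zip(dims, max_sizes)):
      for i, m in enumerate(max_sizes):
        if dims[i] * dims[i+1] <= m:
          dims = dims[:i] + (dims[i] * dims[i+1],) + dims[i+2:]
          break
      else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
    return dims

  print(limit_dims((2, 3, 4, 5), (4, 16, 16)))    # (2, 12, 5): 2*3=6 exceeds 4, so 3*4 merges instead
  try: limit_dims((2, 3, 4, 5, 6), (16, 16, 16))  # merges to (6, 4, 5, 6), then no pair fits under 16
  except RuntimeError as e: print(e)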

@@ -12,7 +12,7 @@ from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, M
 from tinygrad.renderer import TensorCore
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable
+# from tinygrad.shape.symbolic import Variable
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
@@ -763,24 +763,24 @@ class TestLinearizer(unittest.TestCase):
     # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])
     # collapse on left-most available axis (the left most is too small)
-    # _assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
+    _assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
     # _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
-    _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
+    # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
     # # dim too large and not factorable
     # with self.assertRaises(AssertionError):
-    #   get_grouped_dims("gidx", 0, (23,), (16,16,16), False,)
+    #   get_grouped_dims("gidx", (23,), (16,16,16), False,)
     # with self.assertRaises(AssertionError):
-    #   get_grouped_dims("gidx", 0, (128,3,4), (16,4,23), False,)
+    #   get_grouped_dims("gidx", (128,3,4), (16,4,23), False,)
-    # # too large for sizes
-    # with self.assertRaises(AssertionError):
-    #   get_grouped_dims("gidx", 0, (2,3,4,5,6), (16,16,16), False,)
+    # too large for sizes
+    with self.assertRaises(RuntimeError):
+      get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
     # # variable too large
     # with self.assertRaises(AssertionError):
-    #   get_grouped_dims("gidx", 0, (Variable("start_pos",0,16),3,4), (16,16,16), False,)
+    #   get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)

   def test_div_collapse(self):
     def helper(t, msg, max_ops=0):
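
A worked trace of the re-enabled `[2,12,5]` case above, following the merge rule:

  # dims=(2,3,4,5), max_sizes=(4,16,16): four dims, only three launch slots -> must merge
  #   i=0: 2*3 = 6  > 4   -> the left-most slot is too small (the comment in the test)
  #   i=1: 3*4 = 12 <= 16 -> merge axes 1 and 2: dims become (2, 12, 5)
  # (2, 12, 5) fits (4, 16, 16), so the grouped sizes are [2, 12, 5]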

@@ -2,11 +2,12 @@ from __future__ import annotations
 from typing import List, Tuple, cast, Optional, Any, Dict
 import functools
 from tinygrad.shape.shapetracker import ShapeTracker, View
+from tinygrad.shape.symbolic import sint
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
 from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo
 from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.renderer import Renderer
-from tinygrad.helpers import getenv, prod
+from tinygrad.helpers import getenv, all_int, get_contraction
 # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
 from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
@@ -53,19 +54,30 @@ else:
   assert uvalid.dtype == dtypes.bool
   return uidx, uvalid

-def get_grouped_dims(prefix, dims, max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
-  # TODO: this should be per dim max
-  maxdim = len(max_sizes) if max_sizes is not None else 0
-  local_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (),
-    (i, f"{prefix}{i}", s)) for i,s in enumerate((prod(dims[:-(maxdim-1)]),) + dims[-(maxdim-1):] if len(dims) > maxdim else dims)]
-  if maxdim != 0 and len(dims) > maxdim:
-    dd = local_idxs[0]
-    nli = []
-    for s in dims[:-(maxdim-1)]:
-      nli.append(dd % s)
-      dd //= s
-    local_idxs = nli + local_idxs[-(maxdim-1):]
-  return local_idxs
+def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
+  # TODO: symbolic shape
+  if not all_int(dims): return dims
+  while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
+    for i,m in enumerate(max_sizes):
+      if dims[i] * dims[i+1] <= m:
+        dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
+        break
+    else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+  return dims
+
+def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
+  limited_dims = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
+  ret = local_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (), (i, f"{prefix}{i}", s)) for i,s in enumerate(limited_dims)]
+  if limited_dims != dims:
+    ret = []
+    # cast for mypy, get_contraction won't be None
+    for idx, contraction in zip(local_idxs, cast(List[List[int]], get_contraction(dims, limited_dims))):
+      if len(contraction) == 1: ret.append(idx)
+      else:
+        for c in contraction:
+          ret.append(idx % dims[c])
+          idx //= dims[c]
+  return ret

 class IndependentLowerer:
   def lower(self, ast:LazyOp, opts:Renderer) -> UOp:
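
On the unmerge step: `get_contraction` reports which original axes each limited dim absorbed, and the loop peels them back off with mod and floordiv. A trace for dims=(2,3,4,5) limited to (2,12,5), assuming `get_contraction` returns [[0], [1, 2], [3]] here:

  # local_idxs = [gidx0(size 2), gidx1(size 12), gidx2(size 5)]  # one SPECIAL per limited dim
  #   [0]    -> gidx0 passes through unchanged
  #   [1, 2] -> gidx1 % 3 recovers the size-3 axis, then gidx1 //= 3,
  #             and (gidx1 // 3) % 4 recovers the size-4 axis
  #   [3]    -> gidx2 passes through unchanged
  # ret = [gidx0, gidx1 % 3, (gidx1 // 3) % 4, gidx2]  # four indices from three launch dims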

@@ -70,7 +70,7 @@ class LazyOp:
     if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
     return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
   @functools.cached_property
-  def full_shape(self):
+  def full_shape(self) -> Tuple[sint, ...]:
     if len(self.src) == 0 and self.op in BufferOps: return self.arg.st.shape
     return tuple(max(x) for x in zip(*[x.full_shape for x in self.src]))
   @functools.cached_property
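
For context on the newly annotated property: `full_shape` is the elementwise max across the source shapes, i.e. their broadcast shape. A small illustration with hypothetical source shapes:

  shapes = [(4, 1, 3), (1, 5, 3)]           # hypothetical src shapes
  full_shape = tuple(max(x) for x in zip(*shapes))
  print(full_shape)                         # (4, 5, 3): the broadcast shape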