Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-07 03:00:26 -04:00
some limit_dims to limit global merging (#5489)
only supports merging dims in a way that does not surpass the limit; no splitting yet
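For intuition, here is a standalone sketch of the merging rule, assuming plain int dims (`limit_dims_sketch` is a hypothetical name; the diff's `_limit_dims` below is the authoritative version, and this variant bounds the pair index slightly differently to stay self-contained): adjacent dims are fused left to right whenever their product still fits under the per-axis limit, and a shape that cannot be fused within the limits is rejected.

from typing import Tuple

def limit_dims_sketch(dims: Tuple[int, ...], max_sizes: Tuple[int, ...]) -> Tuple[int, ...]:
  # merge adjacent dims left to right until the shape has at most
  # len(max_sizes) dims and every dim fits under its per-axis limit
  while len(dims) > len(max_sizes) or any(d > m for d, m in zip(dims, max_sizes)):
    for i in range(min(len(dims) - 1, len(max_sizes))):
      if dims[i] * dims[i + 1] <= max_sizes[i]:
        dims = dims[:i] + (dims[i] * dims[i + 1],) + dims[i + 2:]
        break
    else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
  return dims

assert limit_dims_sketch((2, 3, 4, 5), (16, 16, 16)) == (6, 4, 5)  # 2*3 merged into 6
assert limit_dims_sketch((2, 3, 4, 5), (4, 16, 16)) == (2, 12, 5)  # 2*3=6 > 4, so 3*4 merged instead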
@@ -12,7 +12,7 @@ from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, M
 from tinygrad.renderer import TensorCore
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable
+# from tinygrad.shape.symbolic import Variable
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
@@ -763,24 +763,24 @@ class TestLinearizer(unittest.TestCase):
     # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])

     # collapse on left-most available axis (the left most is too small)
-    # _assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
+    _assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
     # _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])

-    _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
+    # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])

     # # dim too large and not factorable
     # with self.assertRaises(AssertionError):
-    # get_grouped_dims("gidx", 0, (23,), (16,16,16), False,)
+    # get_grouped_dims("gidx", (23,), (16,16,16), False,)
     # with self.assertRaises(AssertionError):
-    # get_grouped_dims("gidx", 0, (128,3,4), (16,4,23), False,)
+    # get_grouped_dims("gidx", (128,3,4), (16,4,23), False,)

-    # # too large for sizes
-    # with self.assertRaises(AssertionError):
-    # get_grouped_dims("gidx", 0, (2,3,4,5,6), (16,16,16), False,)
+    # too large for sizes
+    with self.assertRaises(RuntimeError):
+      get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))

     # # variable too large
     # with self.assertRaises(AssertionError):
-    # get_grouped_dims("gidx", 0, (Variable("start_pos",0,16),3,4), (16,16,16), False,)
+    # get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)

   def test_div_collapse(self):
     def helper(t, msg, max_ops=0):
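As a worked check of the RuntimeError test above (a hypothetical standalone snippet, not part of the commit): (2,3,4,5,6) has one legal merge, 2*3, giving (6,4,5,6); that still leaves four dims for three axes, and every adjacent pair now overflows 16, so no further merge is possible.

import itertools

dims, limit = (6, 4, 5, 6), 16  # state after the only legal merge, 2*3 -> 6
# 6*4=24, 4*5=20, 5*6=30: every adjacent product exceeds the limit,
# so _limit_dims has no move left and raises RuntimeError
assert all(a * b > limit for a, b in itertools.pairwise(dims))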
@@ -2,11 +2,12 @@ from __future__ import annotations
 from typing import List, Tuple, cast, Optional, Any, Dict
 import functools
 from tinygrad.shape.shapetracker import ShapeTracker, View
+from tinygrad.shape.symbolic import sint
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
 from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, MetaOps, KernelInfo
 from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.renderer import Renderer
-from tinygrad.helpers import getenv, prod
+from tinygrad.helpers import getenv, all_int, get_contraction

 # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
 from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
@@ -53,19 +54,30 @@ else:
   assert uvalid.dtype == dtypes.bool
   return uidx, uvalid

-def get_grouped_dims(prefix, dims, max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
-  # TODO: this should be per dim max
-  maxdim = len(max_sizes) if max_sizes is not None else 0
-  local_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (),
-    (i, f"{prefix}{i}", s)) for i,s in enumerate((prod(dims[:-(maxdim-1)]),) + dims[-(maxdim-1):] if len(dims) > maxdim else dims)]
-  if maxdim != 0 and len(dims) > maxdim:
-    dd = local_idxs[0]
-    nli = []
-    for s in dims[:-(maxdim-1)]:
-      nli.append(dd % s)
-      dd //= s
-    local_idxs = nli + local_idxs[-(maxdim-1):]
-  return local_idxs
+def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
+  # TODO: symbolic shape
+  if not all_int(dims): return dims
+  while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
+    for i,m in enumerate(max_sizes):
+      if dims[i] * dims[i+1] <= m:
+        dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
+        break
+    else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+  return dims
+
+def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]]) -> List[UOp]:
+  limited_dims = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
+  ret = local_idxs = [UOp(UOps.SPECIAL, dtypes.bigint, (), (i, f"{prefix}{i}", s)) for i,s in enumerate(limited_dims)]
+  if limited_dims != dims:
+    ret = []
+    # cast for mypy, get_contraction won't be None
+    for idx, contraction in zip(local_idxs, cast(List[List[int]], get_contraction(dims, limited_dims))):
+      if len(contraction) == 1: ret.append(idx)
+      else:
+        for c in contraction:
+          ret.append(idx % dims[c])
+          idx //= dims[c]
+  return ret

 class IndependentLowerer:
   def lower(self, ast:LazyOp, opts:Renderer) -> UOp:
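The contraction pass in the new `get_grouped_dims` inverts the merge: the SPECIAL index created for a merged dim is decomposed back into indices for the original dims via the mod/div loop. A minimal sketch with plain ints (`decompose` is a hypothetical helper; tinygrad's `get_contraction(old, new)` supplies, per new dim, the list of old dims it absorbed):

def decompose(idx: int, group: list) -> list:
  # mirrors the `ret.append(idx % dims[c]); idx //= dims[c]` loop in the diff
  out = []
  for d in group:
    out.append(idx % d)
    idx //= d
  return out

# dims (2,3,4,5) limited to (6,4,5): the first output dim absorbed sizes (2,3),
# so a flat index over 6 splits as (idx % 2, idx // 2), a bijection onto 2 x 3
assert [decompose(i, [2, 3]) for i in range(6)] == [[0,0],[1,0],[0,1],[1,1],[0,2],[1,2]]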
@@ -70,7 +70,7 @@ class LazyOp:
     if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
     return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
   @functools.cached_property
-  def full_shape(self):
+  def full_shape(self) -> Tuple[sint, ...]:
     if len(self.src) == 0 and self.op in BufferOps: return self.arg.st.shape
     return tuple(max(x) for x in zip(*[x.full_shape for x in self.src]))
   @functools.cached_property
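The newly annotated `full_shape` broadcasts by taking the per-axis max over the sources' full shapes; a minimal illustration with plain tuples (hypothetical shape values):

# per-axis max across source shapes, as in the zip(*) expression in the diff
src_shapes = [(4, 1, 16), (4, 8, 1)]
assert tuple(max(x) for x in zip(*src_shapes)) == (4, 8, 16)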