Solve get_grouped_dims does not split issue (#9085)

* Solve dims too large errors on webgpu

* Simplify divisor find

* Test square root divisor

* Fix lint

* Refactor into group_dims and split_dims

* Refactor

* Fix lint

* Add back max check in _group_dims

* Prefer grouping over split

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Ahmed Harmouche
2025-02-17 01:57:29 +01:00
committed by GitHub
parent 84dc331dd1
commit 59fe45f947
5 changed files with 56 additions and 21 deletions

View File

@@ -450,7 +450,7 @@ jobs:
- name: Run selected webgpu tests
run: |
WEBGPU=1 python3 -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit \
--ignore=test/test_copy_speed.py --ignore=test/test_rearrange_einops.py --ignore=test/test_speed_v_torch.py \
--ignore=test/test_copy_speed.py --ignore=test/test_rearrange_einops.py \
--ignore=test/test_fuzz_shape_ops.py --ignore=test/test_linearizer_failures.py --durations=20
- name: Run process replay tests
uses: ./.github/actions/process-replay

View File

@@ -810,7 +810,7 @@ class TestAutoCastType(unittest.TestCase):
np.testing.assert_allclose(t.grad.numpy(), [1, 0])
@unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "error due to too large dimensions")
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_underflow(self):
N = 10000

View File

@@ -1175,13 +1175,14 @@ class TestLinearizer(unittest.TestCase):
assert end_range < k.uops.index(u)
def test_grouped_dims(self):
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True):
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
loop_idxs = dedup(flatten([[y for y in x.toposort if y.op is Ops.SPECIAL] for x in idxs]))
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0])
sizes = [x.arg[1] for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
if assert_same_length:
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
# TODO: add these back after uop symbolic
# for i in range(len(dims)):
@@ -1198,11 +1199,26 @@ class TestLinearizer(unittest.TestCase):
_assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
_assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
# test splitting globals
# _assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
# _assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,4,12])
# _assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [12,16,4])
# _assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,4,24])
# test splitting globals: len(dims) == len(max)
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
_assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
_assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
_assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
# prefer group_dim strategy when possible
_assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
# test splitting globals: len(dims) < len(max)
# len(dim) -> len(limited)
# 1 -> 2
_assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
# 1 -> 3
_assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
# 2 -> 3
_assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
# test when the only divisor is the square root of dim
_assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
# collapse on onto the left most axis
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
@@ -1215,11 +1231,11 @@ class TestLinearizer(unittest.TestCase):
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
# # dim too large and not factorable
# with self.assertRaises(AssertionError):
# get_grouped_dims("gidx", (23,), (16,16,16), False,)
# with self.assertRaises(AssertionError):
# get_grouped_dims("gidx", (128,3,4), (16,4,23), False,)
# dim too large and not factorable
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (23,), (16,16,16), False,)
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
# too large for sizes
with self.assertRaises(RuntimeError):

View File

@@ -21,7 +21,6 @@ class TestSpecific(unittest.TestCase):
(x @ w).reshape(1, 128, 4).contiguous().realize()
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Too large dimensions")
def test_big_vec_mul(self):
# from LLaMA
# 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)]

View File

@@ -6,6 +6,7 @@ from tinygrad.dtype import dtypes, PtrDType
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
import math
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
@@ -15,23 +16,37 @@ def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> l
return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
# ***** indexing *****
def _limit_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
# TODO: symbolic shape
if not all_int(dims): return dims
while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
for i,m in enumerate(max_sizes):
if dims[i] * dims[i+1] <= m:
if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
break
else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
else: return None
return dims
def _split_dims(dims, max_sizes):
if all(d <= m for d,m in zip(dims, max_sizes)): return dims
_dims = list(dims) + [1]*(3-len(dims))
for i in range(len(_dims)):
while _dims[i] > max_sizes[i]:
div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
_dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
return tuple(_dims[:2] if _dims[2] == 1 else _dims[0] if _dims[1:3] == [1,1] else _dims)
def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
if reverse: dims = dims[::-1]
limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
# try to group first: (a, b, c, d) -> (ab, c, d)
limited = (grouped if (grouped := _group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
# check if grouping failed
if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
# try to split up dims: (a,) -> (b, c)
if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
if limited != dims:
if len(limited) < len(dims):
ret = []
if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
for idx, contraction_group in zip(raw_idxs, contraction):
@@ -39,6 +54,11 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
ret.append(idx % dims[c])
idx //= dims[c]
ret.append(idx)
elif len(limited) > len(dims):
a, b = len(limited), len(dims)
if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
return ret[::-1] if reverse else ret
@dataclass