Solve get_grouped_dims does not split issue (#9085)
* Solve dims too large errors on webgpu
* Simplify divisor find
* Test square root divisor
* Fix lint
* Refactor into group_dims and split_dims
* Refactor
* Fix lint
* Add back max check in _group_dims
* Prefer grouping over split

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
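For context, a minimal standalone sketch of the split strategy described above (the helper name split_dims_sketch is illustrative and not part of tinygrad): an oversized launch dimension is repeatedly divided by its smallest factor, and that factor is pushed into the next grid slot until every slot fits under its limit.

import math

def split_dims_sketch(dims, max_sizes):
  # pad to 3 grid slots, then move the smallest divisor of any oversized
  # slot into the neighboring slot until everything fits
  dims = list(dims) + [1] * (3 - len(dims))
  for i in range(len(dims)):
    while dims[i] > max_sizes[i]:
      div = next((d for d in range(2, math.isqrt(dims[i]) + 1) if dims[i] % d == 0), 1)
      if div == 1: raise RuntimeError(f"cannot split {dims=} under {max_sizes=}")
      dims[i] //= div
      dims[(i + 1) % len(dims)] *= div
  return tuple(dims)

print(split_dims_sketch((128,), (16, 16, 256)))  # (16, 8, 1): webgpu-style limits
print(split_dims_sketch((121,), (12, 12, 12)))   # (11, 11, 1): the only usable divisor is sqrt(121)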
.github/workflows/test.yml
@@ -450,7 +450,7 @@ jobs:
     - name: Run selected webgpu tests
       run: |
         WEBGPU=1 python3 -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit \
-        --ignore=test/test_copy_speed.py --ignore=test/test_rearrange_einops.py --ignore=test/test_speed_v_torch.py \
+        --ignore=test/test_copy_speed.py --ignore=test/test_rearrange_einops.py \
         --ignore=test/test_fuzz_shape_ops.py --ignore=test/test_linearizer_failures.py --durations=20
     - name: Run process replay tests
       uses: ./.github/actions/process-replay
@@ -810,7 +810,7 @@ class TestAutoCastType(unittest.TestCase):
     np.testing.assert_allclose(t.grad.numpy(), [1, 0])

   @unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
-  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "error due to too large dimensions")
+  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_mean_half_precision_underflow(self):
     N = 10000
@@ -1175,13 +1175,14 @@ class TestLinearizer(unittest.TestCase):
       assert end_range < k.uops.index(u)

   def test_grouped_dims(self):
-    def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
+    def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True):
       idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
       loop_idxs = dedup(flatten([[y for y in x.toposort if y.op is Ops.SPECIAL] for x in idxs]))
       loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0])
       sizes = [x.arg[1] for x in loop_idxs]
       assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
-      assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
+      if assert_same_length:
+        assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
       assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
       # TODO: add these back after uop symbolic
       # for i in range(len(dims)):
@@ -1198,11 +1199,26 @@ class TestLinearizer(unittest.TestCase):
     _assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
     _assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])

-    # test splitting globals
-    # _assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
-    # _assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,4,12])
-    # _assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [12,16,4])
-    # _assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,4,24])
+    # test splitting globals: len(dims) == len(max)
+    _assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
+    _assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
+    _assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
+    _assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
+    _assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
+
+    # prefer group_dim strategy when possible
+    _assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
+
+    # test splitting globals: len(dims) < len(max)
+    # len(dim) -> len(limited)
+    # 1 -> 2
+    _assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
+    # 1 -> 3
+    _assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
+    # 2 -> 3
+    _assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
+    # test when the only divisor is the square root of dim
+    _assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)

     # collapse on onto the left most axis
     _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
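To make one of the new expected values concrete: for dims (128,3,4) under limits (16,4,256), grouping cannot merge any adjacent pair below its limit, so the dims are split instead. A plain-Python trace of the smallest-divisor splitting (a sketch of the strategy, not the tinygrad call itself):

dims, limits = [128, 3, 4], [16, 4, 256]
for i in range(3):
  while dims[i] > limits[i]:
    div = next(d for d in range(2, dims[i] + 1) if dims[i] % d == 0)  # smallest divisor
    dims[i] //= div
    dims[(i + 1) % 3] *= div
print(dims)  # [16, 3, 32], matching the expected sizes asserted above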
@@ -1215,11 +1231,11 @@ class TestLinearizer(unittest.TestCase):

     # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])

-    # # dim too large and not factorable
-    # with self.assertRaises(AssertionError):
-    # get_grouped_dims("gidx", (23,), (16,16,16), False,)
-    # with self.assertRaises(AssertionError):
-    # get_grouped_dims("gidx", (128,3,4), (16,4,23), False,)
+    # dim too large and not factorable
+    with self.assertRaises(RuntimeError):
+      get_grouped_dims("gidx", (23,), (16,16,16), False,)
+    with self.assertRaises(RuntimeError):
+      get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)

     # too large for sizes
     with self.assertRaises(RuntimeError):
@@ -21,7 +21,6 @@ class TestSpecific(unittest.TestCase):
     (x @ w).reshape(1, 128, 4).contiguous().realize()

   @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
-  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "Too large dimensions")
   def test_big_vec_mul(self):
     # from LLaMA
     # 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)]
@@ -6,6 +6,7 @@ from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
 from tinygrad.renderer import Renderer
 from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
+import math

 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
 def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
@@ -15,23 +16,37 @@ def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> l
   return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]

 # ***** indexing *****

-def _limit_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
+def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
   # TODO: symbolic shape
   if not all_int(dims): return dims
   while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
     for i,m in enumerate(max_sizes):
-      if dims[i] * dims[i+1] <= m:
+      if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
         dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
         break
-    else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+    else: return None
   return dims

+def _split_dims(dims, max_sizes):
+  if all(d <= m for d,m in zip(dims, max_sizes)): return dims
+  _dims = list(dims) + [1]*(3-len(dims))
+  for i in range(len(_dims)):
+    while _dims[i] > max_sizes[i]:
+      div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
+      if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+      _dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
+  return tuple(_dims[:2] if _dims[2] == 1 else _dims[0] if _dims[1:3] == [1,1] else _dims)
+
 def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
   if reverse: dims = dims[::-1]
-  limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
+  # try to group first: (a, b, c, d) -> (ab, c, d)
+  limited = (grouped if (grouped := _group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
+  # check if grouping failed
+  if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+  # try to split up dims: (a,) -> (b, c)
+  if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
   ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
   if limited != dims:
     if len(limited) < len(dims):
       ret = []
       if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
       for idx, contraction_group in zip(raw_idxs, contraction):
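The grouping path can be sanity-checked in isolation; the following standalone sketch mirrors _group_dims above (re-typed here without the tinygrad imports or the symbolic-shape early return) and reproduces the cases exercised by the test:

def group_dims_sketch(dims, max_sizes):
  while len(dims) > len(max_sizes) or any(d > m for d, m in zip(dims, max_sizes)):
    for i, m in enumerate(max_sizes):
      # merge dims[i] and dims[i+1] into slot i if their product still fits under max_sizes[i]
      if i < len(dims) - 1 and dims[i] * dims[i + 1] <= m:
        dims = dims[:i] + (dims[i] * dims[i + 1],) + dims[i + 2:]
        break
    else:
      return None  # grouping alone cannot satisfy the limits; the caller falls back to splitting
  return dims

print(group_dims_sketch((2, 3, 4, 5), (16, 16, 16)))  # (6, 4, 5): collapse onto the left-most axis
print(group_dims_sketch((512, 4, 2), (8192, 2, 2)))   # (2048, 2): grouping preferred over splitting
print(group_dims_sketch((128,), (16, 16, 256)))       # None: a single oversized dim has nothing to group with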
@@ -39,6 +54,11 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
           ret.append(idx % dims[c])
           idx //= dims[c]
         ret.append(idx)
+    elif len(limited) > len(dims):
+      a, b = len(limited), len(dims)
+      if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
+      if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
+      if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
   return ret[::-1] if reverse else ret

 @dataclass
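The new elif branch recombines the split index variables back into the flat index of the original dim: with a 1 -> 3 split of 65536 into (16, 16, 256), the original index is gidx0*(16*256) + gidx1*256 + gidx2. A plain-Python check of that arithmetic (no tinygrad UOps involved, just the formula):

limited = (16, 16, 256)
seen = set()
for g0 in range(limited[0]):
  for g1 in range(limited[1]):
    for g2 in range(limited[2]):
      seen.add(g0 * (limited[1] * limited[2]) + g1 * limited[2] + g2)
assert seen == set(range(65536))  # every original index is produced, and exactly once given 65536 iterations
print("3 -> 1 recombination covers 0..65535")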