move test_grouped_dims to test/null (#14893)

get_grouped_dims is a pure helper, so its test can live in test/null.
This commit is contained in:
chenyu
2026-02-19 14:50:53 -05:00
committed by GitHub
parent af997c1ea5
commit 52f727738b
2 changed files with 90 additions and 96 deletions

View File

@@ -3,8 +3,7 @@ import unittest
from dataclasses import replace
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, CompiledRunner, get_program
@@ -253,100 +252,6 @@ class TestLinearizer(unittest.TestCase):
if any(x.op is Ops.END and x.src[1].op in GroupOp.ALU for x in u.src):
assert end_range < uops.index(u)
def test_grouped_dims(self):
  """get_grouped_dims must collapse/split kernel dims to fit the per-dimension launch-size limits."""
  def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length=True):
    # run the grouping, then recover the SPECIAL (launch-index) uops actually emitted
    idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
    loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
    loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
    sizes = [x.src[0].arg for x in loop_idxs]
    # one returned index per requested dim, regardless of how the launch dims were regrouped
    assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
    if assert_same_length:
      assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
    assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
  # TODO: add these back after uop symbolic
  # for i in range(len(dims)):
  #   assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
  # for i in range(len(loop_idxs)):
  #   assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
  #   assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"

  # no-op
  _assert_grouped_dims("gidx", (2,), (16,16,16), False, [2])
  _assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
  # check reverse dims
  _assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
  _assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
  # test splitting globals: len(dims) == len(max)
  _assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
  _assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
  _assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
  _assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
  _assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
  # prefer group_dim strategy when possible
  _assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
  # test splitting globals: len(dims) < len(max)
  # len(dim) -> len(limited)
  # 1 -> 2
  _assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
  # 1 -> 3
  _assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
  # 2 -> 3
  _assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
  # 2 -> 2
  _assert_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
  # test when the only divisor is the square root of dim
  _assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
  # collapse onto the left-most axis
  _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
  _assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
  # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])
  # collapse on left-most available axis (the left-most is too small)
  _assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
  _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
  # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
  # dim too large and not factorable
  with self.assertRaises(RuntimeError):
    get_grouped_dims("gidx", (23,), (16,16,16), False)
  with self.assertRaises(RuntimeError):
    get_grouped_dims("gidx", (128,3,4), (16,2,2), False)
  # too large for sizes
  with self.assertRaises(RuntimeError):
    get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
  # TODO: In the above cases we only test if the shape after reshape is correct, never the indices.
  # We should check if the returned indices are correct, for all cases.
  # (65536, 2) -> (32768, 4): the flat index must be decomposed with // and % by the original inner dim
  dims, expected_limited_dims = (65536,2), (32768, 4)
  idxs = get_grouped_dims("gidx", dims, (65535,65535,65535))
  # raising from the rewrite callback signals that the expected pattern matched
  def match_div(): raise RuntimeError("match_div")
  def match_mod(): raise RuntimeError("match_mod")
  flat_idx_pattern = UPat(Ops.SPECIAL, arg='gidx0')*expected_limited_dims[1]+UPat(Ops.SPECIAL, arg='gidx1')
  pm = PatternMatcher([
    (flat_idx_pattern//dims[1], match_div),
    (flat_idx_pattern%dims[1], match_mod)
  ])
  with self.assertRaises(RuntimeError) as error:
    graph_rewrite(idxs[0], pm)
  self.assertIn("match_div", str(error.exception))
  with self.assertRaises(RuntimeError) as error:
    graph_rewrite(idxs[1], pm)
  self.assertIn("match_mod", str(error.exception))
  # # variable too large
  # with self.assertRaises(AssertionError):
  #   get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
def test_default_global_reversed(self):
# shrink so that the dims do not collapse

89
test/null/test_gpudims.py Normal file
View File

@@ -0,0 +1,89 @@
import unittest
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import Ops, PatternMatcher, graph_rewrite, UPat
from tinygrad.helpers import flatten, dedup
class TestGroupedDims(unittest.TestCase):
  """Tests for get_grouped_dims: collapsing/splitting kernel dims to fit per-dimension launch limits."""
  def test_grouped_dims(self):
    def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length=True):
      # run the grouping, then recover the SPECIAL (launch-index) uops actually emitted
      idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
      loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
      loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
      sizes = [x.src[0].arg for x in loop_idxs]
      # one returned index per requested dim, regardless of how the launch dims were regrouped
      assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
      if assert_same_length:
        assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
      assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
    # no-op
    _assert_grouped_dims("gidx", (2,), (16,16,16), False, [2])
    _assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
    # check reverse dims
    _assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
    _assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
    # test splitting globals: len(dims) == len(max)
    _assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
    _assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
    _assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
    _assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
    _assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
    # prefer group_dim strategy when possible
    _assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
    # test splitting globals: len(dims) < len(max)
    # len(dim) -> len(limited)
    # 1 -> 2
    _assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
    # 1 -> 3
    _assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
    # 2 -> 3
    _assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
    # 2 -> 2
    _assert_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
    # test when the only divisor is the square root of dim
    _assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
    # collapse onto the left-most axis
    _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
    _assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
    # collapse on left-most available axis (the left-most is too small)
    _assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
    _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
    # dim too large and not factorable
    with self.assertRaises(RuntimeError):
      get_grouped_dims("gidx", (23,), (16,16,16), False)
    with self.assertRaises(RuntimeError):
      get_grouped_dims("gidx", (128,3,4), (16,2,2), False)
    # too large for sizes
    with self.assertRaises(RuntimeError):
      get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
    # TODO: In the above cases we only test if the shape after reshape is correct, never the indices.
    # We should check if the returned indices are correct, for all cases.
    # (65536, 2) -> (32768, 4): the flat index must be decomposed with // and % by the original inner dim
    dims, expected_limited_dims = (65536,2), (32768, 4)
    idxs = get_grouped_dims("gidx", dims, (65535,65535,65535))
    # raising from the rewrite callback signals that the expected pattern matched
    def match_div(): raise RuntimeError("match_div")
    def match_mod(): raise RuntimeError("match_mod")
    flat_idx_pattern = UPat(Ops.SPECIAL, arg='gidx0')*expected_limited_dims[1]+UPat(Ops.SPECIAL, arg='gidx1')
    pm = PatternMatcher([
      (flat_idx_pattern//dims[1], match_div),
      (flat_idx_pattern%dims[1], match_mod)
    ])
    with self.assertRaises(RuntimeError) as error:
      graph_rewrite(idxs[0], pm)
    self.assertIn("match_div", str(error.exception))
    with self.assertRaises(RuntimeError) as error:
      graph_rewrite(idxs[1], pm)
    self.assertIn("match_mod", str(error.exception))
# allow running this test file directly with `python test/null/test_gpudims.py`
if __name__ == '__main__':
  unittest.main()