use simple gidxs on GPU

This commit is contained in:
George Hotz
2025-08-24 12:44:49 -07:00
parent a286a1a6f7
commit fa61b692fc
2 changed files with 10 additions and 119 deletions

View File

@@ -3,7 +3,6 @@ import unittest
from dataclasses import replace
from tinygrad.codegen.opt.kernel import Opt, OptOps, KernelOptError, Kernel, AxisType
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops, GroupOp, KernelInfo
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.shape.shapetracker import ShapeTracker
@@ -467,77 +466,7 @@ class TestLinearizer(unittest.TestCase):
end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0]
assert end_range < uops.index(u)
def test_grouped_dims(self):
# Table-driven test of get_grouped_dims: each case passes a prefix, the requested dims, the per-axis
# maximum sizes and a reverse flag, then checks the sizes carried by the SPECIAL uops that come back.
# The final cases check that unsatisfiable requests raise RuntimeError.
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True):
# Helper: call get_grouped_dims and validate the returned index expressions.
#   prefix, dims, max_sizes, reverse_dims: forwarded unchanged to get_grouped_dims
#   expected_sizes: expected arg[1] of each distinct SPECIAL uop, listed in arg[0] sort order
#   assert_same_length: when True, also check the number of distinct SPECIAL uops (see note below)
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
# gather every distinct SPECIAL uop reachable from any of the returned index expressions
loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
# order by arg[0] so the comparison with expected_sizes does not depend on traversal order
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0])
# arg[1] is treated as the size of each SPECIAL uop
sizes = [x.arg[1] for x in loop_idxs]
# one index expression must come back per requested dim, however the dims were grouped or split
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
if assert_same_length:
# NOTE(review): sizes is built from loop_idxs, so len(sizes) == len(loop_idxs) and this check
# reduces to len(loop_idxs) <= len(dims); max_sizes is not involved — confirm that is intended.
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
# NOTE(review): presumably runs regardless of assert_same_length (the cases that pass False still
# supply expected sizes) — confirm the nesting against the original file.
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
# TODO: add these back after uop symbolic
# for i in range(len(dims)):
#   assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
# for i in range(len(loop_idxs)):
#   assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
#   assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
# In every passing case below, the product of the expected sizes equals the product of dims.
# no-op: every dim already fits under its max, so the sizes are expected unchanged
_assert_grouped_dims("gidx", (2,), (16,16,16), False, [2])
_assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
# check reverse dims: same dims, expected sizes come back in reversed order
_assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
_assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
# test splitting globals: len(dims) == len(max)
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
_assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
_assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
_assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
# prefer group_dim strategy when possible (three dims are expected to come back as two sizes)
_assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
# test splitting globals: len(dims) < len(max)
# these cases pass assert_same_length=False because more sizes than dims are expected
# len(dim) -> len(limited)
# 1 -> 2
_assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
# 1 -> 3
_assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
# 2 -> 3
_assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
# test when the only divisor is the square root of dim
_assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
# collapse onto the left-most axis
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
_assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])
# collapse on left-most available axis (the left most is too small)
_assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
# dim too large and not factorable
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (23,), (16,16,16), False,)
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
# too large for sizes
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
# # variable too large
# with self.assertRaises(AssertionError):
#   get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)
@unittest.skip("only one global now")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
def test_default_global_reversed(self):
# shrink so that the dims do not collapse