Fix dim splitting bug for len(dim) == len(limited) case (#13142)

* Fix gpudims bug on webgpu

* Fix split dim bug

* Remove webgpu_bug from examples

* Add test for shape correctness

* Fix 3D indexing

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Ahmed Harmouche
2025-11-07 18:31:06 +01:00
committed by GitHub
parent b8e48effcb
commit 3ecff3a8da
2 changed files with 29 additions and 1 deletions

View File

@@ -4,7 +4,7 @@ from dataclasses import replace
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
@@ -278,6 +278,8 @@ class TestLinearizer(unittest.TestCase):
_assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
# 2 -> 3
_assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
# 2 -> 2
_assert_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
# test when the only divisor is the square root of dim
_assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
@@ -302,6 +304,27 @@ class TestLinearizer(unittest.TestCase):
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
# TODO: In the above cases we only test if the shape after reshape is correct, never the indices.
# We should check if the returned indices are correct, for all cases.
# (65536, 2) -> (32768, 4)
dims, expected_limited_dims = (65536,2), (32768, 4)
idxs = get_grouped_dims("gidx", dims, (65535,65535,65535))
def match_div(): raise RuntimeError("match_div")
def match_mod(): raise RuntimeError("match_mod")
flat_idx_pattern = UPat(Ops.SPECIAL, arg='gidx0')*expected_limited_dims[1]+UPat(Ops.SPECIAL, arg='gidx1')
pm = PatternMatcher([
(flat_idx_pattern//dims[1], match_div),
(flat_idx_pattern%dims[1], match_mod)
])
with self.assertRaises(RuntimeError) as error:
graph_rewrite(idxs[0], pm)
self.assertIn("match_div", str(error.exception))
with self.assertRaises(RuntimeError) as error:
graph_rewrite(idxs[1], pm)
self.assertIn("match_mod", str(error.exception))
# # variable too large
# with self.assertRaises(AssertionError):
# get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)

View File

@@ -47,6 +47,11 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
elif limited != dims:
# Convert to 1D
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(dims) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
# Get back original indices from 1D
ret = [flat//dims[1], flat%dims[1]] if len(dims) == 2 else [flat//(dims[2]*dims[1]), (flat//dims[2])%dims[1], flat%dims[2]]
return ret[::-1] if reverse else ret
def add_gpudims(ctx:Renderer, s:UOp):