fix gpudim bug and test_split_2d_to_3d (#14896)

This commit is contained in:
chenyu
2026-02-19 16:46:24 -05:00
committed by GitHub
parent 2b31823ef9
commit 9d6cf00be2
2 changed files with 3 additions and 7 deletions

View File

@@ -85,10 +85,7 @@ class TestGroupedDims(unittest.TestCase):
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
@unittest.expectedFailure
def test_split_2d_to_3d_bug(self):
# TODO: fix get_grouped_dims a=3,b=2 path: _split_dims redistributes factors across all dims,
# but line 51 assumes limited[0]*limited[1]==dims[0]. triggers on WebGPU with 2D shapes > 65535.
def test_split_2d_to_3d(self):
self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
def test_max_sizes_none(self):

View File

@@ -48,10 +48,9 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
elif (a:=len(limited)) > (b:=len(dims)):
if a == 2 and b == 1: return [raw_idxs[0] * limited[1] + raw_idxs[1]]
if a == 3 and b == 1: return [(raw_idxs[0] * limited[1] + raw_idxs[1]) * limited[2] + raw_idxs[2]]
if a == 3 and b == 2: return [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
elif limited != dims:
if limited != dims:
# Convert to 1D
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(dims) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(limited) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
# Get back original indices from 1D
return [flat//dims[1], flat%dims[1]] if len(dims) == 2 else [flat//(dims[2]*dims[1]), (flat//dims[2])%dims[1], flat%dims[2]]
return raw_idxs