diff --git a/test/null/test_gpudims.py b/test/null/test_gpudims.py index e84b5ec752..8324e29a8c 100644 --- a/test/null/test_gpudims.py +++ b/test/null/test_gpudims.py @@ -85,10 +85,7 @@ class TestGroupedDims(unittest.TestCase): with self.assertRaises(RuntimeError): get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16)) - @unittest.expectedFailure - def test_split_2d_to_3d_bug(self): - # TODO: fix get_grouped_dims a=3,b=2 path: _split_dims redistributes factors across all dims, - # but line 51 assumes limited[0]*limited[1]==dims[0]. triggers on WebGPU with 2D shapes > 65535. + def test_split_2d_to_3d(self): self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False) def test_max_sizes_none(self): diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py index 4c9decdcfc..11fa8b7c72 100644 --- a/tinygrad/codegen/gpudims.py +++ b/tinygrad/codegen/gpudims.py @@ -48,10 +48,9 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No elif (a:=len(limited)) > (b:=len(dims)): if a == 2 and b == 1: return [raw_idxs[0] * limited[1] + raw_idxs[1]] if a == 3 and b == 1: return [(raw_idxs[0] * limited[1] + raw_idxs[1]) * limited[2] + raw_idxs[2]] - if a == 3 and b == 2: return [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]] - elif limited != dims: + if limited != dims: # Convert to 1D - flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(dims) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2] + flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(limited) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2] # Get back original indices from 1D return [flat//dims[1], flat%dims[1]] if len(dims) == 2 else [flat//(dims[2]*dims[1]), (flat//dims[2])%dims[1], flat%dims[2]] return raw_idxs