mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix gpudim bug and test_split_2d_to_3d (#14896)
This commit is contained in:
@@ -85,10 +85,7 @@ class TestGroupedDims(unittest.TestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_split_2d_to_3d_bug(self):
|
||||
# TODO: fix get_grouped_dims a=3,b=2 path: _split_dims redistributes factors across all dims,
|
||||
# but line 51 assumes limited[0]*limited[1]==dims[0]. triggers on WebGPU with 2D shapes > 65535.
|
||||
def test_split_2d_to_3d(self):
|
||||
self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
|
||||
|
||||
def test_max_sizes_none(self):
|
||||
|
||||
@@ -48,10 +48,9 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
|
||||
elif (a:=len(limited)) > (b:=len(dims)):
|
||||
if a == 2 and b == 1: return [raw_idxs[0] * limited[1] + raw_idxs[1]]
|
||||
if a == 3 and b == 1: return [(raw_idxs[0] * limited[1] + raw_idxs[1]) * limited[2] + raw_idxs[2]]
|
||||
if a == 3 and b == 2: return [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
|
||||
elif limited != dims:
|
||||
if limited != dims:
|
||||
# Convert to 1D
|
||||
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(dims) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
|
||||
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(limited) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
|
||||
# Get back original indices from 1D
|
||||
return [flat//dims[1], flat%dims[1]] if len(dims) == 2 else [flat//(dims[2]*dims[1]), (flat//dims[2])%dims[1], flat%dims[2]]
|
||||
return raw_idxs
|
||||
|
||||
Reference in New Issue
Block a user