mirror of https://github.com/tinygrad/tinygrad.git
Fix dim splitting bug for len(dim) == len(limited) case (#13142)

* Fix gpudims bug on webgpu
* Fix split dim bug
* Remove webgpu_bug from examples
* Add test for shape correctness
* Fix 3D indexing

Co-authored-by: chenyu <chenyu@fastmail.com>
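
The fix targets get_grouped_dims, which reshapes the global dimensions so that each entry fits the backend's per-dimension limit (65535 per grid dimension on WebGPU); the diff below adds an elif branch that recomputes the kernel indices when the reshaped ("limited") shape has the same length as the original. A minimal sketch of the shape property the new test asserts, in plain Python with illustrative names:

import math

# The new test case: dims (65536,2) must be limited to (32768,4) because
# 65536 exceeds the 65535 per-dimension maximum. A valid limited shape
# keeps the total size and fits every dimension under its maximum.
dims, max_sizes, limited = (65536, 2), (65535, 65535, 65535), (32768, 4)
assert math.prod(limited) == math.prod(dims)            # no elements gained or lost
assert len(limited) <= len(max_sizes)                   # still fits the grid rank
assert all(d <= m for d, m in zip(limited, max_sizes))  # each dim within its limit
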
@@ -4,7 +4,7 @@ from dataclasses import replace
 from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.codegen.gpudims import get_grouped_dims
-from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType
+from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat
 from tinygrad.device import Device, Buffer, is_dtype_supported
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program

@@ -278,6 +278,8 @@ class TestLinearizer(unittest.TestCase):
     _assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
     # 2 -> 3
     _assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
+    # 2 -> 2
+    _assert_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
     # test when the only divisor is the square root of dim
     _assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)

@@ -302,6 +304,27 @@ class TestLinearizer(unittest.TestCase):
     with self.assertRaises(RuntimeError):
       get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
 
+    # TODO: In the above cases we only test if the shape after reshape is correct, never the indices.
+    # We should check if the returned indices are correct, for all cases.
+    # (65536, 2) -> (32768, 4)
+    dims, expected_limited_dims = (65536,2), (32768, 4)
+    idxs = get_grouped_dims("gidx", dims, (65535,65535,65535))
+    def match_div(): raise RuntimeError("match_div")
+    def match_mod(): raise RuntimeError("match_mod")
+    flat_idx_pattern = UPat(Ops.SPECIAL, arg='gidx0')*expected_limited_dims[1]+UPat(Ops.SPECIAL, arg='gidx1')
+    pm = PatternMatcher([
+      (flat_idx_pattern//dims[1], match_div),
+      (flat_idx_pattern%dims[1], match_mod)
+    ])
+
+    with self.assertRaises(RuntimeError) as error:
+      graph_rewrite(idxs[0], pm)
+    self.assertIn("match_div", str(error.exception))
+
+    with self.assertRaises(RuntimeError) as error:
+      graph_rewrite(idxs[1], pm)
+    self.assertIn("match_mod", str(error.exception))
+
     # # variable too large
     # with self.assertRaises(AssertionError):
     #   get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)

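The PatternMatcher in the test above checks the returned index expressions symbolically: the kernel iterates the limited (32768, 4) shape via gidx0/gidx1, and the original indices must come back as (gidx0*4 + gidx1)//2 and (gidx0*4 + gidx1)%2. The same property stated as a brute-force numeric check, a sketch in plain Python outside tinygrad:

# Flattening the limited (32768, 4) index and dividing/modding by the
# original dims must hit every index of the original (65536, 2) shape
# exactly once, i.e. the recovery is a bijection.
dims, limited = (65536, 2), (32768, 4)
seen = set()
for g0 in range(limited[0]):
  for g1 in range(limited[1]):
    flat = g0 * limited[1] + g1                  # 1D position in the limited grid
    seen.add((flat // dims[1], flat % dims[1]))  # recovered original index
assert seen == {(a, b) for a in range(dims[0]) for b in range(dims[1])}
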
@@ -47,6 +47,11 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
     if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
     if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
     if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
+  elif limited != dims:
+    # Convert to 1D
+    flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(dims) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
+    # Get back original indices from 1D
+    ret = [flat//dims[1], flat%dims[1]] if len(dims) == 2 else [flat//(dims[2]*dims[1]), (flat//dims[2])%dims[1], flat%dims[2]]
   return ret[::-1] if reverse else ret
 
 def add_gpudims(ctx:Renderer, s:UOp):
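
The new elif limited != dims branch handles both the 2D and 3D cases with the same flatten/unflatten trick; the 3D arm is what the "Fix 3D indexing" bullet refers to. A self-contained check of the diff's 3D formulas, using hypothetical same-size shapes that are illustrative only (not necessarily shapes get_grouped_dims would produce):

# Verify: flat = i0*(l1*l2) + i1*l2 + i2, then
#         orig = (flat // (d2*d1), (flat // d2) % d1, flat % d2)
# maps the limited index space one-to-one onto the original one.
dims, limited = (8, 6, 4), (12, 4, 4)  # both shapes have 192 elements

def recover(i0: int, i1: int, i2: int) -> tuple[int, int, int]:
  flat = i0 * (limited[1] * limited[2]) + i1 * limited[2] + i2
  return flat // (dims[2] * dims[1]), (flat // dims[2]) % dims[1], flat % dims[2]

seen = {recover(i0, i1, i2) for i0 in range(limited[0])
        for i1 in range(limited[1]) for i2 in range(limited[2])}
assert seen == {(a, b, c) for a in range(dims[0])
                for b in range(dims[1]) for c in range(dims[2])}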