diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 23d392ef1a..5b7b1a921c 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -4,7 +4,7 @@
 from dataclasses import replace
 from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.codegen.gpudims import get_grouped_dims
-from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType
+from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat
 from tinygrad.device import Device, Buffer, is_dtype_supported
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
@@ -278,6 +278,8 @@ class TestLinearizer(unittest.TestCase):
     _assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
     # 2 -> 3
     _assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
+    # 2 -> 2
+    _assert_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
 
     # test when the only divisor is the square root of dim
     _assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
@@ -302,6 +304,27 @@ class TestLinearizer(unittest.TestCase):
     with self.assertRaises(RuntimeError):
       get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
 
+    # TODO: the cases above only check that the shape after the reshape is correct, never the indices.
+    # We should check that the returned indices are correct for all cases.
+    # (65536, 2) -> (32768, 4)
+    dims, expected_limited_dims = (65536,2), (32768,4)
+    idxs = get_grouped_dims("gidx", dims, (65535,65535,65535))
+    def match_div(): raise RuntimeError("match_div")
+    def match_mod(): raise RuntimeError("match_mod")
+    flat_idx_pattern = UPat(Ops.SPECIAL, arg='gidx0')*expected_limited_dims[1]+UPat(Ops.SPECIAL, arg='gidx1')
+    pm = PatternMatcher([
+      (flat_idx_pattern//dims[1], match_div),
+      (flat_idx_pattern%dims[1], match_mod)
+    ])
+
+    with self.assertRaises(RuntimeError) as error:
+      graph_rewrite(idxs[0], pm)
+    self.assertIn("match_div", str(error.exception))
+
+    with self.assertRaises(RuntimeError) as error:
+      graph_rewrite(idxs[1], pm)
+    self.assertIn("match_mod", str(error.exception))
+
     # # variable too large
     # with self.assertRaises(AssertionError):
     #   get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)
diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py
index 763e2d440f..b394fb81d3 100644
--- a/tinygrad/codegen/gpudims.py
+++ b/tinygrad/codegen/gpudims.py
@@ -47,6 +47,11 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
     if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
     if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
     if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
+  elif limited != dims:
+    # same number of axes but different sizes: convert the limited indices to a single flat 1D index
+    flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(dims) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
+    # then recover the original per-axis indices from the flat index by div/mod with the original dims
+    ret = [flat//dims[1], flat%dims[1]] if len(dims) == 2 else [flat//(dims[2]*dims[1]), (flat//dims[2])%dims[1], flat%dims[2]]
   return ret[::-1] if reverse else ret
 
 def add_gpudims(ctx:Renderer, s:UOp):
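
For context on the gpudims.py hunk: (65536,2) does not fit under the (65535,65535,65535) per-axis limit, so the launch grid is regrouped to (32768,4) and the original indices are rebuilt by flattening the limited indices to 1D, then div/mod-ing by the original dims. Below is a minimal standalone sketch of that round-trip with plain Python ints; `recover_indices_2d` is a name invented for this note, while the real code builds the same expressions symbolically over UOps:

```python
# Sketch of the 2D branch of the elif added in gpudims.py, over plain ints.
def recover_indices_2d(raw_idxs, limited, dims):
  # limited 2D launch index -> flat 1D index over the whole grid
  flat = raw_idxs[0] * limited[1] + raw_idxs[1]
  # flat 1D index -> original 2D index, via div/mod with the original dims
  return flat // dims[1], flat % dims[1]

limited, dims = (32768, 4), (65536, 2)   # the new "2 -> 2" test case
seen = set()
for i0 in range(limited[0]):
  for i1 in range(limited[1]):
    o0, o1 = recover_indices_2d((i0, i1), limited, dims)
    assert 0 <= o0 < dims[0] and 0 <= o1 < dims[1]
    seen.add((o0, o1))
# bijective because the volumes match: 32768*4 == 65536*2
assert len(seen) == dims[0] * dims[1]
print("2D round-trip OK")
```

This is the structure the new test pins down: its PatternMatcher callbacks raise only if idxs[0] contains (gidx0*4+gidx1)//2 and idxs[1] contains (gidx0*4+gidx1)%2.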
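The same elif also covers three axes, flattening as i0*(l1*l2)+i1*l2+i2 and recovering with flat//(d2*d1), (flat//d2)%d1, flat%d2. A sketch under the same assumptions; the shapes here are hypothetical equal-volume picks for illustration, not from the patch:

```python
# Sketch of the 3D branch of the elif added in gpudims.py, over plain ints.
def recover_indices_3d(raw_idxs, limited, dims):
  # limited 3D launch index -> flat 1D index
  flat = raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]
  # flat 1D index -> original 3D index
  return flat // (dims[2] * dims[1]), (flat // dims[2]) % dims[1], flat % dims[2]

limited, dims = (8, 4, 4), (16, 2, 4)    # hypothetical shapes, 128 points each
seen = {recover_indices_3d((i, j, k), limited, dims)
        for i in range(limited[0]) for j in range(limited[1]) for k in range(limited[2])}
assert all(0 <= a < dims[0] and 0 <= b < dims[1] and 0 <= c < dims[2] for a, b, c in seen)
assert len(seen) == dims[0] * dims[1] * dims[2]
print("3D round-trip OK")
```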