diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py index d8ab1279d4..88807fba97 100644 --- a/extra/optimization/helpers.py +++ b/extra/optimization/helpers.py @@ -81,7 +81,7 @@ def lin_to_feats(lin:Kernel, use_sts=True): ret = [float(x) for x in ret] if use_sts: - my_sts = dedup([(x.shape == lin.full_shape, x.real_strides(), any(v.mask is not None for v in x.views), len(x.views)) for x in lin.sts]) + my_sts = dedup([(x.shape == lin.full_shape, x.is_expanded(), any(v.mask is not None for v in x.views), len(x.views)) for x in lin.sts]) assert len(my_sts) < MAX_BUFS sts_len = 3 + 5*MAX_DIMS for s in my_sts: diff --git a/test/external/external_benchmark_bert_matmuls.py b/test/external/external_benchmark_bert_matmuls.py index 4f64629b54..fa2afd0388 100644 --- a/test/external/external_benchmark_bert_matmuls.py +++ b/test/external/external_benchmark_bert_matmuls.py @@ -13,6 +13,6 @@ if __name__ == "__main__": (Tensor.empty(BS, 16, 512, 512), Tensor.empty(BS, 512, 16, 64).permute(0,2,1,3)), # qk@v ] for t0, t1 in tensors: - print(f"{t0.shape=}, {t0.uop.st.real_strides()=}, {t1.shape=}, {t1.uop.st.real_strides()=}") + print(f"{t0.shape=}, {t0.uop.st.is_expanded()=}, {t1.shape=}, {t1.uop.st.is_expanded()=}") for _ in range(5): t0.dot(t1, dtype=acc_dtype).realize() diff --git a/test/test_tensor.py b/test/test_tensor.py index b046378118..43b8202dc4 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -595,21 +595,21 @@ class TestMoveTensor(unittest.TestCase): np.testing.assert_equal(x.grad.numpy(), [[2,2,2],[0,0,0],[-2,-2,-2]]) class TestZeroShapeTensor(unittest.TestCase): - def test_shape_stride(self): + def test_shape_is_expanded(self): t = Tensor.empty(3, 2, 0) assert t.shape == (3, 2, 0) # numpy has stride 0, 0, 0; torch has stride 2, 1, 1 - assert t.uop.st.real_strides() == (0, 0, 0) + assert t.uop.st.is_expanded() == (True, True, True) t = Tensor.empty(3, 0, 2) assert t.shape == (3, 0, 2) # numpy has stride 0, 0, 0; torch has stride 2, 2, 
1 - assert t.uop.st.real_strides() == (0, 0, 0) + assert t.uop.st.is_expanded() == (True, True, True) t = Tensor.empty(0, 0, 0) assert t.shape == (0, 0, 0) # numpy has stride 0, 0, 0; torch has stride 1, 1, 1 - assert t.uop.st.real_strides() == (0, 0, 0) + assert t.uop.st.is_expanded() == (True, True, True) def test_rand(self): t = Tensor.rand(3, 2, 0) diff --git a/test/unit/test_indexing.py b/test/unit/test_indexing.py index c9d6d7c7da..36a9885aa8 100644 --- a/test/unit/test_indexing.py +++ b/test/unit/test_indexing.py @@ -971,9 +971,8 @@ class TestIndexing(unittest.TestCase): numpy_testing_assert_equal_helper((2, 0, 4), z.shape) # this isn't technically necessary, but matches NumPy stride calculations. # NOTE: this is empty and shouldn't have strides - #numpy_testing_assert_equal_helper((60, 20, 5), z.uop.st.real_strides()) - # NOTE tinygrad's int slicing implementation makes this not contiguous - # self.assertTrue(z.uop.st.contiguous) + numpy_testing_assert_equal_helper((True, True, True), z.uop.st.is_expanded()) + self.assertTrue(z.uop.st.contiguous) @unittest.skip("bool indexing not supported") def test_index_getitem_copy_bools_slices(self): diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index 2412ea475f..ad570f8129 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -95,23 +95,20 @@ class TestRealIssues(unittest.TestCase): class TestRealDoesntSimplify(unittest.TestCase): def tearDown(self): - st = self.st.real_strides() - print(st) self.st = self.st.simplify() assert len(self.st.views) != 1 - assert None in st def test_1(self): self.st = ShapeTracker(( View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None), View.create((8, 6, 11), (66, 11, 1), 0, None))) - self.assertEqual(self.st.real_strides(), (33, None, 1)) + self.assertEqual(self.st.is_expanded(), (False, False, False)) def test_2(self): self.st = ShapeTracker(( View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None), 
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None))) - self.assertEqual(self.st.real_strides(), (None, 18, -3, -1)) + self.assertEqual(self.st.is_expanded(), (False, False, False, False)) class TestRealStrides(unittest.TestCase): def test_1(self): @@ -119,7 +116,7 @@ class TestRealStrides(unittest.TestCase): View.create((2048,), (1,), 0, ((0, 512),)), View.create((16, 32, 4), (128, 4, 1), 0, None), )) - self.assertEqual(st.real_strides(), (None, 4, 1)) + self.assertEqual(st.is_expanded(), (False, False, False)) def test_2(self): # test/test_ops.py::TestOps::test_simple_padding_conv1d @@ -128,7 +125,7 @@ class TestRealStrides(unittest.TestCase): View.create((6, 2, 78), (140, 70, 1), 0, ((0, 6), (0, 2), (0, 70))), View.create((6, 2, 13, 6), (156, 78, 1, 13), 0, None), )) - self.assertEqual(st.real_strides(), (90, 45, None, None)) + self.assertEqual(st.is_expanded(), (False, False, False, False)) def test_3(self): # test/test_ops.py::TestOps::test_simple_cumsum @@ -137,7 +134,7 @@ class TestRealStrides(unittest.TestCase): View.create((4, 131327), (131072, 1), 0, ((0, 4), (0, 131072))), View.create((4, 511, 257), (131327, 1, 511), 0, None), )) - self.assertEqual(st.real_strides(), (256, None, None)) + self.assertEqual(st.is_expanded(), (False, False, False)) def test_4(self): # test/test_nn.py::TestNN::test_conv_transpose1d @@ -146,7 +143,7 @@ class TestRealStrides(unittest.TestCase): View.create((1, 4, 1, 16, 8, 121), (0, 1792, 0, 112, 0, 1), -5, ((0, 1), (0, 4), (0, 1), (0, 16), (0, 8), (5, 116))), View.create((4, 64, 115, 16, 7), (15488, 0, 1, 968, 122), 0, None), )) - self.assertEqual(st.real_strides(), (896, 0, None, 56, None)) + self.assertEqual(st.is_expanded(), (False, True, False, False, False)) def test_5(self): # test/test_ops.py::TestOps::test_conv2d @@ -155,15 +152,12 @@ class TestRealStrides(unittest.TestCase): View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))), View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None), 
)) - self.assertEqual(st.real_strides(), (132, 12, None, None, None)) + self.assertEqual(st.is_expanded(), (False, False, False, True, False)) class TestRealSimplifies(unittest.TestCase): def tearDown(self): - st = self.st.real_strides() self.st = self.st.simplify() assert len(self.st.views) == 1 - print(self.st.views[-1].strides, st) - self.assertEqual(self.st.views[-1].strides, st) def test_1(self): self.st = ShapeTracker(( diff --git a/test/unit/test_symbolic_shapetracker.py b/test/unit/test_symbolic_shapetracker.py index c89419a2f9..0c5d11b46d 100644 --- a/test/unit/test_symbolic_shapetracker.py +++ b/test/unit/test_symbolic_shapetracker.py @@ -10,22 +10,20 @@ class TestSymbolic(unittest.TestCase): def test_symbolic_st(self): x = Variable("x", 1, 100) st = ShapeTracker.from_shape((x, 3)) - assert st.shape == (x, 3) - assert st.real_strides() == (3, 1) + self.assert_tuple_equal(st.shape, (x, 3)) + self.assert_tuple_equal(st.is_expanded(), (False, False)) - def test_real_strides_0(self): + def test_is_expanded_0(self): st = ShapeTracker(views=(View(shape=(2, (Variable('start_pos', 1, 8)+1), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (Variable('start_pos', 1, 8)+1)), strides=((Variable('start_pos', 1, 8)+1), 1), offset=0, mask=None, contiguous=True))) # noqa: E501 - self.assertEqual(st.real_strides(), (8, None)) + self.assert_tuple_equal(st.is_expanded(), (False, False)) - @unittest.expectedFailure - def test_real_strides_1(self): + def test_is_expanded_1(self): st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+2)), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501 - self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None)) + self.assert_tuple_equal(st.is_expanded(), (False, False)) - @unittest.expectedFailure - def test_real_strides_2(self): + def test_is_expanded_2(self): st = 
ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501 - self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None)) + self.assert_tuple_equal(st.is_expanded(), (False, False)) def test_merge_view_recursion_err(self): vm2 = View(shape=(Variable('j', 1, 10),), strides=(0,), offset=0, mask=None, contiguous=False) @@ -43,18 +41,18 @@ class TestSymbolic(unittest.TestCase): self.assertEqual(vm3.strides, vm1.strides) self.assertEqual(vm2+vm3, vm2) - def test_cat_dim0_strides(self): + def test_cat_dim0_is_expanded(self): i = Variable("i", 1, 5).bind(3) j = Variable("j", 1, 5).bind(3) k = Variable("k", 1, 5).bind(3) t = Tensor.rand(5, 4)[:i].cat(Tensor.rand(5, 4)[:j], dim=0).cat(Tensor.rand(5, 4)[:k], dim=0) st = t.uop.st self.assert_tuple_equal(st.shape, (i+j+k, 4)) - assert st.real_strides() == (4, 1) + self.assert_tuple_equal(st.is_expanded(), (False, False)) t = Tensor.rand(5, 3)[:i].cat(Tensor.rand(5, 3)[:i], dim=0).cat(Tensor.rand(3, 3), dim=0) st = t.uop.st self.assert_tuple_equal(st.shape, (2*i+3, 3)) - assert st.real_strides() == (3, 1) + self.assert_tuple_equal(st.is_expanded(), (False, False)) def test_cat_dim1_strides(self): i = Variable("i", 1, 5).bind(4) @@ -63,7 +61,7 @@ class TestSymbolic(unittest.TestCase): t = Tensor.rand(3, 5)[:, :i].cat(Tensor.rand(3, 5)[:, :j], dim=1).cat(Tensor.rand(3, 5)[:, :k], dim=1) st = t.uop.st self.assert_tuple_equal(st.shape, (3, i+j+k)) - self.assert_tuple_equal(st.real_strides(), (i+j+k, 1)) + self.assert_tuple_equal(st.is_expanded(), (False, False)) class TestSymbolicVarVals(unittest.TestCase): def assert_equal(self, x, y): self.assertFalse(x != y) diff --git a/tinygrad/codegen/opt/heuristic.py b/tinygrad/codegen/opt/heuristic.py index fb17ea629d..cfc61dd511 100644 --- a/tinygrad/codegen/opt/heuristic.py +++ b/tinygrad/codegen/opt/heuristic.py @@ -51,7 +51,7 @@ def 
hand_coded_optimizations(k:Scheduler) -> Scheduler: # upcast float4 images, this must be early so we don't accidentally add locals before the upcast for buf_index,buf in enumerate(k.bufs): if isinstance(buf.src[0].dtype, ImageDType): - # part of real_strides + # unit-stride axes are the bare RANGE terms of the index sum unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if c.op is Ops.RANGE and (c.vmax+1)%4 == 0] if len(unit_stride_axes_mul_4): diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 3d3722de92..bffff16e70 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -31,9 +31,9 @@ def split_reduceop(reduce:UOp, x:UOp): # ~2**10 should be enough if GROUP is used # 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum. # split is moved to the end to provide maximum locality for the second phase reduce. - real_strides = unwrap(x.st).real_strides(ignore_valid=True) + is_expanded = unwrap(x.st).is_expanded() if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1) - if x.shape[i]%d==0 and real_strides[i]!=0]): return None + if x.shape[i]%d==0 and not is_expanded[i]]): return None dim_to_split, divisor = split_candidates[0] splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:] splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split])) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 57c84f9e79..b4e0b584f0 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -17,20 +17,12 @@ def views_to_valid_uop(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) return graph_rewrite(idx, sym, name="indexing sym @ 1") @functools.cache -def views_to_real_strides(views: tuple[View, ...], 
ignore_valid=False) -> tuple[sint|None, ...]: - # NOTE: if a stride is not always valid, it will be None - if len(views) == 1 and views[-1].mask is None: return views[-1].strides - ret: list[sint|None] = [None] * len(views[-1].shape) - idx, valid = (vidx:=views_to_valid_uop(views)).get_idx(), vidx.get_valid() - for c in idx.split_uop(Ops.ADD): - if c.op is Ops.RANGE: ret[c.arg[0]] = 1 - if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg - if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg +def views_to_is_expanded(views: tuple[View, ...]) -> tuple[bool, ...]: + # NOTE: returns, for each dim, True if it is expanded (stride 0, or its range is unused by the index) + if len(views) == 1 and views[-1].mask is None: return tuple([bool(st==0) for st in views[-1].strides]) + idx = views_to_valid_uop(views).get_idx() used_ranges = [x.arg[0] for x in idx.toposort() if x.op is Ops.RANGE] - ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)] - if not ignore_valid: - for masked_axis in [x.arg[0] for x in valid.toposort() if x.op is Ops.RANGE]: ret[masked_axis] = None - return tuple(ret) + return tuple([i not in used_ranges for i in range(len(views[-1].shape))]) @dataclass(frozen=True, order=True) class ShapeTracker: @@ -63,8 +55,8 @@ class ShapeTracker: if all(len(x) == 0 for x in var_vals): return self, {} return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals) - def real_strides(self, ignore_valid=False) -> tuple[sint|None, ...]: - with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid) + def is_expanded(self) -> tuple[bool, ...]: + with Context(TRACK_MATCH_STATS=0): return views_to_is_expanded(self.views) def simplify(self) -> ShapeTracker: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: