ShapeTracker.real_strides -> is_expanded [pr] (#12579)

only keep the used part
This commit is contained in:
chenyu
2025-10-10 10:52:45 +08:00
committed by GitHub
parent 88ce63a49a
commit c8dfd10257
9 changed files with 37 additions and 54 deletions

View File

@@ -81,7 +81,7 @@ def lin_to_feats(lin:Kernel, use_sts=True):
ret = [float(x) for x in ret] ret = [float(x) for x in ret]
if use_sts: if use_sts:
my_sts = dedup([(x.shape == lin.full_shape, x.real_strides(), any(v.mask is not None for v in x.views), len(x.views)) for x in lin.sts]) my_sts = dedup([(x.shape == lin.full_shape, x.is_expanded(), any(v.mask is not None for v in x.views), len(x.views)) for x in lin.sts])
assert len(my_sts) < MAX_BUFS assert len(my_sts) < MAX_BUFS
sts_len = 3 + 5*MAX_DIMS sts_len = 3 + 5*MAX_DIMS
for s in my_sts: for s in my_sts:

View File

@@ -13,6 +13,6 @@ if __name__ == "__main__":
(Tensor.empty(BS, 16, 512, 512), Tensor.empty(BS, 512, 16, 64).permute(0,2,1,3)), # qk@v (Tensor.empty(BS, 16, 512, 512), Tensor.empty(BS, 512, 16, 64).permute(0,2,1,3)), # qk@v
] ]
for t0, t1 in tensors: for t0, t1 in tensors:
print(f"{t0.shape=}, {t0.uop.st.real_strides()=}, {t1.shape=}, {t1.uop.st.real_strides()=}") print(f"{t0.shape=}, {t0.uop.st.is_expanded()=}, {t1.shape=}, {t1.uop.st.is_expanded()=}")
for _ in range(5): for _ in range(5):
t0.dot(t1, dtype=acc_dtype).realize() t0.dot(t1, dtype=acc_dtype).realize()

View File

@@ -595,21 +595,21 @@ class TestMoveTensor(unittest.TestCase):
np.testing.assert_equal(x.grad.numpy(), [[2,2,2],[0,0,0],[-2,-2,-2]]) np.testing.assert_equal(x.grad.numpy(), [[2,2,2],[0,0,0],[-2,-2,-2]])
class TestZeroShapeTensor(unittest.TestCase): class TestZeroShapeTensor(unittest.TestCase):
def test_shape_stride(self): def test_shape_is_expanded(self):
t = Tensor.empty(3, 2, 0) t = Tensor.empty(3, 2, 0)
assert t.shape == (3, 2, 0) assert t.shape == (3, 2, 0)
# numpy has stride 0, 0, 0; torch has stride 2, 1, 1 # numpy has stride 0, 0, 0; torch has stride 2, 1, 1
assert t.uop.st.real_strides() == (0, 0, 0) assert t.uop.st.is_expanded() == (True, True, True)
t = Tensor.empty(3, 0, 2) t = Tensor.empty(3, 0, 2)
assert t.shape == (3, 0, 2) assert t.shape == (3, 0, 2)
# numpy has stride 0, 0, 0; torch has stride 2, 2, 1 # numpy has stride 0, 0, 0; torch has stride 2, 2, 1
assert t.uop.st.real_strides() == (0, 0, 0) assert t.uop.st.is_expanded() == (True, True, True)
t = Tensor.empty(0, 0, 0) t = Tensor.empty(0, 0, 0)
assert t.shape == (0, 0, 0) assert t.shape == (0, 0, 0)
# numpy has stride 0, 0, 0; torch has stride 1, 1, 1 # numpy has stride 0, 0, 0; torch has stride 1, 1, 1
assert t.uop.st.real_strides() == (0, 0, 0) assert t.uop.st.is_expanded() == (True, True, True)
def test_rand(self): def test_rand(self):
t = Tensor.rand(3, 2, 0) t = Tensor.rand(3, 2, 0)

View File

@@ -971,9 +971,8 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper((2, 0, 4), z.shape) numpy_testing_assert_equal_helper((2, 0, 4), z.shape)
# this isn't technically necessary, but matches NumPy stride calculations. # this isn't technically necessary, but matches NumPy stride calculations.
# NOTE: this is empty and shouldn't have strides # NOTE: this is empty and shouldn't have strides
#numpy_testing_assert_equal_helper((60, 20, 5), z.uop.st.real_strides()) numpy_testing_assert_equal_helper((True, True, True), z.uop.st.is_expanded())
# NOTE tinygrad's int slicing implementation makes this not contiguous self.assertTrue(z.uop.st.contiguous)
# self.assertTrue(z.uop.st.contiguous)
@unittest.skip("bool indexing not supported") @unittest.skip("bool indexing not supported")
def test_index_getitem_copy_bools_slices(self): def test_index_getitem_copy_bools_slices(self):

View File

@@ -95,23 +95,20 @@ class TestRealIssues(unittest.TestCase):
class TestRealDoesntSimplify(unittest.TestCase): class TestRealDoesntSimplify(unittest.TestCase):
def tearDown(self): def tearDown(self):
st = self.st.real_strides()
print(st)
self.st = self.st.simplify() self.st = self.st.simplify()
assert len(self.st.views) != 1 assert len(self.st.views) != 1
assert None in st
def test_1(self): def test_1(self):
self.st = ShapeTracker(( self.st = ShapeTracker((
View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None), View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
View.create((8, 6, 11), (66, 11, 1), 0, None))) View.create((8, 6, 11), (66, 11, 1), 0, None)))
self.assertEqual(self.st.real_strides(), (33, None, 1)) self.assertEqual(self.st.is_expanded(), (False, False, False))
def test_2(self): def test_2(self):
self.st = ShapeTracker(( self.st = ShapeTracker((
View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None), View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None))) View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
self.assertEqual(self.st.real_strides(), (None, 18, -3, -1)) self.assertEqual(self.st.is_expanded(), (False, False, False, False))
class TestRealStrides(unittest.TestCase): class TestRealStrides(unittest.TestCase):
def test_1(self): def test_1(self):
@@ -119,7 +116,7 @@ class TestRealStrides(unittest.TestCase):
View.create((2048,), (1,), 0, ((0, 512),)), View.create((2048,), (1,), 0, ((0, 512),)),
View.create((16, 32, 4), (128, 4, 1), 0, None), View.create((16, 32, 4), (128, 4, 1), 0, None),
)) ))
self.assertEqual(st.real_strides(), (None, 4, 1)) self.assertEqual(st.is_expanded(), (False, False, False))
def test_2(self): def test_2(self):
# test/test_ops.py::TestOps::test_simple_padding_conv1d # test/test_ops.py::TestOps::test_simple_padding_conv1d
@@ -128,7 +125,7 @@ class TestRealStrides(unittest.TestCase):
View.create((6, 2, 78), (140, 70, 1), 0, ((0, 6), (0, 2), (0, 70))), View.create((6, 2, 78), (140, 70, 1), 0, ((0, 6), (0, 2), (0, 70))),
View.create((6, 2, 13, 6), (156, 78, 1, 13), 0, None), View.create((6, 2, 13, 6), (156, 78, 1, 13), 0, None),
)) ))
self.assertEqual(st.real_strides(), (90, 45, None, None)) self.assertEqual(st.is_expanded(), (False, False, False, False))
def test_3(self): def test_3(self):
# test/test_ops.py::TestOps::test_simple_cumsum # test/test_ops.py::TestOps::test_simple_cumsum
@@ -137,7 +134,7 @@ class TestRealStrides(unittest.TestCase):
View.create((4, 131327), (131072, 1), 0, ((0, 4), (0, 131072))), View.create((4, 131327), (131072, 1), 0, ((0, 4), (0, 131072))),
View.create((4, 511, 257), (131327, 1, 511), 0, None), View.create((4, 511, 257), (131327, 1, 511), 0, None),
)) ))
self.assertEqual(st.real_strides(), (256, None, None)) self.assertEqual(st.is_expanded(), (False, False, False))
def test_4(self): def test_4(self):
# test/test_nn.py::TestNN::test_conv_transpose1d # test/test_nn.py::TestNN::test_conv_transpose1d
@@ -146,7 +143,7 @@ class TestRealStrides(unittest.TestCase):
View.create((1, 4, 1, 16, 8, 121), (0, 1792, 0, 112, 0, 1), -5, ((0, 1), (0, 4), (0, 1), (0, 16), (0, 8), (5, 116))), View.create((1, 4, 1, 16, 8, 121), (0, 1792, 0, 112, 0, 1), -5, ((0, 1), (0, 4), (0, 1), (0, 16), (0, 8), (5, 116))),
View.create((4, 64, 115, 16, 7), (15488, 0, 1, 968, 122), 0, None), View.create((4, 64, 115, 16, 7), (15488, 0, 1, 968, 122), 0, None),
)) ))
self.assertEqual(st.real_strides(), (896, 0, None, 56, None)) self.assertEqual(st.is_expanded(), (False, True, False, False, False))
def test_5(self): def test_5(self):
# test/test_ops.py::TestOps::test_conv2d # test/test_ops.py::TestOps::test_conv2d
@@ -155,15 +152,12 @@ class TestRealStrides(unittest.TestCase):
View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))), View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))),
View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None), View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None),
)) ))
self.assertEqual(st.real_strides(), (132, 12, None, None, None)) self.assertEqual(st.is_expanded(), (False, False, False, True, False))
class TestRealSimplifies(unittest.TestCase): class TestRealSimplifies(unittest.TestCase):
def tearDown(self): def tearDown(self):
st = self.st.real_strides()
self.st = self.st.simplify() self.st = self.st.simplify()
assert len(self.st.views) == 1 assert len(self.st.views) == 1
print(self.st.views[-1].strides, st)
self.assertEqual(self.st.views[-1].strides, st)
def test_1(self): def test_1(self):
self.st = ShapeTracker(( self.st = ShapeTracker((

View File

@@ -10,22 +10,20 @@ class TestSymbolic(unittest.TestCase):
def test_symbolic_st(self): def test_symbolic_st(self):
x = Variable("x", 1, 100) x = Variable("x", 1, 100)
st = ShapeTracker.from_shape((x, 3)) st = ShapeTracker.from_shape((x, 3))
assert st.shape == (x, 3) self.assert_tuple_equal(st.shape, (x, 3))
assert st.real_strides() == (3, 1) self.assert_tuple_equal(st.is_expanded(), (False, False))
def test_real_strides_0(self): def test_is_expanded_0(self):
st = ShapeTracker(views=(View(shape=(2, (Variable('start_pos', 1, 8)+1), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (Variable('start_pos', 1, 8)+1)), strides=((Variable('start_pos', 1, 8)+1), 1), offset=0, mask=None, contiguous=True))) # noqa: E501 st = ShapeTracker(views=(View(shape=(2, (Variable('start_pos', 1, 8)+1), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (Variable('start_pos', 1, 8)+1)), strides=((Variable('start_pos', 1, 8)+1), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
self.assertEqual(st.real_strides(), (8, None)) self.assert_tuple_equal(st.is_expanded(), (False, False))
@unittest.expectedFailure def test_is_expanded_1(self):
def test_real_strides_1(self):
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+2)), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501 st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+2)), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None)) self.assert_tuple_equal(st.is_expanded(), (False, False))
@unittest.expectedFailure def test_is_expanded_2(self):
def test_real_strides_2(self):
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501 st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None)) self.assert_tuple_equal(st.is_expanded(), (False, False))
def test_merge_view_recursion_err(self): def test_merge_view_recursion_err(self):
vm2 = View(shape=(Variable('j', 1, 10),), strides=(0,), offset=0, mask=None, contiguous=False) vm2 = View(shape=(Variable('j', 1, 10),), strides=(0,), offset=0, mask=None, contiguous=False)
@@ -43,18 +41,18 @@ class TestSymbolic(unittest.TestCase):
self.assertEqual(vm3.strides, vm1.strides) self.assertEqual(vm3.strides, vm1.strides)
self.assertEqual(vm2+vm3, vm2) self.assertEqual(vm2+vm3, vm2)
def test_cat_dim0_strides(self): def test_cat_dim0_is_expanded(self):
i = Variable("i", 1, 5).bind(3) i = Variable("i", 1, 5).bind(3)
j = Variable("j", 1, 5).bind(3) j = Variable("j", 1, 5).bind(3)
k = Variable("k", 1, 5).bind(3) k = Variable("k", 1, 5).bind(3)
t = Tensor.rand(5, 4)[:i].cat(Tensor.rand(5, 4)[:j], dim=0).cat(Tensor.rand(5, 4)[:k], dim=0) t = Tensor.rand(5, 4)[:i].cat(Tensor.rand(5, 4)[:j], dim=0).cat(Tensor.rand(5, 4)[:k], dim=0)
st = t.uop.st st = t.uop.st
self.assert_tuple_equal(st.shape, (i+j+k, 4)) self.assert_tuple_equal(st.shape, (i+j+k, 4))
assert st.real_strides() == (4, 1) self.assert_tuple_equal(st.is_expanded(), (False, False))
t = Tensor.rand(5, 3)[:i].cat(Tensor.rand(5, 3)[:i], dim=0).cat(Tensor.rand(3, 3), dim=0) t = Tensor.rand(5, 3)[:i].cat(Tensor.rand(5, 3)[:i], dim=0).cat(Tensor.rand(3, 3), dim=0)
st = t.uop.st st = t.uop.st
self.assert_tuple_equal(st.shape, (2*i+3, 3)) self.assert_tuple_equal(st.shape, (2*i+3, 3))
assert st.real_strides() == (3, 1) self.assert_tuple_equal(st.is_expanded(), (False, False))
def test_cat_dim1_strides(self): def test_cat_dim1_strides(self):
i = Variable("i", 1, 5).bind(4) i = Variable("i", 1, 5).bind(4)
@@ -63,7 +61,7 @@ class TestSymbolic(unittest.TestCase):
t = Tensor.rand(3, 5)[:, :i].cat(Tensor.rand(3, 5)[:, :j], dim=1).cat(Tensor.rand(3, 5)[:, :k], dim=1) t = Tensor.rand(3, 5)[:, :i].cat(Tensor.rand(3, 5)[:, :j], dim=1).cat(Tensor.rand(3, 5)[:, :k], dim=1)
st = t.uop.st st = t.uop.st
self.assert_tuple_equal(st.shape, (3, i+j+k)) self.assert_tuple_equal(st.shape, (3, i+j+k))
self.assert_tuple_equal(st.real_strides(), (i+j+k, 1)) self.assert_tuple_equal(st.is_expanded(), (False, False))
class TestSymbolicVarVals(unittest.TestCase): class TestSymbolicVarVals(unittest.TestCase):
def assert_equal(self, x, y): self.assertFalse(x != y) def assert_equal(self, x, y): self.assertFalse(x != y)

View File

@@ -51,7 +51,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# upcast float4 images, this must be early so we don't accidentally add locals before the upcast # upcast float4 images, this must be early so we don't accidentally add locals before the upcast
for buf_index,buf in enumerate(k.bufs): for buf_index,buf in enumerate(k.bufs):
if isinstance(buf.src[0].dtype, ImageDType): if isinstance(buf.src[0].dtype, ImageDType):
# part of real_strides # part of is_expanded
unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if
c.op is Ops.RANGE and (c.vmax+1)%4 == 0] c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
if len(unit_stride_axes_mul_4): if len(unit_stride_axes_mul_4):

View File

@@ -31,9 +31,9 @@ def split_reduceop(reduce:UOp, x:UOp):
# ~2**10 should be enough if GROUP is used # ~2**10 should be enough if GROUP is used
# 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum. # 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
# split is moved to the end to provide maximum locality for the second phase reduce. # split is moved to the end to provide maximum locality for the second phase reduce.
real_strides = unwrap(x.st).real_strides(ignore_valid=True) is_expanded = unwrap(x.st).is_expanded()
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1) if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
if x.shape[i]%d==0 and real_strides[i]!=0]): return None if x.shape[i]%d==0 and not is_expanded[i]]): return None
dim_to_split, divisor = split_candidates[0] dim_to_split, divisor = split_candidates[0]
splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:] splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:]
splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split])) splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split]))

View File

@@ -17,20 +17,12 @@ def views_to_valid_uop(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None)
return graph_rewrite(idx, sym, name="indexing sym @ 1") return graph_rewrite(idx, sym, name="indexing sym @ 1")
@functools.cache @functools.cache
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]: def views_to_is_expanded(views: tuple[View, ...]) -> tuple[bool, ...]:
# NOTE: if a stride is not always valid, it will be None # NOTE: return if each dim is expanded
if len(views) == 1 and views[-1].mask is None: return views[-1].strides if len(views) == 1 and views[-1].mask is None: return tuple([bool(st==0) for st in views[-1].strides])
ret: list[sint|None] = [None] * len(views[-1].shape) idx = views_to_valid_uop(views).get_idx()
idx, valid = (vidx:=views_to_valid_uop(views)).get_idx(), vidx.get_valid()
for c in idx.split_uop(Ops.ADD):
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
used_ranges = [x.arg[0] for x in idx.toposort() if x.op is Ops.RANGE] used_ranges = [x.arg[0] for x in idx.toposort() if x.op is Ops.RANGE]
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)] return tuple([i not in used_ranges for i in range(len(views[-1].shape))])
if not ignore_valid:
for masked_axis in [x.arg[0] for x in valid.toposort() if x.op is Ops.RANGE]: ret[masked_axis] = None
return tuple(ret)
@dataclass(frozen=True, order=True) @dataclass(frozen=True, order=True)
class ShapeTracker: class ShapeTracker:
@@ -63,8 +55,8 @@ class ShapeTracker:
if all(len(x) == 0 for x in var_vals): return self, {} if all(len(x) == 0 for x in var_vals): return self, {}
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals) return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
def real_strides(self, ignore_valid=False) -> tuple[sint|None, ...]: def is_expanded(self) -> tuple[bool, ...]:
with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid) with Context(TRACK_MATCH_STATS=0): return views_to_is_expanded(self.views)
def simplify(self) -> ShapeTracker: def simplify(self) -> ShapeTracker:
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: