mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
ShapeTracker.real_strides -> is_expanded [pr] (#12579)
only keep the used part
This commit is contained in:
@@ -81,7 +81,7 @@ def lin_to_feats(lin:Kernel, use_sts=True):
|
|||||||
ret = [float(x) for x in ret]
|
ret = [float(x) for x in ret]
|
||||||
|
|
||||||
if use_sts:
|
if use_sts:
|
||||||
my_sts = dedup([(x.shape == lin.full_shape, x.real_strides(), any(v.mask is not None for v in x.views), len(x.views)) for x in lin.sts])
|
my_sts = dedup([(x.shape == lin.full_shape, x.is_expanded(), any(v.mask is not None for v in x.views), len(x.views)) for x in lin.sts])
|
||||||
assert len(my_sts) < MAX_BUFS
|
assert len(my_sts) < MAX_BUFS
|
||||||
sts_len = 3 + 5*MAX_DIMS
|
sts_len = 3 + 5*MAX_DIMS
|
||||||
for s in my_sts:
|
for s in my_sts:
|
||||||
|
|||||||
@@ -13,6 +13,6 @@ if __name__ == "__main__":
|
|||||||
(Tensor.empty(BS, 16, 512, 512), Tensor.empty(BS, 512, 16, 64).permute(0,2,1,3)), # qk@v
|
(Tensor.empty(BS, 16, 512, 512), Tensor.empty(BS, 512, 16, 64).permute(0,2,1,3)), # qk@v
|
||||||
]
|
]
|
||||||
for t0, t1 in tensors:
|
for t0, t1 in tensors:
|
||||||
print(f"{t0.shape=}, {t0.uop.st.real_strides()=}, {t1.shape=}, {t1.uop.st.real_strides()=}")
|
print(f"{t0.shape=}, {t0.uop.st.is_expanded()=}, {t1.shape=}, {t1.uop.st.is_expanded()=}")
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
t0.dot(t1, dtype=acc_dtype).realize()
|
t0.dot(t1, dtype=acc_dtype).realize()
|
||||||
|
|||||||
@@ -595,21 +595,21 @@ class TestMoveTensor(unittest.TestCase):
|
|||||||
np.testing.assert_equal(x.grad.numpy(), [[2,2,2],[0,0,0],[-2,-2,-2]])
|
np.testing.assert_equal(x.grad.numpy(), [[2,2,2],[0,0,0],[-2,-2,-2]])
|
||||||
|
|
||||||
class TestZeroShapeTensor(unittest.TestCase):
|
class TestZeroShapeTensor(unittest.TestCase):
|
||||||
def test_shape_stride(self):
|
def test_shape_is_expanded(self):
|
||||||
t = Tensor.empty(3, 2, 0)
|
t = Tensor.empty(3, 2, 0)
|
||||||
assert t.shape == (3, 2, 0)
|
assert t.shape == (3, 2, 0)
|
||||||
# numpy has stride 0, 0, 0; torch has stride 2, 1, 1
|
# numpy has stride 0, 0, 0; torch has stride 2, 1, 1
|
||||||
assert t.uop.st.real_strides() == (0, 0, 0)
|
assert t.uop.st.is_expanded() == (True, True, True)
|
||||||
|
|
||||||
t = Tensor.empty(3, 0, 2)
|
t = Tensor.empty(3, 0, 2)
|
||||||
assert t.shape == (3, 0, 2)
|
assert t.shape == (3, 0, 2)
|
||||||
# numpy has stride 0, 0, 0; torch has stride 2, 2, 1
|
# numpy has stride 0, 0, 0; torch has stride 2, 2, 1
|
||||||
assert t.uop.st.real_strides() == (0, 0, 0)
|
assert t.uop.st.is_expanded() == (True, True, True)
|
||||||
|
|
||||||
t = Tensor.empty(0, 0, 0)
|
t = Tensor.empty(0, 0, 0)
|
||||||
assert t.shape == (0, 0, 0)
|
assert t.shape == (0, 0, 0)
|
||||||
# numpy has stride 0, 0, 0; torch has stride 1, 1, 1
|
# numpy has stride 0, 0, 0; torch has stride 1, 1, 1
|
||||||
assert t.uop.st.real_strides() == (0, 0, 0)
|
assert t.uop.st.is_expanded() == (True, True, True)
|
||||||
|
|
||||||
def test_rand(self):
|
def test_rand(self):
|
||||||
t = Tensor.rand(3, 2, 0)
|
t = Tensor.rand(3, 2, 0)
|
||||||
|
|||||||
@@ -971,9 +971,8 @@ class TestIndexing(unittest.TestCase):
|
|||||||
numpy_testing_assert_equal_helper((2, 0, 4), z.shape)
|
numpy_testing_assert_equal_helper((2, 0, 4), z.shape)
|
||||||
# this isn't technically necessary, but matches NumPy stride calculations.
|
# this isn't technically necessary, but matches NumPy stride calculations.
|
||||||
# NOTE: this is empty and shouldn't have strides
|
# NOTE: this is empty and shouldn't have strides
|
||||||
#numpy_testing_assert_equal_helper((60, 20, 5), z.uop.st.real_strides())
|
numpy_testing_assert_equal_helper((True, True, True), z.uop.st.is_expanded())
|
||||||
# NOTE tinygrad's int slicing implementation makes this not contiguous
|
self.assertTrue(z.uop.st.contiguous)
|
||||||
# self.assertTrue(z.uop.st.contiguous)
|
|
||||||
|
|
||||||
@unittest.skip("bool indexing not supported")
|
@unittest.skip("bool indexing not supported")
|
||||||
def test_index_getitem_copy_bools_slices(self):
|
def test_index_getitem_copy_bools_slices(self):
|
||||||
|
|||||||
@@ -95,23 +95,20 @@ class TestRealIssues(unittest.TestCase):
|
|||||||
|
|
||||||
class TestRealDoesntSimplify(unittest.TestCase):
|
class TestRealDoesntSimplify(unittest.TestCase):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
st = self.st.real_strides()
|
|
||||||
print(st)
|
|
||||||
self.st = self.st.simplify()
|
self.st = self.st.simplify()
|
||||||
assert len(self.st.views) != 1
|
assert len(self.st.views) != 1
|
||||||
assert None in st
|
|
||||||
|
|
||||||
def test_1(self):
|
def test_1(self):
|
||||||
self.st = ShapeTracker((
|
self.st = ShapeTracker((
|
||||||
View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
|
View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
|
||||||
View.create((8, 6, 11), (66, 11, 1), 0, None)))
|
View.create((8, 6, 11), (66, 11, 1), 0, None)))
|
||||||
self.assertEqual(self.st.real_strides(), (33, None, 1))
|
self.assertEqual(self.st.is_expanded(), (False, False, False))
|
||||||
|
|
||||||
def test_2(self):
|
def test_2(self):
|
||||||
self.st = ShapeTracker((
|
self.st = ShapeTracker((
|
||||||
View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
|
View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
|
||||||
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
|
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
|
||||||
self.assertEqual(self.st.real_strides(), (None, 18, -3, -1))
|
self.assertEqual(self.st.is_expanded(), (False, False, False, False))
|
||||||
|
|
||||||
class TestRealStrides(unittest.TestCase):
|
class TestRealStrides(unittest.TestCase):
|
||||||
def test_1(self):
|
def test_1(self):
|
||||||
@@ -119,7 +116,7 @@ class TestRealStrides(unittest.TestCase):
|
|||||||
View.create((2048,), (1,), 0, ((0, 512),)),
|
View.create((2048,), (1,), 0, ((0, 512),)),
|
||||||
View.create((16, 32, 4), (128, 4, 1), 0, None),
|
View.create((16, 32, 4), (128, 4, 1), 0, None),
|
||||||
))
|
))
|
||||||
self.assertEqual(st.real_strides(), (None, 4, 1))
|
self.assertEqual(st.is_expanded(), (False, False, False))
|
||||||
|
|
||||||
def test_2(self):
|
def test_2(self):
|
||||||
# test/test_ops.py::TestOps::test_simple_padding_conv1d
|
# test/test_ops.py::TestOps::test_simple_padding_conv1d
|
||||||
@@ -128,7 +125,7 @@ class TestRealStrides(unittest.TestCase):
|
|||||||
View.create((6, 2, 78), (140, 70, 1), 0, ((0, 6), (0, 2), (0, 70))),
|
View.create((6, 2, 78), (140, 70, 1), 0, ((0, 6), (0, 2), (0, 70))),
|
||||||
View.create((6, 2, 13, 6), (156, 78, 1, 13), 0, None),
|
View.create((6, 2, 13, 6), (156, 78, 1, 13), 0, None),
|
||||||
))
|
))
|
||||||
self.assertEqual(st.real_strides(), (90, 45, None, None))
|
self.assertEqual(st.is_expanded(), (False, False, False, False))
|
||||||
|
|
||||||
def test_3(self):
|
def test_3(self):
|
||||||
# test/test_ops.py::TestOps::test_simple_cumsum
|
# test/test_ops.py::TestOps::test_simple_cumsum
|
||||||
@@ -137,7 +134,7 @@ class TestRealStrides(unittest.TestCase):
|
|||||||
View.create((4, 131327), (131072, 1), 0, ((0, 4), (0, 131072))),
|
View.create((4, 131327), (131072, 1), 0, ((0, 4), (0, 131072))),
|
||||||
View.create((4, 511, 257), (131327, 1, 511), 0, None),
|
View.create((4, 511, 257), (131327, 1, 511), 0, None),
|
||||||
))
|
))
|
||||||
self.assertEqual(st.real_strides(), (256, None, None))
|
self.assertEqual(st.is_expanded(), (False, False, False))
|
||||||
|
|
||||||
def test_4(self):
|
def test_4(self):
|
||||||
# test/test_nn.py::TestNN::test_conv_transpose1d
|
# test/test_nn.py::TestNN::test_conv_transpose1d
|
||||||
@@ -146,7 +143,7 @@ class TestRealStrides(unittest.TestCase):
|
|||||||
View.create((1, 4, 1, 16, 8, 121), (0, 1792, 0, 112, 0, 1), -5, ((0, 1), (0, 4), (0, 1), (0, 16), (0, 8), (5, 116))),
|
View.create((1, 4, 1, 16, 8, 121), (0, 1792, 0, 112, 0, 1), -5, ((0, 1), (0, 4), (0, 1), (0, 16), (0, 8), (5, 116))),
|
||||||
View.create((4, 64, 115, 16, 7), (15488, 0, 1, 968, 122), 0, None),
|
View.create((4, 64, 115, 16, 7), (15488, 0, 1, 968, 122), 0, None),
|
||||||
))
|
))
|
||||||
self.assertEqual(st.real_strides(), (896, 0, None, 56, None))
|
self.assertEqual(st.is_expanded(), (False, True, False, False, False))
|
||||||
|
|
||||||
def test_5(self):
|
def test_5(self):
|
||||||
# test/test_ops.py::TestOps::test_conv2d
|
# test/test_ops.py::TestOps::test_conv2d
|
||||||
@@ -155,15 +152,12 @@ class TestRealStrides(unittest.TestCase):
|
|||||||
View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))),
|
View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))),
|
||||||
View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None),
|
View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None),
|
||||||
))
|
))
|
||||||
self.assertEqual(st.real_strides(), (132, 12, None, None, None))
|
self.assertEqual(st.is_expanded(), (False, False, False, True, False))
|
||||||
|
|
||||||
class TestRealSimplifies(unittest.TestCase):
|
class TestRealSimplifies(unittest.TestCase):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
st = self.st.real_strides()
|
|
||||||
self.st = self.st.simplify()
|
self.st = self.st.simplify()
|
||||||
assert len(self.st.views) == 1
|
assert len(self.st.views) == 1
|
||||||
print(self.st.views[-1].strides, st)
|
|
||||||
self.assertEqual(self.st.views[-1].strides, st)
|
|
||||||
|
|
||||||
def test_1(self):
|
def test_1(self):
|
||||||
self.st = ShapeTracker((
|
self.st = ShapeTracker((
|
||||||
|
|||||||
@@ -10,22 +10,20 @@ class TestSymbolic(unittest.TestCase):
|
|||||||
def test_symbolic_st(self):
|
def test_symbolic_st(self):
|
||||||
x = Variable("x", 1, 100)
|
x = Variable("x", 1, 100)
|
||||||
st = ShapeTracker.from_shape((x, 3))
|
st = ShapeTracker.from_shape((x, 3))
|
||||||
assert st.shape == (x, 3)
|
self.assert_tuple_equal(st.shape, (x, 3))
|
||||||
assert st.real_strides() == (3, 1)
|
self.assert_tuple_equal(st.is_expanded(), (False, False))
|
||||||
|
|
||||||
def test_real_strides_0(self):
|
def test_is_expanded_0(self):
|
||||||
st = ShapeTracker(views=(View(shape=(2, (Variable('start_pos', 1, 8)+1), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (Variable('start_pos', 1, 8)+1)), strides=((Variable('start_pos', 1, 8)+1), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
|
st = ShapeTracker(views=(View(shape=(2, (Variable('start_pos', 1, 8)+1), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (Variable('start_pos', 1, 8)+1)), strides=((Variable('start_pos', 1, 8)+1), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
|
||||||
self.assertEqual(st.real_strides(), (8, None))
|
self.assert_tuple_equal(st.is_expanded(), (False, False))
|
||||||
|
|
||||||
@unittest.expectedFailure
|
def test_is_expanded_1(self):
|
||||||
def test_real_strides_1(self):
|
|
||||||
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+2)), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
|
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+2)), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
|
||||||
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
|
self.assert_tuple_equal(st.is_expanded(), (False, False))
|
||||||
|
|
||||||
@unittest.expectedFailure
|
def test_is_expanded_2(self):
|
||||||
def test_real_strides_2(self):
|
|
||||||
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
|
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
|
||||||
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
|
self.assert_tuple_equal(st.is_expanded(), (False, False))
|
||||||
|
|
||||||
def test_merge_view_recursion_err(self):
|
def test_merge_view_recursion_err(self):
|
||||||
vm2 = View(shape=(Variable('j', 1, 10),), strides=(0,), offset=0, mask=None, contiguous=False)
|
vm2 = View(shape=(Variable('j', 1, 10),), strides=(0,), offset=0, mask=None, contiguous=False)
|
||||||
@@ -43,18 +41,18 @@ class TestSymbolic(unittest.TestCase):
|
|||||||
self.assertEqual(vm3.strides, vm1.strides)
|
self.assertEqual(vm3.strides, vm1.strides)
|
||||||
self.assertEqual(vm2+vm3, vm2)
|
self.assertEqual(vm2+vm3, vm2)
|
||||||
|
|
||||||
def test_cat_dim0_strides(self):
|
def test_cat_dim0_is_expanded(self):
|
||||||
i = Variable("i", 1, 5).bind(3)
|
i = Variable("i", 1, 5).bind(3)
|
||||||
j = Variable("j", 1, 5).bind(3)
|
j = Variable("j", 1, 5).bind(3)
|
||||||
k = Variable("k", 1, 5).bind(3)
|
k = Variable("k", 1, 5).bind(3)
|
||||||
t = Tensor.rand(5, 4)[:i].cat(Tensor.rand(5, 4)[:j], dim=0).cat(Tensor.rand(5, 4)[:k], dim=0)
|
t = Tensor.rand(5, 4)[:i].cat(Tensor.rand(5, 4)[:j], dim=0).cat(Tensor.rand(5, 4)[:k], dim=0)
|
||||||
st = t.uop.st
|
st = t.uop.st
|
||||||
self.assert_tuple_equal(st.shape, (i+j+k, 4))
|
self.assert_tuple_equal(st.shape, (i+j+k, 4))
|
||||||
assert st.real_strides() == (4, 1)
|
self.assert_tuple_equal(st.is_expanded(), (False, False))
|
||||||
t = Tensor.rand(5, 3)[:i].cat(Tensor.rand(5, 3)[:i], dim=0).cat(Tensor.rand(3, 3), dim=0)
|
t = Tensor.rand(5, 3)[:i].cat(Tensor.rand(5, 3)[:i], dim=0).cat(Tensor.rand(3, 3), dim=0)
|
||||||
st = t.uop.st
|
st = t.uop.st
|
||||||
self.assert_tuple_equal(st.shape, (2*i+3, 3))
|
self.assert_tuple_equal(st.shape, (2*i+3, 3))
|
||||||
assert st.real_strides() == (3, 1)
|
self.assert_tuple_equal(st.is_expanded(), (False, False))
|
||||||
|
|
||||||
def test_cat_dim1_strides(self):
|
def test_cat_dim1_strides(self):
|
||||||
i = Variable("i", 1, 5).bind(4)
|
i = Variable("i", 1, 5).bind(4)
|
||||||
@@ -63,7 +61,7 @@ class TestSymbolic(unittest.TestCase):
|
|||||||
t = Tensor.rand(3, 5)[:, :i].cat(Tensor.rand(3, 5)[:, :j], dim=1).cat(Tensor.rand(3, 5)[:, :k], dim=1)
|
t = Tensor.rand(3, 5)[:, :i].cat(Tensor.rand(3, 5)[:, :j], dim=1).cat(Tensor.rand(3, 5)[:, :k], dim=1)
|
||||||
st = t.uop.st
|
st = t.uop.st
|
||||||
self.assert_tuple_equal(st.shape, (3, i+j+k))
|
self.assert_tuple_equal(st.shape, (3, i+j+k))
|
||||||
self.assert_tuple_equal(st.real_strides(), (i+j+k, 1))
|
self.assert_tuple_equal(st.is_expanded(), (False, False))
|
||||||
|
|
||||||
class TestSymbolicVarVals(unittest.TestCase):
|
class TestSymbolicVarVals(unittest.TestCase):
|
||||||
def assert_equal(self, x, y): self.assertFalse(x != y)
|
def assert_equal(self, x, y): self.assertFalse(x != y)
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
|
|||||||
# upcast float4 images, this must be early so we don't accidentally add locals before the upcast
|
# upcast float4 images, this must be early so we don't accidentally add locals before the upcast
|
||||||
for buf_index,buf in enumerate(k.bufs):
|
for buf_index,buf in enumerate(k.bufs):
|
||||||
if isinstance(buf.src[0].dtype, ImageDType):
|
if isinstance(buf.src[0].dtype, ImageDType):
|
||||||
# part of real_strides
|
# part of is_expanded
|
||||||
unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if
|
unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if
|
||||||
c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
|
c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
|
||||||
if len(unit_stride_axes_mul_4):
|
if len(unit_stride_axes_mul_4):
|
||||||
|
|||||||
@@ -31,9 +31,9 @@ def split_reduceop(reduce:UOp, x:UOp):
|
|||||||
# ~2**10 should be enough if GROUP is used
|
# ~2**10 should be enough if GROUP is used
|
||||||
# 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
|
# 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
|
||||||
# split is moved to the end to provide maximum locality for the second phase reduce.
|
# split is moved to the end to provide maximum locality for the second phase reduce.
|
||||||
real_strides = unwrap(x.st).real_strides(ignore_valid=True)
|
is_expanded = unwrap(x.st).is_expanded()
|
||||||
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
|
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
|
||||||
if x.shape[i]%d==0 and real_strides[i]!=0]): return None
|
if x.shape[i]%d==0 and not is_expanded[i]]): return None
|
||||||
dim_to_split, divisor = split_candidates[0]
|
dim_to_split, divisor = split_candidates[0]
|
||||||
splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:]
|
splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:]
|
||||||
splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split]))
|
splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split]))
|
||||||
|
|||||||
@@ -17,20 +17,12 @@ def views_to_valid_uop(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None)
|
|||||||
return graph_rewrite(idx, sym, name="indexing sym @ 1")
|
return graph_rewrite(idx, sym, name="indexing sym @ 1")
|
||||||
|
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]:
|
def views_to_is_expanded(views: tuple[View, ...]) -> tuple[bool, ...]:
|
||||||
# NOTE: if a stride is not always valid, it will be None
|
# NOTE: return if each dim is expanded
|
||||||
if len(views) == 1 and views[-1].mask is None: return views[-1].strides
|
if len(views) == 1 and views[-1].mask is None: return tuple([bool(st==0) for st in views[-1].strides])
|
||||||
ret: list[sint|None] = [None] * len(views[-1].shape)
|
idx = views_to_valid_uop(views).get_idx()
|
||||||
idx, valid = (vidx:=views_to_valid_uop(views)).get_idx(), vidx.get_valid()
|
|
||||||
for c in idx.split_uop(Ops.ADD):
|
|
||||||
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
|
|
||||||
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
|
|
||||||
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
|
|
||||||
used_ranges = [x.arg[0] for x in idx.toposort() if x.op is Ops.RANGE]
|
used_ranges = [x.arg[0] for x in idx.toposort() if x.op is Ops.RANGE]
|
||||||
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
|
return tuple([i not in used_ranges for i in range(len(views[-1].shape))])
|
||||||
if not ignore_valid:
|
|
||||||
for masked_axis in [x.arg[0] for x in valid.toposort() if x.op is Ops.RANGE]: ret[masked_axis] = None
|
|
||||||
return tuple(ret)
|
|
||||||
|
|
||||||
@dataclass(frozen=True, order=True)
|
@dataclass(frozen=True, order=True)
|
||||||
class ShapeTracker:
|
class ShapeTracker:
|
||||||
@@ -63,8 +55,8 @@ class ShapeTracker:
|
|||||||
if all(len(x) == 0 for x in var_vals): return self, {}
|
if all(len(x) == 0 for x in var_vals): return self, {}
|
||||||
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
|
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
|
||||||
|
|
||||||
def real_strides(self, ignore_valid=False) -> tuple[sint|None, ...]:
|
def is_expanded(self) -> tuple[bool, ...]:
|
||||||
with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)
|
with Context(TRACK_MATCH_STATS=0): return views_to_is_expanded(self.views)
|
||||||
|
|
||||||
def simplify(self) -> ShapeTracker:
|
def simplify(self) -> ShapeTracker:
|
||||||
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
|
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user