fix to_shape_strides (#1374)

* add tests for expr_node and expr_idxs

* simplify condition and add missing optimization
This commit is contained in:
S-Lykles
2023-07-31 03:42:46 +02:00
committed by GitHub
parent 1fdf560fb1
commit c2b82ea8ac
2 changed files with 73 additions and 26 deletions

View File

@@ -4,6 +4,7 @@ import numpy as np
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
from tinygrad.shape.symbolic import Variable
from itertools import product
def shapetracker_getitem(st, val):
locals = {"idx": val, "valid": 1}
@@ -120,43 +121,87 @@ class TestRealSimplifies(unittest.TestCase):
View((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
class TestIndexExpressions(unittest.TestCase):
class TestIndexExpressions2d(unittest.TestCase):
def setUp(self):
self.st = ShapeTracker((10, 10))
self.numel = prod(self.st.shape)
shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
offsets = [0, 1, 15, 28, 10000]
self.sts = [ShapeTracker(base_shape, [View(base_shape, offset=offset)]) for base_shape in shapes for offset in offsets]
self.offset = [Variable.num(offset) for base_shape in shapes for offset in offsets]
self.shapes = [shape for shape in shapes for offset in offsets]
self.node_exprs = []
self.idxs_exprs = []
def tearDown(self):
assert self.expr(self.default_idx()) == self.st.expr_node()[0]
for idx in [None, (0, 99), (7, 203), (2, 5), (0, 0), (0, self.numel-1), (self.numel, self.numel), (0, self.numel), (0, self.numel+1), (self.numel+100, self.numel+100)]:
if idx is not None:
idx = test_idx = Variable("idx", idx[0], idx[1])
else:
test_idx = self.default_idx()
self.check_bounds(self.expr(test_idx))
assert self.expr(test_idx) == self.st.expr_node(idx)[0]
for st, offset, shape, node_expr, idxs_expr in zip(self.sts, self.offset, self.shapes, self.node_exprs, self.idxs_exprs):
numel = prod(shape)
assert node_expr(self.default_idx(st.shape)) == st.expr_node()[0]
assert node_expr(self.default_idx(st.shape)) == st.expr_node(None)[0]
assert node_expr(self.default_idx(st.shape)) == st.expr_node('idx')[0]
self.check_bounds(node_expr(self.default_idx(st.shape)), offset, numel)
for idx in [(0, numel-1), (7, 203), (2, 5), (0, 0), (numel, numel), (0, numel), (0, numel+1), (numel+100, numel+100)]:
idx = Variable("idx", idx[0], idx[1])
assert node_expr(idx) == st.expr_node(idx)[0]
self.check_bounds(node_expr(idx), offset, numel)
def default_idx(self):
return Variable("idx", 0, prod(self.st.shape)-1)
assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs()[0]
assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs(None)[0]
self.check_bounds(idxs_expr(self.default_idxs(st.shape)), offset, numel)
idx0s = [(0,0), (0, min(1, st.shape[0]-1)), (0, st.shape[0]-1), (min(3, st.shape[0]-1), min(6, st.shape[0]-1)), (st.shape[0]-1, st.shape[0]-1)]
idx1s = [(0,0), (0, min(1, st.shape[1]-1)), (0, st.shape[1]-1), (min(3, st.shape[1]-1), min(6, st.shape[1]-1)), (st.shape[1]-1, st.shape[1]-1)]
idx2s = [(0,0), (0, min(1, st.shape[2]-1)), (0, st.shape[2]-1), (min(3, st.shape[2]-1), min(6, st.shape[2]-1)), (st.shape[2]-1, st.shape[2]-1)] if len(st.shape) == 3 else [None for _ in idx0s]
for idx0, idx1, idx2 in product(idx0s, idx1s, idx2s):
idxs = [Variable(f"idx{i}", idx[0], idx[1]) for i, idx in enumerate((idx0, idx1, idx2)) if idx is not None]
assert idxs_expr(idxs) == st.expr_idxs(idxs)[0]
self.check_bounds(idxs_expr(idxs), offset, numel)
def check_bounds(self, expr):
assert expr.min >= self.st.real_offset()
assert expr.max <= self.st.real_offset() + self.numel - 1
def default_idx(self, shape):
return Variable("idx", 0, prod(shape)-1)
def test_noop(self):
self.expr = lambda idx: idx%100
def default_idxs(self, shape):
return [Variable(f"idx{i}", 0, d-1) for i,d in enumerate(shape)]
def check_bounds(self, expr, offset, numel):
assert expr.min >= offset
assert expr.max <= offset + numel - 1
def test_noop(self):
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[1] + offset)
def test_permute(self):
self.st.permute((1, 0))
self.expr = lambda idx: idx%10*10 + idx//10%10
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.permute((1, 0))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + idxs[1]*base_shape[1] + offset)
def test_reshape(self):
self.st.reshape((10, 1, 10))
self.expr = lambda idx: idx%10 + idx//10%10*10
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.reshape((base_shape[0], 1, base_shape[1]))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
def test_reshape_expand(self):
self.st.reshape((10, 1, 10))
self.st.expand((10, 10, 10))
self.expr = lambda idx: idx//100*10 + idx%10
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.reshape((base_shape[0], 1, base_shape[1]))
st.expand((base_shape[0], base_shape[1], base_shape[1]))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
def test_permute_reshape_1(self): # This tests multiple views
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.permute((1, 0))
st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[0]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[0]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
def test_permute_reshape_2(self):
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.permute((1, 0))
st.reshape((1, base_shape[0]//5, base_shape[1]*5))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[1]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[1]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
class TestSimplifyingShapeTracker(unittest.TestCase):
def setUp(self):

View File

@@ -14,8 +14,10 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tu
assert len(shape) == len(strides)
ret = [(shape[0], strides[0])] if len(shape) > 0 else []
for i in range(1, len(shape)):
if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or ret[-1][0] == 1 or (strides[i] == 0 and ret[-1][1] == 0):
if ret[-1][1] == shape[i]*strides[i] or ret[-1][0] == 1:
ret[-1] = (ret[-1][0] * shape[i], strides[i])
elif shape[i] == 1:
continue
else:
ret.append((shape[i], strides[i]))
return tuple(ret)