mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix to_shape_strides (#1374)
* add tests for expr_node and expr_idxs
* simplify condition and add missing optimization
This commit is contained in:
@@ -4,6 +4,7 @@ import numpy as np
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from itertools import product
|
||||
|
||||
def shapetracker_getitem(st, val):
|
||||
locals = {"idx": val, "valid": 1}
|
||||
@@ -120,43 +121,87 @@ class TestRealSimplifies(unittest.TestCase):
|
||||
View((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
|
||||
View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
|
||||
|
||||
class TestIndexExpressions(unittest.TestCase):
|
||||
class TestIndexExpressions2d(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.st = ShapeTracker((10, 10))
|
||||
self.numel = prod(self.st.shape)
|
||||
shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
|
||||
offsets = [0, 1, 15, 28, 10000]
|
||||
self.sts = [ShapeTracker(base_shape, [View(base_shape, offset=offset)]) for base_shape in shapes for offset in offsets]
|
||||
self.offset = [Variable.num(offset) for base_shape in shapes for offset in offsets]
|
||||
self.shapes = [shape for shape in shapes for offset in offsets]
|
||||
self.node_exprs = []
|
||||
self.idxs_exprs = []
|
||||
|
||||
def tearDown(self):
|
||||
assert self.expr(self.default_idx()) == self.st.expr_node()[0]
|
||||
for idx in [None, (0, 99), (7, 203), (2, 5), (0, 0), (0, self.numel-1), (self.numel, self.numel), (0, self.numel), (0, self.numel+1), (self.numel+100, self.numel+100)]:
|
||||
if idx is not None:
|
||||
idx = test_idx = Variable("idx", idx[0], idx[1])
|
||||
else:
|
||||
test_idx = self.default_idx()
|
||||
self.check_bounds(self.expr(test_idx))
|
||||
assert self.expr(test_idx) == self.st.expr_node(idx)[0]
|
||||
for st, offset, shape, node_expr, idxs_expr in zip(self.sts, self.offset, self.shapes, self.node_exprs, self.idxs_exprs):
|
||||
numel = prod(shape)
|
||||
assert node_expr(self.default_idx(st.shape)) == st.expr_node()[0]
|
||||
assert node_expr(self.default_idx(st.shape)) == st.expr_node(None)[0]
|
||||
assert node_expr(self.default_idx(st.shape)) == st.expr_node('idx')[0]
|
||||
self.check_bounds(node_expr(self.default_idx(st.shape)), offset, numel)
|
||||
for idx in [(0, numel-1), (7, 203), (2, 5), (0, 0), (numel, numel), (0, numel), (0, numel+1), (numel+100, numel+100)]:
|
||||
idx = Variable("idx", idx[0], idx[1])
|
||||
assert node_expr(idx) == st.expr_node(idx)[0]
|
||||
self.check_bounds(node_expr(idx), offset, numel)
|
||||
|
||||
def default_idx(self):
|
||||
return Variable("idx", 0, prod(self.st.shape)-1)
|
||||
assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs()[0]
|
||||
assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs(None)[0]
|
||||
self.check_bounds(idxs_expr(self.default_idxs(st.shape)), offset, numel)
|
||||
idx0s = [(0,0), (0, min(1, st.shape[0]-1)), (0, st.shape[0]-1), (min(3, st.shape[0]-1), min(6, st.shape[0]-1)), (st.shape[0]-1, st.shape[0]-1)]
|
||||
idx1s = [(0,0), (0, min(1, st.shape[1]-1)), (0, st.shape[1]-1), (min(3, st.shape[1]-1), min(6, st.shape[1]-1)), (st.shape[1]-1, st.shape[1]-1)]
|
||||
idx2s = [(0,0), (0, min(1, st.shape[2]-1)), (0, st.shape[2]-1), (min(3, st.shape[2]-1), min(6, st.shape[2]-1)), (st.shape[2]-1, st.shape[2]-1)] if len(st.shape) == 3 else [None for _ in idx0s]
|
||||
for idx0, idx1, idx2 in product(idx0s, idx1s, idx2s):
|
||||
idxs = [Variable(f"idx{i}", idx[0], idx[1]) for i, idx in enumerate((idx0, idx1, idx2)) if idx is not None]
|
||||
assert idxs_expr(idxs) == st.expr_idxs(idxs)[0]
|
||||
self.check_bounds(idxs_expr(idxs), offset, numel)
|
||||
|
||||
def check_bounds(self, expr):
|
||||
assert expr.min >= self.st.real_offset()
|
||||
assert expr.max <= self.st.real_offset() + self.numel - 1
|
||||
def default_idx(self, shape):
|
||||
return Variable("idx", 0, prod(shape)-1)
|
||||
|
||||
def test_noop(self):
|
||||
self.expr = lambda idx: idx%100
|
||||
def default_idxs(self, shape):
|
||||
return [Variable(f"idx{i}", 0, d-1) for i,d in enumerate(shape)]
|
||||
|
||||
def check_bounds(self, expr, offset, numel):
|
||||
assert expr.min >= offset
|
||||
assert expr.max <= offset + numel - 1
|
||||
|
||||
def test_noop(self):
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[1] + offset)
|
||||
|
||||
def test_permute(self):
|
||||
self.st.permute((1, 0))
|
||||
self.expr = lambda idx: idx%10*10 + idx//10%10
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.permute((1, 0))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + idxs[1]*base_shape[1] + offset)
|
||||
|
||||
def test_reshape(self):
|
||||
self.st.reshape((10, 1, 10))
|
||||
self.expr = lambda idx: idx%10 + idx//10%10*10
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.reshape((base_shape[0], 1, base_shape[1]))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
|
||||
|
||||
def test_reshape_expand(self):
|
||||
self.st.reshape((10, 1, 10))
|
||||
self.st.expand((10, 10, 10))
|
||||
self.expr = lambda idx: idx//100*10 + idx%10
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.reshape((base_shape[0], 1, base_shape[1]))
|
||||
st.expand((base_shape[0], base_shape[1], base_shape[1]))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
|
||||
|
||||
def test_permute_reshape_1(self): # This tests multiple views
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.permute((1, 0))
|
||||
st.reshape((base_shape[0]//5, 1, base_shape[1]*5))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[0]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[0]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
|
||||
|
||||
def test_permute_reshape_2(self):
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.permute((1, 0))
|
||||
st.reshape((1, base_shape[0]//5, base_shape[1]*5))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[1]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[1]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset)
|
||||
|
||||
class TestSimplifyingShapeTracker(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
@@ -14,8 +14,10 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tu
|
||||
assert len(shape) == len(strides)
|
||||
ret = [(shape[0], strides[0])] if len(shape) > 0 else []
|
||||
for i in range(1, len(shape)):
|
||||
if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or ret[-1][0] == 1 or (strides[i] == 0 and ret[-1][1] == 0):
|
||||
if ret[-1][1] == shape[i]*strides[i] or ret[-1][0] == 1:
|
||||
ret[-1] = (ret[-1][0] * shape[i], strides[i])
|
||||
elif shape[i] == 1:
|
||||
continue
|
||||
else:
|
||||
ret.append((shape[i], strides[i]))
|
||||
return tuple(ret)
|
||||
|
||||
Reference in New Issue
Block a user