mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
should mark the ones that are expected to work with expectedFailure, and delete the ones that are not expected to work
186 lines
6.5 KiB
Python
186 lines
6.5 KiB
Python
import unittest
|
|
from typing import List
|
|
from tinygrad.helpers import prod
|
|
from tinygrad.shape.view import View
|
|
from tinygrad.shape.shapetracker import ShapeTracker
|
|
from tinygrad import Variable
|
|
from test.unit.test_shapetracker import shapetracker_getitem
|
|
|
|
class MultiShapeTracker:
  """Holds a list of ShapeTrackers and forwards each movement op to all of them.

  Every op method (reshape, permute, expand, shrink, stride, pad) calls the
  same-named method on each held tracker with `arg`, rebinds `self.sts` to a
  new list of the results, and returns None.
  """

  def __init__(self, sts:List[ShapeTracker]):
    # the caller's list object is stored as-is (no copy is made)
    self.sts = sts

  @property
  def shape(self):
    # only the first tracker is consulted for the shape
    leading = self.sts[0]
    return leading.shape

  def _forward(self, op_name, arg):
    # build a fresh list of results; the previously held list is not mutated
    self.sts = [getattr(tracker, op_name)(arg) for tracker in self.sts]

  def reshape(self, arg):
    self._forward("reshape", arg)

  def permute(self, arg):
    self._forward("permute", arg)

  def expand(self, arg):
    self._forward("expand", arg)

  def shrink(self, arg):
    self._forward("shrink", arg)

  def stride(self, arg):
    self._forward("stride", arg)

  def pad(self, arg):
    self._forward("pad", arg)
|
|
|
|
def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool:
  """Return True when two ShapeTrackers agree at every flat index.

  Returns False immediately if the shapes differ and True immediately if the
  trackers compare equal with `==`. Otherwise every index in
  range(prod(st1.shape)) is looked up in both trackers through
  shapetracker_getitem, which yields a pair unpacked here as (off, v)
  (presumably offset and validity flag; confirm against that helper).
  A mismatch is reported when the `v` values differ, or when the `off` values
  differ while st1's `v` is truthy. On the first mismatch the index, both
  pairs and both trackers are printed and False is returned.
  """
  if st1.shape != st2.shape:
    return False
  if st1 == st2:
    return True

  for i in range(prod(st1.shape)):
    st1_off, st1_v = shapetracker_getitem(st1, i)
    st2_off, st2_v = shapetracker_getitem(st2, i)
    # offsets are only compared where st1's `v` is truthy
    mismatch = st1_v != st2_v or (st1_off != st2_off and st1_v)
    if not mismatch:
      continue
    # local names are kept as-is: the `{name=}` fields print them verbatim
    print(f"ST MISMATCH @ {i}, {st1_v=} != {st2_v=}, {st1_off=} != {st2_off=}")
    print(st1)
    print(st2)
    return False

  return True
|
|
|
|
class TestShapeTrackerBasics(unittest.TestCase):
  """Basic checks on pad/shrink masks, reshape round trips and simplify."""

  def test_pad_shrink_removes_mask(self):
    # pad with ((0,2), (0,2)), then shrink to ((0,10), (0,10)): expect one view and no mask
    tracker = ShapeTracker.from_shape((10, 10)).pad(((0,2), (0,2))).shrink(((0,10), (0,10)))
    assert len(tracker.views) == 1
    assert tracker.views[-1].mask is None

  def test_pad_shrink_leaves_mask(self):
    # same pad, but the shrink keeps index 10 on the second axis: expect one view that still has a mask
    tracker = ShapeTracker.from_shape((10, 10)).pad(((0,2), (0,2))).shrink(((0,10), (0,11)))
    assert len(tracker.views) == 1
    assert tracker.views[-1].mask is not None

  def test_reshape_makes_same(self):
    base = ShapeTracker.from_shape((2, 5))
    padded = base.pad(((2, 0), (0, 0))).reshape((2, 2, 5))
    # reshape to (4, 5) and back to (2, 2, 5); after simplify this must equal the tracker we started from
    roundtrip = padded.reshape((4, 5)).reshape((2, 2, 5))
    assert padded == roundtrip.simplify()

  def test_simplify_is_correct(self):
    # hand-built two-view tracker; simplify() must stay equivalent under st_equal
    first_view = View(shape=(15, 3), strides=(9, 1), offset=6, mask=None, contiguous=False)
    second_view = View(shape=(4, 3), strides=(12, 4), offset=0, mask=None, contiguous=False)
    two_views = ShapeTracker(views=(first_view, second_view))
    assert st_equal(two_views, two_views.simplify())
|
|
|
|
class TestShapeTrackerAdd(unittest.TestCase):
  """Checks the result of composing ShapeTrackers with `+`."""

  def test_simple_add_reshape(self):
    # (10, 10) reshaped to (100,), plus a plain (100,) tracker, must equal the plain (100,) tracker
    flattened = ShapeTracker.from_shape((10, 10)).reshape((100,))
    flat = ShapeTracker.from_shape((100,))
    assert flattened+flat == flat

  def test_simple_add_permute(self):
    # two (1,0) permutes composed must equal an unpermuted (10, 10) tracker
    lhs = ShapeTracker.from_shape((10, 10)).permute((1,0))
    rhs = ShapeTracker.from_shape((10, 10)).permute((1,0))
    assert lhs+rhs == ShapeTracker.from_shape((10, 10))

  def test_plus_real1(self):
    multi = MultiShapeTracker([ShapeTracker.from_shape((15, 9))])
    multi.shrink(((0, 15), (6, 9)))
    backup = multi.sts[0]
    # second tracker starts fresh from the shrunk shape; from here on both receive the same ops
    multi.sts.append(ShapeTracker.from_shape(backup.shape))
    for op, arg in ((multi.reshape, (45,)), (multi.stride, (4,)), (multi.reshape, (4, 3))):
      op(arg)
    # backup + (ops on fresh tracker) must be equivalent to (ops applied on top of backup)
    assert st_equal(backup + multi.sts[1], multi.sts[0])

  def test_off_by_one(self):
    def contiguous_1d(size):
      # 1-D view with stride 1, offset 0, no mask, flagged contiguous
      return View(shape=(size,), strides=(1,), offset=0, mask=None, contiguous=True)

    # the trackers differ only in the first view's length (5 vs 4); st_equal must report them unequal
    five_then_five = ShapeTracker(views=(contiguous_1d(5), contiguous_1d(5)))
    four_then_five = ShapeTracker(views=(contiguous_1d(4), contiguous_1d(5)))
    assert not st_equal(five_then_five, four_then_five)
|
|
|
|
class TestShapeTrackerAddVariable(unittest.TestCase):
  """Checks `+` on ShapeTrackers whose shapes or strides contain Variables."""

  def test_self_add(self):
    # j is a Variable with bounds 0..20 bound to 10, used as the second dim
    bound_j = Variable("j", 0, 20).bind(10)
    sym = ShapeTracker.from_shape((10,10)).reshape((10, bound_j))
    total = sym + sym
    assert total == sym

  def test_self_add_reshape(self):
    bound_j = Variable("j", 0, 20).bind(10)
    sym = ShapeTracker.from_shape((10,10)).reshape((10, bound_j))
    # the left operand is additionally reshaped to (5, 2, j) before adding
    total = sym.reshape((5, 2, bound_j)) + sym
    assert total == sym

  def test_merge_symbolic_views(self):
    var_i = Variable('i', 1, 10)
    # NOTE(review): var_j is also named 'i' (test_merge_symbolic_views_2 uses 'j'); confirm this is intended
    var_j = Variable('i', 1, 10)
    sym_shape = (var_i, var_j, 3)
    broadcast_view = View(shape=sym_shape, strides=(3, 0, 1), offset=0, mask=None, contiguous=False)
    dense_view = View(shape=sym_shape, strides=(var_j*3, 3, 1), offset=0, mask=None, contiguous=True)
    # no assertion: the test passes as long as the addition does not raise
    ShapeTracker((broadcast_view,)) + ShapeTracker((dense_view,))

  def test_merge_symbolic_views_2(self):
    var_i = Variable('i', 1, 10)
    var_j = Variable('j', 1, 10)
    sym_shape = (var_i, var_j)
    zero_stride_view = View(shape=sym_shape, strides=(0, 0), offset=0, mask=None, contiguous=False)
    dense_view = View(shape=sym_shape, strides=(var_j, 1), offset=0, mask=None, contiguous=True)
    target_shape = (var_i, var_j, 1)
    # reshape after the add vs. reshape of the right operand before the add must agree
    add_then_reshape = (ShapeTracker((zero_stride_view,)) + ShapeTracker((dense_view,))).reshape(target_shape)
    reshape_then_add = ShapeTracker((zero_stride_view,)) + ShapeTracker((dense_view,)).reshape(target_shape)
    assert add_then_reshape == reshape_then_add
|
|
|
|
class TestShapeTrackerInvert(unittest.TestCase):
  """Checks ShapeTracker.invert: composing a tracker with its inverse, or getting None back."""

  def _check_roundtrip_exact(self, a, x):
    # x composed with its inverse must compare equal (==) to the starting tracker a
    ap = x + x.invert(a.shape)
    assert ap == a, f"{ap} != {a}"

  def _check_roundtrip_equivalent(self, a, x):
    # same round trip, but compared with st_equal instead of ==
    ap = x + x.invert(a.shape)
    assert st_equal(ap, a)

  def _check_not_invertible(self, a, x):
    # invert is expected to return None for this tracker
    assert x.invert(a.shape) is None

  def test_invert_reshape(self):
    a = ShapeTracker.from_shape((10, 10))
    x = a.reshape((5, 20))
    # here the inverse is composed onto a fresh tracker of x's shape rather than onto x itself
    fresh = ShapeTracker.from_shape(x.shape)
    ap = fresh + x.invert(a.shape)
    assert ap == a, f"{ap} != {a}"

  def test_invert_permute(self):
    source = ShapeTracker.from_shape((5, 20))
    self._check_roundtrip_exact(source, source.permute((1,0)))

  def test_invert_permute_3(self):
    source = ShapeTracker.from_shape((8, 4, 5))
    self._check_roundtrip_exact(source, source.permute((1,2,0)))

  def test_invert_real1(self):
    source = ShapeTracker.from_shape((3, 6, 10))
    moved = source.reshape((3, 3, 2, 10)).permute((2, 1, 3, 0))
    self._check_roundtrip_exact(source, moved)

  def test_cant_invert_expand(self):
    source = ShapeTracker.from_shape((10, 1))
    self._check_not_invertible(source, source.expand((10,10)))

  def test_cant_invert_shrink(self):
    source = ShapeTracker.from_shape((10, 10))
    self._check_not_invertible(source, source.shrink(((0,10),(2,8))))

  def test_can_invert_flip(self):
    source = ShapeTracker.from_shape((20, 10))
    # stride of -1 on the first axis
    self._check_roundtrip_equivalent(source, source.stride((-1,1)))

  def test_can_invert_flip_permute(self):
    source = ShapeTracker.from_shape((20, 10))
    moved = source.permute((1,0)).stride((-1,1))
    self._check_roundtrip_equivalent(source, moved)

  def test_cant_invert_stride(self):
    source = ShapeTracker.from_shape((10, 10))
    self._check_not_invertible(source, source.stride((2,2)))

  def test_invert_failure(self):
    source = ShapeTracker.from_shape((2, 5))
    # pad, then reshape to (2, 2, 5) and on to (4, 5) before attempting the round trip
    moved = source.pad(((2, 0), (0, 0))).reshape((2, 2, 5)).reshape((4, 5))
    self._check_roundtrip_equivalent(source, moved)
|
|
|
|
if __name__ == "__main__":
  # run the TestCase classes in this module when the file is executed directly
  unittest.main()
|
|
|