mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
apply view.py patch (#1844)
This commit is contained in:
@@ -70,7 +70,7 @@ class CheckingShapeTracker:
|
||||
|
||||
class TestRealIssues(unittest.TestCase):
|
||||
def test_reshape_doesnt_multiview(self):
|
||||
self.st = ShapeTracker((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), views=[View((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None)])
|
||||
self.st = ShapeTracker((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), views=[View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None)])
|
||||
self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2))
|
||||
assert len(self.st.views) == 1
|
||||
|
||||
@@ -84,21 +84,21 @@ class TestRealDoesntSimplify(unittest.TestCase):
|
||||
|
||||
def test_1(self):
|
||||
self.st = ShapeTracker((8, 6, 11), views=[
|
||||
View((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
|
||||
View((8, 6, 11), (66, 11, 1), 0, None)])
|
||||
View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
|
||||
View.create((8, 6, 11), (66, 11, 1), 0, None)])
|
||||
assert self.st.real_strides() == (33, None, 1)
|
||||
|
||||
def test_2(self):
|
||||
self.st = ShapeTracker((4, 4, 3, 3), views=[
|
||||
View((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
|
||||
View((4, 4, 3, 3), (36, 9, 3, 1), 0, None)])
|
||||
View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
|
||||
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)])
|
||||
assert self.st.real_strides() == (None, 18, -3, -1)
|
||||
|
||||
class TestRealStrides(unittest.TestCase):
|
||||
def test_1(self):
|
||||
self.st = ShapeTracker((16, 32, 4), views=[
|
||||
View((2048,), (1,), 0, ((0, 512),)),
|
||||
View((16, 32, 4), (128, 4, 1), 0, None)])
|
||||
View.create((2048,), (1,), 0, ((0, 512),)),
|
||||
View.create((16, 32, 4), (128, 4, 1), 0, None)])
|
||||
st = self.st.real_strides()
|
||||
print(self.st, st)
|
||||
assert st == (None, 4, 1)
|
||||
@@ -113,13 +113,13 @@ class TestRealSimplifies(unittest.TestCase):
|
||||
|
||||
def test_1(self):
|
||||
self.st = ShapeTracker((1, 3, 2, 11, 26, 1, 1, 3), views=[
|
||||
View((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None),
|
||||
View((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)])
|
||||
View.create((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None),
|
||||
View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)])
|
||||
|
||||
def test_2(self):
|
||||
self.st = ShapeTracker((8, 1, 6, 10, 28, 3, 2, 1), views=[
|
||||
View((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
|
||||
View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
|
||||
View.create((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
|
||||
View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
|
||||
|
||||
class TestIndexExpressions2d(unittest.TestCase):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user