little symbolic changes [pr] (#6849)

* little symbolic changes [pr]

* symbolic needs resolve too

* no resolve

* less change
This commit is contained in:
George Hotz
2024-10-02 17:12:30 +08:00
committed by GitHub
parent fc78716d31
commit 7214450c23
6 changed files with 34 additions and 21 deletions

View File

@@ -4,6 +4,9 @@ from tinygrad.shape.symbolic import Variable, NumNode
from tinygrad.tensor import Tensor
class TestSymbolic(unittest.TestCase):
def assert_tuple_equal(self, x, y):
for a,b in zip(x,y): self.assertFalse(a != b)
def test_symbolic_st(self):
x = Variable("x", 1, 100)
st = ShapeTracker.from_shape((x, 3))
@@ -31,11 +34,11 @@ class TestSymbolic(unittest.TestCase):
k = Variable("k", 1, 5).bind(3)
t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
st = t.lazydata.st
assert st.shape == (i+j+k, 4)
self.assert_tuple_equal(st.shape, (i+j+k, 4))
assert st.real_strides() == (4, 1)
t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
st = t.lazydata.st
assert st.shape == (2*i+3, 3)
self.assert_tuple_equal(st.shape, (2*i+3, 3))
assert st.real_strides() == (3, 1)
def test_cat_dim1_strides(self):
@@ -44,10 +47,11 @@ class TestSymbolic(unittest.TestCase):
k = Variable("k", 1, 5).bind(4)
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
st = t.lazydata.st
assert st.shape == (3, i+j+k)
assert st.real_strides() == (i+j+k, 1)
self.assert_tuple_equal(st.shape, (3, i+j+k))
self.assert_tuple_equal(st.real_strides(), (i+j+k, 1))
class TestSymbolicVarVals(unittest.TestCase):
def assert_equal(self, x, y): self.assertFalse(x != y)
def test_var_vals_empty(self):
assert ShapeTracker.from_shape((3, 4, 5)).var_vals == {}
@@ -58,7 +62,7 @@ class TestSymbolicVarVals(unittest.TestCase):
def test_var_vals_offset(self):
x = Variable("x", 1, 100).bind(3)
st = ShapeTracker.from_shape((4, 3)).shrink(((x, x+1), (0, 3)))
assert st.views[-1].offset == x * 3
self.assert_equal(st.views[-1].offset, x * 3)
assert st.var_vals == {Variable("x", 1, 100): 3}
def test_var_vals_mask(self):

View File

@@ -39,6 +39,10 @@ class TestUOpResolve(unittest.TestCase):
u = UOp.const(dtypes.int, 4) > 7
self.assertFalse(u)
def test_ssimplify(self):
self.assertEqual((8 % UOp.const(dtypes.int, 4)).ssimplify(), 0)
self.assertEqual((8 * UOp.const(dtypes.int, 4)).ssimplify(), 32)
def test_ambiguous_less_than(self):
u = UOp.define_var("i", dtypes.pyint, 1, 10)
self.assertTrue(resolve(u < 4))