mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
little symbolic changes [pr] (#6849)
* little symbolic changes [pr] * symbolic needs resolve too * no resolve * less change
This commit is contained in:
@@ -4,6 +4,9 @@ from tinygrad.shape.symbolic import Variable, NumNode
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
def assert_tuple_equal(self, x, y):
|
||||
for a,b in zip(x,y): self.assertFalse(a != b)
|
||||
|
||||
def test_symbolic_st(self):
|
||||
x = Variable("x", 1, 100)
|
||||
st = ShapeTracker.from_shape((x, 3))
|
||||
@@ -31,11 +34,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
k = Variable("k", 1, 5).bind(3)
|
||||
t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
|
||||
st = t.lazydata.st
|
||||
assert st.shape == (i+j+k, 4)
|
||||
self.assert_tuple_equal(st.shape, (i+j+k, 4))
|
||||
assert st.real_strides() == (4, 1)
|
||||
t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
|
||||
st = t.lazydata.st
|
||||
assert st.shape == (2*i+3, 3)
|
||||
self.assert_tuple_equal(st.shape, (2*i+3, 3))
|
||||
assert st.real_strides() == (3, 1)
|
||||
|
||||
def test_cat_dim1_strides(self):
|
||||
@@ -44,10 +47,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
k = Variable("k", 1, 5).bind(4)
|
||||
t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
|
||||
st = t.lazydata.st
|
||||
assert st.shape == (3, i+j+k)
|
||||
assert st.real_strides() == (i+j+k, 1)
|
||||
self.assert_tuple_equal(st.shape, (3, i+j+k))
|
||||
self.assert_tuple_equal(st.real_strides(), (i+j+k, 1))
|
||||
|
||||
class TestSymbolicVarVals(unittest.TestCase):
|
||||
def assert_equal(self, x, y): self.assertFalse(x != y)
|
||||
def test_var_vals_empty(self):
|
||||
assert ShapeTracker.from_shape((3, 4, 5)).var_vals == {}
|
||||
|
||||
@@ -58,7 +62,7 @@ class TestSymbolicVarVals(unittest.TestCase):
|
||||
def test_var_vals_offset(self):
|
||||
x = Variable("x", 1, 100).bind(3)
|
||||
st = ShapeTracker.from_shape((4, 3)).shrink(((x, x+1), (0, 3)))
|
||||
assert st.views[-1].offset == x * 3
|
||||
self.assert_equal(st.views[-1].offset, x * 3)
|
||||
assert st.var_vals == {Variable("x", 1, 100): 3}
|
||||
|
||||
def test_var_vals_mask(self):
|
||||
|
||||
@@ -39,6 +39,10 @@ class TestUOpResolve(unittest.TestCase):
|
||||
u = UOp.const(dtypes.int, 4) > 7
|
||||
self.assertFalse(u)
|
||||
|
||||
def test_ssimplify(self):
|
||||
self.assertEqual((8 % UOp.const(dtypes.int, 4)).ssimplify(), 0)
|
||||
self.assertEqual((8 * UOp.const(dtypes.int, 4)).ssimplify(), 32)
|
||||
|
||||
def test_ambiguous_less_than(self):
|
||||
u = UOp.define_var("i", dtypes.pyint, 1, 10)
|
||||
self.assertTrue(resolve(u < 4))
|
||||
|
||||
Reference in New Issue
Block a user