mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
pm_render [pr] (#7430)
* pm_render [pr] * test fixes * use gep, not src * ptx only symbolic, not sym * move cast rules
This commit is contained in:
@@ -456,58 +456,58 @@ class TestExpander(unittest.TestCase):
|
||||
def test_expand_add_broadcast(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
sink = expander_rewrite(e1+3)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [3,4,5,6])
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 4
|
||||
self.assertTupleEqual(sink.src[0].arg, (3,4,5,6))
|
||||
|
||||
def test_contract_simple(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
|
||||
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
|
||||
self.assertEqual(sink.op, UOps.VCONST)
|
||||
self.assertTupleEqual(sink.arg, (0,1,2,3))
|
||||
|
||||
def test_contract_axis_1(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((2,4),)
|
||||
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,4,8,12])
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [3,7,11,15])
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
|
||||
assert sink.src[0].op is UOps.VCONST
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0,4,8,12))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:], (3,7,11,15))
|
||||
|
||||
def test_contract_axis_2(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((1,4),)
|
||||
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,1,2,3])
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [12,13,14,15])
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
|
||||
assert sink.src[0].op is UOps.VCONST
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0,1,2,3))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:], (12,13,14,15))
|
||||
|
||||
def test_contract_axis_2_big(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][0:2], [0,4])
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][12:14], [10,14])
|
||||
self.assertTupleEqual(sink.src[0].arg[0:2], (0,4))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:14], (10,14))
|
||||
|
||||
def test_contract_multi_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 4, 2, 6])
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 4, 2, 6))
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 2, 4, 6])
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 2, 4, 6))
|
||||
|
||||
def test_contract_mid(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1,2),(3,2))
|
||||
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 8
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,1,3,4,6,5,7])
|
||||
assert sink.src[0].op is UOps.VCONST and len(sink.src[0].arg) == 8
|
||||
self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))
|
||||
|
||||
def test_contract_no_expand(self):
|
||||
e1 = UOp(UOps.DEFINE_VAR, dtypes.int)
|
||||
@@ -520,25 +520,26 @@ class TestExpander(unittest.TestCase):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.VECTORIZE and len(sink.src) == 8
|
||||
assert sink.src[0] == sink.src[1]
|
||||
assert sink.src[0] != sink.src[2]
|
||||
assert sink.src[6] == sink.src[7]
|
||||
assert sink.op is UOps.VCONST and len(sink.arg) == 8
|
||||
assert sink.arg[0] == sink.arg[1]
|
||||
assert sink.arg[0] != sink.arg[2]
|
||||
assert sink.arg[6] == sink.arg[7]
|
||||
|
||||
def test_expand_same_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
|
||||
sink = expander_rewrite(e1+e2)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,5,10,15])
|
||||
self.assertEqual(sink.op, UOps.EXPAND)
|
||||
self.assertEqual(sink.src[0].op, UOps.VCONST)
|
||||
self.assertTupleEqual(sink.src[0].arg, (0,5,10,15))
|
||||
|
||||
def test_expand_different_axis(self, flip=False):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
|
||||
sink = expander_rewrite((e2+e1) if flip else (e1+e2))
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16
|
||||
assert sink.arg == ((1, 4), (2, 4))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
|
||||
self.assertTupleEqual(sink.src[0].arg, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
|
||||
|
||||
def test_expand_different_axis_flip(self): self.test_expand_different_axis(True)
|
||||
|
||||
@@ -621,7 +622,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
|
||||
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
|
||||
self.assertEqual(single_load.src[1].op, UOps.VECTORIZE)
|
||||
self.assertEqual(single_load.src[1].op, UOps.CONST)
|
||||
|
||||
def test_simple_load_dont_fold_different_gated(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
|
||||
Reference in New Issue
Block a user