pm_render [pr] (#7430)

* pm_render [pr]

* test fixes

* use gep, not src

* ptx only symbolic, not sym

* move cast rules
This commit is contained in:
George Hotz
2024-10-31 14:04:50 +07:00
committed by GitHub
parent 8fff8fc3e7
commit 17c9a9fde4
4 changed files with 50 additions and 47 deletions

View File

@@ -456,58 +456,58 @@ class TestExpander(unittest.TestCase):
def test_expand_add_broadcast(self):
# Adds the scalar 3 to an EXPAND wrapping the 4-element int vector constant
# (0,1,2,3) and checks that the rewritten result is still an EXPAND whose
# single source carries the values (3,4,5,6).
# NOTE(review): this text is a rendered diff whose +/- markers are missing; the
# two assertions reading `sink.src[0].src` look like the removed lines and the
# two reading `sink.src[0].arg` like their replacements -- confirm against the commit.
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+3)
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [3,4,5,6])
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 4
self.assertTupleEqual(sink.src[0].arg, (3,4,5,6))
def test_contract_simple(self):
# Wraps an EXPAND of the vector constant (0,1,2,3) in a CONTRACT that uses the
# same arg ((1,4),) and checks that the rewrite collapses to a single 4-element
# value holding (0,1,2,3).
# NOTE(review): this text is a rendered diff whose +/- markers are missing; the
# VECTORIZE/`sink.src` assertions look like the removed lines and the
# VCONST/`sink.arg` assertions like their replacements -- confirm against the commit.
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
self.assertEqual(sink.op, UOps.VCONST)
self.assertTupleEqual(sink.arg, (0,1,2,3))
def test_contract_axis_1(self):
# Builds an EXPAND over the 16-element constant 0..15 with arg ((1,4),(2,4)),
# contracts it with arg ((1,4),), and expects an EXPAND with arg ((2,4),)
# whose 16 values start with (0,4,8,12) and end with (3,7,11,15).
# presumably each arg entry is an (axis, size) pair -- TODO confirm.
# NOTE(review): this text is a rendered diff whose +/- markers are missing; the
# four VECTORIZE/`.src` assertions look like the removed lines and the four
# VCONST/`.arg` assertions like their replacements -- confirm against the commit.
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((2,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,4,8,12])
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [3,7,11,15])
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
assert sink.src[0].op is UOps.VCONST
self.assertTupleEqual(sink.src[0].arg[0:4], (0,4,8,12))
self.assertTupleEqual(sink.src[0].arg[12:], (3,7,11,15))
def test_contract_axis_2(self):
# Same 16-element EXPAND with arg ((1,4),(2,4)), but contracted with arg
# ((2,4),): expects an EXPAND with arg ((1,4),) whose 16 values start with
# (0,1,2,3) and end with (12,13,14,15).
# NOTE(review): this text is a rendered diff whose +/- markers are missing; the
# four VECTORIZE/`.src` assertions look like the removed lines and the four
# VCONST/`.arg` assertions like their replacements -- confirm against the commit.
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((1,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,1,2,3])
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [12,13,14,15])
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
assert sink.src[0].op is UOps.VCONST
self.assertTupleEqual(sink.src[0].arg[0:4], (0,1,2,3))
self.assertTupleEqual(sink.src[0].arg[12:], (12,13,14,15))
def test_contract_axis_2_big(self):
# EXPAND over the constant 0..15 with four entries in its arg
# ((1,2),(2,2),(3,2),(4,2)); contracting with ((2,2),) should leave an EXPAND
# with arg ((1,2),(3,2),(4,2)). Only two slices of the values are checked:
# [0:2] == (0,4) and [12:14] == (10,14).
# NOTE(review): this text is a rendered diff whose +/- markers are missing; the
# two assertListEqual lines reading `.src` look like the removed lines and the
# two assertTupleEqual lines reading `.arg` like their replacements -- confirm.
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src][0:2], [0,4])
self.assertListEqual([x.arg for x in sink.src[0].src][12:14], [10,14])
self.assertTupleEqual(sink.src[0].arg[0:2], (0,4))
self.assertTupleEqual(sink.src[0].arg[12:14], (10,14))
def test_contract_multi_axis(self):
# Contracts two entries of a four-entry EXPAND at once, in both orders. Either
# way the remaining EXPAND arg is ((1,2),(4,2)), but the order of the CONTRACT
# arg changes the element order: ((3,2),(2,2)) gives a first group of
# (0,4,2,6), while ((2,2),(3,2)) gives (0,2,4,6).
# NOTE(review): this text is a rendered diff whose +/- markers are missing; each
# assertListEqual reading `.src` looks like a removed line and the
# assertTupleEqual reading `.arg` right after it like its replacement -- confirm.
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 4, 2, 6])
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 4, 2, 6))
# Same contraction with the two arg entries swapped.
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 2, 4, 6])
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 2, 4, 6))
def test_contract_mid(self):
# EXPAND over the constant 0..7 with arg ((1,2),(2,2),(3,2)); contracting the
# middle entry ((2,2),) should leave an EXPAND with arg ((1,2),(3,2)) whose
# eight values are reordered to (0,2,1,3,4,6,5,7).
# NOTE(review): this text is a rendered diff whose +/- markers are missing; the
# VECTORIZE/`.src` pair looks like the removed lines and the VCONST/`.arg`
# pair like their replacements -- confirm against the commit.
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and sink.arg == ((1,2),(3,2))
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 8
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,1,3,4,6,5,7])
assert sink.src[0].op is UOps.VCONST and len(sink.src[0].arg) == 8
self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))
def test_contract_no_expand(self):
e1 = UOp(UOps.DEFINE_VAR, dtypes.int)
@@ -520,25 +520,26 @@ class TestExpander(unittest.TestCase):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 8
assert sink.src[0] == sink.src[1]
assert sink.src[0] != sink.src[2]
assert sink.src[6] == sink.src[7]
assert sink.op is UOps.VCONST and len(sink.arg) == 8
assert sink.arg[0] == sink.arg[1]
assert sink.arg[0] != sink.arg[2]
assert sink.arg[6] == sink.arg[7]
def test_expand_same_axis(self):
# Adds two EXPANDs that share the arg ((1,4),): e1 holds (0,1,2,3) and e2 holds
# (0,4,8,12). The result is expected to stay a 4-element EXPAND carrying the
# element-wise sums (0,5,10,15).
# NOTE(review): this text is a rendered diff whose +/- markers are missing; the
# two assertions reading `sink.src[0].src` look like the removed lines and the
# three assertions that follow (EXPAND, VCONST, `.arg`) like their
# replacements -- confirm against the commit.
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+e2)
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [0,5,10,15])
self.assertEqual(sink.op, UOps.EXPAND)
self.assertEqual(sink.src[0].op, UOps.VCONST)
self.assertTupleEqual(sink.src[0].arg, (0,5,10,15))
def test_expand_different_axis(self, flip=False):
# Adds two EXPANDs with different args: e1 holds (0,4,8,12) with arg ((1,4),)
# and e2 holds (0,1,2,3) with arg ((2,4),). `flip` swaps the operand order of
# the addition; in both orders the result is expected to be an EXPAND with arg
# ((1,4),(2,4)) carrying the 16 values 0..15 in ascending order.
# NOTE(review): this text is a rendered diff whose +/- markers are missing; the
# assertions reading `sink.src[0].src` look like the removed lines and those
# reading `sink.src[0].arg` like their replacements -- confirm against the commit.
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
sink = expander_rewrite((e2+e1) if flip else (e1+e2))
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16
assert sink.arg == ((1, 4), (2, 4))
self.assertListEqual([x.arg for x in sink.src[0].src], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
self.assertTupleEqual(sink.src[0].arg, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
# Re-runs test_expand_different_axis with flip=True, i.e. with the operands of the addition swapped.
def test_expand_different_axis_flip(self): self.test_expand_different_axis(True)
@@ -621,7 +622,7 @@ class TestLoadStoreFolder(unittest.TestCase):
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
self.assertEqual(single_load.src[1].op, UOps.VECTORIZE)
self.assertEqual(single_load.src[1].op, UOps.CONST)
def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())