constant folding (#3675)

* constant fold

* bool math

* fix ptx
This commit is contained in:
George Hotz
2024-03-10 14:47:24 -07:00
committed by GitHub
parent 25aede6fd9
commit 44a67bf783
3 changed files with 18 additions and 9 deletions

View File

@@ -146,6 +146,16 @@ class TestExecALU(TestUOps):
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, 3.0)), 2+(1.0/3.0))
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, -3.0)), -2-(1.0/3.0))
def test_bool_neg(self):
self.assertEqual(exec_alu(UnaryOps.NEG, dtypes.bool, (False,)), True)
self.assertEqual(exec_alu(UnaryOps.NEG, dtypes.bool, (True,)), False)
def test_bool_cmplt(self):
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, True)), False)
def test_overflow(self):
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (256, 0)), 0)