constant folding (#3675)

* constant fold

* bool math

* fix ptx
This commit is contained in:
George Hotz
2024-03-10 14:47:24 -07:00
committed by GitHub
parent 25aede6fd9
commit 44a67bf783
3 changed files with 18 additions and 9 deletions

View File

@@ -66,7 +66,7 @@ class TestLinearizer(unittest.TestCase):
b_bufs = [u.uop for u in lin.uops.uops[-2].vin[1].vin]
assert a_bufs == [UOps.LOAD, UOps.CONST]
assert b_bufs == [UOps.CONST, UOps.CONST]
assert b_bufs == [] # [UOps.CONST, UOps.CONST] will be folded
def test_upcast_cse(self):
# when upcasting, within a subtree, there may be common expressions.
@@ -126,7 +126,6 @@ class TestLinearizer(unittest.TestCase):
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
assert num_ops == 0, "more alu uops than needed"
@unittest.skip("constant folding not supported yet")
def test_constant_fold(self):
a, b = Tensor(2), Tensor(3)
r = a * b
@@ -216,7 +215,7 @@ class TestLinearizer(unittest.TestCase):
c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
c1 = UOp(UOps.CONST, dtypes.float, vin=(), arg=1.0)
assert helper_test_simplify(UOps.ALU, dtypes.float, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1),
arg=TernaryOps.WHERE).uop == UOps.CONST
arg=TernaryOps.WHERE).arg == c0.arg
def helper_realized_ast(r:Tensor):
s = create_schedule([r.lazydata])