mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
simple UOp lt/ge folding (#5657)
works if lhs is a DEFINE_VAR. folds trivial x < -math.inf now, need to change SPECIAL to use DEFINE_VAR to fold more
This commit is contained in:
@@ -30,8 +30,7 @@ def Variable(expr, nmin, nmax):
|
||||
# TODO: fix DEFINE_VAR to not need this
|
||||
class TempVar:
|
||||
def __init__(self, x): self.expr = x
|
||||
#return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr))
|
||||
return UOp(UOps.DEFINE_VAR, dtypes.int, tuple(), TempVar(expr))
|
||||
return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr))
|
||||
class Node:
|
||||
@staticmethod
|
||||
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
|
||||
@@ -73,7 +72,6 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 3), 1, 1, "1")
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 3, 8), 2), 1, 1, "1")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_lt(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 77), 1, 1, "1")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 3, 8), 9), 1, 1, "1")
|
||||
@@ -272,11 +270,9 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
|
||||
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_ge_remove(self):
|
||||
self.helper_test_variable(create_ge_node(Variable("a", 0, 6), 25), 0, 0, "0")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_lt_remove(self):
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), -3), 0, 0, "0")
|
||||
self.helper_test_variable(create_lt_node(Variable("a", 0, 6), 3), 0, 1, "(a<3)")
|
||||
|
||||
Reference in New Issue
Block a user