mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Add rule and test (#10189)
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -560,6 +560,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
w = cond.logical_not().where(a, b)
|
||||
self.helper_test_variable(w, 0, 3, "(b if (x<2) else a)")
|
||||
|
||||
def test_neg_in_comp(self):
|
||||
a = Variable("a", 0, 3)
|
||||
b = Variable("b", 0, 3)
|
||||
self.helper_test_variable(-a<-b, False, True, "(b<a)")
|
||||
|
||||
def test_where_cast(self):
|
||||
s = Variable("s", 0, 3)
|
||||
cond = s < 2
|
||||
|
||||
@@ -275,6 +275,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
((UPat() + (UPat()//UPat.cvar("d", vec=False))*UPat.cvar("c")).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)),
|
||||
# generic lt folding
|
||||
(UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
|
||||
(UPat.var("x", dtypes.sints)*-1 < UPat.var("y", dtypes.sints)*-1, lambda x,y: y<x),
|
||||
# canonicalize a simplex with positive coefficients > 0
|
||||
# not x < 1 -> X > 0
|
||||
((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
|
||||
|
||||
Reference in New Issue
Block a user