remove nop, use upat [run_process_replay] (#6489)

* remove nop, use upat [run_process_replay]

* mypy passes

* no wonder nothing worked

* fixes
This commit is contained in:
George Hotz
2024-09-12 12:16:19 +08:00
committed by GitHub
parent f12f0857d8
commit 76487a3533
8 changed files with 165 additions and 171 deletions

View File

@@ -30,9 +30,9 @@ class TestPatternMatcher(unittest.TestCase):
def test_arg(self):
matcher = PatternMatcher([
(UPat(UOps.CONST, 0, name="x"), lambda x: x),
(UPat(UOps.CONST, False, name="x"), lambda x: x),
(UPat(UOps.ALU, BinaryOps.MAX, name="x"), lambda x: x),
(UPat(UOps.CONST, arg=0, name="x"), lambda x: x),
(UPat(UOps.CONST, arg=False, name="x"), lambda x: x),
(UPat(UOps.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
])
c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
@@ -47,7 +47,7 @@ class TestPatternMatcher(unittest.TestCase):
def test_filter_arg(self):
matcher = PatternMatcher([
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c"), UPat(UOps.CONST, 2)], name="x"),
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=[UPat(UOps.CONST, name="c"), UPat(UOps.CONST, arg=2)], name="x"),
lambda x,c: x if c.arg in {1, -1} else None)
])
y1 = UOp(UOps.CONST, dtypes.int, arg=1)