wip: tracking pattern matcher [run_process_replay] (#6225)

* wip: tracking pattern matcher

* better

* proper dedup

* timing

* early reject

* mergable match stats

* TrackedPattenMatcher

* fix TrackedPattenMatcher

* cleanups

* clean that too

* remove early_reject

* Revert "remove early_reject"

This reverts commit dc2aef14b8f5da58f5ec9566daf252513cac394c.

* total

* sort by time

* match_stats cleanup
This commit is contained in:
George Hotz
2024-08-21 11:57:26 -07:00
committed by GitHub
parent a666450e4d
commit c3168952f0
2 changed files with 66 additions and 35 deletions

View File

@@ -2,7 +2,6 @@ import unittest, itertools
from test.helpers import TestUOps
from tinygrad.dtype import dtypes
from tinygrad.ops import UOps, UOp, PatternMatcher, UPat, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
from tinygrad.codegen.uopgraph import constant_folder
class TestPatternMatcher(TestUOps):
def test_simple_match(self):
@@ -161,25 +160,5 @@ class TestPatternMatcher(TestUOps):
return u.src[0]
for a,b in zip(simple_src(a), simple_src(b)): self._assert_eq_upat(a, b)
def test_upat_str(self):
dtypes._float2 = dtypes.float.vec(2)
dtypes._float4 = dtypes.float.vec(4)
dtypes._float8 = dtypes.float.vec(8)
dtypes._float16 = dtypes.float.vec(16)
dtypes._half2 = dtypes.half.vec(2)
dtypes._half4 = dtypes.half.vec(4)
dtypes._half8 = dtypes.half.vec(8)
dtypes._half16 = dtypes.half.vec(16)
upat = UPat(UOps.CONST, name="x", dtype=dtypes.float)
assert str(upat) == str(eval(str(upat)))
evpat:UPat = eval(repr(UPat(src = [UPat(name='a'), UPat(name='b')])))
assert len(evpat.src) == 2
for i in range(20): upat = UPat(UOps.ALU, name="x", src=[upat, upat], arg=BinaryOps.ADD)
assert len(str(upat)) < 10_000
assert str(eval(str(upat))) == str(upat)
for rules in constant_folder.pdict.values():
for pat,_ in rules:
self._assert_eq_upat(pat, eval(str(pat)))
if __name__ == '__main__':
unittest.main(verbosity=2)