improve matcher speed [run_process_replay] (#5438)

* improve matcher speed [run_process_replay]

* don't use arg set in ptx
This commit is contained in:
George Hotz
2024-07-12 20:02:19 -07:00
committed by GitHub
parent 03c2dc8bd7
commit fb3011ac61
5 changed files with 69 additions and 45 deletions

View File

@@ -1,7 +1,8 @@
from extra.models.resnet import ResNet50
from tinygrad import Tensor
from tinygrad.helpers import Profiling, Timing, getenv
from tinygrad.engine.realize import lower_schedule
from tinygrad.helpers import Profiling, Timing, getenv, dedup
from tinygrad.ops import MetaOps
from tinygrad.codegen.kernel import Kernel
if __name__ == "__main__":
mdl = ResNet50()
@@ -19,10 +20,29 @@ if __name__ == "__main__":
with Timing("***** model schedule in "):
sched = out.schedule()
# snakeviz /tmp/schedule.prof
asts = dedup([x.ast for x in sched if x.ast.op is MetaOps.SINK])
uops = []
with Profiling(PROFILE):
with Timing("***** model uops in "):
for ast in asts:
k = Kernel(ast)
k.hand_coded_optimizations()
k.linearize()
uops.append((k.name, k.uops))
with Profiling(PROFILE, fn="/tmp/schedule.prof"):
with Timing("***** model lower in "):
eis = list(lower_schedule(sched))
with Timing("***** model linearize in "):
for _,u in uops: u.linearize()
#renderer = Device[Device.DEFAULT].renderer
#with Profiling(PROFILE, fn="/tmp/schedule.prof"):
# with Timing("***** model render in "):
# for n,u in uops: renderer.render(n, u)
# snakeviz /tmp/schedule.prof
#with Profiling(PROFILE, fn="/tmp/schedule.prof"):
# with Timing("***** model lower in "):
# eis = list(lower_schedule(sched))
# random makes this slow
#with Profiling(PROFILE):

View File

@@ -1,7 +1,7 @@
import unittest
from test.helpers import TestUOps
from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
from tinygrad.ops import BinaryOps, TernaryOps
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.codegen.uopgraph import UOpGraph, PatternMatcher, UPat, _match
@@ -47,6 +47,7 @@ class TestPatternMatcher(TestUOps):
self.assertEqual(matcher.rewrite(c4), None)
self.assertEqual(matcher.rewrite(c5), None)
@unittest.skip("this is not supported any more")
def test_arg_set(self):
matcher = PatternMatcher([(UPat(UOps.ALU, BinaryOps.MUL, (UPat(UOps.CONST, {-1, 1}), UPat(UOps.CONST, 2)), name="x"), lambda x: x)])
y1 = UOp(UOps.CONST, dtypes.int, arg=1)
@@ -123,14 +124,14 @@ class TestPatternMatcher(TestUOps):
self.assertEqual(matcher.rewrite(c4), None)
def test_allow_len(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST),), allow_len={3}), lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
c4 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.NEG)
#c4 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.NEG)
c5 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c6 = UOp(UOps.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
self.assertEqual(matcher.rewrite(c4), c4)
#self.assertEqual(matcher.rewrite(c4), c4)
self.assertEqual(matcher.rewrite(c5), None)
self.assertEqual(matcher.rewrite(c6), c6)

View File

@@ -3,9 +3,7 @@ import unittest
from tinygrad.engine.graph import print_tree
from tinygrad import Tensor, dtypes
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.codegen.uopgraph import UPat
from tinygrad.ops import BinaryOps
from tinygrad.codegen.uops import UOp
import sys, io
@@ -43,6 +41,7 @@ ker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)
4 ┗━┳ UOps.ALU UnaryOps.NEG\n\
5 ┗━━ UOps.CONST 2\n'
"""
x = UPat(UOp.alu(BinaryOps.ADD, UOp.var("x", dtypes.int), UOp.var("x", dtypes.int)))
assert self._capture_print(lambda: print_tree(x)) == '\
0 ━━ UOps.ALU : dtypes.int [<UOps.VAR: 2>, <UOps.VAR: 2>] BinaryOps.ADD None\n'
@@ -62,6 +61,7 @@ ker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)
9 ┃ ┗━━ None None\n\
10 ┗━┳ UOps.GEP 3\n\
11 ┗━━ None None\n'
"""
if __name__ == "__main__":
unittest.main()