pretty print lazy op per default (#5505)

* pretty lop

* min diff

* walrus

* fix

* min diff

* simplify

* pretty helper function

* ws

* pretty uop upat

* tests

* stricter tests

* test passes

* ws

* stronger upat test

* delete print_tree

* min diff

* stricter exp test

* fix merge

* stronger uops eval test

* +readable and deep upat test

* +readable and deep upat test

* sort inv fix

* fix

* revert allowed_len
This commit is contained in:
kormann
2024-07-18 18:34:08 +02:00
committed by GitHub
parent c30092e56d
commit 2c4add6844
21 changed files with 80 additions and 126 deletions

View File

@@ -10,7 +10,6 @@ from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.uops import UOp
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
from tinygrad.engine.graph import print_tree
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG
from tinygrad.ops import LazyOp, UnaryOps, BufferOps
@@ -121,7 +120,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
SEED = getenv("SEED", 42)
random.seed(SEED)
np.random.seed(SEED)
print_tree(lin.ast)
print(lin.ast)
print(lin.colored_shape())
seen_uops = {}
last_lins = [lin]

View File

@@ -1,7 +1,6 @@
#!/usr/bin/env python3
import subprocess, pickle, shlex, sys, os
from typing import Dict, List, Tuple
from tinygrad.engine.graph import print_tree
from tinygrad.helpers import colored
from tinygrad.ops import LazyOp
@@ -27,7 +26,7 @@ if __name__ == "__main__":
except AssertionError as e:
print(colored("FAILED FOR AST: ", "red"))
print("expected:")
for op in m: print_tree(op)
for op in m: print(op)
print("got:")
for op in f: print_tree(op)
for op in f: print(op)
raise e

View File

@@ -4,7 +4,6 @@ from extra.optimization.helpers import kern_str_to_lin
from test.external.fuzz_linearizer import compare_linearizer
from tinygrad.helpers import colored
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.graph import print_tree
from tinygrad.engine.search import time_linearizer
# Use this with the LOGKERNS options to verify that all executed kernels are valid and evaluate to the same ground truth results
@@ -51,7 +50,6 @@ if __name__ == "__main__":
failures = defaultdict(list)
for i, test_lin in enumerate(test_lins):
print(f"testing kernel {i}")
print_tree(test_lin.ast)
print(test_lin.ast)
print(test_lin.applied_opts)
unoptimized_lin = Kernel(test_lin.ast)

View File

@@ -298,7 +298,7 @@ class TestEqStrDType(unittest.TestCase):
def test_strs(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")
self.assertEqual(str(PtrDType(dtypes.float32)), "PtrDType(dtypes.float)")
class TestHelpers(unittest.TestCase):
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)

View File

@@ -16,7 +16,6 @@ from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
from tinygrad.engine.graph import print_tree
from tinygrad.helpers import DEBUG, prod, Context, getenv, CI, flatten, dedup
from tinygrad.dtype import DType, dtypes
@@ -1235,7 +1234,7 @@ def _temp_create_multireduce_ast(r0:Tensor, r1:Tensor, replace_idxs:Dict[int,Ten
out = merge(op0, _deep_replace(op1, op0_loads))
# limitation: only tests single output
op = LazyOp(BufferOps.STORE, (out, ), MemBuffer(0, s0[-1].ast.src[-1].arg.dtype, s0[-1].ast.src[-1].arg.st))
if DEBUG >= 3: print_tree(op)
if DEBUG >= 3: print(op)
return op,
def check_fused_tc_opt(tc:TensorCore, r0:Tensor, r1:Tensor, inputs:List[Tensor]):

View File

@@ -1,9 +1,9 @@
import unittest
import unittest, itertools
from test.helpers import TestUOps
from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps, TernaryOps
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps # noqa: F401
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.codegen.uopgraph import UOpGraph, PatternMatcher, UPat, _match
from tinygrad.codegen.uopgraph import UOpGraph, PatternMatcher, UPat, _match, constant_folder
class TestPatternMatcher(TestUOps):
def test_simple_match(self):
@@ -175,5 +175,26 @@ class TestPatternMatcher(TestUOps):
self.assert_equiv_uops(e2, uops.uops[1])
self.assert_equiv_uops(e3, uops.uops[2])
def _assert_eq_upat(self, a:UPat, b:UPat):
  """Assert that two UPat trees are structurally equal.

  `op` and `dtype` are compared as order-insensitive collections (a falsy value counts as empty), then `name`
  and the type of `src` are compared, and finally the sources are compared pairwise, recursively.

  Args:
    a: the reference pattern.
    b: the pattern that must match `a` (e.g. the result of eval(str(a))).

  Raises:
    AssertionError: on the first field that differs.
  """
  # Build both sides first, then compare. The previous single-expression form
  # `assert (X if a.op else [] == (Y if b.op else []))` parsed as `X if a.op else ([] == ...)`,
  # so it passed for every non-empty a.op without ever looking at b.
  ops_a = sorted(map(str, a.op)) if a.op else []
  ops_b = sorted(map(str, b.op)) if b.op else []
  assert ops_a == ops_b, f"op mismatch: {ops_a} != {ops_b}"
  dtypes_a = sorted(a.dtype) if a.dtype else []
  dtypes_b = sorted(b.dtype) if b.dtype else []
  assert dtypes_a == dtypes_b, f"dtype mismatch: {dtypes_a} != {dtypes_b}"
  assert (a.name, type(a.src)) == (b.name, type(b.src))
  def simple_src(u:UPat):
    # Pick one representative iterable of sub-patterns for u.
    if u.src is None: return []
    # an itertools.repeat cannot be subscripted, so hand it back as is
    if isinstance(u.src, itertools.repeat): return u.src
    # NOTE(review): only the first entry of u.src is compared -- confirm this is the intended alternative
    return u.src[0]
  src_a, src_b = simple_src(a), simple_src(b)
  # itertools.repeat is infinite: zipping two of them would never terminate, so compare the repeated pattern once
  if isinstance(src_a, itertools.repeat) and isinstance(src_b, itertools.repeat):
    src_a, src_b = [next(src_a)], [next(src_b)]
  for sub_a, sub_b in zip(src_a, src_b): self._assert_eq_upat(sub_a, sub_b)
def test_upat_str(self):
  """Check that str(UPat) is compact and that eval() of it rebuilds an equivalent pattern.

  eval() is only applied to strings produced by str() on patterns built here or taken from constant_folder.
  """
  # NOTE(review): presumably registered so that eval() of a printed pattern can resolve `dtypes._float2` -- confirm
  dtypes._float2 = dtypes.float.vec(2)
  upat = UPat(UOps.CONST, name="x", dtype=dtypes.float)
  # round trip on a leaf pattern: printing the evaluated string must give the same string back
  assert str(upat) == str(eval(str(upat)))
  # 20 levels of nesting, each level using the previous pattern twice as its sources
  for i in range(20): upat = UPat(UOps.ALU, name="x", src=[upat, upat], arg=BinaryOps.ADD)
  # spelling out every repeated sub-pattern would need 2**20 copies of the leaf; this bounds the printed size
  assert len(str(upat)) < 10_000
  assert str(eval(str(upat))) == str(upat)
  # for every value in constant_folder.pdict, round-trip its [0][0] element and compare it structurally
  for rule in constant_folder.pdict.values():
    pat = rule[0][0]
    self._assert_eq_upat(pat, eval(str(pat)))
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -1,67 +0,0 @@
#%%
import unittest
from tinygrad.engine.graph import print_tree
from tinygrad import Tensor, dtypes
from tinygrad.codegen.uops import UOp
import sys, io
class TestPrintTree(unittest.TestCase):
def _capture_print(self, fn):
capturedOutput = io.StringIO()
sys.stdout = capturedOutput
fn()
sys.stdout = sys.__stdout__
return capturedOutput.getvalue()
def test_print_uop(self):
x = Tensor.arange(10).schedule()[-1].ast.src[0]
output = self._capture_print(lambda: print_tree(x))
assert output == '\
0 ━┳ BufferOps.STORE MemBuffer(idx=0, dtype=dtypes.int, \
st=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))\n\
1 ┗━┳ BinaryOps.ADD None\n\
2 ┣━┳ ReduceOps.SUM (1,)\n\
3 ┃ ┗━━ BufferOps.CONST ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTrac\
ker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19))\
, contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))))\n\
4 ┗━━ BufferOps.CONST ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10,\
1), strides=(0, 0), offset=0, mask=None, contiguous=False),)))\n'
x = UOp.var("x", dtypes.int)
x = (x + x) - UOp.const(dtypes.int, 2)
output = self._capture_print(lambda: print_tree(x))
assert output == '\
0 ━┳ UOps.ALU BinaryOps.ADD\n\
1 ┣━┳ UOps.ALU BinaryOps.ADD\n\
2 ┃ ┣━━ UOps.VAR x\n\
3 ┃ ┗━━ UOps.VAR x\n\
4 ┗━┳ UOps.ALU UnaryOps.NEG\n\
5 ┗━━ UOps.CONST 2\n'
"""
x = UPat(UOp.alu(BinaryOps.ADD, UOp.var("x", dtypes.int), UOp.var("x", dtypes.int)))
assert self._capture_print(lambda: print_tree(x)) == '\
0 ━━ UOps.ALU : dtypes.int [<UOps.VAR: 2>, <UOps.VAR: 2>] BinaryOps.ADD None\n'
x = UPat.compile(UOp.store(UOp.var("buf"), UOp.var("idx"),
UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store)
assert self._capture_print(lambda: print_tree(x)) == '\
0 ━┳ UOps.STORE None\n\
1 ┣━━ None None\n\
2 ┣━━ None None\n\
3 ┗━┳ UOps.CAST None\n\
4 ┣━┳ UOps.GEP 0\n\
5 ┃ ┗━━ None None\n\
6 ┣━┳ UOps.GEP 1\n\
7 ┃ ┗━━ None None\n\
8 ┣━┳ UOps.GEP 2\n\
9 ┃ ┗━━ None None\n\
10 ┗━┳ UOps.GEP 3\n\
11 ┗━━ None None\n'
"""
if __name__ == "__main__":
unittest.main()

View File

@@ -11,7 +11,6 @@ from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps
from tinygrad.helpers import DEBUG, flatten, getenv
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.graph import print_tree
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from test.helpers import is_dtype_supported
@@ -33,7 +32,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
if len(sched) != allowed or DEBUG >= 3:
for i, s in enumerate(sched):
print("kernel", i+1)
print_tree(s.ast)
print(s.ast)
if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}")
# test the (sink) ops linearize
for s in sched:

View File

@@ -5,7 +5,6 @@ from tinygrad.dtype import PtrDType
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.codegen.uopgraph import UOpGraph, PatternMatcher, graph_rewrite
from tinygrad.engine.graph import print_tree # noqa: F401 # pylint: disable=unused-import
simple_pm = PatternMatcher([
(UOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
@@ -386,7 +385,7 @@ class TestExpander(unittest.TestCase):
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
sink = UOp(UOps.REDUCE, dtypes.int, (e1,e2), ReduceOps.SUM)
sink = expander_rewrite(sink)
print_tree(sink)
print(sink)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -5,13 +5,13 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu # noqa F401
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.codegen.uopgraph import UOpGraph
from test.helpers import is_dtype_supported
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
def _uops_to_prg(uops_list, print_uops=False):
uops = UOpGraph(uops_list)
@@ -333,5 +333,18 @@ class TestUOpCompare(unittest.TestCase):
mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
assert (add < mul) or (mul < add), "add and mul with same src should have an order"
class TestUOpStr(TestEqUOps):
  """Tests that str(UOp) is compact and can be passed to eval() to rebuild the UOp."""
  def test_uop_str(self):
    """Round-trip a deeply shared synthetic UOp and a UOp sink from a real kernel through str()/eval()."""
    a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0)
    # each iteration adds the expression to itself, so every level references the previous one twice
    for _ in range(20): a = a + a
    assert len(str(a)) < 10_000, "exponential string growth"
    # eval() only sees text produced by str() on the UOp built above
    assert str(eval(str(a))) == str(a)
    t = Tensor.arange(10)
    t = t + t * Tensor.rand(10)
    # nice big complicated uop
    # built from the AST of the last item returned by t.schedule(), linearized for the default device's renderer
    sink = get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops.sink
    # compared with assert_equiv_uops (inherited from TestEqUOps) instead of by string
    self.assert_equiv_uops(sink, eval(str(sink)))
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import unittest
from tinygrad.codegen.kernel import Kernel
#from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.graph import print_tree
from tinygrad.helpers import DEBUG
from tinygrad.ops import BufferOps, MemBuffer, LazyOp, ReduceOps, MetaOps, verify_lazyop
from tinygrad.shape.shapetracker import ShapeTracker
@@ -13,7 +11,7 @@ class InvalidLazyOpException(Exception): pass
def lower(*ast:LazyOp):
sink_ast = LazyOp(MetaOps.KERNEL, ast)
if DEBUG >= 3:
for op in ast: print_tree(op)
for op in ast: print(op)
try: verify_lazyop(sink_ast)
except AssertionError: raise InvalidLazyOpException()
k = Kernel(sink_ast)