diff --git a/extra/viz/README.md b/extra/viz/README.md index c665f5b6b7..25f327f5c1 100644 --- a/extra/viz/README.md +++ b/extra/viz/README.md @@ -7,7 +7,7 @@ Supported on all backends. Flags: VIZ=-1 to only save the trace to a file, VIZ=1 also launches a web server. 1. Set VIZ to -1 to save the trace. -2. Use `extra/viz/cli.py` to inspect the trace files. +2. Use `extra/viz/cli.py` to inspect the trace files. Set NO_COLOR=1 to disable colored output. ## Inspect runtime profiling diff --git a/extra/viz/cli.py b/extra/viz/cli.py index 74bb5b4269..9b3e6b0a45 100755 --- a/extra/viz/cli.py +++ b/extra/viz/cli.py @@ -5,7 +5,7 @@ if hasattr(signal, "SIGPIPE"): signal.signal(signal.SIGPIPE, signal.SIG_DFL) from typing import Iterator from tinygrad.viz import serve as viz from tinygrad.uop.ops import RewriteTrace -from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent, TracingKey, unwrap +from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent, TracingKey, unwrap, NO_COLOR # profile decoder used in CLI and tests def decode_profile(data:bytes) -> dict: @@ -55,7 +55,7 @@ def get(data:dict, key:str): def main(args) -> None: viz.load_rewrites(viz_data:=viz.VizData(viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})))) - def format_colored(s:str) -> str: return ansistrip(s) if args.no_color else s + def format_colored(s:str) -> str: return ansistrip(s) if NO_COLOR else s if args.profile: events:list = viz.load_pickle(args.profile_path, default=[]) @@ -86,7 +86,7 @@ def main(args) -> None: assert isinstance(e.name, TracingKey) op_name, info = e.name.display_name, e.name.ret or "" color = next((v for k,v in viz.wave_colors.items() if k in op_name), None) - op_str = hex_colored(op_name, color) if color and not args.no_color else op_name + op_str = hex_colored(op_name, color) if color and not NO_COLOR else op_name phase, delay = None, 0 idx = next(pkt_idxs.setdefault(e.device, itertools.count())) if e.device.startswith("WAVE"): @@ -161,7 +161,7 @@ def main(args) -> None: if args.item is None: for k,v in steps.items(): print(" "*v["depth"]+k+(f" - {v['match_count']}" if v.get('match_count', 0) else '')) else: - data = viz.get_render(data, get(steps, args.item)["query"]) + data = viz.get_render(viz_data, get(steps, args.item)["query"]) if isinstance(data.get("value"), Iterator): for m in data["value"]: if m.get("uop"): print(f"Input UOp:\n{m['uop']}") @@ -169,7 +169,7 @@ def main(args) -> None: loc = pathlib.Path(m["upat"][0][0]) print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}") for line in m["diff"]: - print(line if args.no_color else colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None)) + print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None)) if data.get("src") is not None: print(data["src"]) def get_arg_parser() -> argparse.ArgumentParser: @@ -180,7 +180,6 @@ def get_arg_parser() -> argparse.ArgumentParser: g_opts = parser.add_argument_group("optional args") g_opts.add_argument("-s", "--src", type=str, default=None, metavar="NAME", help="Select a data source (default: list all sources)") g_opts.add_argument("-i", "--item", type=str, default=None, metavar="NAME", help="Select an item within the source (default: list all items)") - g_opts.add_argument("--no-color", action="store_true", help="Turn off colored names") g_opts.add_argument("--profile-path", type=pathlib.Path, metavar="PATH", help="Path to profile.pkl (optional file, default: latest profile)", default=pathlib.Path(temp("profile.pkl", append_user=True))) g_opts.add_argument("--rewrites-path", type=pathlib.Path, metavar="PATH", help="Path to rewrites.pkl (optional file, default: latest rewrites)", diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 06c9655044..157066359f 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -32,6 +32,7 @@ def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) def all_same(items:Sequence): return all(x == items[0] for x in items) # works for empty input def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t) def colored(st, color:str|None, background=False): # replace the termcolor library + if NO_COLOR: return st colors = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'] return f"\u001b[{10*background+60*(color.upper() == color)+30+colors.index(color.lower())}m{st}\u001b[0m" if color is not None else st def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow') @@ -224,7 +225,7 @@ class _DEV(ContextVar): DEV, DEBUG, BEAM, NOOPT = _DEV("DEV", ""), ContextVar("DEBUG", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0) IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0) JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32) -WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1) +WINO, CAPTURING, TRACEMETA, NO_COLOR = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("NO_COLOR", 0) USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0) TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0) SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)