diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 22d426fb60..e51f85b1fa 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -436,6 +436,12 @@ class TestVizProfiler(BaseTestViz): sz = len(get_profile(prof)) self.assertLessEqual(sz/n_events, 26) + def test_calltrace(self): + def fxn(): return Tensor.empty(10).mul(2).realize() + fxn() + trace = get_viz_list()[0]["steps"][0]["trace"] + assert any(fxn.__code__.co_filename == f and fxn.__code__.co_firstlineno == l for f,l,*_ in trace), str(trace) + # can pack up to 1hr 11 min of trace events def test_trace_duration(self): dur_mins = 72 diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 65f649739d..ba3d451655 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -240,11 +240,29 @@ class Profiling(contextlib.ContextDecorator): def perf_counter_us() -> decimal.Decimal: return decimal.Decimal(time.perf_counter_ns())/1000 +@functools.cache +def lines(fn) -> list[str]: + try: + with open(fn, encoding="utf-8") as f: return f.readlines() + except (FileNotFoundError, OSError): return [] + +def printable(loc:tuple[str, int]) -> str: + try: return lines(loc[0])[loc[1]-1].strip() + except IndexError: return "" + +def get_stacktrace(frm, max_frames=30) -> tuple[tuple, ...]: + ret:list[tuple] = [] + for i in range(max_frames): + if (frm:=frm.f_back) is None: break + ret.append(((fc:=frm.f_code).co_filename, frm.f_lineno, fc.co_name, printable((fc.co_filename, frm.f_lineno)))) + return tuple(ret) + @dataclass(frozen=True) class TracingKey: display_name:str # display name of this trace event keys:tuple[Any, ...]=() # optional keys to search for related traces ret:Any=None + tb:tuple[tuple, ...]|None=field(default_factory=lambda: get_stacktrace(sys._getframe(1)) if VIZ else None) class ProfileEvent: pass diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index ef4e4cce7e..322e8b5f9b 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -8,7 +8,7 @@ from tinygrad.mixin import OpMixin from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI -from tinygrad.helpers import strip_parens, colored, ansilen +from tinygrad.helpers import strip_parens, colored, ansilen, printable if TYPE_CHECKING: from tinygrad.device import Buffer, MultiBuffer @@ -865,14 +865,6 @@ def get_location() -> tuple[str, int]: frm = frm.f_back return frm.f_code.co_filename, frm.f_lineno -@functools.cache -def lines(fn) -> list[str]: - with open(fn) as f: return f.readlines() - -def printable(loc:tuple[str, int]) -> str: - try: return lines(loc[0])[loc[1]-1].strip() - except FileNotFoundError: return "" - class UPat(OpMixin): __slots__ = ("op", "dtype", "arg", "name", "src") def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|None=None, diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index b989260cea..415aaa742e 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -551,6 +551,7 @@ document.getElementById("zoom-to-fit-btn").addEventListener("click", () => { // **** main VIZ interfacae +const pathLink = (fp, lineno) => d3.create("a").attr("href", "vscode://file/"+fp+":"+lineno).text(`${fp.split("/").at(-1)}:${lineno}`); function codeBlock(st, language, { loc, wrap }={}) { const code = document.createElement("code"); // plaintext renders like a terminal print, otherwise render with syntax highlighting @@ -559,11 +560,7 @@ function codeBlock(st, language, { loc, wrap }={}) { code.className = "hljs"; const ret = document.createElement("pre"); if (wrap) ret.className = "wrap"; - if (loc != null) { - const link = ret.appendChild(document.createElement("a")); - link.href = "vscode://file/"+loc.join(":"); - link.textContent = `${loc[0].split("/").at(-1)}:${loc[1]}`+"\n\n"; - } + if (loc != null) ret.appendChild(pathLink(loc[0], loc[1]).style("margin-bottom", "4px").node()); ret.appendChild(code); return ret; } @@ -763,6 +760,15 @@ async function main() { // ** right sidebar code blocks const codeElement = codeBlock(ret[currentRewrite].uop, "python", { wrap:false }); metadata.replaceChildren(toggleLabel, codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }), codeElement); + if (step.trace) { + const trace = d3.create("pre").append("code").classed("hljs", true); + for (let i=step.trace.length-1; i>=0; i--) { + const [fp, lineno, fn, code] = step.trace[i]; + trace.append("div").style("margin-bottom", "2px").style("display","flex").text(fn+" ").append(() => pathLink(fp, lineno).node()); + trace.append("div").html(hljs.highlight(code, { language: "python" }).value).style("margin-bottom", "1ex"); + } + metadata.insertBefore(trace.node().parentNode, codeElement); + } // ** rewrite steps if (step.match_count >= 1) { const rewriteList = metadata.appendChild(document.createElement("div")); diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 1a19ec4fcd..952928699f 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -7,7 +7,8 @@ from http.server import BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse from typing import Any, TypedDict, TypeVar, Generator, Callable from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp -from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender +from tinygrad.helpers import printable +from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender from tinygrad.uop.ops import print_uops, range_start, multirange_str from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device from tinygrad.renderer import ProgramSpec @@ -30,7 +31,7 @@ ref_map:dict[Any, int] = {} def get_rewrites(t:RewriteTrace) -> list[dict]: ret = [] for i,(k,v) in enumerate(zip(t.keys, t.rewrites)): - steps = [{"name":s.name, "loc":s.loc, "match_count":len(s.matches), "code_line":printable(s.loc), + steps = [{"name":s.name, "loc":s.loc, "match_count":len(s.matches), "code_line":printable(s.loc), "trace":k.tb if j == 0 else None, "query":f"/ctxs?ctx={i}&idx={j}", "depth":s.depth} for j,s in enumerate(v)] if isinstance(k.ret, ProgramSpec): steps.append({"name":"View UOp List", "query":f"/render?ctx={i}&fmt=uops", "depth":0})