mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
viz early serialize to json + don't viz rewrites on const [pr] (#8435)
This commit is contained in:
@@ -15,7 +15,7 @@ def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]:
|
||||
assert len(contexts[0]) == 1
|
||||
k = get_metadata(keys, contexts)[0][0]
|
||||
g = get_details(*k)
|
||||
return g.graphs[1:]
|
||||
return g.uops[1:]
|
||||
|
||||
class TestViz(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
@@ -34,8 +34,9 @@ class GraphRewriteMetadata:
|
||||
@dataclass
|
||||
class GraphRewriteDetails(GraphRewriteMetadata):
|
||||
"""Full details about a single call to graph_rewrite"""
|
||||
graphs: list[UOp]
|
||||
"""Sink at every step of graph_rewrite"""
|
||||
uops: list[UOp]
|
||||
graphs: list[dict]
|
||||
"""Sink at every step of graph_rewrite + the json serialized version"""
|
||||
diffs: list[list[str]]
|
||||
""".diff style before and after of the rewritten UOp child"""
|
||||
changed_nodes: list[list[int]]
|
||||
@@ -88,10 +89,10 @@ def _replace_uop(base:UOp, replaces:dict[UOp, UOp]) -> UOp:
|
||||
@functools.lru_cache(None)
|
||||
def _prg(k:Kernel): return k.to_program().src
|
||||
def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
|
||||
g = GraphRewriteDetails(**asdict(metadata), graphs=[pickle.loads(ctx.sink)], diffs=[], changed_nodes=[],
|
||||
kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None)
|
||||
g = GraphRewriteDetails(**asdict(metadata), uops=[pickle.loads(ctx.sink)], diffs=[], changed_nodes=[],
|
||||
kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None, graphs=[])
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
sink = g.graphs[0]
|
||||
g.graphs.append(uop_to_json(sink:=g.uops[0]))
|
||||
for i,(u0_b,u1_b,upat,_) in enumerate(ctx.matches):
|
||||
u0 = pickle.loads(u0_b)
|
||||
# if the match didn't result in a rewrite we move forward
|
||||
@@ -104,9 +105,10 @@ def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -
|
||||
# sanity check
|
||||
if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
|
||||
# update ret data
|
||||
g.changed_nodes.append([id(x) for x in u1.toposort if x.op is not Ops.CONST])
|
||||
g.graphs.append(new_sink_js:=uop_to_json(new_sink))
|
||||
g.changed_nodes.append([id(x) for x in u1.toposort if id(x) in new_sink_js])
|
||||
g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
|
||||
g.graphs.append(sink:=new_sink)
|
||||
g.uops.append(sink:=new_sink)
|
||||
return g
|
||||
|
||||
# Profiler API
|
||||
@@ -158,7 +160,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
query = parse_qs(url.query)
|
||||
if (qkernel:=query.get("kernel")) is not None:
|
||||
g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])])
|
||||
jret: Any = {**asdict(g), "graphs": [uop_to_json(x) for x in g.graphs], "uops": [pcall(str,x) for x in g.graphs]}
|
||||
jret: Any = {**asdict(g), "uops": [pcall(str,x) for x in g.uops]}
|
||||
else: jret = [list(map(lambda x:asdict(x[2]), v)) for v in kernels]
|
||||
ret, content_type = json.dumps(jret).encode(), "application/json"
|
||||
elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
|
||||
|
||||
Reference in New Issue
Block a user