viz early serialize to json + don't viz rewrites on const [pr] (#8435)

This commit is contained in:
qazal
2024-12-28 15:32:25 +02:00
committed by GitHub
parent 90ce2c6029
commit da2fa0b37f
2 changed files with 11 additions and 9 deletions


@@ -15,7 +15,7 @@ def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]:
   assert len(contexts[0]) == 1
   k = get_metadata(keys, contexts)[0][0]
   g = get_details(*k)
-  return g.graphs[1:]
+  return g.uops[1:]
 
 class TestViz(unittest.TestCase):
   def setUp(self):
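Note, illustrative only (not part of the commit): after the rename a caller holds two parallel views of the same rewrite trace; `k` below is the key tuple built from get_metadata exactly as in the helper above.

g = get_details(*k)                  # returns GraphRewriteDetails
rewritten_sinks = g.uops[1:]         # g.uops[0] is the sink before any rewrite was applied
serialized_steps = g.graphs          # uop_to_json output for every entry in g.uops
assert len(serialized_steps) == len(g.uops)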


@@ -34,8 +34,9 @@ class GraphRewriteMetadata:
 @dataclass
 class GraphRewriteDetails(GraphRewriteMetadata):
   """Full details about a single call to graph_rewrite"""
-  graphs: list[UOp]
-  """Sink at every step of graph_rewrite"""
+  uops: list[UOp]
+  graphs: list[dict]
+  """Sink at every step of graph_rewrite + the json serialized version"""
   diffs: list[list[str]]
   """.diff style before and after of the rewritten UOp child"""
   changed_nodes: list[list[int]]
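For orientation, a self-contained toy mirror of the new field split; stand-in types only, not tinygrad's actual classes. Keeping `graphs` as plain dicts means json.dumps can take them unmodified, while `uops` keeps the live objects for str() and diffing.

import json
from dataclasses import dataclass, field

@dataclass
class ToyRewriteDetails:
  uops: list = field(default_factory=list)              # sink object after every rewrite step
  graphs: list[dict] = field(default_factory=list)      # same steps, already JSON-able
  diffs: list[list[str]] = field(default_factory=list)
  changed_nodes: list[list[int]] = field(default_factory=list)

print(json.dumps(ToyRewriteDetails(graphs=[{"0": {"op": "SINK"}}]).graphs))  # no custom encoder needed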
@@ -88,10 +89,10 @@ def _replace_uop(base:UOp, replaces:dict[UOp, UOp]) -> UOp:
 @functools.lru_cache(None)
 def _prg(k:Kernel): return k.to_program().src
 def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
-  g = GraphRewriteDetails(**asdict(metadata), graphs=[pickle.loads(ctx.sink)], diffs=[], changed_nodes=[],
-                          kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None)
+  g = GraphRewriteDetails(**asdict(metadata), uops=[pickle.loads(ctx.sink)], diffs=[], changed_nodes=[],
+                          kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None, graphs=[])
   replaces: dict[UOp, UOp] = {}
-  sink = g.graphs[0]
+  g.graphs.append(uop_to_json(sink:=g.uops[0]))
   for i,(u0_b,u1_b,upat,_) in enumerate(ctx.matches):
     u0 = pickle.loads(u0_b)
     # if the match didn't result in a rewrite we move forward
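The walrus on the added line packs two steps into one; spelled out, it is a sketch of the same logic, nothing new:

sink = g.uops[0]                    # the unpickled sink captured by TrackedGraphRewrite
g.graphs.append(uop_to_json(sink))  # serialize it before walking ctx.matches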
@@ -104,9 +105,10 @@ def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -
     # sanity check
     if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
     # update ret data
-    g.changed_nodes.append([id(x) for x in u1.toposort if x.op is not Ops.CONST])
+    g.graphs.append(new_sink_js:=uop_to_json(new_sink))
+    g.changed_nodes.append([id(x) for x in u1.toposort if id(x) in new_sink_js])
     g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
-    g.graphs.append(sink:=new_sink)
+    g.uops.append(sink:=new_sink)
   return g
 
 # Profiler API
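The new membership filter leans on an assumed property of uop_to_json: it returns a dict keyed by id(uop) and omits nodes the viewer doesn't draw (notably CONSTs, per the commit title). A rough sketch of that contract, not this file's actual implementation:

from tinygrad.ops import Ops, UOp  # assumed import path

def uop_to_json_sketch(sink:UOp) -> dict[int, dict]:
  graph: dict[int, dict] = {}
  for u in sink.toposort:
    if u.op is Ops.CONST: continue                                   # const nodes aren't rendered
    graph[id(u)] = {"op": str(u.op), "src": [id(s) for s in u.src]}  # minimal node record
  return graph

Under that assumption, changed_nodes[i] goes from "every non-CONST node in the rewritten subtree" to "every node of the rewritten subtree that the serialized new sink actually contains", so the highlight list never references a node the frontend won't draw.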
@@ -158,7 +160,7 @@ class Handler(BaseHTTPRequestHandler):
       query = parse_qs(url.query)
       if (qkernel:=query.get("kernel")) is not None:
         g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])])
-        jret: Any = {**asdict(g), "graphs": [uop_to_json(x) for x in g.graphs], "uops": [pcall(str,x) for x in g.graphs]}
+        jret: Any = {**asdict(g), "uops": [pcall(str,x) for x in g.uops]}
       else: jret = [list(map(lambda x:asdict(x[2]), v)) for v in kernels]
       ret, content_type = json.dumps(jret).encode(), "application/json"
     elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
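A minimal model of the /kernels?kernel=..&idx=.. payload after this change; the values are placeholders and the GraphRewriteMetadata fields are elided:

import json
example = {
  "graphs": [{140234: {"op": "Ops.SINK", "src": []}},    # step 0: sink before any rewrite, pre-serialized in get_details
             {140234: {"op": "Ops.SINK", "src": []}}],   # step 1: sink after the first applied match
  "uops": ["UOp(Ops.SINK, ...)", "UOp(Ops.SINK, ...)"],  # str() of each per-step sink, built in the handler
  "changed_nodes": [[140234]],                           # one entry per applied match
  "diffs": [["--- ", "+++ ", "@@ ..."]],
  "kernel_code": None,
}
print(json.dumps(example))  # everything is JSON-able as-is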