pass viz render errors (#7369)

* pass viz render errors

* pcall
This commit is contained in:
qazal
2024-10-29 16:48:27 +02:00
committed by GitHub
parent 51c0c8d27e
commit 7bd79f4922

View File

@@ -41,6 +41,10 @@ class GraphRewriteDetails(GraphRewriteMetadata):
# ** API functions
def pcall(fxn, *args, **kwargs):
try: return fxn(*args, **kwargs)
except Exception as e: return f"ERROR: {e}"
def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List[List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]]:
kernels: Dict[Optional[str], List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]] = {}
for k,ctxs in contexts:
@@ -69,7 +73,7 @@ def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
@functools.lru_cache(None)
def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=_prg(k))
g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=pcall(_prg, k))
replaces: Dict[UOp, UOp] = {}
sink = ctx.sink
for i,(u0,u1,upat,_) in enumerate(ctx.matches):
@@ -83,7 +87,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
# update ret data
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST])
g.diffs.append(list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())))
g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
g.graphs.append(sink:=new_sink)
return g
@@ -104,7 +108,7 @@ class Handler(BaseHTTPRequestHandler):
query = parse_qs(url.query)
if (qkernel:=query.get("kernel")) is not None:
g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])])
ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs)), "uops": list(map(str, g.graphs))}).encode()
ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs)), "uops": list(map(lambda x:pcall(str,x), g.graphs))}).encode()
else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode()
else:
self.send_response(404)