diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index e8a1895763..5ec2af476f 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -41,6 +41,10 @@ class GraphRewriteDetails(GraphRewriteMetadata): # ** API functions +def pcall(fxn, *args, **kwargs): + try: return fxn(*args, **kwargs) + except Exception as e: return f"ERROR: {e}" + def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List[List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]]: kernels: Dict[Optional[str], List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]] = {} for k,ctxs in contexts: @@ -69,7 +73,7 @@ def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp: @functools.lru_cache(None) def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails: - g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=_prg(k)) + g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=pcall(_prg, k)) replaces: Dict[UOp, UOp] = {} sink = ctx.sink for i,(u0,u1,upat,_) in enumerate(ctx.matches): @@ -83,7 +87,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}") # update ret data g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST]) - g.diffs.append(list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines()))) + g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines()))) g.graphs.append(sink:=new_sink) return g @@ -104,7 +108,7 @@ class Handler(BaseHTTPRequestHandler): query = parse_qs(url.query) if (qkernel:=query.get("kernel")) is not None: g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]) - ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs)), "uops": list(map(str, g.graphs))}).encode() + ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs)), "uops": list(map(lambda x:pcall(str,x), g.graphs))}).encode() else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode() else: self.send_response(404)