mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 08:48:15 -05:00
@@ -41,6 +41,10 @@ class GraphRewriteDetails(GraphRewriteMetadata):
|
||||
|
||||
# ** API functions
|
||||
|
||||
def pcall(fxn, *args, **kwargs):
|
||||
try: return fxn(*args, **kwargs)
|
||||
except Exception as e: return f"ERROR: {e}"
|
||||
|
||||
def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List[List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]]:
|
||||
kernels: Dict[Optional[str], List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]] = {}
|
||||
for k,ctxs in contexts:
|
||||
@@ -69,7 +73,7 @@ def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
|
||||
@functools.lru_cache(None)
|
||||
def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
|
||||
def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
|
||||
g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=_prg(k))
|
||||
g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=pcall(_prg, k))
|
||||
replaces: Dict[UOp, UOp] = {}
|
||||
sink = ctx.sink
|
||||
for i,(u0,u1,upat,_) in enumerate(ctx.matches):
|
||||
@@ -83,7 +87,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
|
||||
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
|
||||
# update ret data
|
||||
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST])
|
||||
g.diffs.append(list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())))
|
||||
g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
|
||||
g.graphs.append(sink:=new_sink)
|
||||
return g
|
||||
|
||||
@@ -104,7 +108,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
query = parse_qs(url.query)
|
||||
if (qkernel:=query.get("kernel")) is not None:
|
||||
g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])])
|
||||
ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs)), "uops": list(map(str, g.graphs))}).encode()
|
||||
ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs)), "uops": list(map(lambda x:pcall(str,x), g.graphs))}).encode()
|
||||
else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode()
|
||||
else:
|
||||
self.send_response(404)
|
||||
|
||||
Reference in New Issue
Block a user