From 0d7bd4f389ca40ae4bfe3317fa33c270e671a85e Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:29:33 -0800 Subject: [PATCH] empty graph rewrite to VIZ tensor graph [pr] (#8658) * empty graph rewrite to VIZ tensor graph [pr] * fix lint --- tinygrad/engine/schedule.py | 7 +++++-- tinygrad/viz/serve.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index d879daff54..1db3f226e4 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -519,9 +519,12 @@ remove_movement_ops = PatternMatcher([ @track_rewrites(named=True) def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: - if not skip_check: type_verify(list(UOp.sink(*outs).toposort), tensor_uop_spec) + big_sink = UOp.sink(*outs) + # if using VIZ, do an empty graph rewrite to vizualize the Tensor graph + if getenv("VIZ"): graph_rewrite(big_sink, PatternMatcher([])) + if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec) # to_uop is removing (many) of the movement ops - sink = add_buffers(UOp.sink(*outs), ctx:=ScheduleContext(), cache={}) + sink = add_buffers(big_sink, ctx:=ScheduleContext(), cache={}) # const folding and fusion sink = graph_rewrite(sink, remove_movement_ops+ops_folding+do_realize, ctx) sink = graph_rewrite(sink, merge_bufs, ctx) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 5a3bb71df6..c95fb5f800 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -73,7 +73,7 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]: if u.op is Ops.VIEW: argst = ("\n".join([f"{v.shape} / {v.strides}"+(f"\nMASK {v.mask}" if v.mask is not None else "")+ ("" if v.offset == 0 else f" / {v.offset}") for v in unwrap(u.st).views])) - label = f"{str(u.op).split('.')[1]}{(' '+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}" + label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}" for idx,x in enumerate(u.src): if x in excluded: if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}"