diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py
index 5cb2b0cfc6..9917eaedf6 100755
--- a/test/external/process_replay/process_replay.py
+++ b/test/external/process_replay/process_replay.py
@@ -19,7 +19,7 @@ os.environ["RUN_PROCESS_REPLAY"] = "0"
 os.environ["CAPTURE_PROCESS_REPLAY"] = "0"
 early_stop = multiprocessing.Event()
 logging.basicConfig(level=logging.INFO, format="%(message)s")
-MAX_LINES = 1_000
+MAX_LINES = 500
 def trunc_log(x):
   if len(lines:=repr(x).splitlines()) > MAX_LINES: lines = lines[:MAX_LINES]+[f"WARN: truncated string with {len(lines)} lines"]
   logging.info("\n".join(lines))
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 83c6e98059..3c9fc69b4e 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -286,7 +286,7 @@ class Kernel:
   def __repr__(self): return f""

 @dataclass(frozen=True)
-class ScheduleItemContext:
+class KernelContext:
   var_vals: dict[Variable, int]
   bufs: list[UOp] = field(default_factory=list)

@@ -342,14 +342,14 @@ view_right = merge_views+PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])

-def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
+def _append_st_vars(ctx:KernelContext, x:UOp) -> UOp|None:
   st = unwrap(x.st).simplify()
   if any(x.op is Ops.BIND for x in st.vars()):
     st, var_vals = st.unbind()
     ctx.var_vals.update(var_vals)
   return st.to_uop() if st != x.st else None

-def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
+def _append_buf(ctx:KernelContext, x:UOp) -> UOp:
   ctx.bufs.append(x)
   return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)

@@ -380,7 +380,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> UOp:
   # unbind_vars + push views to edges
   sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
   # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
-  ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
+  ast = graph_rewrite(sink, to_si, si_ctx:=KernelContext(var_vals))
   # deal with ASSIGN
   if len(ctx.assigns) != 0:
     assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
@@ -430,10 +430,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
     ops_metadata[b] = k.metadata
   realize_map = group_realizes(sink, ctx:=ScheduleContext(ops_metadata))

-  # TODO: this should be the break between the "grouper" and the "linearizer"
-  # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
-  # call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`
-
   # create kernels + map buffers to realized tensors
   sinks: list[UOp] = []
   var_vals: dict[Variable, int] = {}
@@ -449,6 +445,10 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]))
   type_verify(list(sched_sink.toposort), kernel_spec)

+  # TODO: this should be the break between the "grouper" and the "linearizer"
+  # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
+  # call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`
+
   # convert kernels to ScheduleItem
   prescheduled = [kernel_to_si(k) for k in sched_sink.src]
   # add ScheduleItem children
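
Context for the TODO moved in the last two hunks: it proposes splitting scheduling into a "grouper" that produces one schedule sink UOp and a "linearizer" that converts it into ScheduleItems. A minimal sketch of that entry point, not part of this patch, assuming the `kernel_to_si` helper and the one-kernel-per-sink-src layout the diff already relies on (the imports are assumptions about what is module-level at this commit):

# hypothetical sketch only, not part of the patch above
from tinygrad.ops import UOp
from tinygrad.engine.schedule import ScheduleItem, kernel_to_si  # assumed importable at this commit

def linearize_schedule(sched_sink: UOp) -> list[ScheduleItem]:
  # each src of the schedule sink is expected to be a kernel-like uop,
  # mirroring `prescheduled = [kernel_to_si(k) for k in sched_sink.src]` above
  return [kernel_to_si(k) for k in sched_sink.src]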