rename to KernelContext and move the linearize_sched comment [pr] (#8899)

* rename to KernelContext and move that comment [pr]

* 500
qazal
2025-02-05 07:49:58 +01:00
committed by GitHub
parent 6fb0e5751b
commit 6f0cc2e9c5
2 changed files with 9 additions and 9 deletions


@@ -19,7 +19,7 @@ os.environ["RUN_PROCESS_REPLAY"] = "0"
 os.environ["CAPTURE_PROCESS_REPLAY"] = "0"
 early_stop = multiprocessing.Event()
 logging.basicConfig(level=logging.INFO, format="%(message)s")
-MAX_LINES = 1_000
+MAX_LINES = 500
 def trunc_log(x):
   if len(lines:=repr(x).splitlines()) > MAX_LINES: lines = lines[:MAX_LINES]+[f"WARN: truncated string with {len(lines)} lines"]
   logging.info("\n".join(lines))
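For context, trunc_log caps how much of a repr gets logged during process replay: with the new limit, anything past 500 lines is dropped and a warning carrying the original line count is appended. A minimal standalone sketch of that behavior (the ManyLines class is hypothetical, made up here just to produce a long multi-line repr):

import logging
logging.basicConfig(level=logging.INFO, format="%(message)s")

MAX_LINES = 500
def trunc_log(x):
  if len(lines:=repr(x).splitlines()) > MAX_LINES: lines = lines[:MAX_LINES]+[f"WARN: truncated string with {len(lines)} lines"]
  logging.info("\n".join(lines))

class ManyLines:
  def __repr__(self): return "\n".join(f"line {i}" for i in range(600))

trunc_log(ManyLines())  # logs lines 0..499, then "WARN: truncated string with 600 lines"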


@@ -286,7 +286,7 @@ class Kernel:
   def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
 
 @dataclass(frozen=True)
-class ScheduleItemContext:
+class KernelContext:
   var_vals: dict[Variable, int]
   bufs: list[UOp] = field(default_factory=list)
@@ -342,14 +342,14 @@ view_right = merge_views+PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
 
-def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
+def _append_st_vars(ctx:KernelContext, x:UOp) -> UOp|None:
   st = unwrap(x.st).simplify()
   if any(x.op is Ops.BIND for x in st.vars()):
     st, var_vals = st.unbind()
     ctx.var_vals.update(var_vals)
   return st.to_uop() if st != x.st else None
 
-def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
+def _append_buf(ctx:KernelContext, x:UOp) -> UOp:
   ctx.bufs.append(x)
   return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
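Both helpers are graph_rewrite callbacks that mutate a shared KernelContext: _append_st_vars unbinds any bound Variables out of a ShapeTracker into ctx.var_vals, and _append_buf swaps a BUFFER for a DEFINE_GLOBAL indexed by its position in ctx.bufs. A rough sketch of how they could be wired into a matcher (the real to_si matcher is used in the next hunk; the patterns below are assumptions for illustration, not the actual ones):

# hypothetical wiring, for illustration only; the real `to_si` may match differently
to_si_sketch = PatternMatcher([
  (UPat(Ops.VIEW, name="x"), _append_st_vars),  # simplify/unbind ShapeTracker vars
  (UPat(Ops.BUFFER, name="x"), _append_buf),    # BUFFER -> DEFINE_GLOBAL
])
ast = graph_rewrite(sink, to_si_sketch, ctx=KernelContext(var_vals={}))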
@@ -380,7 +380,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> UOp:
   # unbind_vars + push views to edges
   sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
   # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
-  ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
+  ast = graph_rewrite(sink, to_si, si_ctx:=KernelContext(var_vals))
   # deal with ASSIGN
   if len(ctx.assigns) != 0:
     assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
@@ -430,10 +430,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
       ops_metadata[b] = k.metadata
   realize_map = group_realizes(sink, ctx:=ScheduleContext(ops_metadata))
 
-  # TODO: this should be the break between the "grouper" and the "linearizer"
-  # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
-  # call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`
-
   # create kernels + map buffers to realized tensors
   sinks: list[UOp] = []
   var_vals: dict[Variable, int] = {}
@@ -449,6 +445,10 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]))
   type_verify(list(sched_sink.toposort), kernel_spec)
 
+  # TODO: this should be the break between the "grouper" and the "linearizer"
+  # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
+  # call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`
+
   # convert kernels to ScheduleItem
   prescheduled = [kernel_to_si(k) for k in sched_sink.src]
   # add ScheduleItem children
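With the comment moved, the TODO now sits exactly at the boundary it describes: above it the "grouper" has produced one sched_sink whose sources are kernels over BUFFER/KERNEL/COPY/ASSIGN uops, and below it the conversion to ScheduleItems begins. A rough sketch of the proposed split, assuming only the signature named in the comment (linearize_schedule does not exist yet; kernel_to_si and the sink layout are taken from the diff):

def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]:
  # sketch: one ScheduleItem per kernel hanging off the sink, mirroring the
  # `prescheduled = [kernel_to_si(k) for k in sched_sink.src]` line above;
  # dependency ordering between the items is left out of this sketch
  return [kernel_to_si(k) for k in sched_sink.src]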