mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
rename to KernelContext and move the linearize_sched comment [pr] (#8899)
* rename to KernelContext and move that comment [pr] * 500
This commit is contained in:
@@ -19,7 +19,7 @@ os.environ["RUN_PROCESS_REPLAY"] = "0"
|
||||
os.environ["CAPTURE_PROCESS_REPLAY"] = "0"
|
||||
early_stop = multiprocessing.Event()
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
MAX_LINES = 1_000
|
||||
MAX_LINES = 500
|
||||
def trunc_log(x):
|
||||
if len(lines:=repr(x).splitlines()) > MAX_LINES: lines = lines[:MAX_LINES]+[f"WARN: truncated string with {len(lines)} lines"]
|
||||
logging.info("\n".join(lines))
|
||||
|
||||
@@ -286,7 +286,7 @@ class Kernel:
|
||||
def __repr__(self): return f"<Kernel {len(list(self.ast.toposort))} {self.ast.op} {self.metadata}>"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItemContext:
|
||||
class KernelContext:
|
||||
var_vals: dict[Variable, int]
|
||||
bufs: list[UOp] = field(default_factory=list)
|
||||
|
||||
@@ -342,14 +342,14 @@ view_right = merge_views+PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
])
|
||||
|
||||
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
|
||||
def _append_st_vars(ctx:KernelContext, x:UOp) -> UOp|None:
|
||||
st = unwrap(x.st).simplify()
|
||||
if any(x.op is Ops.BIND for x in st.vars()):
|
||||
st, var_vals = st.unbind()
|
||||
ctx.var_vals.update(var_vals)
|
||||
return st.to_uop() if st != x.st else None
|
||||
|
||||
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
|
||||
def _append_buf(ctx:KernelContext, x:UOp) -> UOp:
|
||||
ctx.bufs.append(x)
|
||||
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
|
||||
|
||||
@@ -380,7 +380,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> UOp:
|
||||
# unbind_vars + push views to edges
|
||||
sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
|
||||
# remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
|
||||
ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
|
||||
ast = graph_rewrite(sink, to_si, si_ctx:=KernelContext(var_vals))
|
||||
# deal with ASSIGN
|
||||
if len(ctx.assigns) != 0:
|
||||
assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
|
||||
@@ -430,10 +430,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
ops_metadata[b] = k.metadata
|
||||
realize_map = group_realizes(sink, ctx:=ScheduleContext(ops_metadata))
|
||||
|
||||
# TODO: this should be the break between the "grouper" and the "linearizer"
|
||||
# here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
|
||||
# call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`
|
||||
|
||||
# create kernels + map buffers to realized tensors
|
||||
sinks: list[UOp] = []
|
||||
var_vals: dict[Variable, int] = {}
|
||||
@@ -449,6 +445,10 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]))
|
||||
type_verify(list(sched_sink.toposort), kernel_spec)
|
||||
|
||||
# TODO: this should be the break between the "grouper" and the "linearizer"
|
||||
# here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
|
||||
# call into `def linearize_schedule(sched_sink:UOp) -> list[ScheduleItem]`
|
||||
|
||||
# convert kernels to ScheduleItem
|
||||
prescheduled = [kernel_to_si(k) for k in sched_sink.src]
|
||||
# add ScheduleItem children
|
||||
|
||||
Reference in New Issue
Block a user