mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -15,7 +15,7 @@ def _unwrap_src(s: UOp) -> UOp:
|
||||
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
|
||||
return s
|
||||
|
||||
def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||
def create_schedule(sched_sink:UOp) -> UOp:
|
||||
with cpu_profile(TracingKey("toposort sched_sink")):
|
||||
# build kernel dependency graph: edges from producer kernel to consumer kernels
|
||||
children: dict[UOp, list[UOp]] = {}
|
||||
@@ -47,20 +47,17 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||
|
||||
with cpu_profile(TracingKey("linearize schedule")):
|
||||
queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0)
|
||||
pre_schedule: list[ExecItem] = []
|
||||
buf_uops_list: list[UOp] = []
|
||||
linearized: list[UOp] = []
|
||||
while len(queue):
|
||||
rk = queue.popleft()
|
||||
k = rk.src[0] if rk.op is Ops.END else rk
|
||||
assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
||||
pre_schedule.append(ExecItem(k.src[0], [], k.arg.metadata))
|
||||
buf_uops_list.append(UOp.sink(*buf_uops))
|
||||
linearized.append(k.src[0].call(*buf_uops, metadata=k.arg.metadata))
|
||||
for x in children.get(rk, []):
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
return pre_schedule, UOp.sink(*buf_uops_list)
|
||||
return UOp(Ops.LINEAR, src=tuple(linearized))
|
||||
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.schedule.rangeify import get_kernel_graph
|
||||
@@ -76,42 +73,55 @@ pm_post_sched_cache = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
|
||||
])
|
||||
|
||||
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
|
||||
schedule_cache: dict[bytes, UOp] = {}
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
|
||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
|
||||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
big_sink, buffer_map = transform_to_call(big_sink)
|
||||
|
||||
function = big_sink.src[0]
|
||||
|
||||
if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
|
||||
if SPEC: type_verify(big_sink, tensor_spec)
|
||||
pre_schedule, buf_uops_sink = create_schedule(get_kernel_graph(function))
|
||||
if SCACHE: schedule_cache[function.key] = (pre_schedule, buf_uops_sink)
|
||||
linear = create_schedule(get_kernel_graph(function))
|
||||
if SCACHE: schedule_cache[function.key] = linear
|
||||
else:
|
||||
# schedule cache hit
|
||||
pre_schedule, buf_uops_sink = sc_ret
|
||||
# it's a call that we late apply
|
||||
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
|
||||
linear = sc_ret
|
||||
|
||||
# add bufs to pre_schedule
|
||||
# it's a call that we late apply
|
||||
linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
|
||||
|
||||
# vars used in the schedule
|
||||
used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
|
||||
# get var_vals
|
||||
var_vals: dict[str, int] = {}
|
||||
for b in big_sink.src[1:]:
|
||||
if b.op is Ops.BIND:
|
||||
nm = b.src[0].expr
|
||||
if nm not in used_vars: continue
|
||||
val = b.src[1].arg
|
||||
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
|
||||
var_vals[nm] = val
|
||||
|
||||
# convert LINEAR to ExecItems
|
||||
schedule: list[ExecItem] = []
|
||||
for i, si in enumerate(pre_schedule):
|
||||
buf_uops = buf_uops_sink.src[i].src
|
||||
for si in linear.src:
|
||||
ast, buf_uops = si.src[0], si.src[1:]
|
||||
# create subbuffers if needed
|
||||
if si.ast.op is Ops.BUFFER_VIEW:
|
||||
if ast.op is Ops.BUFFER_VIEW:
|
||||
base = buf_uops[1].buffer
|
||||
assert isinstance(base, Buffer), "base can't be MultiBuffer"
|
||||
buffers[buf_uops[0]] = base.view(buf_uops[0].arg, si.ast.dtype, si.ast.arg[1]*base.dtype.itemsize)
|
||||
ubufs = tuple(b.buffer for b in buf_uops)
|
||||
buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
ubufs = [b.buffer for b in buf_uops]
|
||||
metadata = si.arg.metadata
|
||||
if any(isinstance(x, MultiBuffer) for x in ubufs):
|
||||
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
||||
dnums = [x for x in si.ast.variables() if x.expr == '_device_num']
|
||||
dnums = [x for x in ast.variables() if x.expr == '_device_num']
|
||||
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
|
||||
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
|
||||
schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {}))
|
||||
else:
|
||||
# ONE -> ONE
|
||||
schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars))
|
||||
schedule.append(ExecItem(ast, list(ubufs), metadata))
|
||||
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
||||
|
||||
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
|
||||
@@ -124,15 +134,4 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
|
||||
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
|
||||
|
||||
# vars used in the schedule
|
||||
used_vars = set().union(*[{v.expr for v in si.ast.variables()} for si in schedule])
|
||||
# get var_vals
|
||||
var_vals: dict[str, int] = {}
|
||||
for i,b in enumerate(big_sink.src[1:]):
|
||||
if b.op is Ops.BIND:
|
||||
nm = b.src[0].expr
|
||||
if nm not in used_vars: continue
|
||||
val = b.src[1].arg
|
||||
assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
|
||||
var_vals[nm] = val
|
||||
return buffer_map, schedule, var_vals
|
||||
|
||||
Reference in New Issue
Block a user