refactor schedule creation (#5297)

This commit is contained in:
qazal
2024-07-05 21:14:38 +03:00
committed by GitHub
parent 5292d37db6
commit b369e75ed0
2 changed files with 24 additions and 33 deletions

View File

@@ -5,7 +5,7 @@ from tinygrad.device import Buffer
from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.engine.schedule import _graph_schedule, _LBScheduleItem, ScheduleItem
from tinygrad.engine.schedule import _graph_schedule, ScheduleItem
from tinygrad.ops import LoadOps
from tinygrad.tensor import Tensor, _to_np_dtype
@@ -13,7 +13,7 @@ ctx_vars = { MULTIOUTPUT: (0, 1) }
def fuzz_schedule(outs:List[LazyBuffer]):
# find toposorts across all tunable params
unique_ts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict, Dict[LazyBuffer, _LBScheduleItem]]] = {}
unique_ts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict, Dict[LazyBuffer, Tuple]]] = {}
for combination in itertools.product(*ctx_vars.values()):
for var, val in zip(ctx_vars, combination): var.value = val
graph, in_degree, prescheduled = _graph_schedule(outs, set())
@@ -29,16 +29,16 @@ def fuzz_schedule(outs:List[LazyBuffer]):
seed = Tensor._seed
ts, (_, prescheduled) = toposorts[0]
for key in ts:
for out in (ps:=prescheduled[key]).outputs:
for out in (ps:=prescheduled[key])[0]:
# freeze assign state before exec
if out.op is LoadOps.ASSIGN:
prerealized[out] = out.buffer.as_buffer()
assign_targets[out.srcs[1]] = out
for x in ps.inputs:
for x in ps[2]:
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
si = ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0))
si = ScheduleItem(ps[1], tuple(x.buffer for x in (tuple(ps[0])+ps[2]) if x.size != 0))
_exec_si(si, seed)
for out in ps.outputs:
for out in ps[0]:
ground_truth[out] = out.buffer.as_buffer()
del out.srcs # only schedule the LazyBuffer in this fuzz run
@@ -47,19 +47,19 @@ def fuzz_schedule(outs:List[LazyBuffer]):
if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
rawbufs: Dict[LazyBuffer, Buffer] = {}
for key in ts:
for out in (ps:=prescheduled[key]).outputs:
for out in (ps:=prescheduled[key])[0]:
rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype)
if out.op is LoadOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
for x in ps.inputs:
for x in ps[2]:
if x not in rawbufs:
# override the assign_target after ASSIGN
if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
elif x.device == "NPY": rawbufs[x] = x.buffer
# copy the pre realized input
else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x])
si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in (ps.outputs+ps.inputs) if x.size != 0))
si = ScheduleItem(ps[1], tuple(rawbufs[x] for x in (tuple(ps[0])+ps[2]) if x.size != 0))
_exec_si(si, seed)
for out in ps.outputs:
for out in ps[0]:
outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2)
except Exception as e:

View File

@@ -34,14 +34,6 @@ class ScheduleItem:
# *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
@dataclass(frozen=True)
class _LBScheduleItem:
ast: Tuple[LazyOp, ...]
outputs: Tuple[LazyBuffer, ...]
inputs: Tuple[LazyBuffer, ...]
var_vals: Dict[Variable, int]
def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
"""recursively create a lazyop"""
@@ -100,7 +92,7 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
return ret
def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]):
"""create a schedule item from a list of outputs"""
inputs: List[LazyBuffer] = []
ast: List[LazyOp] = []
@@ -123,7 +115,7 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None]
output_view, vv = output_view.simplify().unbind()
if vv: var_vals.update(vv)
ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)
return tuple(ast), tuple(inputs), var_vals
# *** DAG creation: decide which LazyBuffers should realize ***
@@ -181,8 +173,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa
if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r)
_recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group)
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int],
Dict[LazyBuffer, _LBScheduleItem]]:
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
"""create a graph for realizing the outputs"""
# start by just realizing the buffers passed in
realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
@@ -269,20 +260,20 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
buf.buffer.options = None
# preschedule all buffers in realizes
prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()}
schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
prescheduled = {group[0]:(group, *_schedule_group(tuple(group), realizes, reduce_for_op)) for group in output_groups.values()}
schedule_targets = {out:ps for ps in prescheduled.values() for out in ps[0]}
graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
for key, lsi in prescheduled.items():
if key not in in_degree: in_degree[key] = 0
# realize outputs after all parents are realized
scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets)
scheduled_parents = set(schedule_targets[x][0][0] for x in lsi[2] if x in schedule_targets)
for x in scheduled_parents:
graph[x].append(key)
in_degree[key] += 1
# realize outputs before a parent is assigned to
parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets)
parents_assigns = set(schedule_targets[assign_targets[x]][0][0] for x in lsi[2] if x in assign_targets)
for assign in parents_assigns:
graph[key].append(assign)
in_degree[assign] += 1
@@ -301,15 +292,15 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
kernel_number = GlobalCounters.kernel_count
while queue:
ps = queue.popleft()
for buf in ps.outputs: seen.add(buf)
for buf in ps[0]: seen.add(buf)
if GRAPH:
kernel_number += 1
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
var_vals = merge_dicts([var_vals, ps.var_vals])
for out in ps.outputs: del out.srcs # can only schedule once
schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
for out in ps[0]: realized_lazybuffer(out, kernel_number)
var_vals = merge_dicts([var_vals, ps[3]])
for out in ps[0]: del out.srcs # can only schedule once
schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in (tuple(ps[0])+ps[2]) if x.size != 0)))
if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
for x in graph[ps.outputs[0]]:
for x in graph[ps[0][0]]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(prescheduled[x])
@@ -318,7 +309,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
if len(SCHEDULES) == 0: atexit.register(_save)
SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
SCHEDULES.extend((ps[1] for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
# confirm everything was scheduled correctly
if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")