diff --git a/test/external/fuzz_schedule.py b/test/external/fuzz_schedule.py index 6cc2e87ab6..21b27aa1ea 100644 --- a/test/external/fuzz_schedule.py +++ b/test/external/fuzz_schedule.py @@ -5,7 +5,7 @@ from tinygrad.device import Buffer from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv from tinygrad.lazy import LazyBuffer -from tinygrad.engine.schedule import _graph_schedule, _LBScheduleItem, ScheduleItem +from tinygrad.engine.schedule import _graph_schedule, ScheduleItem from tinygrad.ops import LoadOps from tinygrad.tensor import Tensor, _to_np_dtype @@ -13,7 +13,7 @@ ctx_vars = { MULTIOUTPUT: (0, 1) } def fuzz_schedule(outs:List[LazyBuffer]): # find toposorts across all tunable params - unique_ts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict, Dict[LazyBuffer, _LBScheduleItem]]] = {} + unique_ts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict, Dict[LazyBuffer, Tuple]]] = {} for combination in itertools.product(*ctx_vars.values()): for var, val in zip(ctx_vars, combination): var.value = val graph, in_degree, prescheduled = _graph_schedule(outs, set()) @@ -29,16 +29,16 @@ def fuzz_schedule(outs:List[LazyBuffer]): seed = Tensor._seed ts, (_, prescheduled) = toposorts[0] for key in ts: - for out in (ps:=prescheduled[key]).outputs: + for out in (ps:=prescheduled[key])[0]: # freeze assign state before exec if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer() assign_targets[out.srcs[1]] = out - for x in ps.inputs: + for x in ps[2]: if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer() - si = ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)) + si = ScheduleItem(ps[1], tuple(x.buffer for x in (tuple(ps[0])+ps[2]) if x.size != 0)) _exec_si(si, seed) - for out in ps.outputs: + for out in ps[0]: ground_truth[out] = out.buffer.as_buffer() del out.srcs # only schedule the LazyBuffer in this fuzz run @@ -47,19 +47,19 @@ def fuzz_schedule(outs:List[LazyBuffer]): if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow")) rawbufs: Dict[LazyBuffer, Buffer] = {} for key in ts: - for out in (ps:=prescheduled[key]).outputs: + for out in (ps:=prescheduled[key])[0]: rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype) if out.op is LoadOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out]) - for x in ps.inputs: + for x in ps[2]: if x not in rawbufs: # override the assign_target after ASSIGN if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]] elif x.device == "NPY": rawbufs[x] = x.buffer # copy the pre realized input else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x]) - si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in (ps.outputs+ps.inputs) if x.size != 0)) + si = ScheduleItem(ps[1], tuple(rawbufs[x] for x in (tuple(ps[0])+ps[2]) if x.size != 0)) _exec_si(si, seed) - for out in ps.outputs: + for out in ps[0]: outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype)) try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2) except Exception as e: diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2397dc425b..50271f9f8d 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -34,14 +34,6 @@ class ScheduleItem: # *** DAG transformation: List[LazyBuffer] -> ScheduleItem *** -# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort -@dataclass(frozen=True) -class _LBScheduleItem: - ast: Tuple[LazyOp, ...] - outputs: Tuple[LazyBuffer, ...] - inputs: Tuple[LazyBuffer, ...] - var_vals: Dict[Variable, int] - def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker, realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp: """recursively create a lazyop""" @@ -100,7 +92,7 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg) return ret -def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem: +def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]): """create a schedule item from a list of outputs""" inputs: List[LazyBuffer] = [] ast: List[LazyOp] = [] @@ -123,7 +115,7 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None] output_view, vv = output_view.simplify().unbind() if vv: var_vals.update(vv) ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view))) - return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals) + return tuple(ast), tuple(inputs), var_vals # *** DAG creation: decide which LazyBuffers should realize *** @@ -181,8 +173,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa if len(st_childs:=dedup(s for s in tr_next.srcs if s.base == tr)) > 1: return group.add(r) _recursive_group(tr_next, st+st_childs[0].st, r, children, realizes, reduce_for_op, group) -def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int], - Dict[LazyBuffer, _LBScheduleItem]]: +def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]): """create a graph for realizing the outputs""" # start by just realizing the buffers passed in realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None} @@ -269,20 +260,20 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul buf.buffer.options = None # preschedule all buffers in realizes - prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()} - schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs} + prescheduled = {group[0]:(group, *_schedule_group(tuple(group), realizes, reduce_for_op)) for group in output_groups.values()} + schedule_targets = {out:ps for ps in prescheduled.values() for out in ps[0]} graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list) in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int) for key, lsi in prescheduled.items(): if key not in in_degree: in_degree[key] = 0 # realize outputs after all parents are realized - scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets) + scheduled_parents = set(schedule_targets[x][0][0] for x in lsi[2] if x in schedule_targets) for x in scheduled_parents: graph[x].append(key) in_degree[key] += 1 # realize outputs before a parent is assigned to - parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets) + parents_assigns = set(schedule_targets[assign_targets[x]][0][0] for x in lsi[2] if x in assign_targets) for assign in parents_assigns: graph[key].append(assign) in_degree[assign] += 1 @@ -301,15 +292,15 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe kernel_number = GlobalCounters.kernel_count while queue: ps = queue.popleft() - for buf in ps.outputs: seen.add(buf) + for buf in ps[0]: seen.add(buf) if GRAPH: kernel_number += 1 - for out in ps.outputs: realized_lazybuffer(out, kernel_number) - var_vals = merge_dicts([var_vals, ps.var_vals]) - for out in ps.outputs: del out.srcs # can only schedule once - schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0))) + for out in ps[0]: realized_lazybuffer(out, kernel_number) + var_vals = merge_dicts([var_vals, ps[3]]) + for out in ps[0]: del out.srcs # can only schedule once + schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in (tuple(ps[0])+ps[2]) if x.size != 0))) if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n") - for x in graph[ps.outputs[0]]: + for x in graph[ps[0][0]]: in_degree[x] -= 1 if in_degree[x] == 0: queue.append(prescheduled[x]) @@ -318,7 +309,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl")) with open(fp, "wb") as f: pickle.dump(SCHEDULES, f) if len(SCHEDULES) == 0: atexit.register(_save) - SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)]) + SCHEDULES.extend((ps[1] for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)]) # confirm everything was scheduled correctly if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule): raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")