diff --git a/test/external/fuzz_schedule.py b/test/external/fuzz_schedule.py
index 1690b9aa08..700443f0c9 100644
--- a/test/external/fuzz_schedule.py
+++ b/test/external/fuzz_schedule.py
@@ -18,7 +18,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
   for combination in itertools.product(*ctx_vars.values()):
     for var, val in zip(ctx_vars, combination): var.value = val
     ctx_var_values = dict(zip([v.key for v in ctx_vars], combination))
-    graph, in_degree = _graph_schedule(outs)
+    graph, in_degree, _ = _graph_schedule(outs)
     for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = ctx_var_values
   toposorts = list(unique_ts.items())
   if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))
diff --git a/test/external/process_replay/diff_schedule.py b/test/external/process_replay/diff_schedule.py
index cb2ac33002..1639f32b3a 100644
--- a/test/external/process_replay/diff_schedule.py
+++ b/test/external/process_replay/diff_schedule.py
@@ -17,7 +17,7 @@ def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List
   if not os.path.isfile(fp):
     shutil.copyfile(fetch(f"https://raw.githubusercontent.com/tinygrad/tinygrad/{ref_schedule}/tinygrad/engine/schedule.py",
                           allow_caching=False), fp)
   # create the reference graph
-  ref_graph, ref_in_degree = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs, set())
+  ref_graph, ref_in_degree = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs)
   # compare
   diff_schedule([(ref_graph, ref_in_degree), (graph, in_degree)])
diff --git a/test/external/process_replay/test_diff_schedule.py b/test/external/process_replay/test_diff_schedule.py
index ece6122e55..69c4233d12 100644
--- a/test/external/process_replay/test_diff_schedule.py
+++ b/test/external/process_replay/test_diff_schedule.py
@@ -18,8 +18,8 @@ class TestDiffSchedule(unittest.TestCase):
     X = Tensor.randn(10, 10).realize()
     idxs = Tensor([0, 2]).realize()
     xt = cast(LazyBuffer, X[idxs].lazydata)
-    with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt])
-    with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt])
+    with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree, _ = _graph_schedule([xt])
+    with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree, _ = _graph_schedule([xt])
     # 1 arange LazyBuffer folds, 1 arange child's kernel changes
     changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
     self.assertEqual(changed, 1)
@@ -30,15 +30,15 @@ class TestDiffSchedule(unittest.TestCase):
     for _ in range(2):
       X = Tensor.randn(10, 10).realize()
       xt = cast(LazyBuffer, X[idxs].lazydata)
-      with Context(FUSE_ARANGE=0): schedules.append(_graph_schedule([xt]))
-      with Context(FUSE_ARANGE=1): schedules.append(_graph_schedule([xt]))
+      with Context(FUSE_ARANGE=0): schedules.append(_graph_schedule([xt])[:-1])
+      with Context(FUSE_ARANGE=1): schedules.append(_graph_schedule([xt])[:-1])
     changed = diff_schedule(schedules)
     self.assertEqual(changed, 1)
 
   def test_no_diff(self):
     a = cast(LazyBuffer, (Tensor([1])+Tensor([2])).lazydata)
-    with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a])
-    with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a])
+    with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree, _ = _graph_schedule([a])
+    with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree, _ = _graph_schedule([a])
     changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
     self.assertEqual(changed, 0)
 
@@ -49,8 +49,8 @@ class TestDiffSchedule(unittest.TestCase):
     c1(img).relu().mean().backward()
     assert img.grad is not None and c1.weight.grad is not None
     outs = [cast(LazyBuffer, img.grad.lazydata), cast(LazyBuffer, c1.weight.grad.lazydata)]
-    with Context(FUSE_CONV_BW=0): ref_graph, ref_in_degree = _graph_schedule(outs)
-    with Context(FUSE_CONV_BW=1): compare_graph, compare_in_degree = _graph_schedule(outs)
+    with Context(FUSE_CONV_BW=0): ref_graph, ref_in_degree, _ = _graph_schedule(outs)
+    with Context(FUSE_CONV_BW=1): compare_graph, compare_in_degree, _ = _graph_schedule(outs)
     changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
     # 1 reduceop folds, its child reduceop changes
     self.assertEqual(changed, 1)
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index ec22eb562c..616ea4224a 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -38,7 +38,6 @@ class LBScheduleItem:
   ast: UOp
   outputs: List[LazyBuffer]
   inputs: List[LazyBuffer]
-  var_vals: Dict[Variable, int] = field(default_factory=dict)
   metadata: List[Metadata] = field(default_factory=list)
   def __hash__(self):
     """The unique identifier of a schedule item in the toposort."""
@@ -159,10 +158,10 @@ reduceop_fusor = PatternMatcher([
   (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
 
-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
-    return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
+    return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs]), {}
   # create the stores
   var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
   assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
@@ -185,7 +184,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
       from tinygrad.engine.graph import graph_uop
       graph_uop(sink)
     raise e
-  return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
+  return LBScheduleItem(sink, outs, list(inputs), dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])), var_vals
 
 # *** DAG creation: decide which LazyBuffers should realize ***
 
@@ -360,11 +359,16 @@ def _get_output_groups(outs:List[LazyBuffer]) -> \
 SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
 def _graph_schedule(outs:List[LazyBuffer]) -> \
     Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]],  # this is the graph
-          DefaultDict[LBScheduleItem, int]]:                  # this is the in-degree of the graph
+          DefaultDict[LBScheduleItem, int],                   # this is the in-degree of the graph
+          Dict[Variable, int]]:                               # this has all the var values of the schedule
   """create a graph for realizing the outputs"""
   output_groups, realizes, assign_targets = _get_output_groups(outs)
   # preschedule all buffers in realizes
-  prescheduled = [_lower_lazybuffer(group, realizes) for group in output_groups.values()]
+  prescheduled: List[LBScheduleItem] = []
+  var_vals: Dict[Variable, int] = {}
+  for group in output_groups.values():
+    prescheduled.append((ret:=_lower_lazybuffer(group, realizes))[0])
+    var_vals = merge_dicts([var_vals, ret[1]])
   schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
 
   graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)
@@ -388,26 +392,24 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
       with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
     if len(SCHEDULES) == 0: atexit.register(_save)
     SCHEDULES.append((graph, in_degree))
-  return graph, in_degree
+  return graph, in_degree, var_vals
 
 # *** DAG ordering: breadth first search ***
 
 def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
-  graph, in_degree = _graph_schedule(outs)
+  graph, in_degree, var_vals = _graph_schedule(outs)
   if getenv("RUN_PROCESS_REPLAY") and getenv("COMPARE_SCHEDULE", 1):
     # NOTE: process relpay needs PYTHONPATH=., remove this once it just pickles LazyBuffers
     with contextlib.suppress(Exception): importlib.import_module("test.external.process_replay.diff_schedule").process_replay(outs, graph, in_degree)
 
   queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
   schedule: List[ScheduleItem] = []
-  var_vals: Dict[Variable, int] = {}
   kernel_number = GlobalCounters.kernel_count
   while queue:
     lsi = queue.popleft()
     if GRAPH:
       kernel_number += 1
       for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
-    var_vals = merge_dicts([var_vals, lsi.var_vals])
     for out in lsi.outputs: del out.srcs # can only schedule once
     schedule.append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
     for x in graph[lsi]:
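For context, a minimal sketch of the new calling convention (not part of the diff; mirrors the `test_no_diff` setup and assumes this tree's import paths): `_graph_schedule` now returns `var_vals` as a third element, collected once while lowering each output group, instead of being stored on every `LBScheduleItem` and merged during the toposort.

```python
# sketch only: assumes tinygrad.lazy and tinygrad.engine.schedule from this tree
from typing import cast
from tinygrad import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.engine.schedule import _graph_schedule

out = cast(LazyBuffer, (Tensor([1]) + Tensor([2])).lazydata)
# post-patch, _graph_schedule returns a 3-tuple; callers that only need the
# graph (e.g. fuzz_schedule) discard var_vals with `_`
graph, in_degree, var_vals = _graph_schedule([out])
print(len(graph), len(in_degree), var_vals)  # var_vals == {} when nothing is symbolic
```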