From 935b6b658fb875be09190cf8ee3a509392a23664 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 9 Sep 2024 16:26:34 +0800
Subject: [PATCH] delete seen from the scheduler api [run_process_replay]
 (#6427)

docs
---
 examples/handcode_opt.py                     |  3 +--
 examples/llm.c/export.py                     |  6 ++----
 test/external/fuzz_schedule.py               |  2 +-
 .../process_replay/test_diff_schedule.py     | 16 +++++++--------
 test/test_conv_shapetracker.py               |  7 +++----
 test/test_fusion_op.py                       | 10 +++++-----
 tinygrad/engine/schedule.py                  | 20 +++++++++----------
 tinygrad/tensor.py                           | 16 +++++++++------
 8 files changed, 39 insertions(+), 41 deletions(-)

diff --git a/examples/handcode_opt.py b/examples/handcode_opt.py
index 3d6ab58064..468790c975 100644
--- a/examples/handcode_opt.py
+++ b/examples/handcode_opt.py
@@ -17,7 +17,6 @@ def get_sched_resnet():
   BS = getenv("BS", 64)
 
   # run model twice to get only what changes, these are the kernels of the model
-  seen = set()
   for _ in range(2):
     out = mdl(Tensor.empty(BS, 3, 224, 224))
     targets = [out.lazydata]
@@ -25,7 +24,7 @@ def get_sched_resnet():
       optim.zero_grad()
       out.sparse_categorical_crossentropy(Tensor.empty(BS, dtype=dtypes.int)).backward()
       targets += [x.lazydata for x in optim.schedule_step()]
-    sched = create_schedule(targets, seen)
+    sched = create_schedule(targets)
     print(f"schedule length {len(sched)}")
   return sched
 
diff --git a/examples/llm.c/export.py b/examples/llm.c/export.py
index 7fa1cc1dcb..0d6b4d938f 100755
--- a/examples/llm.c/export.py
+++ b/examples/llm.c/export.py
@@ -16,8 +16,7 @@ if __name__ == "__main__":
   #model.load_pretrained()
   for p in nn.state.get_parameters(model): p.replace(Tensor.empty(p.shape, dtype=p.dtype)) # fake load pretrained
 
-  seen = set()
-  #early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)], seen)
+  #early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)])
   #print(f"built model {len(early_sched)}")
 
   #B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64)
@@ -38,10 +37,9 @@ if __name__ == "__main__":
       tensors = optimizer.schedule_step()
     else:
       tensors = []
-    sched = create_schedule([loss.lazydata] + [x.lazydata for x in tensors], seen)
+    sched = create_schedule([loss.lazydata] + [x.lazydata for x in tensors])
     print(f"calls {i}:", len(sched))
     #run_schedule(sched[:])
-  del seen # free the LazyBuffers
   sched = memory_planner(sched)
   ast_dedup = dedup([si.ast for si in sched if si.ast.op is UOps.SINK])
   srcs = {}
diff --git a/test/external/fuzz_schedule.py b/test/external/fuzz_schedule.py
index 96ea4e71c9..1690b9aa08 100644
--- a/test/external/fuzz_schedule.py
+++ b/test/external/fuzz_schedule.py
@@ -18,7 +18,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
   for combination in itertools.product(*ctx_vars.values()):
     for var, val in zip(ctx_vars, combination): var.value = val
     ctx_var_values = dict(zip([v.key for v in ctx_vars], combination))
-    graph, in_degree = _graph_schedule(outs, set())
+    graph, in_degree = _graph_schedule(outs)
     for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = ctx_var_values
   toposorts = list(unique_ts.items())
   if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))
diff --git a/test/external/process_replay/test_diff_schedule.py b/test/external/process_replay/test_diff_schedule.py
index 3df2e2a966..ece6122e55 100644
--- a/test/external/process_replay/test_diff_schedule.py
+++ b/test/external/process_replay/test_diff_schedule.py
@@ -18,8 +18,8 @@ class TestDiffSchedule(unittest.TestCase):
     X = Tensor.randn(10, 10).realize()
     idxs = Tensor([0, 2]).realize()
     xt = cast(LazyBuffer, X[idxs].lazydata)
-    with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt], set())
-    with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt], set())
+    with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt])
+    with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt])
     # 1 arange LazyBuffer folds, 1 arange child's kernel changes
     changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
     self.assertEqual(changed, 1)
@@ -30,15 +30,15 @@ class TestDiffSchedule(unittest.TestCase):
     for _ in range(2):
       X = Tensor.randn(10, 10).realize()
       xt = cast(LazyBuffer, X[idxs].lazydata)
-      with Context(FUSE_ARANGE=0): schedules.append(_graph_schedule([xt], set()))
-      with Context(FUSE_ARANGE=1): schedules.append(_graph_schedule([xt], set()))
+      with Context(FUSE_ARANGE=0): schedules.append(_graph_schedule([xt]))
+      with Context(FUSE_ARANGE=1): schedules.append(_graph_schedule([xt]))
     changed = diff_schedule(schedules)
     self.assertEqual(changed, 1)
 
   def test_no_diff(self):
     a = cast(LazyBuffer, (Tensor([1])+Tensor([2])).lazydata)
-    with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a], set())
-    with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a], set())
+    with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a])
+    with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a])
     changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
     self.assertEqual(changed, 0)
 
@@ -49,8 +49,8 @@ class TestDiffSchedule(unittest.TestCase):
     c1(img).relu().mean().backward()
     assert img.grad is not None and c1.weight.grad is not None
     outs = [cast(LazyBuffer, img.grad.lazydata), cast(LazyBuffer, c1.weight.grad.lazydata)]
-    with Context(FUSE_CONV_BW=0): ref_graph, ref_in_degree = _graph_schedule(outs, set())
-    with Context(FUSE_CONV_BW=1): compare_graph, compare_in_degree = _graph_schedule(outs, set())
+    with Context(FUSE_CONV_BW=0): ref_graph, ref_in_degree = _graph_schedule(outs)
+    with Context(FUSE_CONV_BW=1): compare_graph, compare_in_degree = _graph_schedule(outs)
     changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
     # 1 reduceop folds, its child reduceop changes
     self.assertEqual(changed, 1)
diff --git a/test/test_conv_shapetracker.py b/test/test_conv_shapetracker.py
index 4156b2991d..2a7bb85ced 100644
--- a/test/test_conv_shapetracker.py
+++ b/test/test_conv_shapetracker.py
@@ -11,12 +11,11 @@ from test.unit.test_shapetracker import shapetracker_getitem
 
 class TestConvShapetracker(unittest.TestCase):
   def test_conv_3x3_one_view(self):
     conv = Conv2d(16, 32, (3, 3))
-    seen = set()
-    # first run to init the weights, they are saved in seen
-    create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
+    # first run to init the weights, they are scheduled.
+    create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata])
     # run it again to get the kernels
-    sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op is UOps.SINK]
+    sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is UOps.SINK]
     assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
     for st in [x.st_arg for x in sched[0].ast.parents if x.op is UOps.LOAD]:
       assert len(st.views) == 1
diff --git a/test/test_fusion_op.py b/test/test_fusion_op.py
index 0337bf55a4..fd2ed89fd7 100644
--- a/test/test_fusion_op.py
+++ b/test/test_fusion_op.py
@@ -18,7 +18,7 @@ class TestFusionOp(unittest.TestCase):
   def test_expand_fuse(self):
     bt = Tensor(np.ones((10, 1)), dtype=dtypes.float32)
     out = (bt*2).expand(10,10).sum(1)
-    sched = create_schedule([out.lazydata], None)
+    sched = create_schedule([out.lazydata])
     run_schedule(sched)
     outd = out.tolist()
     assert all(x == 20.0 for x in outd)
@@ -27,7 +27,7 @@ class TestFusionOp(unittest.TestCase):
     st = time.perf_counter()
     a = Tensor([1,2,3,4])
     for _ in range(24): a = a + a
-    sched = create_schedule([a.lazydata], None)
+    sched = create_schedule([a.lazydata])
     ei = lower_schedule_item(sched[-1])
     self.assertLess(time.perf_counter()-st, 2.0)
     assert len(ei.prg.p.src.splitlines()) < 250
@@ -36,13 +36,13 @@ class TestFusionOp(unittest.TestCase):
     st = time.perf_counter()
     a = Tensor([1,2,3,4])
     for _ in range(24): a = a + a
-    sched1 = create_schedule([a.lazydata], None)
+    sched1 = create_schedule([a.lazydata])
     b = Tensor([1,2,3,4])
     for _ in range(24): b = b + b
-    sched2 = create_schedule([b.lazydata], None)
+    sched2 = create_schedule([b.lazydata])
     c = Tensor([1,2,3,4])
     for _ in range(23): c = c + c
-    sched3 = create_schedule([c.lazydata], None)
+    sched3 = create_schedule([c.lazydata])
     assert_equiv_uops(sched1[-1].ast, sched2[-1].ast)
     with self.assertRaises(AssertionError): assert_equiv_uops(sched1[-1].ast, sched3[-1].ast)
     self.assertLess(time.perf_counter()-st, 2.0)
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 5976969aa5..c6efe3af4a 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -1,7 +1,7 @@
 import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
-from typing import Callable, Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
+from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast, get_args
 from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
 from tinygrad.ops import PatternMatcher, UPat, graph_rewrite
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
@@ -320,7 +320,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
   for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
   return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
 
-def _get_output_groups(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
+def _get_output_groups(outs:List[LazyBuffer]) -> \
     Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]],  # these are the output groups
           Dict[LazyBuffer, None],  # these are all the realizes in the graph
           Dict[LazyBuffer, LazyBuffer]]:  # these are the buffers we ASSIGN to in this schedule
@@ -399,7 +399,7 @@ def _get_output_groups(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
 
   output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
   for buf in realizes:
-    if buf.realized is not None or buf.op is MetaOps.CONST or buf in seen: continue
+    if buf.realized is not None or buf.op is MetaOps.CONST: continue
     output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
 
   # make things that can't be images not images
@@ -415,11 +415,11 @@ def _get_output_groups(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
   return output_groups, realizes, assign_targets
 
 SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
-def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
+def _graph_schedule(outs:List[LazyBuffer]) -> \
     Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]],  # this is the graph
           DefaultDict[LBScheduleItem, int]]:  # this is the in-degree of the graph
   """create a graph for realizing the outputs"""
-  output_groups, realizes, assign_targets = _get_output_groups(outs, seen)
+  output_groups, realizes, assign_targets = _get_output_groups(outs)
   # preschedule all buffers in realizes
   prescheduled = flatten([_lower_lazybuffer(group, realizes) for group in output_groups.values()])
   schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
@@ -449,9 +449,8 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
 
 # *** DAG ordering: breadth first search ***
 
-def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
-  if seen is None: seen = set()
-  graph, in_degree = _graph_schedule(outs, seen)
+def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+  graph, in_degree = _graph_schedule(outs)
   if getenv("RUN_PROCESS_REPLAY") and getenv("COMPARE_SCHEDULE", 1):
     # NOTE: process relpay needs PYTHONPATH=., remove this once it just pickles LazyBuffers
     with contextlib.suppress(Exception): importlib.import_module("test.external.process_replay.diff_schedule").process_replay(outs, graph, in_degree)
@@ -462,7 +461,6 @@ def create_schedule_with_vars(outs:List[LazyBuffe
   kernel_number = GlobalCounters.kernel_count
   while queue:
     lsi = queue.popleft()
-    for buf in lsi.outputs: seen.add(buf)
     if GRAPH:
       kernel_number += 1
       for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
@@ -479,7 +477,7 @@ def create_schedule_with_vars(outs:List[LazyBuffe
   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
   return schedule, var_vals
 
-def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
-  schedule, var_vals = create_schedule_with_vars(outs, seen)
+def create_schedule(outs:List[LazyBuffer]) -> List[ScheduleItem]:
+  schedule, var_vals = create_schedule_with_vars(outs)
   assert len(var_vals) == 0
   return schedule
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 5c35fc24e5..6797b45e53 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import dataclasses
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string
 from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set, Literal
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
 from collections import defaultdict
 
 import numpy as np
@@ -191,17 +191,21 @@ class Tensor:
 
   # ***** data handlers ****
 
-  def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
-    """Creates the schedule needed to realize these Tensor(s), with Variables."""
+  def schedule_with_vars(self, *lst:Tensor) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+    """
+    Creates the schedule needed to realize these Tensor(s), with Variables.
+
+    NOTE: A Tensor can only be scheduled once.
+    """
     if getenv("FUZZ_SCHEDULE"):
       from test.external.fuzz_schedule import fuzz_schedule
       fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
-    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
     return memory_planner(schedule), var_vals
 
-  def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+  def schedule(self, *lst:Tensor) -> List[ScheduleItem]:
     """Creates the schedule needed to realize these Tensor(s)."""
-    schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+    schedule, var_vals = self.schedule_with_vars(*lst)
     assert len(var_vals) == 0
     return schedule
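
Usage sketch (illustrative, not part of the patch above): how a call site looks after this change, assuming the tinygrad API at this commit. The `seen` set is gone from Tensor.schedule() and create_schedule(); a Tensor should only be scheduled once, and already-realized buffers are skipped by the scheduler itself.

    from tinygrad import Tensor
    from tinygrad.engine.realize import run_schedule

    a, b = Tensor.rand(4, 4), Tensor.rand(4, 4)
    out = (a @ b).relu()

    # before: seen = set(); sched = out.schedule(seen=seen)
    # after: no seen set to thread through repeated calls
    sched = out.schedule()  # List[ScheduleItem]
    run_schedule(sched)     # execute the kernels
    print(out.numpy())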