delete seen from the scheduler api [run_process_replay] (#6427)

docs
qazal
2024-09-09 16:26:34 +08:00
committed by GitHub
parent 6c7abd18df
commit 935b6b658f
8 changed files with 39 additions and 41 deletions
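In short: every scheduler entry point loses its caller-managed seen:Set[LazyBuffer] argument, and the scheduler itself now skips buffers that are already realized. A minimal before/after sketch of a call site (the targets list is illustrative, not taken from this diff):

# before this commit: callers threaded a shared `seen` set through every schedule call
seen = set()
sched = create_schedule(targets, seen)   # buffers already in `seen` were filtered out

# after this commit: the extra argument is gone; already-realized buffers are skipped internally
sched = create_schedule(targets)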

View File

@@ -17,7 +17,6 @@ def get_sched_resnet():
BS = getenv("BS", 64)
# run model twice to get only what changes, these are the kernels of the model
-seen = set()
for _ in range(2):
out = mdl(Tensor.empty(BS, 3, 224, 224))
targets = [out.lazydata]
@@ -25,7 +24,7 @@ def get_sched_resnet():
optim.zero_grad()
out.sparse_categorical_crossentropy(Tensor.empty(BS, dtype=dtypes.int)).backward()
targets += [x.lazydata for x in optim.schedule_step()]
-sched = create_schedule(targets, seen)
+sched = create_schedule(targets)
print(f"schedule length {len(sched)}")
return sched

View File

@@ -16,8 +16,7 @@ if __name__ == "__main__":
#model.load_pretrained()
for p in nn.state.get_parameters(model): p.replace(Tensor.empty(p.shape, dtype=p.dtype)) # fake load pretrained
-seen = set()
-#early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)], seen)
+#early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)])
#print(f"built model {len(early_sched)}")
#B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64)
@@ -38,10 +37,9 @@ if __name__ == "__main__":
tensors = optimizer.schedule_step()
else:
tensors = []
-sched = create_schedule([loss.lazydata] + [x.lazydata for x in tensors], seen)
+sched = create_schedule([loss.lazydata] + [x.lazydata for x in tensors])
print(f"calls {i}:", len(sched))
#run_schedule(sched[:])
-del seen # free the LazyBuffers
sched = memory_planner(sched)
ast_dedup = dedup([si.ast for si in sched if si.ast.op is UOps.SINK])
srcs = {}

View File

@@ -18,7 +18,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
for combination in itertools.product(*ctx_vars.values()):
for var, val in zip(ctx_vars, combination): var.value = val
ctx_var_values = dict(zip([v.key for v in ctx_vars], combination))
-graph, in_degree = _graph_schedule(outs, set())
+graph, in_degree = _graph_schedule(outs)
for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = ctx_var_values
toposorts = list(unique_ts.items())
if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))

View File

@@ -18,8 +18,8 @@ class TestDiffSchedule(unittest.TestCase):
X = Tensor.randn(10, 10).realize()
idxs = Tensor([0, 2]).realize()
xt = cast(LazyBuffer, X[idxs].lazydata)
-with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt], set())
-with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt], set())
+with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt])
+with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt])
# 1 arange LazyBuffer folds, 1 arange child's kernel changes
changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
self.assertEqual(changed, 1)
@@ -30,15 +30,15 @@ class TestDiffSchedule(unittest.TestCase):
for _ in range(2):
X = Tensor.randn(10, 10).realize()
xt = cast(LazyBuffer, X[idxs].lazydata)
-with Context(FUSE_ARANGE=0): schedules.append(_graph_schedule([xt], set()))
-with Context(FUSE_ARANGE=1): schedules.append(_graph_schedule([xt], set()))
+with Context(FUSE_ARANGE=0): schedules.append(_graph_schedule([xt]))
+with Context(FUSE_ARANGE=1): schedules.append(_graph_schedule([xt]))
changed = diff_schedule(schedules)
self.assertEqual(changed, 1)
def test_no_diff(self):
a = cast(LazyBuffer, (Tensor([1])+Tensor([2])).lazydata)
-with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a], set())
-with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a], set())
+with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a])
+with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a])
changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
self.assertEqual(changed, 0)
@@ -49,8 +49,8 @@ class TestDiffSchedule(unittest.TestCase):
c1(img).relu().mean().backward()
assert img.grad is not None and c1.weight.grad is not None
outs = [cast(LazyBuffer, img.grad.lazydata), cast(LazyBuffer, c1.weight.grad.lazydata)]
-with Context(FUSE_CONV_BW=0): ref_graph, ref_in_degree = _graph_schedule(outs, set())
-with Context(FUSE_CONV_BW=1): compare_graph, compare_in_degree = _graph_schedule(outs, set())
+with Context(FUSE_CONV_BW=0): ref_graph, ref_in_degree = _graph_schedule(outs)
+with Context(FUSE_CONV_BW=1): compare_graph, compare_in_degree = _graph_schedule(outs)
changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
# 1 reduceop folds, its child reduceop changes
self.assertEqual(changed, 1)

View File

@@ -11,12 +11,11 @@ from test.unit.test_shapetracker import shapetracker_getitem
class TestConvShapetracker(unittest.TestCase):
def test_conv_3x3_one_view(self):
conv = Conv2d(16, 32, (3, 3))
-seen = set()
-# first run to init the weights, they are saved in seen
-create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
+# first run to init the weights, they are scheduled.
+create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata])
# run it again to get the kernels
-sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op is UOps.SINK]
+sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is UOps.SINK]
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
for st in [x.st_arg for x in sched[0].ast.parents if x.op is UOps.LOAD]:
assert len(st.views) == 1

View File

@@ -18,7 +18,7 @@ class TestFusionOp(unittest.TestCase):
def test_expand_fuse(self):
bt = Tensor(np.ones((10, 1)), dtype=dtypes.float32)
out = (bt*2).expand(10,10).sum(1)
-sched = create_schedule([out.lazydata], None)
+sched = create_schedule([out.lazydata])
run_schedule(sched)
outd = out.tolist()
assert all(x == 20.0 for x in outd)
@@ -27,7 +27,7 @@ class TestFusionOp(unittest.TestCase):
st = time.perf_counter()
a = Tensor([1,2,3,4])
for _ in range(24): a = a + a
-sched = create_schedule([a.lazydata], None)
+sched = create_schedule([a.lazydata])
ei = lower_schedule_item(sched[-1])
self.assertLess(time.perf_counter()-st, 2.0)
assert len(ei.prg.p.src.splitlines()) < 250
@@ -36,13 +36,13 @@ class TestFusionOp(unittest.TestCase):
st = time.perf_counter()
a = Tensor([1,2,3,4])
for _ in range(24): a = a + a
-sched1 = create_schedule([a.lazydata], None)
+sched1 = create_schedule([a.lazydata])
b = Tensor([1,2,3,4])
for _ in range(24): b = b + b
-sched2 = create_schedule([b.lazydata], None)
+sched2 = create_schedule([b.lazydata])
c = Tensor([1,2,3,4])
for _ in range(23): c = c + c
-sched3 = create_schedule([c.lazydata], None)
+sched3 = create_schedule([c.lazydata])
assert_equiv_uops(sched1[-1].ast, sched2[-1].ast)
with self.assertRaises(AssertionError): assert_equiv_uops(sched1[-1].ast, sched3[-1].ast)
self.assertLess(time.perf_counter()-st, 2.0)

View File

@@ -1,7 +1,7 @@
import sys, pickle, atexit, importlib, contextlib
from collections import defaultdict, deque
from dataclasses import dataclass, field
-from typing import Callable, Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
+from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast, get_args
from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
from tinygrad.ops import PatternMatcher, UPat, graph_rewrite
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
@@ -320,7 +320,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
-def _get_output_groups(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
+def _get_output_groups(outs:List[LazyBuffer]) -> \
Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], # these are the output groups
Dict[LazyBuffer, None], # these are all the realizes in the graph
Dict[LazyBuffer, LazyBuffer]]: # these are the buffers we ASSIGN to in this schedule
@@ -399,7 +399,7 @@ def _get_output_groups(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
for buf in realizes:
-if buf.realized is not None or buf.op is MetaOps.CONST or buf in seen: continue
+if buf.realized is not None or buf.op is MetaOps.CONST: continue
output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
# make things that can't be images not images
@@ -415,11 +415,11 @@ def _get_output_groups(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
return output_groups, realizes, assign_targets
SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
-def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
+def _graph_schedule(outs:List[LazyBuffer]) -> \
Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], # this is the graph
DefaultDict[LBScheduleItem, int]]: # this is the in-degree of the graph
"""create a graph for realizing the outputs"""
-output_groups, realizes, assign_targets = _get_output_groups(outs, seen)
+output_groups, realizes, assign_targets = _get_output_groups(outs)
# preschedule all buffers in realizes
prescheduled = flatten([_lower_lazybuffer(group, realizes) for group in output_groups.values()])
schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
@@ -449,9 +449,8 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
# *** DAG ordering: breadth first search ***
-def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
-if seen is None: seen = set()
-graph, in_degree = _graph_schedule(outs, seen)
+def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+graph, in_degree = _graph_schedule(outs)
if getenv("RUN_PROCESS_REPLAY") and getenv("COMPARE_SCHEDULE", 1):
# NOTE: process relpay needs PYTHONPATH=., remove this once it just pickles LazyBuffers
with contextlib.suppress(Exception): importlib.import_module("test.external.process_replay.diff_schedule").process_replay(outs, graph, in_degree)
@@ -462,7 +461,6 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
kernel_number = GlobalCounters.kernel_count
while queue:
lsi = queue.popleft()
-for buf in lsi.outputs: seen.add(buf)
if GRAPH:
kernel_number += 1
for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
@@ -479,7 +477,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
return schedule, var_vals
-def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
-schedule, var_vals = create_schedule_with_vars(outs, seen)
+def create_schedule(outs:List[LazyBuffer]) -> List[ScheduleItem]:
+schedule, var_vals = create_schedule_with_vars(outs)
assert len(var_vals) == 0
return schedule
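Consolidated, the scheduler entry points after this change read as follows (signatures copied from the hunks above, bodies elided):

def _get_output_groups(outs:List[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], Dict[LazyBuffer, None], Dict[LazyBuffer, LazyBuffer]]: ...
def _graph_schedule(outs:List[LazyBuffer]) -> Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]: ...
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: ...
def create_schedule(outs:List[LazyBuffer]) -> List[ScheduleItem]: ...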

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import dataclasses
import time, math, itertools, functools, struct, sys, inspect, pathlib, string
from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set, Literal
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
from collections import defaultdict
import numpy as np
@@ -191,17 +191,21 @@ class Tensor:
# ***** data handlers ****
-def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
-"""Creates the schedule needed to realize these Tensor(s), with Variables."""
+def schedule_with_vars(self, *lst:Tensor) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+"""
+Creates the schedule needed to realize these Tensor(s), with Variables.
+NOTE: A Tensor can only be scheduled once.
+"""
if getenv("FUZZ_SCHEDULE"):
from test.external.fuzz_schedule import fuzz_schedule
fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
-schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
return memory_planner(schedule), var_vals
-def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+def schedule(self, *lst:Tensor) -> List[ScheduleItem]:
"""Creates the schedule needed to realize these Tensor(s)."""
-schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+schedule, var_vals = self.schedule_with_vars(*lst)
assert len(var_vals) == 0
return schedule
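With the seen parameter gone from Tensor.schedule as well, a typical caller just schedules and runs. A minimal usage sketch (it assumes the top-level tinygrad Tensor import and run_schedule from tinygrad.engine.realize, neither of which appears in this diff):

# hypothetical caller-side usage after this commit
from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule

out = (Tensor([1.0, 2.0]) + Tensor([3.0, 4.0])).sum()
sched = out.schedule()   # no seen set to pass; note a Tensor can only be scheduled once
run_schedule(sched)      # realize the scheduled kernels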