Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
@@ -17,7 +17,6 @@ def get_sched_resnet():
 BS = getenv("BS", 64)
 # run model twice to get only what changes, these are the kernels of the model
-seen = set()
 for _ in range(2):
 out = mdl(Tensor.empty(BS, 3, 224, 224))
 targets = [out.lazydata]
@@ -25,7 +24,7 @@ def get_sched_resnet():
 optim.zero_grad()
 out.sparse_categorical_crossentropy(Tensor.empty(BS, dtype=dtypes.int)).backward()
 targets += [x.lazydata for x in optim.schedule_step()]
-sched = create_schedule(targets, seen)
+sched = create_schedule(targets)
 print(f"schedule length {len(sched)}")
 return sched
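The visible effect on callers throughout this diff is that the `seen` accumulator disappears: schedules are no longer deduplicated by threading a set through repeated `create_schedule` calls. A minimal sketch of the new-style call (illustrative tensors, not the benchmark's ResNet; assumes the import path shown in the schedule.py hunks below):

from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule

x = Tensor.empty(4, 4)
out = (x + 1).sum()
sched = create_schedule([out.lazydata])  # no `seen` argument anymore
print(f"schedule length {len(sched)}")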
@@ -16,8 +16,7 @@ if __name__ == "__main__":
 #model.load_pretrained()
 for p in nn.state.get_parameters(model): p.replace(Tensor.empty(p.shape, dtype=p.dtype)) # fake load pretrained

-seen = set()
-#early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)], seen)
+#early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)])
 #print(f"built model {len(early_sched)}")

 #B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64)
@@ -38,10 +37,9 @@ if __name__ == "__main__":
 tensors = optimizer.schedule_step()
 else:
 tensors = []
-sched = create_schedule([loss.lazydata] + [x.lazydata for x in tensors], seen)
+sched = create_schedule([loss.lazydata] + [x.lazydata for x in tensors])
 print(f"calls {i}:", len(sched))
 #run_schedule(sched[:])
-del seen # free the LazyBuffers
 sched = memory_planner(sched)
 ast_dedup = dedup([si.ast for si in sched if si.ast.op is UOps.SINK])
 srcs = {}
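`dedup([si.ast for si in sched ...])` keeps one entry per distinct kernel AST so each unique kernel is inspected once. A self-contained sketch of that order-preserving dedup pattern, with stand-in strings instead of ASTs:

def dedup(items):
  # keep the first occurrence of each item, preserving order
  out, seen_items = [], set()
  for x in items:
    if x not in seen_items:
      seen_items.add(x)
      out.append(x)
  return out

print(dedup(["add", "mul", "add", "reduce"]))  # ['add', 'mul', 'reduce']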
test/external/fuzz_schedule.py
@@ -18,7 +18,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
 for combination in itertools.product(*ctx_vars.values()):
 for var, val in zip(ctx_vars, combination): var.value = val
 ctx_var_values = dict(zip([v.key for v in ctx_vars], combination))
-graph, in_degree = _graph_schedule(outs, set())
+graph, in_degree = _graph_schedule(outs)
 for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = ctx_var_values
 toposorts = list(unique_ts.items())
 if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))
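`_graph_schedule` returns a dependency graph plus its in-degree map, and `find_all_toposorts` enumerates the valid kernel orderings the fuzzer then runs. A generic illustration of that idea with string nodes (not tinygrad's actual implementation):

from typing import Dict, List, Tuple

def all_toposorts(graph:Dict[str, List[str]], in_degree:Dict[str, int]) -> List[Tuple[str, ...]]:
  # depth-first enumeration of every topological order of a DAG
  orders, path = [], []
  def walk():
    ready = [n for n, d in in_degree.items() if d == 0 and n not in path]
    if not ready:
      if len(path) == len(in_degree): orders.append(tuple(path))
      return
    for n in ready:
      path.append(n)
      for m in graph.get(n, []): in_degree[m] -= 1
      walk()
      for m in graph.get(n, []): in_degree[m] += 1
      path.pop()
  walk()
  return orders

print(all_toposorts({"a": ["b", "c"], "b": [], "c": []}, {"a": 0, "b": 1, "c": 1}))
# [('a', 'b', 'c'), ('a', 'c', 'b')]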
@@ -18,8 +18,8 @@ class TestDiffSchedule(unittest.TestCase):
 X = Tensor.randn(10, 10).realize()
 idxs = Tensor([0, 2]).realize()
 xt = cast(LazyBuffer, X[idxs].lazydata)
-with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt], set())
-with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt], set())
+with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt])
+with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt])
 # 1 arange LazyBuffer folds, 1 arange child's kernel changes
 changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
 self.assertEqual(changed, 1)
@@ -30,15 +30,15 @@ class TestDiffSchedule(unittest.TestCase):
 for _ in range(2):
 X = Tensor.randn(10, 10).realize()
 xt = cast(LazyBuffer, X[idxs].lazydata)
-with Context(FUSE_ARANGE=0): schedules.append(_graph_schedule([xt], set()))
-with Context(FUSE_ARANGE=1): schedules.append(_graph_schedule([xt], set()))
+with Context(FUSE_ARANGE=0): schedules.append(_graph_schedule([xt]))
+with Context(FUSE_ARANGE=1): schedules.append(_graph_schedule([xt]))
 changed = diff_schedule(schedules)
 self.assertEqual(changed, 1)

 def test_no_diff(self):
 a = cast(LazyBuffer, (Tensor([1])+Tensor([2])).lazydata)
-with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a], set())
-with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a], set())
+with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a])
+with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a])
 changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
 self.assertEqual(changed, 0)

@@ -49,8 +49,8 @@ class TestDiffSchedule(unittest.TestCase):
 c1(img).relu().mean().backward()
 assert img.grad is not None and c1.weight.grad is not None
 outs = [cast(LazyBuffer, img.grad.lazydata), cast(LazyBuffer, c1.weight.grad.lazydata)]
-with Context(FUSE_CONV_BW=0): ref_graph, ref_in_degree = _graph_schedule(outs, set())
-with Context(FUSE_CONV_BW=1): compare_graph, compare_in_degree = _graph_schedule(outs, set())
+with Context(FUSE_CONV_BW=0): ref_graph, ref_in_degree = _graph_schedule(outs)
+with Context(FUSE_CONV_BW=1): compare_graph, compare_in_degree = _graph_schedule(outs)
 changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
 # 1 reduceop folds, its child reduceop changes
 self.assertEqual(changed, 1)
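These tests flip scheduler behavior with `Context`, which temporarily overrides ContextVars by keyword name and restores them on exit. A small sketch of that mechanism using `DEBUG` (the same pattern the tests apply to FUSE_ARANGE and FUSE_CONV_BW):

from tinygrad.helpers import Context, DEBUG

print(DEBUG.value)      # whatever the environment set (0 by default)
with Context(DEBUG=2):
  print(DEBUG.value)    # 2 inside the block
print(DEBUG.value)      # restored on exit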
@@ -11,12 +11,11 @@ from test.unit.test_shapetracker import shapetracker_getitem
 class TestConvShapetracker(unittest.TestCase):
 def test_conv_3x3_one_view(self):
 conv = Conv2d(16, 32, (3, 3))
-seen = set()

-# first run to init the weights, they are saved in seen
-create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
+# first run to init the weights, they are scheduled.
+create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata])
 # run it again to get the kernels
-sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op is UOps.SINK]
+sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is UOps.SINK]
 assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
 for st in [x.st_arg for x in sched[0].ast.parents if x.op is UOps.LOAD]:
 assert len(st.views) == 1
@@ -18,7 +18,7 @@ class TestFusionOp(unittest.TestCase):
 def test_expand_fuse(self):
 bt = Tensor(np.ones((10, 1)), dtype=dtypes.float32)
 out = (bt*2).expand(10,10).sum(1)
-sched = create_schedule([out.lazydata], None)
+sched = create_schedule([out.lazydata])
 run_schedule(sched)
 outd = out.tolist()
 assert all(x == 20.0 for x in outd)
@@ -27,7 +27,7 @@ class TestFusionOp(unittest.TestCase):
 st = time.perf_counter()
 a = Tensor([1,2,3,4])
 for _ in range(24): a = a + a
-sched = create_schedule([a.lazydata], None)
+sched = create_schedule([a.lazydata])
 ei = lower_schedule_item(sched[-1])
 self.assertLess(time.perf_counter()-st, 2.0)
 assert len(ei.prg.p.src.splitlines()) < 250
@@ -36,13 +36,13 @@ class TestFusionOp(unittest.TestCase):
 st = time.perf_counter()
 a = Tensor([1,2,3,4])
 for _ in range(24): a = a + a
-sched1 = create_schedule([a.lazydata], None)
+sched1 = create_schedule([a.lazydata])
 b = Tensor([1,2,3,4])
 for _ in range(24): b = b + b
-sched2 = create_schedule([b.lazydata], None)
+sched2 = create_schedule([b.lazydata])
 c = Tensor([1,2,3,4])
 for _ in range(23): c = c + c
-sched3 = create_schedule([c.lazydata], None)
+sched3 = create_schedule([c.lazydata])
 assert_equiv_uops(sched1[-1].ast, sched2[-1].ast)
 with self.assertRaises(AssertionError): assert_equiv_uops(sched1[-1].ast, sched3[-1].ast)
 self.assertLess(time.perf_counter()-st, 2.0)
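The timing tests above build a deep chain of additions and expect both scheduling time and generated source size to stay small. A hedged sketch of the same pattern through the public Tensor API (new-style, no `seen`):

from tinygrad import Tensor

a = Tensor([1, 2, 3, 4])
for _ in range(24):
  a = a + a               # values double 24 times, but it stays one elementwise expression
sched = a.schedule()      # should remain a short schedule despite the deep graph
print(len(sched))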
@@ -1,7 +1,7 @@
 import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
-from typing import Callable, Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
+from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast, get_args
 from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
 from tinygrad.ops import PatternMatcher, UPat, graph_rewrite
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
@@ -320,7 +320,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
 for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
 return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])

-def _get_output_groups(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
+def _get_output_groups(outs:List[LazyBuffer]) -> \
 Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], # these are the output groups
 Dict[LazyBuffer, None], # these are all the realizes in the graph
 Dict[LazyBuffer, LazyBuffer]]: # these are the buffers we ASSIGN to in this schedule
@@ -399,7 +399,7 @@ def _get_output_groups(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \

 output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
 for buf in realizes:
-if buf.realized is not None or buf.op is MetaOps.CONST or buf in seen: continue
+if buf.realized is not None or buf.op is MetaOps.CONST: continue
 output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)

 # make things that can't be images not images
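The loop above buckets every buffer that must be realized under its owning reduce (when MULTIOUTPUT is enabled) or under itself. A stand-alone sketch of that defaultdict grouping pattern, using plain strings instead of LazyBuffers:

from collections import defaultdict

realizes = ["a", "b", "c"]      # stand-ins for LazyBuffers that must be realized
reduce_for_op = {"b": "r"}      # buffer -> the reduce it belongs to, if any
MULTIOUTPUT = True

output_groups = defaultdict(list)
for buf in realizes:
  output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
print(dict(output_groups))      # {'a': ['a'], 'r': ['b'], 'c': ['c']}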
@@ -415,11 +415,11 @@ def _get_output_groups(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
 return output_groups, realizes, assign_targets

 SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
-def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
+def _graph_schedule(outs:List[LazyBuffer]) -> \
 Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], # this is the graph
 DefaultDict[LBScheduleItem, int]]: # this is the in-degree of the graph
 """create a graph for realizing the outputs"""
-output_groups, realizes, assign_targets = _get_output_groups(outs, seen)
+output_groups, realizes, assign_targets = _get_output_groups(outs)
 # preschedule all buffers in realizes
 prescheduled = flatten([_lower_lazybuffer(group, realizes) for group in output_groups.values()])
 schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
@@ -449,9 +449,8 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \

 # *** DAG ordering: breadth first search ***

-def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
-if seen is None: seen = set()
-graph, in_degree = _graph_schedule(outs, seen)
+def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+graph, in_degree = _graph_schedule(outs)
 if getenv("RUN_PROCESS_REPLAY") and getenv("COMPARE_SCHEDULE", 1):
 # NOTE: process replay needs PYTHONPATH=., remove this once it just pickles LazyBuffers
 with contextlib.suppress(Exception): importlib.import_module("test.external.process_replay.diff_schedule").process_replay(outs, graph, in_degree)
@@ -462,7 +461,6 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
 kernel_number = GlobalCounters.kernel_count
 while queue:
 lsi = queue.popleft()
-for buf in lsi.outputs: seen.add(buf)
 if GRAPH:
 kernel_number += 1
 for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
@@ -479,7 +477,7 @@ def create_schedule_with_vars(outs:List[LazyBuffe
 if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
 return schedule, var_vals

-def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
-schedule, var_vals = create_schedule_with_vars(outs, seen)
+def create_schedule(outs:List[LazyBuffer]) -> List[ScheduleItem]:
+schedule, var_vals = create_schedule_with_vars(outs)
 assert len(var_vals) == 0
 return schedule
@@ -3,7 +3,7 @@ from __future__ import annotations
 import dataclasses
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string
 from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set, Literal
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
 from collections import defaultdict
 import numpy as np

@@ -191,17 +191,21 @@ class Tensor:

 # ***** data handlers ****

-def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
-"""Creates the schedule needed to realize these Tensor(s), with Variables."""
+def schedule_with_vars(self, *lst:Tensor) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+"""
+Creates the schedule needed to realize these Tensor(s), with Variables.
+
+NOTE: A Tensor can only be scheduled once.
+"""
 if getenv("FUZZ_SCHEDULE"):
 from test.external.fuzz_schedule import fuzz_schedule
 fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
-schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
 return memory_planner(schedule), var_vals

-def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+def schedule(self, *lst:Tensor) -> List[ScheduleItem]:
 """Creates the schedule needed to realize these Tensor(s)."""
-schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+schedule, var_vals = self.schedule_with_vars(*lst)
 assert len(var_vals) == 0
 return schedule
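With the `seen` keyword removed, the Tensor-level entry points take only the tensors to schedule. A hedged sketch of the new-style call with illustrative tensors:

from tinygrad import Tensor

a, b = Tensor.empty(8, 8), Tensor.empty(8, 8)
out = (a @ b).relu()
sched = out.schedule()    # no seen= keyword anymore
print(len(sched))         # per the new docstring NOTE, don't schedule the same Tensor twice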