Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
search children in fusion (#4322)
* scheduler diff
* tests diff
* new changes
* realizes
* chores
* assign
* kind of r3
* forced_realize wont do it
* with forced_realize
* start with children
* test search
* r3 with parents
* diff cleanup
* add children
* crossing assign
* late fuse descendants
* update kernel counts
* assign diff doesnt belong here
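In effect, once a reduce has more than one realized child, the scheduler now also searches downstream from those children and folds compatible, already-scheduled elementwise descendants into the same fusion group, which is why the expected kernel counts in the tests below drop. A minimal, hedged sketch of the observable effect, using the same multi-output pattern as the @@ -485 hunk below (the import paths and the LoadOps filter are assumptions based on this era of the tree, not part of this commit):

# hedged sketch, not repo code: reproduce the multi-output pattern from the test diff below
from tinygrad import Tensor
from tinygrad.ops import LoadOps                      # LoadOps filter mirrors check_schedule
from tinygrad.engine.schedule import create_schedule  # import path assumed for this era

a = Tensor.empty(4, 4)
out0 = a.sum() + 2   # first realized child of the reduce
out1 = a.sum() + 4   # second realized child of the same reduce
out2 = out0 * out1   # elementwise descendant of both children

sched = create_schedule([t.lazydata for t in (out0, out1, out2)])
kernels = [s for s in sched if s.ast[0].op not in LoadOps]
print(len(kernels))  # the @@ -485 hunk below changes the expected count from 2 to 1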
@@ -32,7 +32,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
     for i, s in enumerate(sched):
       print("kernel", i+1)
       for op in s.ast: print_tree(op)
-  assert len(sched) == allowed
+  assert len(sched) == allowed, f"{len(sched)} != {allowed}"
   # test the (non loadops) ops linearize
   for s in sched:
     if s.ast[0].op in LoadOps: continue
@@ -485,7 +485,7 @@ class TestSchedule(unittest.TestCase):
     out0 = a.sum() + 2
     out1 = a.sum() + 4
     out2 = out0 * out1
-    check_schedule([out0, out1, out2], 2)
+    check_schedule([out0, out1, out2], 1)
 
   def test_reduce_multiple_paths(self):
     a = Tensor.empty(4, 4)
@@ -581,7 +581,7 @@ class TestSchedule(unittest.TestCase):
       layer = nn.Linear(768, 768*4)
       opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
       layer(x).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 14)
+      check_schedule(opt.schedule_step(), 12)
 
   def test_adam_conv_fuse(self):
     with Tensor.train():
@@ -590,7 +590,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
       opt.zero_grad()
       c1(img).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 14)
+      check_schedule(opt.schedule_step(), 12)
 
   def test_adam_2convs_fuse(self):
     with Tensor.train():
@@ -600,7 +600,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 15)
+      check_schedule(opt.schedule_step(), 14)
 
   def test_sgd_conv_fuse(self):
     with Tensor.train():
@@ -736,7 +736,7 @@ class TestSchedule(unittest.TestCase):
     d = a.sum() * 2
     e = c * d
     f = b.sum() - e
-    check_schedule([c, d, e, f], 3)
+    check_schedule([c, d, e, f], 2)
 
   def test_partial_fuse4(self):
     a = Tensor.empty(16, 16)
@@ -169,11 +169,19 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
             forced_realize = True
             break
           if len(realized_children) > 1:
-            rc_parents = deque(realized_children)
+            rc_parents, rc_children = deque(realized_children), deque(realized_children)
             while rc_parents and not forced_realize:
               # max one reduceop per kernel
               if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True
               else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r)
+            realized_descendants: Set[LazyBuffer] = set()
+            while rc_children and not forced_realize:
+              if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op:
+                realized_descendants.clear()
+                break
+              if c in realizes and c not in (*realized_children, tr): realized_descendants.add(c)
+              rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device)
+            realized_children.update((rd, st) for rd in realized_descendants)
           continue
         for tr_next in children[tr].keys():
          if not tr_next.realized:
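Stripped of the scheduler's bookkeeping, the added block is a worklist search over `children`: it starts from the reduce's already-realized children, bails out entirely if it reaches another reduce, a non-contiguous or differently sized view, or a buffer already claimed by another reduce, and otherwise collects every descendant that is itself a realization point. A standalone sketch of that shape, with hypothetical names (`collect_fusable_descendants`, `is_blocking`, `is_realized_output` are illustrative, not the scheduler's API):

from collections import deque
from typing import Callable, Dict, Hashable, Iterable, Set

def collect_fusable_descendants(start: Iterable[Hashable], children: Dict[Hashable, Iterable[Hashable]],
                                is_blocking: Callable[[Hashable], bool],
                                is_realized_output: Callable[[Hashable], bool]) -> Set[Hashable]:
  # walk downstream from the realized children; any blocking node cancels the whole group,
  # mirroring `realized_descendants.clear(); break` in the diff above
  roots = set(start)
  queue, found, seen = deque(roots), set(), set(roots)
  while queue:
    node = queue.pop()
    if is_blocking(node): return set()
    if is_realized_output(node) and node not in roots: found.add(node)
    for nxt in children.get(node, ()):
      if nxt not in seen:
        seen.add(nxt)
        queue.append(nxt)
  return found

In the real scheduler the surviving descendants are merged into realized_children with the reduce's ShapeTracker (the realized_children.update line above), which is what lets the extra outputs share a single kernel.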