mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
search children in fusion (#4322)
* scheduler diff * tests diff * new changes * realizes * chores * assign * kind of r3 * forced_realize wont do it * with forced_realize * start with children * test search * r3 with parents * diff cleanup * add children * crossing assign * late fuse descendants * update kernel counts * assign diff doesnt belong here
This commit is contained in:
@@ -32,7 +32,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
|
||||
for i, s in enumerate(sched):
|
||||
print("kernel", i+1)
|
||||
for op in s.ast: print_tree(op)
|
||||
assert len(sched) == allowed
|
||||
assert len(sched) == allowed, f"{len(sched)} != {allowed}"
|
||||
# test the (non loadops) ops linearize
|
||||
for s in sched:
|
||||
if s.ast[0].op in LoadOps: continue
|
||||
@@ -485,7 +485,7 @@ class TestSchedule(unittest.TestCase):
|
||||
out0 = a.sum() + 2
|
||||
out1 = a.sum() + 4
|
||||
out2 = out0 * out1
|
||||
check_schedule([out0, out1, out2], 2)
|
||||
check_schedule([out0, out1, out2], 1)
|
||||
|
||||
def test_reduce_multiple_paths(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
@@ -581,7 +581,7 @@ class TestSchedule(unittest.TestCase):
|
||||
layer = nn.Linear(768, 768*4)
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
|
||||
layer(x).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 14)
|
||||
check_schedule(opt.schedule_step(), 12)
|
||||
|
||||
def test_adam_conv_fuse(self):
|
||||
with Tensor.train():
|
||||
@@ -590,7 +590,7 @@ class TestSchedule(unittest.TestCase):
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
|
||||
opt.zero_grad()
|
||||
c1(img).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 14)
|
||||
check_schedule(opt.schedule_step(), 12)
|
||||
|
||||
def test_adam_2convs_fuse(self):
|
||||
with Tensor.train():
|
||||
@@ -600,7 +600,7 @@ class TestSchedule(unittest.TestCase):
|
||||
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
|
||||
opt.zero_grad()
|
||||
c2(c1(img).relu()).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 15)
|
||||
check_schedule(opt.schedule_step(), 14)
|
||||
|
||||
def test_sgd_conv_fuse(self):
|
||||
with Tensor.train():
|
||||
@@ -736,7 +736,7 @@ class TestSchedule(unittest.TestCase):
|
||||
d = a.sum() * 2
|
||||
e = c * d
|
||||
f = b.sum() - e
|
||||
check_schedule([c, d, e, f], 3)
|
||||
check_schedule([c, d, e, f], 2)
|
||||
|
||||
def test_partial_fuse4(self):
|
||||
a = Tensor.empty(16, 16)
|
||||
|
||||
Reference in New Issue
Block a user