search children in fusion (#4322)

* scheduler diff

* tests diff

* new changes

* realizes

* chores

* assign

* kind of r3

* forced_realize won't do it

* with forced_realize

* start with children

* test search

* r3 with parents

* diff cleanup

* add children

* crossing assign

* late fuse descendants

* update kernel counts

* assign diff doesn't belong here
This commit is contained in:
qazal
2024-05-04 22:22:15 +08:00
committed by GitHub
parent 249cadd106
commit 5f3bae378f
2 changed files with 15 additions and 7 deletions

View File

@@ -32,7 +32,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
for i, s in enumerate(sched):
print("kernel", i+1)
for op in s.ast: print_tree(op)
assert len(sched) == allowed
assert len(sched) == allowed, f"{len(sched)} != {allowed}"
# test the (non loadops) ops linearize
for s in sched:
if s.ast[0].op in LoadOps: continue
@@ -485,7 +485,7 @@ class TestSchedule(unittest.TestCase):
out0 = a.sum() + 2
out1 = a.sum() + 4
out2 = out0 * out1
check_schedule([out0, out1, out2], 2)
check_schedule([out0, out1, out2], 1)
def test_reduce_multiple_paths(self):
a = Tensor.empty(4, 4)
@@ -581,7 +581,7 @@ class TestSchedule(unittest.TestCase):
layer = nn.Linear(768, 768*4)
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
layer(x).relu().sum().backward()
check_schedule(opt.schedule_step(), 14)
check_schedule(opt.schedule_step(), 12)
def test_adam_conv_fuse(self):
with Tensor.train():
@@ -590,7 +590,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
opt.zero_grad()
c1(img).relu().sum().backward()
check_schedule(opt.schedule_step(), 14)
check_schedule(opt.schedule_step(), 12)
def test_adam_2convs_fuse(self):
with Tensor.train():
@@ -600,7 +600,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
check_schedule(opt.schedule_step(), 15)
check_schedule(opt.schedule_step(), 14)
def test_sgd_conv_fuse(self):
with Tensor.train():
@@ -736,7 +736,7 @@ class TestSchedule(unittest.TestCase):
d = a.sum() * 2
e = c * d
f = b.sum() - e
check_schedule([c, d, e, f], 3)
check_schedule([c, d, e, f], 2)
def test_partial_fuse4(self):
a = Tensor.empty(16, 16)