delete multi output support (#8822)

* delete multioutput for now

* test_schedule

* test_assign too

* linter

* 515 for sd

* update tests and ctx

* update that assign check
This commit is contained in:
qazal
2025-01-30 22:45:50 -05:00
committed by GitHub
parent 7647cd8428
commit 1fce864a6d
4 changed files with 24 additions and 32 deletions

View File

@@ -323,7 +323,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim(self):
# this is too high
for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 11)]:
for optim, cnt in [(nn.optim.Adam, 30), (nn.optim.SGD, 11)]:
with self.subTest(optim=optim.__name__):
with Tensor.train():
img = Tensor.ones(1,3,4,4)
@@ -682,6 +682,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 2, filter_sink=False)
# multireduce spec
@unittest.expectedFailure
def test_reduce_same_size(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
@@ -694,6 +695,7 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6)
# multireduce spec
@unittest.expectedFailure
def test_reduce_multiple_paths(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
@@ -714,7 +716,7 @@ class TestSchedule(unittest.TestCase):
out2 = b.sum().exp2()
out3 = b.sum() + out2
# run_schedule(check_schedule([out0, out1, out2, out3], 1))
run_schedule(check_schedule([out0, out1, out2, out3], 2))
run_schedule(check_schedule([out0, out1, out2, out3], 6))
np.testing.assert_allclose(out0.numpy(), np_out0:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), np_out1:=a.numpy().sum()+np_out0, atol=1e-4, rtol=1e-4)
np_b = (a.numpy() + np_out0 + np_out1)
@@ -793,6 +795,7 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy(), atol=1e-4, rtol=1e-4)
@unittest.expectedFailure
def test_reduce_shrink_child(self):
a = Tensor.empty(100, 100)
b = Tensor.empty(10,)
@@ -1039,7 +1042,7 @@ class TestSchedule(unittest.TestCase):
_realize_weights(layer)
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
layer(x).relu().sum().backward()
check_schedule(opt.schedule_step(), 10)
check_schedule(opt.schedule_step(), 16)
def test_adam_conv_fuse(self):
with Tensor.train():
@@ -1049,7 +1052,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
opt.zero_grad()
c1(img).relu().sum().backward()
check_schedule(opt.schedule_step(), 10)
check_schedule(opt.schedule_step(), 16)
def test_adam_2convs_fuse(self):
with Tensor.train():
@@ -1060,7 +1063,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
check_schedule(opt.schedule_step(), 14)
check_schedule(opt.schedule_step(), 20)
def test_sgd_conv_fuse(self):
with Tensor.train():
@@ -1136,7 +1139,7 @@ class TestSchedule(unittest.TestCase):
shared = x.sum().half().float()
a = shared * 2
b = shared * 3
sched = check_schedule([a, b], 1)
sched = check_schedule([a, b], 3)
for si in sched[:-2]: assert all(out.dtype == dtypes.half for out in si.outputs)
# reduce
@@ -1272,6 +1275,7 @@ class TestSchedule(unittest.TestCase):
# changed by: multireduce spec
# pattern in adam
@unittest.expectedFailure
def test_partial_fuse3(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
@@ -1288,6 +1292,7 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(f.numpy(), b.numpy().sum() - e_np, atol=1e-4, rtol=1e-4)
# changed by: multireduce spec
@unittest.expectedFailure
def test_partial_fuse4(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
@@ -1763,6 +1768,7 @@ class TestIndexing(unittest.TestCase):
loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)
@unittest.expectedFailure
def test_arange_fuse_grouped_children(self):
X = Tensor.randn(4, 4).realize()
r = (X+Tensor.arange(16).reshape(4, 4)).sum()
@@ -1780,7 +1786,7 @@ class TestIndexing(unittest.TestCase):
self.check_schedule([r], 1)
np.testing.assert_allclose(r.numpy(), (X.numpy()+np.arange(16).reshape(4, 4)).sum(1, keepdims=True))
@unittest.expectedFailure
@unittest.skip("multi output isn't supported")
def test_multiview_arange_children(self):
X = Tensor.randn(2,3,4,4).numpy()
with Context(FUSE_ARANGE=1):