beautiful_mnist -4.3% kernels (#5709)

* add is_complete

* partially delete forced_realized

* p2

* start

* refactor to can_group

* remove steps

* _get_inputs is nicer

* fix the cache

* cache is dict now

* rename to group
This commit is contained in:
qazal
2024-07-26 01:30:49 +08:00
committed by GitHub
parent 92eefab4b0
commit 9ceb3a3d1f
2 changed files with 29 additions and 17 deletions

View File

@@ -209,7 +209,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim(self):
# this is too high
for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]:
for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 16)]:
with self.subTest(optim=optim.__name__):
with Tensor.train():
img = Tensor.ones(1,3,4,4)
@@ -256,7 +256,7 @@ class TestSchedule(unittest.TestCase):
fw = bn(x).contiguous_backward().relu().contiguous()
fw.sum().backward()
# TODO: this is too many
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 9)
def test_fold_conv_relu(self):
c1 = nn.Conv2d(3,16,3)
@@ -620,7 +620,7 @@ class TestSchedule(unittest.TestCase):
out0 = a.sum() + b.sum() + 2
out1 = a.sum() + b.sum() + 4
# run_schedule(check_schedule([out0, out1], 1))
run_schedule(check_schedule([out0, out1], 4))
run_schedule(check_schedule([out0, out1], 2))
np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+b.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy().sum()+4, atol=1e-4, rtol=1e-4)
@@ -649,7 +649,7 @@ class TestSchedule(unittest.TestCase):
out1 = b.max() + out0*2
out2 = a.sum() + out1
# run_schedule(check_schedule([out0, out1, out2], 1))
run_schedule(check_schedule([out0, out1, out2], 4))
run_schedule(check_schedule([out0, out1, out2], 3))
np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out1.numpy(), out1_np:=b.numpy().max() + out0_np*2, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out2.numpy(), a.numpy().sum() + out1_np, atol=1e-4, rtol=1e-6)
@@ -1096,7 +1096,7 @@ class TestSchedule(unittest.TestCase):
c = a.sum() + 2
d = (a.sum() - b.sum()) * 4
# run_schedule(check_schedule([c, d], 1))
run_schedule(check_schedule([c, d], 3))
run_schedule(check_schedule([c, d], 2))
np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(d.numpy(), (a.numpy().sum() - b.numpy().sum()) * 4, atol=1e-4, rtol=1e-4)