mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
beautiful_mnist -4.3% kernels (#5709)
* add is_complete
* partially delete forced_realized
* p2
* start
* refactor to can_group
* remove steps
* _get_inputs is nicer
* fix the cache
* cache is dict now
* rename to group
This commit is contained in:
@@ -209,7 +209,7 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
def test_fold_conv_batchnorm_optim(self):
|
||||
# this is too high
|
||||
- for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]:
|
||||
+ for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 16)]:
|
||||
with self.subTest(optim=optim.__name__):
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
@@ -256,7 +256,7 @@ class TestSchedule(unittest.TestCase):
|
||||
fw = bn(x).contiguous_backward().relu().contiguous()
|
||||
fw.sum().backward()
|
||||
# TODO: this is too many
|
||||
- check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
|
||||
+ check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 9)
|
||||
|
||||
def test_fold_conv_relu(self):
|
||||
c1 = nn.Conv2d(3,16,3)
|
||||
@@ -620,7 +620,7 @@ class TestSchedule(unittest.TestCase):
|
||||
out0 = a.sum() + b.sum() + 2
|
||||
out1 = a.sum() + b.sum() + 4
|
||||
# run_schedule(check_schedule([out0, out1], 1))
|
||||
- run_schedule(check_schedule([out0, out1], 4))
|
||||
+ run_schedule(check_schedule([out0, out1], 2))
|
||||
np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+b.numpy().sum()+2, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy().sum()+4, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@@ -649,7 +649,7 @@ class TestSchedule(unittest.TestCase):
|
||||
out1 = b.max() + out0*2
|
||||
out2 = a.sum() + out1
|
||||
# run_schedule(check_schedule([out0, out1, out2], 1))
|
||||
- run_schedule(check_schedule([out0, out1, out2], 4))
|
||||
+ run_schedule(check_schedule([out0, out1, out2], 3))
|
||||
np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
|
||||
np.testing.assert_allclose(out1.numpy(), out1_np:=b.numpy().max() + out0_np*2, atol=1e-4, rtol=1e-6)
|
||||
np.testing.assert_allclose(out2.numpy(), a.numpy().sum() + out1_np, atol=1e-4, rtol=1e-6)
|
||||
@@ -1096,7 +1096,7 @@ class TestSchedule(unittest.TestCase):
|
||||
c = a.sum() + 2
|
||||
d = (a.sum() - b.sum()) * 4
|
||||
# run_schedule(check_schedule([c, d], 1))
|
||||
- run_schedule(check_schedule([c, d], 3))
|
||||
+ run_schedule(check_schedule([c, d], 2))
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(d.numpy(), (a.numpy().sum() - b.numpy().sum()) * 4, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user