From 0b464df605abc1de9bebce59dfeda689c8a0f098 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 16 May 2024 18:30:49 +0800 Subject: [PATCH] base change scheduling spec (#4613) * spec and kernel cnt * dont use half * skip half --- test/test_schedule.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/test_schedule.py b/test/test_schedule.py index cedecf5616..00d0bbc452 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -807,5 +807,40 @@ class TestSchedule(unittest.TestCase): with self.assertRaises(AssertionError): np.testing.assert_equal(out.numpy(), [2, 0]) + def test_base_change_shrink_pad(self): + a = Tensor.ones(3, 3).contiguous().realize() + b = a.exp2() + c = b[:-1, :-1] + d = c.pad(((0, 1), (0, 1))) * 2 + run_schedule(check_schedule(d, 1)) + with self.assertRaises(AssertionError): # TODO unsafe pads + np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2) + + def test_base_change_expand_pad(self): + a = Tensor.ones(3, 3).contiguous().realize() + b = a.exp2() + c = b[:, None, :] + d = c.pad(((0, 0), (1, 1), (0, 0))) * 2 + run_schedule(check_schedule(d, 2)) + np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:, None, :], ((0, 0), (1, 1), (0, 0)))*2) + + # TODO like openpilot with imagef + @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") + def test_base_change_expand_expand(self): + a = Tensor.ones(4, 4).contiguous().realize() + b = a.cast(dtypes.half).expand(2, 4, 4) + c = b.cast(dtypes.int).expand(2, 2, 4, 4) + run_schedule(check_schedule(c, 2)) + np.testing.assert_equal(c.numpy(), np.ones(((2, 2, 4, 4)), dtype=np.int32)) + + def test_base_change_pad_expand(self): + a = Tensor.full((4, 4), 1.).contiguous().realize() + b = Tensor.full((4, 4), 2.).contiguous().realize() + c = (a + b).pad(((1, 1), (1, 1))) + d = c.cast(dtypes.int).expand((2, 6, 6)) * 4 + run_schedule(check_schedule(d, 2)) + c_np = np.pad((np.full((4, 4), 2., dtype=np.float32) + np.full((4, 4), 1., dtype=np.float32)), ((1, 1), (1, 1)), constant_values=0.0) + np.testing.assert_equal(d.numpy(), np.broadcast_to(c_np.astype(np.half), (2, *c_np.shape)) * 4) + if __name__ == '__main__': unittest.main(verbosity=2)