mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
base change scheduling spec (#4613)
* spec and kernel cnt * dont use half * skip half
This commit is contained in:
@@ -807,5 +807,40 @@ class TestSchedule(unittest.TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
np.testing.assert_equal(out.numpy(), [2, 0])
|
||||
|
||||
def test_base_change_shrink_pad(self):
|
||||
a = Tensor.ones(3, 3).contiguous().realize()
|
||||
b = a.exp2()
|
||||
c = b[:-1, :-1]
|
||||
d = c.pad(((0, 1), (0, 1))) * 2
|
||||
run_schedule(check_schedule(d, 1))
|
||||
with self.assertRaises(AssertionError): # TODO unsafe pads
|
||||
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2)
|
||||
|
||||
def test_base_change_expand_pad(self):
|
||||
a = Tensor.ones(3, 3).contiguous().realize()
|
||||
b = a.exp2()
|
||||
c = b[:, None, :]
|
||||
d = c.pad(((0, 0), (1, 1), (0, 0))) * 2
|
||||
run_schedule(check_schedule(d, 2))
|
||||
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:, None, :], ((0, 0), (1, 1), (0, 0)))*2)
|
||||
|
||||
# TODO like openpilot with imagef
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_base_change_expand_expand(self):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
b = a.cast(dtypes.half).expand(2, 4, 4)
|
||||
c = b.cast(dtypes.int).expand(2, 2, 4, 4)
|
||||
run_schedule(check_schedule(c, 2))
|
||||
np.testing.assert_equal(c.numpy(), np.ones(((2, 2, 4, 4)), dtype=np.int32))
|
||||
|
||||
def test_base_change_pad_expand(self):
|
||||
a = Tensor.full((4, 4), 1.).contiguous().realize()
|
||||
b = Tensor.full((4, 4), 2.).contiguous().realize()
|
||||
c = (a + b).pad(((1, 1), (1, 1)))
|
||||
d = c.cast(dtypes.int).expand((2, 6, 6)) * 4
|
||||
run_schedule(check_schedule(d, 2))
|
||||
c_np = np.pad((np.full((4, 4), 2., dtype=np.float32) + np.full((4, 4), 1., dtype=np.float32)), ((1, 1), (1, 1)), constant_values=0.0)
|
||||
np.testing.assert_equal(d.numpy(), np.broadcast_to(c_np.astype(np.half), (2, *c_np.shape)) * 4)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user