Files
tinygrad/test/external/process_replay/test_diff_schedule.py
qazal 728b7e189e diff_schedule tests [run_process_replay] (#5958)
* diff_schedule tests [run_process_replay]

* ok to run serial
2024-08-07 13:50:27 +03:00

38 lines
1.1 KiB
Python

import unittest
from test.external.process_replay.diff_schedule import diff_schedule
from tinygrad import Tensor
from tinygrad.helpers import Context
from tinygrad.engine.schedule import SCHEDULES
class TestDiffSchedule(unittest.TestCase):
  """Tests for `diff_schedule` over the schedules recorded in the module-global `SCHEDULES` list."""
  def setUp(self):
    # SCHEDULES is shared, module-level state: start every test from an empty list so entries
    # recorded by earlier tests (or earlier schedule() calls) can't skew the length checks...
    SCHEDULES.clear()
    # ...and always leave it empty afterwards, even if an assertion fails part-way through.
    self.addCleanup(SCHEDULES.clear)

  def _check_diff(self, expected_schedules:int, expected_changed:int):
    """Assert SCHEDULES holds `expected_schedules` entries and that `diff_schedule` reports
    `expected_changed` for them, then reset SCHEDULES for the next case.

    NOTE(review): `expected_changed` is compared against diff_schedule's return value, which
    presumably counts the kernels that differ between the recorded schedules — confirm.
    """
    self.assertEqual(len(SCHEDULES), expected_schedules)
    changed = diff_schedule(SCHEDULES)
    self.assertEqual(changed, expected_changed)
    SCHEDULES.clear()

  def test_diff_arange(self):
    # Each .schedule() under ARANGE_DIFF=1 is expected to record two entries in SCHEDULES;
    # the expected lengths below (2 per call, 4 for two calls) rely on that.

    # diff a single arange kernel
    x = Tensor.randn(10, 10).realize()
    idxs = Tensor([0, 2]).realize()
    xt = x[idxs]
    with Context(ARANGE_DIFF=1): xt.schedule()
    self._check_diff(2, 1)

    # no diff
    a = Tensor([1])+Tensor([2])
    with Context(ARANGE_DIFF=1): a.schedule()
    self._check_diff(2, 0)

    # no diff with two schedule creation calls
    a = Tensor([1])+Tensor([2])
    with Context(ARANGE_DIFF=1): a.schedule()
    b = Tensor([3])+Tensor([4])
    with Context(ARANGE_DIFF=1): b.schedule()
    self._check_diff(4, 0)
if __name__ == '__main__':
  # allow running this test file directly as a script (e.g. `python test/external/process_replay/test_diff_schedule.py`)
  unittest.main()