mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
25 lines
694 B
Python
25 lines
694 B
Python
import unittest
|
|
from tinygrad import Tensor
|
|
from tinygrad.engine.schedule import schedule_cache
|
|
|
|
class TestScheduleCache(unittest.TestCase):
  """Tests for the scheduler's ``schedule_cache``."""

  def test_simple(self):
    """Repeating an identical computation must not grow ``schedule_cache``.

    Two realized tensors of ten ones are summed and added together several
    times. After a warm-up phase, the size of ``schedule_cache`` is recorded
    and must stay constant across further identical runs. Every run also
    checks the numeric result, so a cache hit that produces a wrong answer
    is caught too.
    """
    a = Tensor.ones(10).contiguous()
    b = Tensor.ones(10).contiguous()
    Tensor.realize(a, b)
    # each tensor sums to 10, so the combined result is 20
    expected = 20.0

    def _run():
      # Build and execute the same graph each time; .item() forces execution.
      return (a.sum().contiguous()+b.sum().contiguous()).item()

    # warm up: the first runs may add entries to schedule_cache
    for _ in range(2):
      self.assertEqual(_run(), expected)

    # confirm schedule cache doesn't grow; check after every repeat run, not
    # only at the end, so transient growth is also detected
    start_len_schedule_cache = len(schedule_cache)
    for i in range(3):
      self.assertEqual(_run(), expected)
      self.assertEqual(len(schedule_cache), start_len_schedule_cache,
                       f"schedule_cache grew on repeat run {i}")
|
|
|
|
if __name__ == "__main__":
  # Executed directly (not imported): hand control to the unittest runner,
  # which discovers and runs the TestCase classes defined in this module.
  unittest.main()
|