Files
tinygrad/test/unit/test_callify.py
George Hotz 8a6dffc87e Tensor.callify will be the JIT (#14983)
* close

* simple callify, support linear in the scheduler

* all tests pass

* everyone is happy

* dumb test

* Remove unnecessary blank line in rangeify.py
2026-02-24 18:42:24 +08:00

112 lines
2.7 KiB
Python

import unittest
from tinygrad import Tensor, dtypes
class TestCallify(unittest.TestCase):
  """Unit tests for `Tensor.callify`.

  Each test builds a small lazy tensor graph, calls `callify()` on the output tensor(s), then reads the
  result back with `tolist()` / `item()` and checks it against the values plain arithmetic gives.
  NOTE(review): what callify does internally is not visible from this file; these tests only pin
  observable results (correct values, and that repeated / chained / multi-output use keeps them correct).
  """

  def test_basic(self):
    """Elementwise add of two list-built tensors."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    out = a + b
    out.callify()
    self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])

  def test_const(self):
    """Add of two scalar tensors built from Python floats."""
    out = Tensor(2.0) + Tensor(3.0)
    out.callify()
    self.assertEqual(out.item(), 5.0)

  def test_sum(self):
    """Sum over an input that has been passed through `.contiguous()`."""
    out = Tensor.ones(16).contiguous().sum()
    out.callify()
    self.assertEqual(out.item(), 16.0)

  def test_multi_output(self):
    """Passing a second tensor `d` to `c.callify(d)` must leave both `c` and `d` correct."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    c = a + b
    d = a * b
    c.callify(d)
    self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
    self.assertListEqual(d.tolist(), [4.0, 10.0, 18.0])

  def test_two_callify_independent(self):
    """Two callify'd graphs that share no tensors must not interfere with each other."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    c = a + b
    c.callify()
    d = Tensor([10.,20,30])
    e = Tensor([1.,1,1])
    f = d - e
    f.callify()
    self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
    self.assertListEqual(f.tolist(), [9.0, 19.0, 29.0])

  def test_two_callify_shared_input(self):
    """Two callify'd graphs that both read the same already-realized tensor `a`."""
    a = Tensor([1.,2,3]).contiguous().realize()
    b = a + 1
    b.callify()
    c = a * 2
    c.callify()
    self.assertListEqual(b.tolist(), [2.0, 3.0, 4.0])
    self.assertListEqual(c.tolist(), [2.0, 4.0, 6.0])

  def test_chained_callify(self):
    """The realized output of one callify'd graph is used as input to a second one."""
    a = Tensor([1.,2,3])
    b = a + 1
    b.callify()
    b.realize()
    c = b + 1
    c.callify()
    self.assertListEqual(c.tolist(), [3.0, 4.0, 5.0])

  def test_gemm(self):
    """8x8 matmul of all-ones by the identity; the product must be all ones."""
    a = Tensor.ones(8, 8).contiguous()
    b = Tensor.eye(8).contiguous()
    out = a @ b
    out.callify()
    # compare the whole matrix at once: on failure assertListEqual shows a diff of the full result
    # instead of stopping at the first mismatching element like a per-element loop would
    self.assertListEqual(out.tolist(), [[1.0] * 8 for _ in range(8)])

  def test_int_dtype(self):
    """Elementwise add on integer (`dtypes.int`) tensors."""
    a = Tensor([1,2,3], dtype=dtypes.int)
    b = Tensor([4,5,6], dtype=dtypes.int)
    out = a + b
    out.callify()
    self.assertListEqual(out.tolist(), [5, 7, 9])

  def test_reduce(self):
    """Sum of a tensor built directly from a Python list (no `.contiguous()`, unlike test_sum)."""
    out = Tensor([1.,2,3,4]).sum()
    out.callify()
    self.assertEqual(out.item(), 10.0)

  def test_multiple_ops(self):
    """Several elementwise ops in one graph: (a + b) * (a - b)."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    out = (a + b) * (a - b)
    out.callify()
    self.assertListEqual(out.tolist(), [-15.0, -21.0, -27.0])

  def test_double_callify(self):
    """Calling callify twice on the same tensor must not change its result."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    out = a + b
    out.callify()
    out.callify()
    self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])

  def test_double_callify_multi_output(self):
    """Calling `c.callify(d)` twice must leave both `c` and `d` correct."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    c = a + b
    d = a * b
    c.callify(d)
    c.callify(d)
    self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
    self.assertListEqual(d.tolist(), [4.0, 10.0, 18.0])
if __name__ == "__main__":
  # entry point when this file is executed directly rather than collected by a test runner
  unittest.main()