mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
* close
* simple callify, support linear in the scheduler
* all tests pass
* everyone is happy
* dumb test
* Remove unnecessary blank line in rangeify.py
112 lines
2.7 KiB
Python
112 lines
2.7 KiB
Python
import unittest
|
|
from tinygrad import Tensor, dtypes
|
|
|
|
class TestCallify(unittest.TestCase):
  """Tests for Tensor.callify.

  Each test builds a small lazy tensor graph and calls callify() on the output tensor(s).
  Some tests call it more than once, or on several independent or dependent graphs.
  The result is then read back with tolist()/item() and compared with hand-computed values.
  """

  def test_basic(self):
    """Elementwise add of two 1-D tensors."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    out = a + b
    out.callify()
    self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])

  def test_const(self):
    """Graph built only from scalar constant tensors."""
    out = Tensor(2.0) + Tensor(3.0)
    out.callify()
    self.assertEqual(out.item(), 5.0)

  def test_sum(self):
    """Full reduction over a .contiguous() input."""
    out = Tensor.ones(16).contiguous().sum()
    out.callify()
    self.assertEqual(out.item(), 16.0)

  def test_multi_output(self):
    """Passing a second output tensor to callify; both outputs read back correctly."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    c = a + b
    d = a * b
    c.callify(d)
    self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
    self.assertListEqual(d.tolist(), [4.0, 10.0, 18.0])

  def test_two_callify_independent(self):
    """Two callify() calls on unrelated graphs before either result is read."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    c = a + b
    c.callify()

    d = Tensor([10.,20,30])
    e = Tensor([1.,1,1])
    f = d - e
    f.callify()

    self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
    self.assertListEqual(f.tolist(), [9.0, 19.0, 29.0])

  def test_two_callify_shared_input(self):
    """Two callify() calls whose graphs both read the same already-realized tensor."""
    a = Tensor([1.,2,3]).contiguous().realize()
    b = a + 1
    b.callify()
    c = a * 2
    c.callify()
    self.assertListEqual(b.tolist(), [2.0, 3.0, 4.0])
    self.assertListEqual(c.tolist(), [2.0, 4.0, 6.0])

  def test_chained_callify(self):
    """The realized output of one callify() is the input of a second callify()."""
    a = Tensor([1.,2,3])
    b = a + 1
    b.callify()
    b.realize()
    c = b + 1
    c.callify()
    self.assertListEqual(c.tolist(), [3.0, 4.0, 5.0])

  def test_gemm(self):
    """Matrix multiply by the identity must return the left operand unchanged.

    The left operand holds distinct values (0..63), so a transposed or misindexed product
    fails the check. With an all-ones operand, ones @ eye looks the same under
    transposition and would hide such bugs.
    """
    a = Tensor.arange(64, dtype=dtypes.float).reshape(8, 8).contiguous()
    b = Tensor.eye(8).contiguous()
    out = a @ b
    out.callify()
    # Compare the whole matrix at once: on failure assertListEqual shows where the rows differ,
    # whereas a per-element assertEqual loop only reports the first bad value without its index.
    expected = [[float(8*y + x) for x in range(8)] for y in range(8)]
    self.assertListEqual(out.tolist(), expected)

  def test_int_dtype(self):
    """Elementwise add on integer tensors."""
    a = Tensor([1,2,3], dtype=dtypes.int)
    b = Tensor([4,5,6], dtype=dtypes.int)
    out = a + b
    out.callify()
    self.assertListEqual(out.tolist(), [5, 7, 9])

  def test_reduce(self):
    """Full reduction directly over a literal tensor (no .contiguous())."""
    out = Tensor([1.,2,3,4]).sum()
    out.callify()
    self.assertEqual(out.item(), 10.0)

  def test_multiple_ops(self):
    """Several chained elementwise ops in one graph."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    out = (a + b) * (a - b)
    out.callify()
    self.assertListEqual(out.tolist(), [-15.0, -21.0, -27.0])

  def test_double_callify(self):
    """Calling callify() twice on the same output must not change the result."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    out = a + b
    out.callify()
    out.callify()
    self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])

  def test_double_callify_multi_output(self):
    """Calling callify(d) twice on the same pair of outputs must not change either result."""
    a = Tensor([1.,2,3])
    b = Tensor([4.,5,6])
    c = a + b
    d = a * b
    c.callify(d)
    c.callify(d)
    self.assertListEqual(c.tolist(), [5.0, 7.0, 9.0])
    self.assertListEqual(d.tolist(), [4.0, 10.0, 18.0])
|
if __name__ == "__main__":
  # Executing this file directly (rather than via a test runner) runs every test in the module.
  unittest.main()