mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
* start function and add walk rewrite * work * add function on feed_forward * llm progress * stuff * none of that
43 lines · 877 B · Python
import unittest
|
|
from tinygrad.function import function
|
|
from tinygrad import Tensor
|
|
|
|
class TestFunction(unittest.TestCase):
  """Tests for the `@function` decorator from tinygrad.function.

  Each test wraps a small elementwise Tensor expression in `@function`, calls it, realizes
  the result, and checks the realized values (not just that realization does not crash).
  """

  def test_simple(self):
    # every Tensor the body uses is passed in explicitly as an argument
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    c.realize()
    self.assertListEqual(c.tolist(), [5,7,9])

  def test_implicit(self):
    # `inp` is not a parameter of f: the body reads it from the enclosing scope (a closure),
    # so the decorated function has to pick up this implicit input as well
    inp = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    c.realize()
    self.assertListEqual(c.tolist(), [12,15,18])

  def test_implicit_2(self):
    # two decorated functions with the same signature and the same body shape, differing only in
    # which captured Tensor they read; inp and inp2 differ in their last element, so if the two
    # functions were mixed up (e.g. sharing captured state) the results below would not match
    inp = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor:
      return a+b+inp
    inp2 = Tensor([7,8,10])
    @function
    def g(a:Tensor, b:Tensor) -> Tensor:
      return a+b+inp2

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    d = g(a,b)
    # d is passed to c.realize so both outputs are realized by the same call
    c.realize(d)
    self.assertListEqual(c.tolist(), [12,15,18])
    self.assertListEqual(d.tolist(), [12,15,19])
|
|
|
|
if __name__ == "__main__":
  # when this file is executed directly as a script, hand control to the unittest runner
  unittest.main()
|