Files
tinygrad/test/unit/test_call.py
2026-02-03 13:51:02 +08:00

97 lines
3.2 KiB
Python

import unittest
import numpy as np
from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp
class TestCall(unittest.TestCase):
  """Tests for `Tensor.call`: evaluating a function built from `UOp.param`s on concrete tensors, forward and backward."""

  @staticmethod
  def _reference_grads(a, b, fn) -> tuple[np.ndarray, np.ndarray]:
    """Backprop `fn(a, b).mean()` through the ordinary Tensor path and return `(a.grad, b.grad)` as numpy arrays.

    The `.grad` of both tensors is reset to None afterwards, so the same tensors can then be reused for the
    `Tensor.call` path without the reference gradients leaking into the result under test.
    """
    fn(a, b).mean().backward()
    grads = a.grad.numpy(), b.grad.numpy()
    a.grad, b.grad = None, None
    return grads

  def test_call_plus(self):
    a = Tensor.randn(10, 10)
    b = Tensor.randn(10, 10)
    Tensor.realize(a, b)
    # we define a plus function over two (10,10) float params
    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
    c = Tensor.call(a, b, fxn=plus_fxn)
    # compared exactly (assert_equal, no tolerance): both sides are a single elementwise add
    np.testing.assert_equal(c.numpy(), (a+b).numpy())

  def test_call_plus_backward(self):
    a = Tensor.ones(10, 10, requires_grad=True)
    b = Tensor.ones(10, 10, requires_grad=True)
    gt_a_grad, gt_b_grad = self._reference_grads(a, b, lambda x, y: x+y)
    # this is the explicit gradient for +: the incoming grad flows unchanged to both inputs
    def grad_fxn(grad:UOp, call:UOp): return (grad, grad)
    # we define a plus function
    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
    c = Tensor.call(a, b, fxn=plus_fxn, grad_fxn=grad_fxn)
    c.mean().backward()
    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)

  def test_call_plus_backward_auto(self):
    a = Tensor.ones(10, 10, requires_grad=True)
    b = Tensor.ones(10, 10, requires_grad=True)
    gt_a_grad, gt_b_grad = self._reference_grads(a, b, lambda x, y: x+y)
    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
    # no grad_fxn is passed here, so the gradient must come from Tensor.call itself (the "auto" case)
    c = Tensor.call(a, b, fxn=plus_fxn)
    c.mean().backward()
    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)

  def test_call_gemm(self):
    M, K, N = 4, 8, 4
    a = Tensor.randn(M, K)
    b = Tensor.randn(K, N)
    Tensor.realize(a, b)
    # params are derived from the tensors themselves via as_param, instead of UOp.param with an explicit shape
    c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)

  @unittest.skip("needs GEMM on mixins")
  def test_call_gemm_uop(self):
    M, K, N = 4, 8, 4
    a = Tensor.randn(M, K)
    b = Tensor.randn(K, N)
    Tensor.realize(a, b)
    # we define a gemm function directly on UOp params
    x = UOp.param(0, dtypes.float, shape=(M, K))
    y = UOp.param(1, dtypes.float, shape=(K, N))
    c = Tensor.call(a, b, fxn=x@y)
    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)

  def test_call_complex_backward_auto(self):
    # complex chain: (a*b + a).exp2() * b.reciprocal() - tests mul, add, exp2, reciprocal, param reuse
    a = Tensor.randn(10, 10, requires_grad=True)
    # rand is uniform in [0, 1), so b lies in [1, 2): bounded away from zero, keeping b.reciprocal() and its
    # gradient well-conditioned. (randn + 2 did not guarantee this: randn is unbounded and often dips below -2.)
    b = Tensor.rand(10, 10, requires_grad=True) + 1
    Tensor.realize(a, b)
    gt_a_grad, gt_b_grad = self._reference_grads(a, b, lambda x, y: (x*y + x).exp2() * y.reciprocal())
    p0, p1 = UOp.param(0, dtypes.float, (10,10)), UOp.param(1, dtypes.float, (10,10))
    # p0 appears twice, exercising gradient accumulation for a reused param
    complex_fxn = (p0*p1 + p0).exp2() * p1.reciprocal()
    c = Tensor.call(a, b, fxn=complex_fxn)
    c.mean().backward()
    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
if __name__ == '__main__':
  # running this file directly hands control to the unittest runner
  unittest.main()