import unittest
import numpy as np
from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp

class TestCall(unittest.TestCase):
  """Tests for `Tensor.call`.

  Each test builds a function as a UOp graph over `UOp.param` placeholders (or via `Tensor.as_param`),
  applies it to concrete tensors with `Tensor.call`, and compares the forward result and/or the
  gradients against the equivalent eager Tensor computation.
  """

  def setUp(self):
    # seed the RNG so the randn/rand-based tests are reproducible if they ever fail
    Tensor.manual_seed(0)

  def test_call_plus(self):
    a = Tensor.randn(10, 10)
    b = Tensor.randn(10, 10)
    Tensor.realize(a, b)
    # we define a plus function over two (10,10) float params
    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
    c = Tensor.call(a, b, fxn=plus_fxn)
    np.testing.assert_equal(c.numpy(), (a+b).numpy())

  def test_call_plus_backward(self):
    a = Tensor.ones(10, 10, requires_grad=True)
    b = Tensor.ones(10, 10, requires_grad=True)
    # reference gradients from the eager graph
    (a+b).mean().backward()
    gt_a_grad = a.grad.numpy()
    gt_b_grad = b.grad.numpy()
    # reset so the call-based backward below is not mixed with the reference grads
    a.grad, b.grad = None, None

    # this is the gradient for +: d(x+y)/dx = d(x+y)/dy = 1, so the upstream grad is returned once per param
    # (`call` is unused here; presumably it is the UOp of the call being differentiated)
    def grad_fxn(grad:UOp, call:UOp): return (grad, grad)

    # we define a plus function
    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
    c = Tensor.call(a, b, fxn=plus_fxn, grad_fxn=grad_fxn)
    c.mean().backward()
    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)

  def test_call_plus_backward_auto(self):
    a = Tensor.ones(10, 10, requires_grad=True)
    b = Tensor.ones(10, 10, requires_grad=True)
    # reference gradients from the eager graph
    (a+b).mean().backward()
    gt_a_grad = a.grad.numpy()
    gt_b_grad = b.grad.numpy()
    a.grad, b.grad = None, None

    # no grad_fxn: the gradient must be derived from plus_fxn itself
    plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
    c = Tensor.call(a, b, fxn=plus_fxn)
    c.mean().backward()
    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)

  def test_call_gemm(self):
    M, K, N = 4, 8, 4
    a = Tensor.randn(M, K)
    b = Tensor.randn(K, N)
    Tensor.realize(a, b)
    # build the function from the tensors' own param placeholders
    c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)

  @unittest.skip("needs GEMM on mixins")
  def test_call_gemm_uop(self):
    M, K, N = 4, 8, 4
    a = Tensor.randn(M, K)
    b = Tensor.randn(K, N)
    Tensor.realize(a, b)
    # we define a gemm function
    x = UOp.param(0, dtypes.float, shape=(M, K))
    y = UOp.param(1, dtypes.float, shape=(K, N))
    c = Tensor.call(a, b, fxn=x@y)
    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)

  def test_call_complex_backward_auto(self):
    # complex chain: (a*b + a).exp2() * b.reciprocal() - tests mul, add, exp2, reciprocal, param reuse
    a = Tensor.randn(10, 10, requires_grad=True)
    # rand is in [0, 1), so b is in [1, 2): b.reciprocal() and its 1/b^2 gradient stay well-conditioned.
    # (randn + 2 is unbounded below and can land arbitrarily close to zero.)
    b = Tensor.rand(10, 10, requires_grad=True) + 1
    Tensor.realize(a, b)
    ((a*b + a).exp2() * b.reciprocal()).mean().backward()
    gt_a_grad, gt_b_grad = a.grad.numpy(), b.grad.numpy()
    a.grad, b.grad = None, None

    p0, p1 = UOp.param(0, dtypes.float, (10,10)), UOp.param(1, dtypes.float, (10,10))
    complex_fxn = (p0*p1 + p0).exp2() * p1.reciprocal()
    c = Tensor.call(a, b, fxn=complex_fxn)
    c.mean().backward()
    np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
    np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)

if __name__ == '__main__':
  unittest.main()