import numpy as np
import unittest
from tinygrad.function import function
from tinygrad import Tensor
from tinygrad.uop.ops import UOp

class TestFunction(unittest.TestCase):
  def test_simple(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    np.testing.assert_equal(f(a,b).numpy(), [5,7,9])

  def test_simple_same(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
    a = Tensor([1,2,3])
    np.testing.assert_equal(f(a,a).numpy(), [2,4,6])

  def test_implicit(self):
    inp = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    np.testing.assert_equal(f(a,b).numpy(), [12,15,18])

  def test_implicit_same_as_input(self):
    inp = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
    a = Tensor([1,2,3])
    np.testing.assert_equal(f(a, inp).numpy(), [15,18,21])

  def test_implicit_2(self):
    inp = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
    inp2 = Tensor([7,8,10])
    @function
    def g(a:Tensor, b:Tensor) -> Tensor: return a+b+inp2
    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    d = g(a,b)
    c.realize(d)
    np.testing.assert_equal(c.numpy(), [12,15,18])
    np.testing.assert_equal(d.numpy(), [12,15,19])

  def test_implicit_unrealized(self):
    inp = Tensor([1,2,3]) + Tensor([4,5,6])
    @function
    def f(a:Tensor) -> Tensor: return a + inp
    np.testing.assert_equal(f(Tensor([10,20,30])).numpy(), [15,27,39])

  def test_detach(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a.detach() + b
    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    np.testing.assert_equal(f(a, b).numpy(), [5,7,9])

  def test_contiguous_backward(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return (a + b).contiguous_backward()
    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    np.testing.assert_equal(f(a, b).numpy(), [5,7,9])

  def test_method(self):
    class Foo:
      def __init__(self): self.w = Tensor([10,20,30])
      @function
      def __call__(self, x:Tensor) -> Tensor: return x + self.w
    foo = Foo()
    np.testing.assert_equal(foo(Tensor([1,2,3])).numpy(), [11,22,33])

  def test_grad_gemm(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a @ b
    a = Tensor([[1.,2.],[3.,4.]], requires_grad=True)
    b = Tensor([[5.,6.],[7.,8.]], requires_grad=True)
    (f(a, b).contiguous() * b).sum().backward()
    Tensor.realize(a, b, a.grad, b.grad)
    # L = sum((a@b) * b), dL/d(a@b) = b, dL/da = b @ b^T, dL/db = a^T @ b + (a@b)
    na, nb = a.numpy(), b.numpy()
    np.testing.assert_allclose(a.grad.numpy(), nb @ nb.T)
    np.testing.assert_allclose(b.grad.numpy(), na.T @ nb + na @ nb)

  def test_grad_implicit(self):
    w = Tensor([1., 2., 3.], requires_grad=True)
    w.realize()  # TODO: this is required
    @function
    def f(x:Tensor) -> Tensor: return x * w
    x = Tensor([4., 5., 6.])
    f(x).sum().backward()
    np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6.])

  def test_symbolic_index(self):
    table = Tensor([10,20,30,40]).contiguous().realize()
    @function
    def f(x:Tensor, start_pos:int|UOp) -> Tensor: return x + table[start_pos]
    v = UOp.variable("start_pos", 0, 3)
    np.testing.assert_equal(f(Tensor([1,2,3]), v.bind(0)).numpy(), [11,12,13])

  def test_symbolic_shape_input(self):
    table = Tensor([10,20,30,40]).contiguous().realize()
    @function
    def f(x:Tensor) -> Tensor: return x * 2
    sz = UOp.variable("sz", 1, 3)
    slic = table[:sz.bind(2)]
    np.testing.assert_equal(f(slic)[:2].numpy(), [20,40])

  def test_nested_calls(self):
    w = Tensor([10., 20., 30.])
    @function
    def f(a:Tensor) -> Tensor: return a + w
    @function
    def g(a:Tensor) -> Tensor: return a * w
    a = Tensor([1., 2., 3.])
    np.testing.assert_allclose(g(f(a)).numpy(), [110., 440., 990.])

  def test_nested_calls_backward(self):
    w = Tensor([[1., 2.], [3., 4.]]).contiguous().realize()
    @function
    def inner(x:Tensor) -> Tensor: return x + w
    @function
    def outer(a:Tensor, b:Tensor) -> Tensor: return inner(a.reshape(1,2) + b.reshape(1,2))
    a = Tensor([1., 2.], requires_grad=True)
    b = Tensor([3., 4.], requires_grad=True)
    outer(a, b).sum().backward()
    np.testing.assert_allclose(a.grad.numpy(), [2., 2.])
    np.testing.assert_allclose(b.grad.numpy(), [2., 2.])

  def test_unused_param_backward(self):
    @function
    def f(a:Tensor, b:Tensor, c:Tensor) -> Tensor: return a + c  # b is unused
    a = Tensor([1., 2., 3.], requires_grad=True)
    b = Tensor([4., 5., 6.], requires_grad=True)
    c = Tensor([7., 8., 9.], requires_grad=True)
    f(a, b, c).sum().backward()
    np.testing.assert_allclose(a.grad.numpy(), [1., 1., 1.])
    np.testing.assert_allclose(b.grad.numpy(), [0., 0., 0.])
    np.testing.assert_allclose(c.grad.numpy(), [1., 1., 1.])

  def test_name(self):
    @function
    def f(a:Tensor) -> Tensor: return a + 1
    assert f(Tensor([1])).uop.arg.name.endswith("f")

  def test_method_name(self):
    class Foo:
      @function
      def __call__(self, x:Tensor) -> Tensor: return x + 1
    assert Foo()(Tensor([1])).uop.arg.name.endswith("Foo.__call__")

  def test_callable_instance(self):
    class Foo:
      def __init__(self): self.w = Tensor([10,20,30])
      def __call__(self, x:Tensor) -> Tensor: return x + self.w
    foo = Foo()
    f = function(foo)
    np.testing.assert_equal(f(Tensor([1,2,3])).numpy(), [11,22,33])
    assert f(Tensor([1,2,3])).uop.arg.name.endswith("Foo")

  def test_iadd(self):
    @function
    def f(x:Tensor) -> Tensor:
      x += 1
      return x
    a = Tensor([1,2,3]).realize()
    np.testing.assert_equal(f(a).numpy(), [2,3,4])
    np.testing.assert_equal(a.numpy(), [3,4,5])  # TODO: should be [1,2,3]

  def test_implicit_assign(self):
    a = Tensor([1,2,3])
    a += 1
    c = Tensor([2,2,2]).contiguous()
    @function
    def f(b:Tensor) -> Tensor: return a+b+c
    b = Tensor([10,20,30]).realize()
    np.testing.assert_equal(f(b).numpy(), [14,25,36])

  def test_assign_input(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor:
      a.assign(b+1)
      return a
    a = Tensor([1,2,3]).realize()
    b = Tensor([10,20,30]).realize()
    np.testing.assert_equal(f(a,b).numpy(), [11,21,31])
    np.testing.assert_equal(a.numpy(), [11,21,31])  # TODO: should be [1,2,3]
    np.testing.assert_equal(b.numpy(), [10,20,30])

  @unittest.expectedFailure
  def test_assign_slice(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor:
      a[1:] = b[1:]+1
      return a
    a = Tensor([1,2,3]).realize()
    b = Tensor([10,20,30]).realize()
    np.testing.assert_equal(f(a,b).numpy(), [1,21,31])
    np.testing.assert_equal(a.numpy(), [1,2,3])
    np.testing.assert_equal(b.numpy(), [10,20,30])

class TestFunctionMulti(unittest.TestCase):
  devices_2 = ("CPU:0", "CPU:1")

  def test_simple_multi(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
    a = Tensor([1,2,3,4]).shard(self.devices_2, axis=None)
    b = Tensor([10,20,30,40]).shard(self.devices_2, axis=None)
    np.testing.assert_equal(f(a,b).numpy(), [11,22,33,44])

  def test_simple_multi_sharded(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
    a = Tensor([1,2,3,4]).shard(self.devices_2, axis=0)
    b = Tensor([10,20,30,40]).shard(self.devices_2, axis=0)
    np.testing.assert_equal(f(a,b).numpy(), [11,22,33,44])

  def test_data_parallel_multi(self):
    @function
    def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
    x = Tensor([[1.,2.],[3.,4.],[5.,6.],[7.,8.]]).shard(self.devices_2, axis=0)
    w = Tensor([[1.,0.],[0.,1.]]).shard(self.devices_2, axis=None)
    np.testing.assert_allclose(f(x, w).numpy(),
                               [[1.,2.],[3.,4.],[5.,6.],[7.,8.]])

  def test_grad_implicit_multi(self):
    w = Tensor([1., 2., 3., 4.], requires_grad=True).shard(self.devices_2, axis=None)
    w.realize()
    @function
    def f(x:Tensor) -> Tensor: return x * w
    x = Tensor([4., 5., 6., 7.]).shard(self.devices_2, axis=None)
    f(x).sum().backward()
    np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6., 7.])

  def test_call_axis(self):
    @function
    def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
    x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]]).shard(self.devices_2, axis=0)
    w = Tensor([[1.,2.],[3.,4.]]).shard(self.devices_2, axis=None)
    result = f(x, w)
    # CALL output should inherit axis=0 from the sharded input
    self.assertEqual(result.uop.axis, 0)
    # reduce on the sharded axis should remove it
    self.assertIsNone(result.sum().uop.axis)

  def test_call_axis_shard_inside(self):
    @function
    def f(x:Tensor, w:Tensor) -> Tensor: return x.shard(self.devices_2, axis=0) @ w.shard(self.devices_2, axis=None)
    x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]])
    w = Tensor([[1.,2.],[3.,4.]])
    result = f(x, w)
    self.assertEqual(result.uop.axis, 0)
    np.testing.assert_allclose(result.numpy(), x.numpy() @ w.numpy())

  def test_data_parallel_backward(self):
    @function
    def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
    x = Tensor([[1.,0.],[0.,1.],[1.,1.],[0.,0.]], requires_grad=True).shard(self.devices_2, axis=0)
    w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(self.devices_2, axis=None)
    w.realize()
    f(x, w).sum().backward()
    # L = sum(x @ w), so dL/d(x@w) = ones(4,2) and dL/dx = ones(4,2) @ w^T
    np.testing.assert_allclose(x.grad.numpy(), np.ones((4,2)) @ np.array([[1,3],[2,4]]))

  def test_data_parallel_backward_4(self):
    devices_4 = tuple(f"CPU:{i}" for i in range(4))
    @function
    def f(x:Tensor, w:Tensor) -> Tensor: return x @ w
    x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
    w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
    w.realize()
    f(x, w).sum().backward()
    np.testing.assert_allclose(x.grad.numpy(), np.ones((8,2)) @ np.array([[1,3],[2,4]]))

  def test_data_parallel_backward_implicit(self):
    devices_4 = tuple(f"CPU:{i}" for i in range(4))
    w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
    w.realize()
    @function
    def f(x:Tensor) -> Tensor: return x @ w
    x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
    f(x).sum().backward()
    np.testing.assert_allclose(x.grad.numpy(), np.ones((8,2)) @ np.array([[1,3],[2,4]]))

  def test_data_parallel_backward_twice(self):
    devices_4 = tuple(f"CPU:{i}" for i in range(4))
    w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
    w.realize()
    # pre-init grads like the training loop does
    w.grad = w.zeros_like().contiguous().realize()
    @function
    def f(x:Tensor) -> Tensor: return x @ w
    expected = np.ones((8,2)) @ np.array([[1,3],[2,4]])
    for _ in range(2):
      x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
      f(x).sum().backward()
      np.testing.assert_allclose(x.grad.numpy(), expected)

if __name__ == '__main__':
  unittest.main()