diff --git a/test/test_linalg.py b/test/test_linalg.py
new file mode 100644
index 0000000000..ab2db74c5c
--- /dev/null
+++ b/test/test_linalg.py
@@ -0,0 +1,66 @@
+import numpy as np
+import unittest
+from tinygrad import Tensor
+from typing import List
+import functools
+
+def orthogonality_helper(A:Tensor, tolerance=1.0e-5):
+  b_shape, m = A.shape[0:-2], A.shape[-2]  # the second-to-last dimension is the one checked for orthogonality
+  A_identity = Tensor.eye(m).reshape((1,) * len(b_shape) + (m, m)).expand(b_shape + (m, m))
+  np.testing.assert_allclose((A @ A.transpose(-2, -1)).numpy(), A_identity.numpy(), atol=tolerance, rtol=tolerance)
+
+def reconstruction_helper(A:List[Tensor], B:Tensor, tolerance=1.0e-5):
+  reconstructed_tensor = functools.reduce(Tensor.matmul, A)
+  np.testing.assert_allclose(reconstructed_tensor.numpy(), B.numpy(), atol=tolerance, rtol=tolerance)
+
+class TestLinAlg(unittest.TestCase):
+
+  def test_svd_general(self):
+    sizes = [(2,2), (5,3), (3,5), (2,2,2,2,3)]
+    for size in sizes:
+      a = Tensor.randn(size).realize()
+      U, S, V = Tensor.svd(a)
+      b_shape, m, n = size[0:-2], size[-2], size[-1]
+      k = min(m, n)
+      s_diag = S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k, k))
+      s_diag = s_diag.expand(b_shape + (k, k)).pad(tuple([(0, 0) for _ in range(len(size) - 2)] + [(0, m - k), (0, n - k)]))
+      orthogonality_helper(U)
+      orthogonality_helper(V)
+      reconstruction_helper([U, s_diag, V], a)
+
+  def test_svd_nonfull(self):
+    sizes = [(2,2), (5,3), (3,5), (2,2,2,2,3)]
+    for size in sizes:
+      a = Tensor.randn(size).realize()
+      U, S, V = Tensor.svd(a, full_matrices=False)
+      b_shape, m, n = size[0:-2], size[-2], size[-1]
+      k = min(m, n)
+      s_diag = S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k, k)).expand(b_shape + (k, k))
+      # the reduced U and V are only orthogonal along the smaller dimension
+      if m < n:
+        orthogonality_helper(U)
+        orthogonality_helper(V)
+      else:
+        orthogonality_helper(U.transpose(-2, -1))
+        orthogonality_helper(V.transpose(-2, -1))
+      reconstruction_helper([U, s_diag, V], a)
+
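+  # NOTE: an untested sketch, following the skip reason below: wrapping the inner Jacobi sweep
+  # (one_round_jacobi inside Tensor.svd) with TinyJit should speed up the large case, since every
+  # sweep runs with identical tensor shapes.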
+  @unittest.skip("very big; recommend wrapping the inner function with TinyJit")
+  def test_svd_large(self):
+    size = (1024, 1024)
+    a = Tensor.randn(size).realize()
+    U, S, V = Tensor.svd(a)
+    b_shape, m, n = size[0:-2], size[-2], size[-1]
+    k = min(m, n)
+    s_diag = S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k, k))
+    s_diag = s_diag.expand(b_shape + (k, k)).pad(tuple([(0, 0) for _ in range(len(size) - 2)] + [(0, m - k), (0, n - k)]))
+    orthogonality_helper(U, tolerance=1.0e-3)
+    orthogonality_helper(V, tolerance=1.0e-3)
+    reconstruction_helper([U, s_diag, V], a, tolerance=1.0e-3)
+
+  def test_qr_general(self):
+    sizes = [(3,3), (3,6), (6,3), (2,2,2,2,2)]
+    for size in sizes:
+      a = Tensor.randn(size).realize()
+      Q, R = Tensor.qr(a)
+      orthogonality_helper(Q)
+      reconstruction_helper([Q, R], a)
+
+if __name__ == "__main__":
+  unittest.main()
\ No newline at end of file
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 477655c1a5..5d37ec4ea4 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -3979,6 +3979,72 @@ class Tensor(MathTrait):
     nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
     return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)
 
+  def qr(self) -> tuple[Tensor, Tensor]:
+    """
+    Computes the full QR decomposition of the input via Householder reflections, batched over any leading dimensions.
+    Returns an orthogonal `Q` of shape `(..., m, m)` and an upper triangular `R` of shape `(..., m, n)` such that `Q @ R` reconstructs the input.
+    """
+    assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
+    R = self.clone()
+    b_shape, m, n = self.shape[0:self.ndim-2], int(R.shape[-2]), int(R.shape[-1])
+    Q = Tensor.eye(m, dtype=self.dtype).reshape((1,) * (len(self.shape)-2) + 2*(m,)).expand(b_shape + 2*(m,)).contiguous()
+    for i in range(int(min(m, n))):
+      # build the Householder reflector (w, tau) that zeroes out column i below the diagonal
+      x = R[..., i:m, i]
+      s = -x[..., 0].sign()
+      u1 = x[..., 0] - s * x.square().sum(-1).sqrt()
+      w = x.unsqueeze(-1) / u1.reshape(b_shape + 2*(1,))
+      w[..., 0, 0] = 1
+      tau = (-s * u1 / x.square().sum(-1).sqrt()).reshape(b_shape + 2*(1,)).expand(w.shape)
+      # apply the reflector to R from the left and accumulate it into Q from the right
+      R[..., i:m, :] = R[..., i:m, :] - (w * tau) @ (w.transpose(-2, -1) @ R[..., i:m, :])
+      Q[..., :, i:m] = Q[..., :, i:m] - (Q[..., :, i:m] @ w) @ (tau.transpose(-2, -1) * w.transpose(-2, -1))
+    return Q, R
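+
+  # usage sketch (illustrative, mirrors test_qr_general): for a 6x3 input `a`, `Q, R = a.qr()`
+  # yields Q of shape (6, 6) with Q @ Q.transpose(-2, -1) close to eye(6), and an upper
+  # triangular R of shape (6, 3) with Q @ R close to `a`.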
+
+  def svd(self, full_matrices=True) -> tuple[Tensor, Tensor, Tensor]:
+    """
+    Computes the singular value decomposition of the input with one-sided Jacobi rotations, batched over any leading dimensions.
+    Returns `(U, S, V)` with the singular values in `S` sorted in descending order, such that the matrix product of `U`,
+    `S` embedded on a diagonal, and `V` reconstructs the input (i.e. `V` is returned already transposed).
+    If `full_matrices=False`, `U` and `V` are trimmed to the smaller of the last two dimensions.
+    """
+    # partial implementation of https://www.netlib.org/lapack/lawnspdf/lawn169.pdf , pg 26
+    assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
+    b_shape, m, n = self.shape[:-2], int(self.shape[-2]), int(self.shape[-1])
+    # preprocess the matrix: a QR factorization reduces the problem to the square min(m, n) case
+    Q, R = Tensor.qr(self) if m >= n else Tensor.qr(self.transpose(-2, -1))
+    num, q_num = int(min(m, n)), int(max(m, n))
+    U = R.shrink(tuple([(0, self.shape[i]) for i in range(self.ndim-2)] + [(0, num), (0, num)])).contiguous()
+    V = Tensor.eye(num, dtype=self.dtype).reshape((1,) * (self.ndim-2) + (num, num)).expand(b_shape + 2*(num,)).contiguous()
+    # prepare the round robin pairing
+    permute, inverse_permute = Tensor.arange(0, num, dtype=dtypes.int), Tensor.zeros(num, dtype=dtypes.int).contiguous()
+    permute[num//2:num] = permute[num//2:num].flip(0)
+    inverse_permute[permute] = Tensor.arange(num, dtype=dtypes.int)
+    def one_round_jacobi(U, V, permute, inverse_permute):
+      # pair all the columns
+      V_permuted, runoff_V = V[..., permute].split(num-1, -1) if num % 2 == 1 else (V[..., permute], None)
+      V_left, V_right = V_permuted.split(num//2, -1)
+      U_permuted, runoff_U = U[..., permute].split(num-1, -1) if num % 2 == 1 else (U[..., permute], None)
+      U_left, U_right = U_permuted.split(num//2, -1)
+      # compute the jacobi rotation for each pairing
+      gamma = (U_left * U_right).sum(-2).reshape(b_shape + (1, num//2))
+      alpha, beta = U_permuted.square().sum(-2).unsqueeze(-2).split(num//2, -1)
+      tau = (beta - alpha) / (2 * gamma)
+      t = tau.sign() / (tau.abs() + (1 + tau.square()).sqrt())
+      c = 1 / (1 + t.square()).sqrt()
+      s = c * t
+      # apply the rotations
+      U_left, U_right = c * U_left - s * U_right, s * U_left + c * U_right
+      U = U_left.cat(U_right.cat(runoff_U, dim=-1) if num % 2 == 1 else U_right, dim=-1)[..., inverse_permute]
+      V_left, V_right = c * V_left - s * V_right, s * V_left + c * V_right
+      V = V_left.cat(V_right.cat(runoff_V, dim=-1) if num % 2 == 1 else V_right, dim=-1)[..., inverse_permute]
+      # prepare the next round robin pairings
+      if num % 2 == 1: permute = (permute - 1) % num
+      else: permute = permute[0].reshape(1).cat(((permute[1:num] - 2) % (num - 1)) + 1)
+      inverse_permute = inverse_permute.scatter(0, permute, Tensor.arange(num, dtype=dtypes.int))
+      return U, V, permute, inverse_permute
+    max_iterations, iterations_per_round = 1, int(num * math.log2(num) * 2 + 2)  # rough heuristic; most implementations use about num*log2(num) sweeps
+    for _ in range(max_iterations * iterations_per_round): U, V, permute, inverse_permute = one_round_jacobi(U, V, permute, inverse_permute)
+    # extract the singular values (the column norms of U), sort them, and construct the final U from Q
+    S, indices = U.square().sum(-2).sqrt().sort(dim=-1, descending=True)
+    new_indices = Tensor.arange(num).reshape((1,) * (self.ndim-1) + (num,)).expand(b_shape + 2*(num,)).contiguous()
+    new_indices[..., :num] = indices.unsqueeze(-2).expand(b_shape + 2*(num,))
+    # reorder the columns of U and V by descending singular value; dividing by S normalizes U's columns
+    U, V = U.gather(-1, new_indices[..., 0:num, 0:num]) / S.unsqueeze(-2), V.gather(-1, new_indices[..., 0:num, 0:num])
+
+    # embed U in the top-left of a q_num x q_num identity and compose with Q from the preprocessing step
+    padded_u = Tensor.eye(q_num, dtype=U.dtype).reshape((1,) * (self.ndim-2) + 2*(q_num,)).expand(b_shape + 2*(q_num,)).contiguous()
+    padded_u[..., 0:num, 0:num] = U
+    U = Q @ padded_u
+    if not full_matrices: U, V = U[..., 0:num], V[..., 0:num]
+    return (U, S, V.transpose(-2, -1)) if m >= n else (V, S, U.transpose(-2, -1))
+
   # ***** Tensor Properties *****
 
   @property
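
A quick smoke test for the new API (illustrative, not part of the patch; it mirrors the reduced path in test_svd_nonfull, and `recon` is just a local name for the product):

  from tinygrad import Tensor
  a = Tensor.randn(5, 3).realize()
  U, S, V = a.svd(full_matrices=False)    # U: (5, 3), S: (3,), V: (3, 3); V comes back already transposed
  recon = (U * S.unsqueeze(-2)) @ V       # equivalent to U @ diag(S) @ V
  print((recon - a).abs().max().numpy())  # should be near zero; the tests use atol=1.0e-5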