Fix QR/SVD NaNs on zero/orthogonal inputs (#13943)
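Householder QR divides by the reflector pivot u1 and by the column norm, both of which are zero when the column below the diagonal is all zeros. The one-sided Jacobi SVD divides by gamma, the inner product of each column pair, which is zero when the pair is already orthogonal (e.g. the identity), and the final normalization divides U by the singular values S, which are zero for rank-deficient inputs such as the zero matrix. Each of these divisions is now guarded with a where() so the degenerate case falls back to a no-op instead of producing NaNs.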
@@ -65,6 +65,24 @@ class TestLinAlg(unittest.TestCase):
     orthogonality_helper(Q)
     reconstruction_helper([Q,R],a)
 
+  def test_qr_zero_column(self):
+    a = Tensor([[0.0, 1.0], [0.0, 2.0]]).realize()
+    Q,R = a.qr()
+    assert not np.isnan(Q.numpy()).any()
+    assert not np.isnan(R.numpy()).any()
+    orthogonality_helper(Q)
+    reconstruction_helper([Q,R], a)
+
+  def test_svd_identity(self):
+    for a in (Tensor.eye(2), Tensor.zeros(2, 2)):
+      a = a.realize()
+      U,S,V = a.svd()
+      assert not np.isnan(U.numpy()).any()
+      assert not np.isnan(S.numpy()).any()
+      assert not np.isnan(V.numpy()).any()
+      s_diag = (S.unsqueeze(-2) * Tensor.eye(2))
+      reconstruction_helper([U, s_diag, V], a)
+
   def test_newton_schulz(self):
     coefficients = [(2, -1.5, 0.5), (2.0, -1.4, 0.2, 0.2)]#these params map to the sign function
     sizes = [(2,2), (3,2), (2,3), (2,2,2)]
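The two new tests pin down the degenerate inputs: a matrix with an all-zero column for QR, and the identity and zero matrices for SVD. orthogonality_helper and reconstruction_helper are defined earlier in the test file and not shown in this diff; a minimal NumPy sketch of the properties they presumably assert (the check_* names below are stand-ins, not tinygrad's helpers):

import numpy as np

def check_orthogonality(Q, atol=1e-5):
  # an orthogonal factor satisfies Q^T Q = I
  assert np.allclose(Q.T @ Q, np.eye(Q.shape[-1]), atol=atol)

def check_reconstruction(factors, a, atol=1e-5):
  # multiplying the factors back together should reproduce the input
  out = factors[0]
  for f in factors[1:]: out = out @ f
  assert np.allclose(out, a, atol=atol)

a = np.array([[0.0, 1.0], [0.0, 2.0]])  # the zero-column input from test_qr_zero_column
Q, R = np.linalg.qr(a)                  # NumPy reference factorization for comparison
check_orthogonality(Q)
check_reconstruction([Q, R], a)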
@@ -3637,11 +3637,13 @@ class Tensor(OpMixin):
     Q = Tensor.eye(m, dtype=self.dtype).reshape((1,) * len(b_shape) + (m, m)).expand(b_shape + (m, m)).contiguous()
     for i in range(min(m, n)):
       x = R[..., i:m, i].contiguous() # TODO: without contigous this can silently be wrong, should at least assert
-      s = -x[..., 0].sign()
-      u1 = x[..., 0] - s * x.square().sum(-1).sqrt()
-      w = x.unsqueeze(-1) / u1.reshape(b_shape + (1, 1))
+      norm = x.square().sum(-1).sqrt()
+      s = (x[..., 0] != 0).where(-x[..., 0].sign(), -1)
+      u1 = x[..., 0] - s * norm
+      w = x.unsqueeze(-1) / (norm != 0).where(u1, 1).reshape(b_shape + (1, 1))
       w[..., 0, 0] = 1
-      tau = (-s * u1 / x.square().sum(-1).sqrt()).reshape(b_shape + (1, 1))
+      tau = (-s * u1 / (norm != 0).where(norm, 1)).reshape(b_shape + (1, 1))
+      tau = (norm != 0).reshape(b_shape + (1, 1)).where(tau, 0)
       R[..., i:m, :] = R[..., i:m, :] - (w * tau) @ (w.transpose(-2, -1) @ R[..., i:m, :])
       Q[..., :, i:m] = Q[..., :, i:m] - (Q[..., :, i:m] @ w) @ (tau * w).transpose(-2, -1)
     return Q,R
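The qr() failure mode: when the subcolumn x is entirely zero, norm and u1 are both zero, so w = x / u1 is 0/0 and every later product turns NaN. The patch computes the norm once, defaults s to -1 when the pivot is zero, swaps a 1 into each zero denominator, and forces tau to 0 so the reflection degenerates to the identity. A minimal non-batched NumPy rendering of one guarded step (householder_step is an illustrative name, not tinygrad API):

import numpy as np

def householder_step(x):
  # one guarded Householder reflector for column x, mirroring the patched qr()
  norm = np.sqrt(np.square(x).sum())
  s = -np.sign(x[0]) if x[0] != 0 else -1.0   # default sign when the pivot is zero
  u1 = x[0] - s * norm
  w = x / (u1 if norm != 0 else 1.0)          # avoid 0/0 on an all-zero column
  w[0] = 1.0
  tau = -s * u1 / norm if norm != 0 else 0.0  # tau = 0 makes the reflection a no-op
  return w.reshape(-1, 1), tau

w, tau = householder_step(np.zeros(3))
H = np.eye(3) - tau * (w @ w.T)  # with tau == 0, H is exactly the identity
assert not np.isnan(H).any()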
@@ -3668,8 +3670,10 @@ class Tensor(OpMixin):
       #compute the jacobi rotations for each pairing
       gamma = (U_left * U_right).sum(-2).reshape(b_shape + (1, num//2))
       alpha, beta = U_permuted.square().sum(-2).unsqueeze(-2).split(num//2, -1)
-      tau = (beta - alpha) / (2 * gamma)
+      rot = gamma != 0
+      tau = (beta - alpha) / (2 * rot.where(gamma, 1))
       t = tau.sign() / (tau.abs() + (1 + tau.square()).sqrt())
+      t = rot.where(t, 0)
       c = 1 / (1 + t.square()).sqrt()
       s = c * t
       #apply the rotations
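In the Jacobi sweep, gamma is the inner product of a column pair, so tau = (beta - alpha) / (2 * gamma) divides by zero exactly when the pair is already orthogonal, which is the identity-input case the new test covers (for the identity, alpha == beta as well, making tau 0/0). The rot mask substitutes a 1 in the denominator and then zeroes t, so c = 1, s = 0, and the rotation collapses to the identity. A scalar NumPy sketch of the same guard (jacobi_rotation is an illustrative name):

import numpy as np

def jacobi_rotation(alpha, beta, gamma):
  # rotation parameters (c, s) for one column pair, with the patch's guard:
  # gamma == 0 means the pair is already orthogonal, so rotate by the identity
  rot = gamma != 0
  tau = (beta - alpha) / (2 * np.where(rot, gamma, 1.0))
  t = np.sign(tau) / (np.abs(tau) + np.sqrt(1 + np.square(tau)))
  t = np.where(rot, t, 0.0)
  c = 1 / np.sqrt(1 + np.square(t))
  return c, c * t

c, s = jacobi_rotation(1.0, 1.0, 0.0)  # the identity's column pairs: alpha == beta, gamma == 0
assert c == 1.0 and s == 0.0           # unguarded, tau would be 0/0 = NaN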
@@ -3688,7 +3692,8 @@ class Tensor(OpMixin):
     S, indices = U.square().sum(-2).sqrt().sort(dim = -1, descending=True)
     new_indices = Tensor.arange(num).reshape((1,) * (self.ndim - 1) + (num,)).expand(b_shape + (num, num)).contiguous()
     new_indices[..., :num] = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
-    U, V = U.gather(-1, new_indices[...,0:num,0:num]) / S.unsqueeze(-2), V.gather(-1, new_indices[..., 0:num, 0:num]).realize()
+    U = U.gather(-1, new_indices[..., 0:num, 0:num]) / (S != 0).where(S, 1).unsqueeze(-2)
+    V = V.gather(-1, new_indices[..., 0:num, 0:num]).realize()
 
     padded_u = Tensor.eye(q_num, dtype=U.dtype).reshape((1,) * len(b_shape) + (q_num, q_num)).expand(b_shape + (q_num, q_num)).contiguous()
     padded_u[..., 0:num, 0:num] = U
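Last, U is normalized by dividing each column by its singular value, so a zero singular value (the zero-matrix input) again gives 0/0. The guard divides by 1 wherever S is zero, leaving that column untouched, and the old one-line tuple assignment is split in two. The same pattern in NumPy:

import numpy as np

U = np.array([[1.0, 0.0], [0.0, 0.0]])  # unnormalized left vectors of a rank-1 matrix
S = np.sqrt(np.square(U).sum(axis=0))   # column norms are the singular values: [1, 0]
U = U / np.where(S != 0, S, 1.0)        # guarded divide: the zero column stays zero
assert not np.isnan(U).any()            # unguarded U / S would hit 0/0 = NaN here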