clean up Tensor.svd slices (#13948)

This commit is contained in:
chenyu
2026-01-01 08:18:45 -05:00
committed by GitHub
parent 1c5ed8e8b5
commit baff10d32c

View File

@@ -3690,10 +3690,9 @@ class Tensor(OpMixin):
for _ in range(max_iterations * iterations_per_round): U, V, permute, inverse_permute = one_round_jacobi(U, V, permute, inverse_permute)
#extract singular values and sort. construct U from Q
S, indices = U.square().sum(-2).sqrt().sort(dim = -1, descending=True)
new_indices = Tensor.arange(num).reshape((1,) * (self.ndim - 1) + (num,)).expand(b_shape + (num, num)).contiguous()
new_indices[..., :num] = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
U = U.gather(-1, new_indices[..., 0:num, 0:num]) / (S != 0).where(S, 1).unsqueeze(-2)
V = V.gather(-1, new_indices[..., 0:num, 0:num]).realize()
new_indices = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
U = U.gather(-1, new_indices) / (S != 0).where(S, 1).unsqueeze(-2)
V = V.gather(-1, new_indices).realize()
padded_u = Tensor.eye(q_num, dtype=U.dtype).reshape((1,) * len(b_shape) + (q_num, q_num)).expand(b_shape + (q_num, q_num)).contiguous()
padded_u[..., 0:num, 0:num] = U