remove realize in Tensor.svd (#14623)

This commit is contained in:
chenyu
2026-02-08 09:36:31 -05:00
committed by GitHub
parent 087dab4c3b
commit c28f7d0167

View File

@@ -3784,7 +3784,7 @@ class Tensor(OpMixin):
S, indices = U.square().sum(-2).sqrt().sort(dim = -1, descending=True)
new_indices = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
U = U.gather(-1, new_indices) / (S != 0).where(S, 1).unsqueeze(-2)
V = V.gather(-1, new_indices).realize()
V = V.gather(-1, new_indices)
padded_u = Tensor.eye(q_num, dtype=U.dtype).reshape((1,) * len(b_shape) + (q_num, q_num)).expand(b_shape + (q_num, q_num)).contiguous()
padded_u[..., 0:num, 0:num] = U