mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
clean up Tensor.svd slices (#13948)
This commit is contained in:
@@ -3690,10 +3690,9 @@ class Tensor(OpMixin):
|
||||
for _ in range(max_iterations * iterations_per_round): U, V, permute, inverse_permute = one_round_jacobi(U, V, permute, inverse_permute)
|
||||
#extract singular values and sort. construct U from Q
|
||||
S, indices = U.square().sum(-2).sqrt().sort(dim = -1, descending=True)
|
||||
new_indices = Tensor.arange(num).reshape((1,) * (self.ndim - 1) + (num,)).expand(b_shape + (num, num)).contiguous()
|
||||
new_indices[..., :num] = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
|
||||
U = U.gather(-1, new_indices[..., 0:num, 0:num]) / (S != 0).where(S, 1).unsqueeze(-2)
|
||||
V = V.gather(-1, new_indices[..., 0:num, 0:num]).realize()
|
||||
new_indices = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
|
||||
U = U.gather(-1, new_indices) / (S != 0).where(S, 1).unsqueeze(-2)
|
||||
V = V.gather(-1, new_indices).realize()
|
||||
|
||||
padded_u = Tensor.eye(q_num, dtype=U.dtype).reshape((1,) * len(b_shape) + (q_num, q_num)).expand(b_shape + (q_num, q_num)).contiguous()
|
||||
padded_u[..., 0:num, 0:num] = U
|
||||
|
||||
Reference in New Issue
Block a user