simpler newton_schulz transpose (#12853)

This commit is contained in:
chenyu
2025-10-21 17:21:45 -04:00
committed by GitHub
parent 60d7e232f2
commit 0b673eddec

View File

@@ -4149,10 +4149,10 @@ class Tensor(MathTrait):
```
"""
assert self.ndim > 1, "NS only works for two or more dims"
if self.shape[-2] > self.shape[-1]: return self.transpose(-2, -1).newton_schulz(steps, params, eps).transpose(-2, -1)
G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps)
if (swap := self.shape[-2] > self.shape[-1]): G = G.transpose(-2, -1)
for _ in range(steps): G = sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params))
return G.transpose(-2, -1) if swap else G
return G
def qr(self) -> tuple[Tensor, Tensor]:
assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"