mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
simpler newton_schulz transpose (#12853)
This commit is contained in:
@@ -4149,10 +4149,10 @@ class Tensor(MathTrait):
|
||||
```
|
||||
"""
|
||||
assert self.ndim > 1, "NS only works for two or more dims"
|
||||
if self.shape[-2] > self.shape[-1]: return self.transpose(-2, -1).newton_schulz(steps, params, eps).transpose(-2, -1)
|
||||
G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps)
|
||||
if (swap := self.shape[-2] > self.shape[-1]): G = G.transpose(-2, -1)
|
||||
for _ in range(steps): G = sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params))
|
||||
return G.transpose(-2, -1) if swap else G
|
||||
return G
|
||||
|
||||
def qr(self) -> tuple[Tensor, Tensor]:
|
||||
assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
|
||||
|
||||
Reference in New Issue
Block a user