mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
simplify logcumsumexp (#10908)
clarify and remove some flatten and squeeze/unsqueeze
This commit is contained in:
@@ -2080,7 +2080,7 @@ class Tensor(MathTrait):
|
||||
The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials.
|
||||
|
||||
You can pass in the `axis` keyword argument to control the axis along which
|
||||
the log-cum-sum-exp is computed.
|
||||
the log-cumsum-exp is computed.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
@@ -2100,12 +2100,11 @@ class Tensor(MathTrait):
|
||||
if self.ndim == 0: return self
|
||||
x = self.transpose(axis, -1)
|
||||
last_dim_size = x.shape[-1]
|
||||
x_reshaped = x.reshape(-1, last_dim_size)
|
||||
x_cummax = x_reshaped.cummax(-1).unsqueeze(-1)
|
||||
x_expand = x_reshaped.unsqueeze(1).expand(*x_reshaped.shape, last_dim_size)
|
||||
mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril().unsqueeze(0)
|
||||
ret = mask.where(x_expand - x_cummax, dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax.squeeze(-1)
|
||||
return ret.reshape(*x.shape).transpose(-1, axis)
|
||||
x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None))
|
||||
x_cummax = x.cummax(-1)
|
||||
mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril()
|
||||
ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax
|
||||
return ret.transpose(-1, axis)
|
||||
|
||||
def argmax(self, axis=None, keepdim=False) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user