simplify logcumsumexp (#10908)

clarify and remove some flatten and squeeze/unsqueeze
This commit is contained in:
chenyu
2025-06-20 22:56:42 -04:00
committed by GitHub
parent fa52bdb50f
commit 1373071f19

View File

@@ -2080,7 +2080,7 @@ class Tensor(MathTrait):
The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials.
You can pass in the `axis` keyword argument to control the axis along which
the log-cum-sum-exp is computed.
the log-cumsum-exp is computed.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
@@ -2100,12 +2100,11 @@ class Tensor(MathTrait):
if self.ndim == 0: return self
x = self.transpose(axis, -1)
last_dim_size = x.shape[-1]
x_reshaped = x.reshape(-1, last_dim_size)
x_cummax = x_reshaped.cummax(-1).unsqueeze(-1)
x_expand = x_reshaped.unsqueeze(1).expand(*x_reshaped.shape, last_dim_size)
mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril().unsqueeze(0)
ret = mask.where(x_expand - x_cummax, dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax.squeeze(-1)
return ret.reshape(*x.shape).transpose(-1, axis)
x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None))
x_cummax = x.cummax(-1)
mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril()
ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax
return ret.transpose(-1, axis)
def argmax(self, axis=None, keepdim=False) -> Tensor:
"""