From 1373071f19d5871144ce3fceeba0b108d4890464 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 20 Jun 2025 22:56:42 -0400 Subject: [PATCH] simplify logcumsumexp (#10908) clarify and remove some flatten and squeeze/unsqueeze --- tinygrad/tensor.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 65d449be44..9a4a599d6e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2080,7 +2080,7 @@ class Tensor(MathTrait): The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials. You can pass in the `axis` keyword argument to control the axis along which - the log-cum-sum-exp is computed. + the log-cumsum-exp is computed. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) @@ -2100,12 +2100,11 @@ class Tensor(MathTrait): if self.ndim == 0: return self x = self.transpose(axis, -1) last_dim_size = x.shape[-1] - x_reshaped = x.reshape(-1, last_dim_size) - x_cummax = x_reshaped.cummax(-1).unsqueeze(-1) - x_expand = x_reshaped.unsqueeze(1).expand(*x_reshaped.shape, last_dim_size) - mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril().unsqueeze(0) - ret = mask.where(x_expand - x_cummax, dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax.squeeze(-1) - return ret.reshape(*x.shape).transpose(-1, axis) + x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None)) + x_cummax = x.cummax(-1) + mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril() + ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax + return ret.transpose(-1, axis) def argmax(self, axis=None, keepdim=False) -> Tensor: """