softmax and friends to mixin (#15778)

The mixin `_softmax` now applies `detach` to the subtracted max, matching the previous Tensor implementation.
This commit is contained in:
chenyu
2026-04-16 23:03:37 -04:00
committed by GitHub
parent ec00cefa5b
commit 1fac03ce54
5 changed files with 62 additions and 53 deletions

View File

@@ -70,5 +70,11 @@ class TestTensorUOpStack(unittest.TestCase):
def test_stack_3tensors(self): _check(self, _t(2, 3), lambda x: x.stack(x, x, dim=0))
def test_stack_new_last(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=-1))
class TestTensorUOpSoftmax(unittest.TestCase):
  # verify softmax/log_softmax agree between the UOp path and the reference (_check)
  def test_softmax_default(self):
    _check(self, _t(2, 3).float(), lambda x: x.softmax())
  def test_softmax_axis0(self):
    _check(self, _t(2, 3).float(), lambda x: x.softmax(axis=0))
  def test_log_softmax_default(self):
    _check(self, _t(2, 3).float(), lambda x: x.log_softmax())
  def test_log_softmax_axis0(self):
    _check(self, _t(2, 3).float(), lambda x: x.log_softmax(axis=0))
# Allow running this test file directly.
if __name__ == "__main__":
  unittest.main()

View File

@@ -260,6 +260,58 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
m = self.max(axis=axis, keepdim=True)
return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + (m if keepdim else m.squeeze(axis))
def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Self, Self, Self]:
  # Shared helper for softmax/log_softmax: returns (shifted, exp(shifted), sum(exp)).
  # Subtracting the detached max keeps exp() from overflowing without adding
  # the max reduction to the backward graph.
  shifted = self - self.max(axis=axis, keepdim=True).detach()
  if dtype is not None: shifted = shifted.cast(to_dtype(dtype))
  exps = shifted.exp()
  return shifted, exps, exps.sum(axis=axis, keepdim=True)
def softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Self:
  """
  Computes the softmax of the tensor along `axis` (default: last axis).

  The output values lie in the range [0, 1] and sum to 1 along the reduced axis.
  An optional `dtype` selects the precision the exponentials are computed in.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.softmax().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.softmax(axis=0).numpy())
  ```
  """
  _, exps, total = self._softmax(axis, dtype)
  # normalize by multiplying with the reciprocal of the per-axis sum
  return exps * total.reciprocal()
def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Self:
  """
  Computes the log-softmax of the tensor along `axis` (default: last axis).

  Equivalent to `softmax(...).log()` but numerically stable, since it works
  entirely in log space. An optional `dtype` selects the computation precision.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.log_softmax().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.log_softmax(axis=0).numpy())
  ```
  """
  shifted, _, total = self._softmax(axis, dtype)
  # log(exp(shifted)/total) == shifted - log(total)
  return shifted - total.log()
def cat(self, *args:Self, dim:int=0) -> Self:
"""
Concatenates self with other tensors in `args` along an axis specified by `dim`.

View File

@@ -25,6 +25,10 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
def usum(self, *uops) -> Self: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, argfix(*uops), self)
def uprod(self, *uops) -> Self: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, argfix(*uops), self)
# NOTE: Tensor overrides this to also set requires_grad=False
def detach(self) -> Self:
  # Wrap self in a DETACH op, which blocks gradient flow through this value.
  return self.alu(Ops.DETACH)
def logical_not(self) -> Self:
"""
Computes the logical NOT of the tensor element-wise.

View File

@@ -1439,58 +1439,6 @@ class Tensor(OpMixin):
return data[:16]
def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]:
  # Shared helper for softmax/log_softmax: returns (shifted, exp(shifted), sum(exp)).
  # The detached max subtraction prevents exp() overflow while keeping the
  # max reduction out of the backward graph.
  shifted = self - self.max(axis=axis, keepdim=True).detach()
  if dtype is not None: shifted = shifted.cast(dtype)
  exps = shifted.exp()
  return shifted, exps, exps.sum(axis=axis, keepdim=True)
def softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
  """
  Computes the softmax of the tensor along `axis` (default: last axis).

  The output values lie in the range [0, 1] and sum to 1 along the reduced axis.
  An optional `dtype` selects the precision the exponentials are computed in.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.softmax().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.softmax(axis=0).numpy())
  ```
  """
  _, exps, total = self._softmax(axis, dtype)
  return exps.div(total)
def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
  """
  Computes the log-softmax of the tensor along `axis` (default: last axis).

  Equivalent to `softmax(...).log()` but numerically stable, since it works
  entirely in log space. An optional `dtype` selects the computation precision.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.log_softmax().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.log_softmax(axis=0).numpy())
  ```
  """
  shifted, _, total = self._softmax(axis, dtype)
  # log(exp(shifted)/total) == shifted - log(total)
  return shifted - total.log()
def logcumsumexp(self, axis=0) -> Tensor:
"""
Computes the log-cumsum-exp of the tensor along the specified axis or axes.

View File

@@ -425,7 +425,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def vectorize(self, *srcs, **kwargs):
  # Build a VECTORIZE op over self plus srcs; the dtype is self.dtype widened
  # to a vector of length len(srcs)+1 (one lane per source, including self).
  return UOp(Ops.VECTORIZE, self.dtype.vec(len(srcs)+1), (self,)+srcs, **kwargs)
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
  # Build an INDEX op into self; None entries in srcs are dropped. The result
  # dtype defaults to the pointer dtype when ptr=True, else the pointed-to base
  # dtype; an explicit dtype kwarg overrides (popped so it isn't passed twice).
  return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx):