From 1fac03ce540bfd0c46d4d3d2dd140bf0924dce83 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 16 Apr 2026 23:03:37 -0400 Subject: [PATCH] softmax and friends to mixin (#15778) with detach now --- test/null/test_tensor_uop_mixin.py | 6 ++++ tinygrad/mixin/__init__.py | 52 ++++++++++++++++++++++++++++++ tinygrad/mixin/elementwise.py | 4 +++ tinygrad/tensor.py | 52 ------------------------------ tinygrad/uop/ops.py | 1 - 5 files changed, 62 insertions(+), 53 deletions(-) diff --git a/test/null/test_tensor_uop_mixin.py b/test/null/test_tensor_uop_mixin.py index 2f05a6f3a3..6b498f5182 100644 --- a/test/null/test_tensor_uop_mixin.py +++ b/test/null/test_tensor_uop_mixin.py @@ -70,5 +70,11 @@ class TestTensorUOpStack(unittest.TestCase): def test_stack_3tensors(self): _check(self, _t(2, 3), lambda x: x.stack(x, x, dim=0)) def test_stack_new_last(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=-1)) +class TestTensorUOpSoftmax(unittest.TestCase): + def test_softmax_default(self): _check(self, _t(2, 3).float(), lambda x: x.softmax()) + def test_softmax_axis0(self): _check(self, _t(2, 3).float(), lambda x: x.softmax(axis=0)) + def test_log_softmax_default(self): _check(self, _t(2, 3).float(), lambda x: x.log_softmax()) + def test_log_softmax_axis0(self): _check(self, _t(2, 3).float(), lambda x: x.log_softmax(axis=0)) + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index 8aa9e38c0f..ab2efc9e78 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -260,6 +260,58 @@ class OpMixin(ElementwiseMixin, ReduceMixin): m = self.max(axis=axis, keepdim=True) return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + (m if keepdim else m.squeeze(axis)) + def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Self, Self, Self]: + m = self - self.max(axis=axis, keepdim=True).detach() + if dtype is not None: m = m.cast(to_dtype(dtype)) + e = m.exp() + return m, e, e.sum(axis=axis, keepdim=True) + + def softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Self: + """ + Applies the softmax function to the tensor along the specified axis. + + Rescales the elements of the tensor such that they lie in the range [0, 1] and sum to 1. + + You can pass in the `axis` keyword argument to control the axis along which the softmax is computed. + + ```python exec="true" source="above" session="tensor" result="python" + Tensor.manual_seed(42) + t = Tensor.randn(2, 3) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.softmax().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.softmax(axis=0).numpy()) + ``` + """ + _, e, ss = self._softmax(axis, dtype) + return e * ss.reciprocal() + + def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Self: + """ + Applies the log-softmax function to the tensor along the specified axis. + + The log-softmax function is a numerically stable alternative to the softmax function in log space. + + You can pass in the `axis` keyword argument to control the axis along which the log-softmax is computed. 
+ + ```python exec="true" source="above" session="tensor" result="python" + Tensor.manual_seed(42) + t = Tensor.randn(2, 3) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.log_softmax().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.log_softmax(axis=0).numpy()) + ``` + """ + m, _, ss = self._softmax(axis, dtype) + return m - ss.log() + def cat(self, *args:Self, dim:int=0) -> Self: """ Concatenates self with other tensors in `args` along an axis specified by `dim`. diff --git a/tinygrad/mixin/elementwise.py b/tinygrad/mixin/elementwise.py index 4ab02dcbcc..9df2c03bce 100644 --- a/tinygrad/mixin/elementwise.py +++ b/tinygrad/mixin/elementwise.py @@ -25,6 +25,10 @@ class ElementwiseMixin(DTypeMixin, CreationMixin): def usum(self, *uops) -> Self: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, argfix(*uops), self) def uprod(self, *uops) -> Self: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, argfix(*uops), self) + # NOTE: Tensor overrides this to also set requires_grad=False + def detach(self) -> Self: + return self.alu(Ops.DETACH) + def logical_not(self) -> Self: """ Computes the logical NOT of the tensor element-wise. diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index f4127b99fe..e3abb0f3fb 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1439,58 +1439,6 @@ class Tensor(OpMixin): return data[:16] - def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]: - m = self - self.max(axis=axis, keepdim=True).detach() - if dtype is not None: m = m.cast(dtype) - e = m.exp() - return m, e, e.sum(axis=axis, keepdim=True) - - def softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor: - """ - Applies the softmax function to the tensor along the specified axis. - - Rescales the elements of the tensor such that they lie in the range [0, 1] and sum to 1. - - You can pass in the `axis` keyword argument to control the axis along which the softmax is computed. - - ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.softmax().numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.softmax(axis=0).numpy()) - ``` - """ - _, e, ss = self._softmax(axis, dtype) - return e.div(ss) - - def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor: - """ - Applies the log-softmax function to the tensor along the specified axis. - - The log-softmax function is a numerically stable alternative to the softmax function in log space. - - You can pass in the `axis` keyword argument to control the axis along which the log-softmax is computed. - - ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.log_softmax().numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.log_softmax(axis=0).numpy()) - ``` - """ - m, _, ss = self._softmax(axis, dtype) - return m - ss.log() - def logcumsumexp(self, axis=0) -> Tensor: """ Computes the log-cumsum-exp of the tensor along the specified axis or axes. 
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index dbb0f1657a..5c46025e04 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -425,7 +425,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None])) def vectorize(self, *srcs, **kwargs): return UOp(Ops.VECTORIZE, self.dtype.vec(len(srcs)+1), (self,)+srcs, **kwargs) - def detach(self): return UOp(Ops.DETACH, self.dtype, (self,)) def index(self, *srcs:UOp|None, ptr=False, **kwargs): return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs) def __getitem__(self, idx):
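
Note on the numerics (commentary, not part of the patch): `_softmax` subtracts the detached per-axis max before exponentiating. Softmax is shift-invariant, so this changes nothing mathematically, but it keeps `exp` from overflowing for large inputs. A minimal NumPy sketch of the same trick — `softmax_ref` and `log_softmax_ref` are illustrative names, not tinygrad API:

```python
import numpy as np

def softmax_ref(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # mirrors _softmax: shift by the per-axis max so exp stays in range
    m = x - x.max(axis=axis, keepdims=True)
    e = np.exp(m)
    return e / e.sum(axis=axis, keepdims=True)

def log_softmax_ref(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # mirrors log_softmax's `m - ss.log()` form: the log is taken of the
    # shifted sum, never of a potentially overflowed exp of raw inputs
    m = x - x.max(axis=axis, keepdims=True)
    return m - np.log(np.exp(m).sum(axis=axis, keepdims=True))

x = np.array([[1000.0, 1001.0, 1002.0]])  # naive exp(x) would overflow float64
print(softmax_ref(x))      # ~[[0.0900 0.2447 0.6652]], finite
print(log_softmax_ref(x))  # finite log-probabilities, no inf/nan
```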
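
On the `detach` move: `UOp.detach` is deleted from `uop/ops.py` and recreated on `ElementwiseMixin` (where, per the NOTE, `Tensor` overrides it to also set `requires_grad=False`), so the shared `_softmax` can call it from the mixin for both `Tensor` and `UOp`. Detaching the max keeps the shift out of the backward graph; since softmax is shift-invariant, the true gradient is unaffected. A quick sanity check of that gradient using tinygrad's public API, assuming a working tinygrad install:

```python
from tinygrad import Tensor

t = Tensor([[1.0, 2.0, 3.0]], requires_grad=True)
t.softmax().sum().backward()
# each softmax row sums to 1, a constant, so d(sum)/dt is ~0 everywhere;
# the detached max-shift does not change this gradient
print(t.grad.numpy())
```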