@@ -70,5 +70,11 @@ class TestTensorUOpStack(unittest.TestCase):
  def test_stack_3tensors(self): _check(self, _t(2, 3), lambda x: x.stack(x, x, dim=0))
  def test_stack_new_last(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=-1))

class TestTensorUOpSoftmax(unittest.TestCase):
  def test_softmax_default(self): _check(self, _t(2, 3).float(), lambda x: x.softmax())
  def test_softmax_axis0(self): _check(self, _t(2, 3).float(), lambda x: x.softmax(axis=0))
  def test_log_softmax_default(self): _check(self, _t(2, 3).float(), lambda x: x.log_softmax())
  def test_log_softmax_axis0(self): _check(self, _t(2, 3).float(), lambda x: x.log_softmax(axis=0))

if __name__ == "__main__":
  unittest.main()
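For reference, a minimal sketch of the behavior the new stack tests above exercise, run against the public `Tensor` API (assuming the usual top-level `tinygrad` import; the `_check`/`_t` helpers from the test file are not reproduced here):

```python
from tinygrad import Tensor

# stacking along dim=0 adds a new leading axis; dim=-1 appends a new trailing axis
x = Tensor.ones(2, 3)
print(x.stack(x, x, dim=0).shape)  # (3, 2, 3): three copies along a new first axis
print(x.stack(x, dim=-1).shape)    # (2, 3, 2): two copies along a new last axis
```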
@@ -260,6 +260,58 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
    m = self.max(axis=axis, keepdim=True)
    return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + (m if keepdim else m.squeeze(axis))

  def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Self, Self, Self]:
    m = self - self.max(axis=axis, keepdim=True).detach()
    if dtype is not None: m = m.cast(to_dtype(dtype))
    e = m.exp()
    return m, e, e.sum(axis=axis, keepdim=True)
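A plain numpy sketch (illustrative only, not tinygrad code) of why `_softmax` subtracts the detached row maximum before exponentiating, and why the `logsumexp` tail above adds the maximum back after the log:

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])

# naive softmax overflows: exp(1000) is inf, so the result is nan
naive = np.exp(x) / np.exp(x).sum()

# softmax is shift-invariant, so subtracting max(x) first changes nothing mathematically
m = x - x.max()
stable = np.exp(m) / np.exp(m).sum()
print(naive, stable)  # [nan nan nan] vs ~[0.090 0.245 0.665]

# the same shift gives the logsumexp form above: log(sum(exp(x - max))) + max
print(np.log(np.exp(m).sum()) + x.max())  # ~1002.41, computed without overflow
```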
  def softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Self:
    """
    Applies the softmax function to the tensor along the specified axis.

    Rescales the elements of the tensor such that they lie in the range [0, 1] and sum to 1.

    You can pass in the `axis` keyword argument to control the axis along which the softmax is computed.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    t = Tensor.randn(2, 3)
    print(t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.softmax().numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.softmax(axis=0).numpy())
    ```
    """
    _, e, ss = self._softmax(axis, dtype)
    return e * ss.reciprocal()

  def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Self:
    """
    Applies the log-softmax function to the tensor along the specified axis.

    The log-softmax function is a numerically stable alternative to the softmax function in log space.

    You can pass in the `axis` keyword argument to control the axis along which the log-softmax is computed.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    t = Tensor.randn(2, 3)
    print(t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.log_softmax().numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.log_softmax(axis=0).numpy())
    ```
    """
    m, _, ss = self._softmax(axis, dtype)
    return m - ss.log()
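A short sanity check of the relationships the two methods above encode, using the same `Tensor` calls as the docstring examples (plus `logsumexp`, whose body tail is shown at the top of this hunk; its `axis`/`keepdim` keywords are assumed from that body):

```python
from tinygrad import Tensor

Tensor.manual_seed(42)
t = Tensor.randn(2, 3)

# softmax output is a probability distribution along the reduced axis
print(t.softmax().sum(axis=-1).numpy())  # each row sums to ~1

# log_softmax(x) == log(softmax(x)) == x - logsumexp(x), up to float error
print((t.log_softmax() - t.softmax().log()).abs().max().numpy())
print((t.log_softmax() - (t - t.logsumexp(axis=-1, keepdim=True))).abs().max().numpy())
```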
  def cat(self, *args:Self, dim:int=0) -> Self:
    """
    Concatenates self with other tensors in `args` along an axis specified by `dim`.
@@ -25,6 +25,10 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
  def usum(self, *uops) -> Self: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, argfix(*uops), self)
  def uprod(self, *uops) -> Self: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, argfix(*uops), self)

  # NOTE: Tensor overrides this to also set requires_grad=False
  def detach(self) -> Self:
    return self.alu(Ops.DETACH)

  def logical_not(self) -> Self:
    """
    Computes the logical NOT of the tensor element-wise.
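The new `usum`/`uprod` helpers above fold extra operands into `self` with `functools.reduce`, switching to `|`/`&` for booleans. A standalone Python sketch of that pattern (names here are illustrative, and tinygrad's `argfix` argument normalization is omitted):

```python
import functools, operator

def usum_like(first, *rest, is_bool=False):
  # booleans reduce with OR, everything else with +, starting from `first`
  op = operator.or_ if is_bool else operator.add
  return functools.reduce(op, rest, first)

def uprod_like(first, *rest, is_bool=False):
  # booleans reduce with AND, everything else with *
  op = operator.and_ if is_bool else operator.mul
  return functools.reduce(op, rest, first)

print(usum_like(1, 2, 3))                           # 6
print(uprod_like(2, 3, 4))                          # 24
print(usum_like(True, False, True, is_bool=True))   # True
```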
@@ -1439,58 +1439,6 @@ class Tensor(OpMixin):

    return data[:16]

  def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]:
    m = self - self.max(axis=axis, keepdim=True).detach()
    if dtype is not None: m = m.cast(dtype)
    e = m.exp()
    return m, e, e.sum(axis=axis, keepdim=True)

  def softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
    """
    Applies the softmax function to the tensor along the specified axis.

    Rescales the elements of the tensor such that they lie in the range [0, 1] and sum to 1.

    You can pass in the `axis` keyword argument to control the axis along which the softmax is computed.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    t = Tensor.randn(2, 3)
    print(t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.softmax().numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.softmax(axis=0).numpy())
    ```
    """
    _, e, ss = self._softmax(axis, dtype)
    return e.div(ss)

  def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
    """
    Applies the log-softmax function to the tensor along the specified axis.

    The log-softmax function is a numerically stable alternative to the softmax function in log space.

    You can pass in the `axis` keyword argument to control the axis along which the log-softmax is computed.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    t = Tensor.randn(2, 3)
    print(t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.log_softmax().numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.log_softmax(axis=0).numpy())
    ```
    """
    m, _, ss = self._softmax(axis, dtype)
    return m - ss.log()

  def logcumsumexp(self, axis=0) -> Tensor:
    """
    Computes the log-cumsum-exp of the tensor along the specified axis or axes.
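The `logcumsumexp` context at the end of the removed block computes a running log-sum-exp. An illustrative numpy reference (not the tinygrad implementation) for what that means along one axis:

```python
import numpy as np

def logcumsumexp_ref(x):
  # out[i] = log(exp(x[0]) + ... + exp(x[i])), with a global max shift to keep exp() in range
  m = x.max()
  return np.log(np.cumsum(np.exp(x - m))) + m

print(logcumsumexp_ref(np.array([0.0, 1.0, 2.0])))  # ~[0.0, 1.3133, 2.4076]
```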
@@ -425,7 +425,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
    return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
  def vectorize(self, *srcs, **kwargs):
    return UOp(Ops.VECTORIZE, self.dtype.vec(len(srcs)+1), (self,)+srcs, **kwargs)
  def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
  def index(self, *srcs:UOp|None, ptr=False, **kwargs):
    return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
  def __getitem__(self, idx):