mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)

add Tensor.nll_loss (#7683)

* move nll_loss to new branch
* make nll_loss examples practical
* self *is*
* add to docs
* small
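For orientation (not part of the diff): a minimal sketch of how the new method is called, based on the docstring examples further down in the patch; the values are illustrative only.

```python
from tinygrad import Tensor

# log-probabilities for two samples over three classes, and integer class labels
logits = Tensor([[-1.0, 2.0, -3.0], [1.0, -2.0, 3.0]])
labels = Tensor([1, 2])

# nll_loss expects log-probabilities, so it is normally paired with log_softmax
loss = logits.log_softmax().nll_loss(labels)
print(loss.item())
```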
@@ -44,3 +44,4 @@
 ::: tinygrad.Tensor.binary_crossentropy_logits
 ::: tinygrad.Tensor.sparse_categorical_crossentropy
 ::: tinygrad.Tensor.cross_entropy
+::: tinygrad.Tensor.nll_loss
@@ -2201,6 +2201,36 @@ class TestOps(unittest.TestCase):
     helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
                                        lambda x,y: x.cross_entropy(y, label_smoothing=ls))
 
+  def test_nll_loss(self):
+    helper_test_op([(32,10), (32)], lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1),
+                                                                             torch.clip(y,0).type(torch.long)),
+                                    lambda x,y: x.log_softmax().nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
+
+  def test_nll_loss_reductions(self):
+    for r in ("mean", "sum", "none"):
+      helper_test_op([(32,10), (32)], lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1),
+                                                                               torch.clip(y,0).type(torch.long), reduction=r),
+                                      lambda x,y: x.log_softmax().nll_loss(y.clip(0).cast(dtypes.long), reduction=r), forward_only=True)
+    self.helper_test_exception([(32,10), (32)], lambda x,y: torch.nn.functional.nll_loss(x, torch.clip(y,0).type(torch.long), reduction="typo"),
+                                                lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.long), reduction="typo"), expected=ValueError)
+
+  def test_nll_loss_weight(self):
+    for r in ("mean", "sum", "none"):
+      helper_test_op([(32,10), (32), (10)], lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1),
+                                                                                       torch.clip(y,0).type(torch.long), weight=z, reduction=r),
+                                            lambda x,y,z: x.log_softmax().nll_loss(y.clip(0).cast(dtypes.long),
+                                                                                   weight=z, reduction=r), forward_only=True)
+
+  def test_nll_loss_ignore_index(self):
+    logits = [[2.0, 0.5, -1.0],
+              [1.5, 2.5, -0.5],
+              [0.0, -2.0, 1.0]]
+    targets = [0, 1, 2]
+    helper_test_op(None, lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1),
+                                                                  torch.clip(y,0).type(torch.long), ignore_index=1),
+                         lambda x,y: x.log_softmax().nll_loss(y.clip(0).cast(dtypes.long), ignore_index=1),
+                         forward_only=True, vals=[logits, targets])
+
   def test_one_hot(self):
     data = [1, 2, 4]
     helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6).type(torch.int32),
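These tests pin down the PyTorch definition: for log-probabilities `x` and integer targets `y`, the per-sample loss is `-weight[y[i]] * x[i, y[i]]`, samples equal to `ignore_index` are masked out, and the `mean` reduction divides by the sum of the applied weights rather than the batch size. A plain-Python restatement of that definition, for reference only (not part of the test suite or the diff):

```python
def reference_nll_loss(logp, y, weight=None, ignore_index=None, reduction="mean"):
  # logp: rows of log-probabilities, y: integer class labels (one per row)
  losses, applied_weights = [], []
  for row, label in zip(logp, y):
    if ignore_index is not None and label == ignore_index:
      losses.append(0.0); applied_weights.append(0.0)
      continue
    w = 1.0 if weight is None else weight[label]
    losses.append(-w * row[label]); applied_weights.append(w)
  if reduction == "none": return losses
  if reduction == "sum": return sum(losses)
  if reduction == "mean": return sum(losses) / sum(applied_weights)
  raise ValueError(f"reduction must be one of 'mean', 'sum', 'none', got {reduction!r}")
```

With `weight=None` and no `ignore_index`, the `mean` branch reduces to an ordinary batch mean, which is why the first two tests can compare directly against torch's defaults.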
@@ -3336,6 +3336,32 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1)
     return ret._do_reduction(reduction)
 
+  def nll_loss(self, Y:Tensor, weight:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:ReductionStr="mean") -> Tensor:
+    """
+    Compute the negative log likelihood loss between log-probabilities and target labels.
+
+    NOTE: `self` is log-probabilities and `Y` is the target class labels.
+
+    See: https://pytorch.org/docs/stable/generated/torch.nn.functional.nll_loss.html
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[-1, 2, -3], [1, -2, 3]])
+    Y = Tensor([1, 2])
+    print(t.log_softmax().nll_loss(Y).item())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[-1, 2, -3], [1, -2, 3]])
+    Y = Tensor([1, 2])
+    print(t.log_softmax().nll_loss(Y, reduction='none').numpy())
+    ```
+    """
+    t, Y, target_shape = self.reshape(None, None, -1), Y.reshape(None, -1), Y.shape
+    mask = Tensor.ones_like(Y, requires_grad=False, device=t.device) if ignore_index is None else (Y != ignore_index)
+    masked_weight = mask if weight is None else weight[Y] * mask
+    ret = (-t.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight).reshape(target_shape)
+    if reduction == "mean": return ret.sum() / (masked_weight.sum())
+    return ret._do_reduction(reduction)
+
   # ***** Tensor Properties *****
 
   @property
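In the new method, the `reshape(None, None, -1)` / `reshape(None, -1)` pair flattens any extra target dimensions so a single `gather` along the class axis covers the K-dimensional case, and the `mean` branch divides by `masked_weight.sum()` so ignored or re-weighted samples do not skew the average. One way to sanity-check it (not part of the commit): since `cross_entropy` takes raw logits and applies `log_softmax` internally, feeding `nll_loss` the `log_softmax` output should reproduce it. The sketch below assumes one-hot targets built with `Tensor.one_hot` for `cross_entropy` and integer labels for `nll_loss`; the values are arbitrary.

```python
from tinygrad import Tensor

logits = Tensor([[2.0, 0.5, -1.0], [1.5, 2.5, -0.5], [0.0, -2.0, 1.0]])
labels = Tensor([0, 1, 2])

# cross_entropy consumes raw logits; nll_loss consumes log-probabilities.
a = logits.cross_entropy(labels.one_hot(3).float())
b = logits.log_softmax().nll_loss(labels)
print(a.item(), b.item())  # expected to agree up to float rounding
```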