mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
label_smoothing in sparse_cat_crossentropy (#3568)
* label_smoothing in sparse_cat_crossentropy
* test multiple values, assert
@@ -13,13 +13,13 @@ class TestNN(unittest.TestCase):
     # create in tinygrad
     input = Tensor.randn(3, 5)
     target = Tensor.randint((3,), low=0, high=4)
-    loss = input.sparse_categorical_crossentropy(target)

     torch_input = torch.tensor(input.numpy())
     torch_target = torch.tensor(target.numpy(), dtype=torch.long)
-    torch_loss = torch.nn.CrossEntropyLoss(reduction='mean')(torch_input, torch_target)

-    np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)
+    for smoothing in [0.0, 0.1, 0.5, 1.0]:
+      loss = input.sparse_categorical_crossentropy(target, label_smoothing=smoothing)
+      torch_loss = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=smoothing)(torch_input, torch_target)
+      np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)

   def test_batchnorm2d(self, training=False):
     with Tensor.train(training):
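The loop above checks the new keyword against torch for several smoothing values, including 0.0 (the previous default behaviour). As a standalone usage sketch, assuming nothing beyond the Tensor API already used in this test, the argument is passed straight to sparse_categorical_crossentropy:

    from tinygrad.tensor import Tensor

    logits = Tensor.randn(8, 10)                   # batch of 8 samples, 10 classes
    labels = Tensor.randint((8,), low=0, high=10)  # integer class targets
    hard_loss = logits.sparse_categorical_crossentropy(labels)                        # default, label_smoothing=0.0
    soft_loss = logits.sparse_categorical_crossentropy(labels, label_smoothing=0.1)   # smoothed variant added by this commit
    print(hard_loss.numpy(), soft_loss.numpy())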
@@ -959,12 +959,13 @@ class Tensor:
   def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
     return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()

-  def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1) -> Tensor:
+  def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
+    assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
     # NOTE: self is a logits input
-    loss_mask = (Y != ignore_index)
+    log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
     y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
     y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
-    return self.log_softmax().mul(y).sum() / loss_mask.sum()
+    return (1 - label_smoothing) * log_probs.mul(y).sum() / loss_mask.sum() + -1 * (label_smoothing * log_probs.mean())

   # ***** cast ops *****
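For reference, the new return line is the usual label-smoothing formulation: with smoothing s and C classes, the target distribution becomes (1 - s) on the true class plus s/C on every class, so the mean loss is (1 - s) * NLL plus s times the mean of -log p over all samples and classes. A minimal numpy sketch of that formula (ignoring ignore_index, which the tinygrad version handles through loss_mask):

    import numpy as np

    def smoothed_ce(logits, targets, label_smoothing=0.0):
      # log-softmax over the class dimension
      shifted = logits - logits.max(axis=-1, keepdims=True)
      log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
      n = logits.shape[0]
      # hard part: mean negative log-likelihood of the true class
      nll = -log_probs[np.arange(n), targets].mean()
      # uniform part: mean of -log p over every sample and class
      uniform = -log_probs.mean()
      return (1 - label_smoothing) * nll + label_smoothing * uniform

This agrees with torch.nn.CrossEntropyLoss(label_smoothing=...) up to numerical precision, which is what the test above asserts.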