label_smoothing in sparse_cat_crossentropy (#3568)

* label_smoothing in sparse_cat_crossentropy

* test multiple values, assert
Author: David Hou
Date: 2024-03-01 17:02:46 -08:00
Committed by: GitHub
parent 6b29c70b3d
commit b3cdc11a58
2 changed files with 8 additions and 7 deletions

@@ -13,13 +13,13 @@ class TestNN(unittest.TestCase):
     # create in tinygrad
     input = Tensor.randn(3, 5)
     target = Tensor.randint((3,), low=0, high=4)
-    loss = input.sparse_categorical_crossentropy(target)
     torch_input = torch.tensor(input.numpy())
     torch_target = torch.tensor(target.numpy(), dtype=torch.long)
-    torch_loss = torch.nn.CrossEntropyLoss(reduction='mean')(torch_input, torch_target)
-    np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)
+    for smoothing in [0.0, 0.1, 0.5, 1.0]:
+      loss = input.sparse_categorical_crossentropy(target, label_smoothing=smoothing)
+      torch_loss = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=smoothing)(torch_input, torch_target)
+      np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)
   def test_batchnorm2d(self, training=False):
     with Tensor.train(training):

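The loop above checks tinygrad against PyTorch for several smoothing values, including the edge cases 0.0 (no smoothing) and 1.0 (fully uniform targets). A standalone sketch of the reference side, spelling out what the built-in torch loss computes (assumes a torch version with label_smoothing, i.e. >= 1.10; names here are illustrative, not from this commit):

import torch

logits = torch.randn(3, 5)
target = torch.randint(0, 4, (3,))
for eps in [0.0, 0.1, 0.5, 1.0]:
  ref = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=eps)(logits, target)
  # manual equivalent: (1 - eps) * NLL + eps * (mean over classes of -log p)
  log_p = torch.log_softmax(logits, dim=-1)
  nll = -log_p[torch.arange(3), target]
  manual = ((1 - eps) * nll + eps * (-log_p).mean(dim=-1)).mean()
  assert torch.allclose(ref, manual, atol=1e-6)
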
@@ -959,12 +959,13 @@ class Tensor:
   def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
     return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()
-  def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1) -> Tensor:
+  def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
+    assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
     # NOTE: self is a logits input
-    loss_mask = (Y != ignore_index)
+    log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
     y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
     y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
-    return self.log_softmax().mul(y).sum() / loss_mask.sum()
+    return (1 - label_smoothing) * log_probs.mul(y).sum() / loss_mask.sum() + -1 * (label_smoothing * log_probs.mean())
   # ***** cast ops *****
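In the new return line, y holds -1 at the target positions, so log_probs.mul(y).sum() / loss_mask.sum() is the mean negative log-likelihood over the non-ignored samples, while label_smoothing * log_probs.mean() contributes the uniform part of the smoothed target distribution. A minimal NumPy restatement of the expression, assuming no index is ignored so loss_mask.sum() equals the batch size (the function name and layout are illustrative):

import numpy as np

def sparse_ce(logits: np.ndarray, Y: np.ndarray, label_smoothing: float = 0.0) -> float:
  # numerically stable log-softmax along the class axis
  shifted = logits - logits.max(axis=-1, keepdims=True)
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
  nll = -log_probs[np.arange(Y.shape[0]), Y].mean()  # mean NLL of the true classes
  # (1 - eps) on the hard targets, eps spread uniformly over every (sample, class) entry
  return (1 - label_smoothing) * nll - label_smoothing * log_probs.mean()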