mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
label_smoothing in sparse_cat_crossentropy (#3568)
* label_smoothing in sparse_cat_crossentropy * test multiple values, assert
This commit is contained in:
@@ -13,13 +13,13 @@ class TestNN(unittest.TestCase):
|
||||
# create in tinygrad
|
||||
input = Tensor.randn(3, 5)
|
||||
target = Tensor.randint((3,), low=0, high=4)
|
||||
loss = input.sparse_categorical_crossentropy(target)
|
||||
|
||||
torch_input = torch.tensor(input.numpy())
|
||||
torch_target = torch.tensor(target.numpy(), dtype=torch.long)
|
||||
torch_loss = torch.nn.CrossEntropyLoss(reduction='mean')(torch_input, torch_target)
|
||||
|
||||
np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)
|
||||
for smoothing in [0.0, 0.1, 0.5, 1.0]:
|
||||
loss = input.sparse_categorical_crossentropy(target, label_smoothing=smoothing)
|
||||
torch_loss = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=smoothing)(torch_input, torch_target)
|
||||
np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_batchnorm2d(self, training=False):
|
||||
with Tensor.train(training):
|
||||
|
||||
Reference in New Issue
Block a user