mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
move sparse_categorical_crossentropy to test_ops (#11083)
also flattened the tests
This commit is contained in:
@@ -2950,6 +2950,39 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(32,10)], lambda x: torch.nn.functional.cross_entropy(x, torch.tensor(classes), label_smoothing=ls),
|
||||
lambda x: x.cross_entropy(Tensor(classes), label_smoothing=ls))
|
||||
|
||||
def test_sparse_categorical_crossentropy(self):
|
||||
classes = np.random.randint(0, 10, (12,), dtype=np.int32).tolist()
|
||||
helper_test_op([(12,10)], lambda x: torch.nn.CrossEntropyLoss()(x, torch.tensor(classes)),
|
||||
lambda x: x.sparse_categorical_crossentropy(Tensor(classes)))
|
||||
|
||||
# combine args
|
||||
helper_test_op([(12,10)],
|
||||
lambda x: torch.nn.CrossEntropyLoss(reduction="mean", ignore_index=classes[0], label_smoothing=0.3)(x, torch.tensor(classes)),
|
||||
lambda x: x.sparse_categorical_crossentropy(Tensor(classes), reduction="mean", ignore_index=classes[0], label_smoothing=0.3))
|
||||
|
||||
# with batch. somehow this does not match torch
|
||||
classes = np.random.randint(0, 10, (3,12), dtype=np.int32).tolist()
|
||||
helper_test_op([(3,12,10)], lambda x: torch.nn.CrossEntropyLoss()(x.permute(0,2,1), torch.tensor(classes)),
|
||||
lambda x: x.sparse_categorical_crossentropy(Tensor(classes)))
|
||||
|
||||
def test_sparse_categorical_crossentropy_reductions(self):
|
||||
for r in ("mean", "sum", "none"):
|
||||
classes = np.random.randint(0, 10, (12,), dtype=np.int32).tolist()
|
||||
helper_test_op([(12,10)], lambda x: torch.nn.CrossEntropyLoss(reduction=r)(x, torch.tensor(classes)),
|
||||
lambda x: x.sparse_categorical_crossentropy(Tensor(classes), reduction=r))
|
||||
|
||||
def test_sparse_categorical_crossentropy_ignore_index(self):
|
||||
classes = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
|
||||
for i in (-1, 0, 3):
|
||||
helper_test_op([(12,10)], lambda x: torch.nn.CrossEntropyLoss(ignore_index=i)(x, torch.tensor(classes)),
|
||||
lambda x: x.sparse_categorical_crossentropy(Tensor(classes), ignore_index=i))
|
||||
|
||||
def test_sparse_categorical_crossentropy_label_smoothing(self):
|
||||
for s in (0.3, 0.9):
|
||||
classes = np.random.randint(0, 10, (12,), dtype=np.int32).tolist()
|
||||
helper_test_op([(12,10)], lambda x: torch.nn.CrossEntropyLoss(label_smoothing=s)(x, torch.tensor(classes)),
|
||||
lambda x: x.sparse_categorical_crossentropy(Tensor(classes), label_smoothing=s))
|
||||
|
||||
def test_nll_loss(self):
|
||||
target = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
|
||||
helper_test_op([(32,10)],
|
||||
|
||||
Reference in New Issue
Block a user