diff --git a/test/test_nn.py b/test/test_nn.py index 000e459164..72fc0d77df 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -13,7 +13,7 @@ from test.helpers import not_support_multi_device @unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow") class TestNN(unittest.TestCase): - def test_sparse_cat_cross_entropy(self): + def test_sparse_categorical_crossentropy(self): # create in tinygrad input_tensor = Tensor.randn(6, 5) # not square to test that mean scaling uses the correct dimension target = Tensor([0, 0, 0, 1, 2, 3]) # torch doesn't support target=-1 diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 40bd2e20af..55e8cf6735 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3919,7 +3919,7 @@ class Tensor(MathTrait): ``` """ assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]" - assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']" + assert reduction in get_args(ReductionStr), f"reduction must be one of {get_args(ReductionStr)}" log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool) y_counted = Y.to(self.device).flatten().reshape(-1, 1)._one_hot_along_dim(self.shape[-1]) y = (y_counted * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])