From 487cb027dc5120542755446d1595ec7b76c207e8 Mon Sep 17 00:00:00 2001 From: Louis Novy <101842021+louisnovy@users.noreply.github.com> Date: Sun, 13 Oct 2024 15:21:23 -0700 Subject: [PATCH] oops this shouldn't be edited --- test/test_schedule.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 6a9c97f797..337c37c2a9 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1549,18 +1549,18 @@ class TestIndexing(unittest.TestCase): self.check_schedule(loss, 4) np.testing.assert_allclose(loss.item(), 0.878309, atol=1e-5, rtol=1e-6) - # def test_mnist_val(self): - # from tinygrad.nn.datasets import mnist - # import torch - # _, Y_train, _, _ = mnist() - # samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1])).realize() - # yt = Tensor.randn(BS, 10).realize() - # with Context(SPLIT_REDUCEOP=0): - # loss = yt.sparse_categorical_crossentropy(Y_train[samples]) - # self.check_schedule(loss, 6) - # loss_fused = loss.numpy() - # loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())]) - # np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6) + def test_mnist_val(self): + from tinygrad.nn.datasets import mnist + import torch + _, Y_train, _, _ = mnist() + samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1])).realize() + yt = Tensor.randn(BS, 10).realize() + with Context(SPLIT_REDUCEOP=0): + loss = yt.sparse_categorical_crossentropy(Y_train[samples]) + self.check_schedule(loss, 6) + loss_fused = loss.numpy() + loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())]) + np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6) def test_arange_fuse_grouped_children(self): X = Tensor.randn(4, 4).realize()