RANGEIFY test_tensor (#12235)

This commit is contained in:
chenyu
2025-09-18 10:35:43 -04:00
committed by GitHub
parent 7487c13b61
commit f82b16a0e9
2 changed files with 16 additions and 8 deletions

View File

@@ -4,7 +4,7 @@ import torch
import unittest, copy, mmap, random, math, array
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _METADATA
from tinygrad.helpers import getenv, temp, mv_address
from tinygrad.helpers import getenv, temp, mv_address, RANGEIFY
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat
from tinygrad.device import is_dtype_supported
@@ -871,11 +871,18 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
self.assertTrue(y.grad.uop.metadata[0].backward)
si = Tensor.schedule(out, x.grad, y.grad)[-1]
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})
bw = [m for m in si.metadata if m.backward]
self.assertEqual(len(bw), 2)
self.assertEqual(bw[0].name, "sigmoid")
if not RANGEIFY:
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})
bw = [m for m in si.metadata if m.backward]
self.assertEqual(len(bw), 2)
self.assertEqual(bw[0].name, "sigmoid")
else:
self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"})
bw = [m for m in si.metadata if m.backward]
self.assertEqual(len(bw), 1)
self.assertEqual(bw[0].name, "sigmoid")
class TestIdxUpcast(unittest.TestCase):
def _find_op(self, ast: UOp, op: Ops):