mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
RANGEIFY test_tensor (#12235)
This commit is contained in:
5
.github/workflows/test.yml
vendored
5
.github/workflows/test.yml
vendored
@@ -522,11 +522,12 @@ jobs:
|
|||||||
# test_embedding issue with jit
|
# test_embedding issue with jit
|
||||||
# test_load_state_dict_sharded_model_dict_same_axis issue with multi
|
# test_load_state_dict_sharded_model_dict_same_axis issue with multi
|
||||||
# test_instancenorm_3d is very slow
|
# test_instancenorm_3d is very slow
|
||||||
|
# test_copy_from_disk issue with DISK
|
||||||
run: |
|
run: |
|
||||||
CPU=1 CPU_LLVM=0 RANGEIFY=1 python3 -m pytest -n auto --durations 20 \
|
CPU=1 CPU_LLVM=0 RANGEIFY=1 python3 -m pytest -n auto --durations 20 \
|
||||||
-k "not test_embedding and not test_load_state_dict_sharded_model_dict_same_axis and not test_instancenorm_3d" \
|
-k "not test_embedding and not test_load_state_dict_sharded_model_dict_same_axis and not test_instancenorm_3d and not test_copy_from_disk" \
|
||||||
test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_symbolic_ops.py test/test_tensor_variable.py \
|
test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_symbolic_ops.py test/test_tensor_variable.py \
|
||||||
test/test_outerworld_range.py test/test_sample.py test/test_randomness.py test/test_nn.py test/test_arange.py
|
test/test_outerworld_range.py test/test_sample.py test/test_randomness.py test/test_nn.py test/test_arange.py test/test_tensor.py
|
||||||
- name: Test const folding
|
- name: Test const folding
|
||||||
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding"
|
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding"
|
||||||
- name: Test multitensor
|
- name: Test multitensor
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import torch
|
|||||||
import unittest, copy, mmap, random, math, array
|
import unittest, copy, mmap, random, math, array
|
||||||
from tinygrad import Tensor, Device, dtypes
|
from tinygrad import Tensor, Device, dtypes
|
||||||
from tinygrad.tensor import _METADATA
|
from tinygrad.tensor import _METADATA
|
||||||
from tinygrad.helpers import getenv, temp, mv_address
|
from tinygrad.helpers import getenv, temp, mv_address, RANGEIFY
|
||||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||||
from hypothesis import given, settings, strategies as strat
|
from hypothesis import given, settings, strategies as strat
|
||||||
from tinygrad.device import is_dtype_supported
|
from tinygrad.device import is_dtype_supported
|
||||||
@@ -871,11 +871,18 @@ class TestTensorMetadata(unittest.TestCase):
|
|||||||
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
|
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
|
||||||
self.assertTrue(y.grad.uop.metadata[0].backward)
|
self.assertTrue(y.grad.uop.metadata[0].backward)
|
||||||
si = Tensor.schedule(out, x.grad, y.grad)[-1]
|
si = Tensor.schedule(out, x.grad, y.grad)[-1]
|
||||||
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
|
if not RANGEIFY:
|
||||||
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})
|
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
|
||||||
bw = [m for m in si.metadata if m.backward]
|
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})
|
||||||
self.assertEqual(len(bw), 2)
|
bw = [m for m in si.metadata if m.backward]
|
||||||
self.assertEqual(bw[0].name, "sigmoid")
|
self.assertEqual(len(bw), 2)
|
||||||
|
self.assertEqual(bw[0].name, "sigmoid")
|
||||||
|
else:
|
||||||
|
self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
|
||||||
|
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"})
|
||||||
|
bw = [m for m in si.metadata if m.backward]
|
||||||
|
self.assertEqual(len(bw), 1)
|
||||||
|
self.assertEqual(bw[0].name, "sigmoid")
|
||||||
|
|
||||||
class TestIdxUpcast(unittest.TestCase):
|
class TestIdxUpcast(unittest.TestCase):
|
||||||
def _find_op(self, ast: UOp, op: Ops):
|
def _find_op(self, ast: UOp, op: Ops):
|
||||||
|
|||||||
Reference in New Issue
Block a user