rangeify: only test correctness in multi (#12339)

* work

* more work

* back here

* skip tests

* work
This commit is contained in:
qazal
2025-09-30 09:55:59 +03:00
committed by GitHub
parent ab6b0d3a21
commit 6a56d3c859
4 changed files with 8 additions and 7 deletions

View File

@@ -533,12 +533,9 @@ jobs:
-k "not test_assign_diamond_cycle" \
test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_tensor_variable.py \
test/test_outerworld_range.py test/test_randomness.py test/test_nn.py test/test_arange.py test/test_tensor.py test/test_optim.py \
test/test_setitem.py test/test_assign.py
test/test_setitem.py test/test_assign.py test/test_multitensor.py
- name: Test const folding
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding"
- name: Test multitensor
run: |
CPU=1 RANGEIFY=1 python3 -m pytest -n=auto test/test_multitensor.py::TestMultiTensor -k 'not const_folding'
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding"
# RANGEIFY=2 isn't supported
#- name: Test CPU=1 RANGEIFY=2
# run: CPU=1 CPU_LLVM=0 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20

View File

@@ -245,7 +245,7 @@ class TestReduceOpsConstFolding(unittest.TestCase):
t = Tensor.ones(16, dtype=dt).reshape(4, 4)
assert t.sum().dtype == t.contiguous().sum().dtype
@unittest.skipIf(not_support_multi_device(), "no multi")
@unittest.skipIf(not_support_multi_device() or RANGEIFY, "no multi, RANGEIFY doesn't support multi const folding")
class TestMultiConstFolding(unittest.TestCase):
def test_multi_const_folding_literal(self):
ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))

View File

@@ -793,6 +793,7 @@ class TestMultiTensor(unittest.TestCase):
t = Tensor.rand(16, 16).shard(devices_2, axis=0)
np.testing.assert_allclose(t.numpy(), t.clone().numpy())
@unittest.skipIf(RANGEIFY, "RANGEIFY doesn't support multi const folding")
def test_multi_const_folding(self):
with Context(TRACK_MATCH_STATS=0):
a = Tensor.arange(3).realize()

View File

@@ -48,6 +48,9 @@ earliest_rewrites = double_reshape+PatternMatcher([
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None),
# make inputs to mstack contiguous
(UPat(Ops.MSTACK, name="ms"), lambda ms: ms.replace(src=tuple(s if s.op in ALWAYS_CONTIGUOUS else s.contiguous() for s in ms.src))),
# assign only to buffer, otherwise make it a CONTIGUOUS
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"),
lambda x,target,assign: x.f(Ops.CONTIGUOUS, tag=assign.tag) if ((t:=target.base).op is not Ops.BUFFER and \
@@ -358,7 +361,7 @@ pm_rangeify = pm_mops+PatternMatcher([
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
def cleanup_dead_axes(b:UOp):
# if it's user contiguous or assigned to something, we don't touch it
if b.src[0].op in {Ops.CONTIGUOUS, Ops.ASSIGN}: return None
if b.src[0].op in {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN}: return None
new_rng = []
hit = False