RANGEIFY=1 test_jit (#12254)

* RANGEIFY=1 test_jit

* don't do any of that

* disk

* simple disk tensor

* more work

* run more tests

* it also doesn't copy every time

* skip tests that hang everything
Author: qazal
Date: 2025-09-20 17:34:32 +03:00
Committed by: GitHub
Parent: 393c6b236c
Commit: 57c7e0a8f8
4 changed files with 22 additions and 11 deletions
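
The mechanism behind "skip tests that hang everything" is a plain unittest skip gated on the RANGEIFY flag that the test files now import from tinygrad.helpers. A minimal sketch of that pattern (not part of this diff; the test body is made up, and RANGEIFY being a ContextVar-style truthy flag is an assumption):

import unittest
from tinygrad import Tensor
from tinygrad.helpers import RANGEIFY  # truthy when the env sets RANGEIFY=1 (assumed ContextVar-style flag)

class TestExample(unittest.TestCase):
  @unittest.skipIf(RANGEIFY, "TODO: pm_rangeify hangs")  # same gating the diff adds in test_jit / test_multitensor
  def test_placeholder(self):
    # hypothetical test body, only here to show the decorator usage
    self.assertEqual(Tensor([1, 2, 3]).sum().item(), 6)

if __name__ == "__main__":
  unittest.main()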

View File

@@ -553,7 +553,7 @@ jobs:
opencl: 'true'
llvm: "true"
- name: Test CL=1 RANGEIFY=1
run: CL=1 RANGEIFY=1 pytest -n auto test/test_ops.py test/test_schedule.py test/test_symbolic_ops.py --durations 20
run: CL=1 RANGEIFY=1 pytest -n auto test/test_ops.py test/test_schedule.py test/test_symbolic_ops.py test/test_jit.py test/unit/test_disk_tensor.py test/models/test_mnist.py test/unit/test_mnist_dataset.py --durations 20
- name: Test Fuse
run: CL=1 RANGEIFY=2 python3 -m pytest --durations 20 test/test_softmax_fusion.py -k "not test_auto_softmax"
- name: Test ONNX
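
The expanded job can be reproduced locally from the repo root with the same command the workflow now runs: CL=1 RANGEIFY=1 pytest -n auto test/test_ops.py test/test_schedule.py test/test_symbolic_ops.py test/test_jit.py test/unit/test_disk_tensor.py test/models/test_mnist.py test/unit/test_mnist_dataset.py --durations 20 (the CI job additionally enables the opencl and llvm setup steps shown above).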

View File

@@ -8,7 +8,7 @@ from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class
from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
from tinygrad.device import Device
from tinygrad.helpers import Context, JIT, GlobalCounters, getenv
from tinygrad.helpers import Context, JIT, RANGEIFY, GlobalCounters, getenv
from tinygrad.dtype import dtypes
from extra.models.unet import ResBlock
@@ -605,6 +605,7 @@ class TestJitPrune(unittest.TestCase):
assert len(w2_prune.captured.jit_cache) == 1, "prune should have removed the copy"
class TestJitFree(unittest.TestCase):
@unittest.skipIf(RANGEIFY, "needs a rewrite")
def test_free_intermediates(self):
ext_tensor = Tensor([1,24,23,45,1])
@TinyJit

View File

@@ -2,7 +2,7 @@ import unittest, functools, random
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
from tinygrad.device import is_dtype_supported
from tinygrad.uop.ops import Ops, UOp
from tinygrad.helpers import CI, getenv, prod, Context
from tinygrad.helpers import CI, getenv, prod, Context, RANGEIFY
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
import numpy as np
@@ -372,6 +372,7 @@ class TestMultiTensor(unittest.TestCase):
# NOTE: this is failing on LLVM CI, no idea why. Works locally.
@unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "CPU", "AMD"), "slow, and flaky on CPU")
@unittest.skipIf(RANGEIFY, "TODO: pm_rangeify hangs")
def test_data_parallel_resnet(self):
from extra.models.resnet import ResNet18
@@ -408,6 +409,7 @@ class TestMultiTensor(unittest.TestCase):
np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)
@unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "CPU", "AMD"), "slow, and flaky on CPU")
@unittest.skipIf(RANGEIFY, "TODO: pm_rangeify hangs")
def test_data_parallel_resnet_train_step(self):
from extra.models.resnet import ResNet18
fake_image = Tensor.rand((2, 3, 224//16, 224//16))
@@ -415,6 +417,7 @@ class TestMultiTensor(unittest.TestCase):
m = ResNet18()
self._test_model_train_step(m, fake_image, labels)
@unittest.skipIf(RANGEIFY, "TODO: pm_rangeify hangs")
def test_data_parallel_simple_train_step(self):
class Model:
def __init__(self): self.conv1 = nn.Linear(128,128)

View File

@@ -35,11 +35,13 @@ earliest_rewrites = double_reshape+PatternMatcher([
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# copy reorder
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None),
# the next two rules break the JIT
# TODO: this is causing many copies with the replace tag None
# RESHAPE after COPY
(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d), tag=None).reshape(r.arg)),
# TODO: this should be BUFFER_VIEW
(UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d), tag=None).shrink(r.arg)),
#(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d), tag=None).reshape(r.arg)),
# this becomes BUFFER_VIEW on disk
# const hacks
#(UPat(Ops.CONST, name="x"), lambda x:
@@ -50,6 +52,10 @@ earliest_rewrites = double_reshape+PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"),
lambda x,target,assign: x.f(Ops.NOOP, tag=assign.tag) if target.base.op is not Ops.BUFFER else None),
# handle disk
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"), lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,),
(t.size, x.st.views[0].offset), tag=t.tag).reshape(t.shape) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
# contiguous/buffer/copy/assign is already contiguous
#(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
])
@@ -65,7 +71,7 @@ def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
for s in rb.src:
if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
@@ -338,7 +344,7 @@ pm_rangeify = pm_mops+PatternMatcher([
# move MAP through elementwise ALU / reduce. these are the items with cost
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union(
{Ops.STORE, Ops.COPY, Ops.DEVICE, Ops.BIND, Ops.CONTIGUOUS, Ops.NOOP})),), allow_any_len=True, name="x"),
{Ops.STORE, Ops.COPY, Ops.BUFFER_VIEW, Ops.DEVICE, Ops.BIND, Ops.CONTIGUOUS, Ops.NOOP})),), allow_any_len=True, name="x"),
lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
@@ -381,7 +387,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
# for now just no REDUCE, COPY, or ASSIGN
ran = src.toposort(gate=lambda x: x.op not in {Ops.INDEX})
# we don't want to bufferize threefry, also causes problems because not all platforms support long
if any(x.op in {Ops.REDUCE, Ops.COPY, Ops.ASSIGN} for x in ran) and src.op is not Ops.THREEFRY: return None
if any(x.op in {Ops.REDUCE, Ops.COPY, Ops.BUFFER_VIEW, Ops.ASSIGN} for x in ran) and src.op is not Ops.THREEFRY: return None
# simple, matching old behavior
#if src.op is not Ops.INDEX: return None
@@ -398,7 +404,8 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
# remove noop buffers. if we look at the next index we can remove even more of these
# NOTE: this is mostly the same case as below, but if there's no INDEX this gets more
(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"),
lambda idx,b2: idx.src[0].replace(tag=nt if len(nt:=(idx.src[0].tag or ()) + (b2.tag or ())) else None) if idx.src[1:] == b2.src[1:] else None),
lambda idx,b2: idx.src[0].replace(tag=nt if len(nt:=(idx.src[0].tag or ()) + (b2.tag or ())) else None) if idx.src[1:] == b2.src[1:] \
and idx.src[0].op is not Ops.BUFFER_VIEW else None),
# remove reindexing with cost function
(UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
# no buffers for const
@@ -548,7 +555,7 @@ def split_store(ctx:list[UOp], x:UOp):
metadatas = [ctx[y].metadata for x in ret.sparents if x.tag is not None for y in x.tag]
# NOTE: the hack for COPY is here
ret = ret.sink() if ret.src[1].op is not Ops.COPY else ret.src[1]
ret = ret.sink() if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None]))))
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
return x.as_buf().assign(kernel)
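
For context on the "simple disk tensor" changes above: the new "handle disk" rule turns a BITCAST or CONTIGUOUS of a DISK-device tensor into a BUFFER_VIEW, and the follow-up changes route BUFFER_VIEW through INDEX, keep it out of bufferize removal, and let split_store treat it like COPY. A minimal sketch of the kind of usage this covers (not from this diff; the scratch-file path and the round-trip through Device.DEFAULT are illustrative assumptions):

# write four little-endian int32 values to a hypothetical scratch file with plain Python I/O
import os, struct, tempfile
from tinygrad import Tensor, Device, dtypes

fn = os.path.join(tempfile.gettempdir(), "rangeify_disk_demo.bin")
with open(fn, "wb") as f: f.write(struct.pack("<4i", 1, 2, 3, 4))

# map the file as a DISK tensor of raw bytes, then bitcast to int32; on a DISK device
# this is the case the new "handle disk" rule lowers to BUFFER_VIEW (assumes a little-endian host)
raw = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
print(raw.bitcast(dtypes.int32).to(Device.DEFAULT).tolist())  # expected: [1, 2, 3, 4]

Running the sketch with RANGEIFY=1 set in the environment goes through the pm_rangeify path this commit extends.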