mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
RANGEIFY=1 test_jit (#12254)
* RANGEIFY=1 test_jit * don't do any of that * disk * simple disk tensor * more work * run more tests * it also doesn't copy everytime * skip tests that hang everything
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -553,7 +553,7 @@ jobs:
|
||||
opencl: 'true'
|
||||
llvm: "true"
|
||||
- name: Test CL=1 RANGEIFY=1
|
||||
run: CL=1 RANGEIFY=1 pytest -n auto test/test_ops.py test/test_schedule.py test/test_symbolic_ops.py --durations 20
|
||||
run: CL=1 RANGEIFY=1 pytest -n auto test/test_ops.py test/test_schedule.py test/test_symbolic_ops.py test/test_jit.py test/unit/test_disk_tensor.py test/models/test_mnist.py test/unit/test_mnist_dataset.py --durations 20
|
||||
- name: Test Fuse
|
||||
run: CL=1 RANGEIFY=2 python3 -m pytest --durations 20 test/test_softmax_fusion.py -k "not test_auto_softmax"
|
||||
- name: Test ONNX
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class
|
||||
from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import Context, JIT, GlobalCounters, getenv
|
||||
from tinygrad.helpers import Context, JIT, RANGEIFY, GlobalCounters, getenv
|
||||
from tinygrad.dtype import dtypes
|
||||
from extra.models.unet import ResBlock
|
||||
|
||||
@@ -605,6 +605,7 @@ class TestJitPrune(unittest.TestCase):
|
||||
assert len(w2_prune.captured.jit_cache) == 1, "prune should have removed the copy"
|
||||
|
||||
class TestJitFree(unittest.TestCase):
|
||||
@unittest.skipIf(RANGEIFY, "needs a rewrite")
|
||||
def test_free_intermediates(self):
|
||||
ext_tensor = Tensor([1,24,23,45,1])
|
||||
@TinyJit
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest, functools, random
|
||||
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.helpers import CI, getenv, prod, Context
|
||||
from tinygrad.helpers import CI, getenv, prod, Context, RANGEIFY
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
|
||||
import numpy as np
|
||||
@@ -372,6 +372,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
|
||||
# NOTE: this is failing on LLVM CI, no idea why. Works locally.
|
||||
@unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "CPU", "AMD"), "slow, and flaky on CPU")
|
||||
@unittest.skipIf(RANGEIFY, "TODO: pm_rangeify hangs")
|
||||
def test_data_parallel_resnet(self):
|
||||
from extra.models.resnet import ResNet18
|
||||
|
||||
@@ -408,6 +409,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "CPU", "AMD"), "slow, and flaky on CPU")
|
||||
@unittest.skipIf(RANGEIFY, "TODO: pm_rangeify hangs")
|
||||
def test_data_parallel_resnet_train_step(self):
|
||||
from extra.models.resnet import ResNet18
|
||||
fake_image = Tensor.rand((2, 3, 224//16, 224//16))
|
||||
@@ -415,6 +417,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
m = ResNet18()
|
||||
self._test_model_train_step(m, fake_image, labels)
|
||||
|
||||
@unittest.skipIf(RANGEIFY, "TODO: pm_rangeify hangs")
|
||||
def test_data_parallel_simple_train_step(self):
|
||||
class Model:
|
||||
def __init__(self): self.conv1 = nn.Linear(128,128)
|
||||
|
||||
@@ -35,11 +35,13 @@ earliest_rewrites = double_reshape+PatternMatcher([
|
||||
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
|
||||
# copy reorder
|
||||
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
|
||||
lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None),
|
||||
# the next two rules breaks the JIT
|
||||
# TODO: this is causing many copies wih the replace tag None
|
||||
# RESHAPE after COPY
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d), tag=None).reshape(r.arg)),
|
||||
# TODO: this should be BUFFER_VIEW
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d), tag=None).shrink(r.arg)),
|
||||
#(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d), tag=None).reshape(r.arg)),
|
||||
# this becomes BUFFER_VIEW on disk
|
||||
|
||||
# const hacks
|
||||
#(UPat(Ops.CONST, name="x"), lambda x:
|
||||
@@ -50,6 +52,10 @@ earliest_rewrites = double_reshape+PatternMatcher([
|
||||
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"),
|
||||
lambda x,target,assign: x.f(Ops.NOOP, tag=assign.tag) if target.base.op is not Ops.BUFFER else None),
|
||||
|
||||
# handle disk
|
||||
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"), lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,),
|
||||
(t.size, x.st.views[0].offset), tag=t.tag).reshape(t.shape) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
|
||||
|
||||
# contiguous/buffer/copy/assign is already contiguous
|
||||
#(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
|
||||
])
|
||||
@@ -65,7 +71,7 @@ def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
||||
|
||||
def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
|
||||
for s in rb.src:
|
||||
if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||
|
||||
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
||||
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
||||
@@ -338,7 +344,7 @@ pm_rangeify = pm_mops+PatternMatcher([
|
||||
|
||||
# move MAP through elementwise ALU / reduce. these are the items with cost
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union(
|
||||
{Ops.STORE, Ops.COPY, Ops.DEVICE, Ops.BIND, Ops.CONTIGUOUS, Ops.NOOP})),), allow_any_len=True, name="x"),
|
||||
{Ops.STORE, Ops.COPY, Ops.BUFFER_VIEW, Ops.DEVICE, Ops.BIND, Ops.CONTIGUOUS, Ops.NOOP})),), allow_any_len=True, name="x"),
|
||||
lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
|
||||
|
||||
@@ -381,7 +387,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
|
||||
# for now just no REDUCE, COPY, or ASSIGN
|
||||
ran = src.toposort(gate=lambda x: x.op not in {Ops.INDEX})
|
||||
# we don't want to bufferize threefry, also causes problems because not all platforms support long
|
||||
if any(x.op in {Ops.REDUCE, Ops.COPY, Ops.ASSIGN} for x in ran) and src.op is not Ops.THREEFRY: return None
|
||||
if any(x.op in {Ops.REDUCE, Ops.COPY, Ops.BUFFER_VIEW, Ops.ASSIGN} for x in ran) and src.op is not Ops.THREEFRY: return None
|
||||
|
||||
# simple, matching old behavior
|
||||
#if src.op is not Ops.INDEX: return None
|
||||
@@ -398,7 +404,8 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
|
||||
# remove noop buffers. if we look at the next index we can remove even more of these
|
||||
# NOTE: this is mostly the same case as below, but if there's no INDEX this gets more
|
||||
(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"),
|
||||
lambda idx,b2: idx.src[0].replace(tag=nt if len(nt:=(idx.src[0].tag or ()) + (b2.tag or ())) else None) if idx.src[1:] == b2.src[1:] else None),
|
||||
lambda idx,b2: idx.src[0].replace(tag=nt if len(nt:=(idx.src[0].tag or ()) + (b2.tag or ())) else None) if idx.src[1:] == b2.src[1:] \
|
||||
and idx.src[0].op is not Ops.BUFFER_VIEW else None),
|
||||
# remove reindexing with cost function
|
||||
(UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
|
||||
# no buffers for const
|
||||
@@ -548,7 +555,7 @@ def split_store(ctx:list[UOp], x:UOp):
|
||||
metadatas = [ctx[y].metadata for x in ret.sparents if x.tag is not None for y in x.tag]
|
||||
|
||||
# NOTE: the hack for COPY is here
|
||||
ret = ret.sink() if ret.src[1].op is not Ops.COPY else ret.src[1]
|
||||
ret = ret.sink() if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
|
||||
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None]))))
|
||||
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
|
||||
return x.as_buf().assign(kernel)
|
||||
|
||||
Reference in New Issue
Block a user