clean up more RANGEIFY flag (#12556)

This commit is contained in:
chenyu
2025-10-09 15:06:48 +08:00
committed by GitHub
parent 658c566e22
commit cf8232ec6a
9 changed files with 34 additions and 43 deletions

View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import temp, RANGEIFY
from tinygrad.helpers import temp
N = 200 # has to be bigger than the cache to fail
@@ -300,7 +300,6 @@ class TestAssign(unittest.TestCase):
#assert ba1 == ba2 and ba1 != bb1
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
@unittest.skipUnless(RANGEIFY, "only correct in rangeify")
def test_post_permuted_assignment_alt(self):
a = Tensor.arange(N*N).reshape(N,N).contiguous().realize()
b = Tensor.arange(N*N).reshape(N,N).contiguous().realize()

View File

@@ -3,7 +3,6 @@ from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType, ConstType
from tinygrad.uop.ops import Ops, UOp
from tinygrad.codegen import full_rewrite_to_sink
from tinygrad.helpers import RANGEIFY
from tinygrad.device import is_dtype_supported
import numpy as np
from test.helpers import not_support_multi_device
@@ -158,8 +157,7 @@ class TestMovedConstFolding(unittest.TestCase):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(6).shrink(((1, 5),)))
def test_add_padded_zero(self):
# NOTE: this was 1 before; the padded zero is now expected to fold (count 0)
_check_ast_count(0 if RANGEIFY else 1, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(2).pad(((1, 1),)))
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(2).pad(((1, 1),)))
def test_mul_shrunk_one(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(6).shrink(((1, 5),)))
@@ -168,16 +166,16 @@ class TestMovedConstFolding(unittest.TestCase):
_check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))
def test_cast_padded(self):
# NOTE: RANGEIFY or not, it's always 1 kernel when calling .numpy, limitation of _check_ast_count
# NOTE: it's always 1 kernel when calling .numpy, limitation of _check_ast_count
if is_dtype_supported(dtypes.int16):
_check_ast_count(1 if RANGEIFY else 0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
if is_dtype_supported(dtypes.uint16):
_check_ast_count(1 if RANGEIFY else 0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
# folded
if is_dtype_supported(dtypes.int64):
_check_ast_count(1 if RANGEIFY else 0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
class TestReduceOpsConstFolding(unittest.TestCase):
@@ -249,7 +247,7 @@ class TestReduceOpsConstFolding(unittest.TestCase):
t = Tensor.ones(16, dtype=dt).reshape(4, 4)
assert t.sum().dtype == t.contiguous().sum().dtype
@unittest.skipIf(not_support_multi_device() or RANGEIFY, "no multi, RANGEIFY doesn't support multi const folding")
@unittest.skipIf(not_support_multi_device() or True, "no multi, RANGEIFY doesn't support multi const folding")
class TestMultiConstFolding(unittest.TestCase):
def test_multi_const_folding_literal(self):
ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))

View File

@@ -4,7 +4,7 @@ from tinygrad import Device, dtypes, Tensor, Context
from tinygrad.device import LRUAllocator, is_dtype_supported
from tinygrad.dtype import ImageDType
from tinygrad.engine.realize import lower_schedule
from tinygrad.helpers import prod, unwrap, RANGEIFY
from tinygrad.helpers import prod, unwrap
from test.helpers import REAL_DEV
IMAGE_SUPPORTED_DEVICES = ("QCOM", "CL")
@@ -139,7 +139,7 @@ class TestImageDType(unittest.TestCase):
# NOTE: the w1 grad must realize to a separate kernel
assert w1.grad.uop.is_realized, f"never realized {w1.grad}"
self.assertEqual(w1.grad.uop.base.buffer.dtype, dtypes.float32)
self.assertEqual(len(sched), 9 if RANGEIFY else 10)
self.assertEqual(len(sched), 9)
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
class TestImageRealization(unittest.TestCase):

View File

@@ -8,7 +8,7 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, RANGEIFY
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
from tinygrad.renderer.ptx import PTXRenderer
@@ -314,7 +314,7 @@ class TestLinearizer(unittest.TestCase):
a.realize()
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
@unittest.skipIf(RANGEIFY and isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexes differently. might be ok?")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexes differently. might be ok?")
def test_where_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
b = a.shrink(((1, 2), None)).pad(((1, 2), None))

View File

@@ -2,7 +2,7 @@ import time, math, unittest, functools, platform, warnings
import numpy as np
from typing import List, Callable
import torch
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, CPU_LLVM, AMD_LLVM, RANGEIFY
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, CPU_LLVM, AMD_LLVM
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
@@ -3040,7 +3040,6 @@ class TestOps(unittest.TestCase):
pos_weight=torch.tensor(pos_weight)),
lambda x,y: x.binary_crossentropy_logits(y.clip(0,1),pos_weight=Tensor(pos_weight)))
@unittest.skipIf(RANGEIFY > 1, "broken on RANGEIFY > 1, TODO: fix")
def test_cross_entropy_class_probabilities(self):
helper_test_op([(32,), (32,)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))

View File

@@ -2,7 +2,7 @@ import unittest
import numpy as np
from tinygrad import Tensor, GlobalCounters, Context, Device
from tinygrad.dtype import DTypeLike, dtypes
from tinygrad.helpers import DEBUG, get_single_element, RANGEIFY
from tinygrad.helpers import DEBUG, get_single_element
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.device import is_dtype_supported
@@ -39,17 +39,17 @@ class TestFuse(unittest.TestCase):
np_multi = fxn(*args, **kwargs).numpy()
np.testing.assert_allclose(np_single, np_multi, atol=atol)
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
@unittest.skip("needs RANGEIFY>1")
def test_fuse_norm(self):
a = Tensor.rand(50,50).realize()
self._test_fuse(lambda a: a / a.mean(axis=1), a)
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
@unittest.skip("needs RANGEIFY>1")
def test_fuse_argmax(self):
a = Tensor.rand(50,50).realize()
self._test_fuse(lambda a: a.argmax(axis=-1), a)
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
@unittest.skip("needs RANGEIFY>1")
def test_fuse_softmax(self):
a = Tensor.rand(50,50).realize()
self._test_fuse(lambda a: a.softmax(axis=-1), a)
@@ -60,7 +60,7 @@ class TestFuse(unittest.TestCase):
self._test_fuse(lambda a,b: ((a@b).relu()+a).contiguous().softmax(axis=-1), a,b, allow_multiple=True)
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
@unittest.skip("needs RANGEIFY>1")
def test_fuse_softmax_dtype(self):
a = Tensor.rand(50,50).realize()
self._test_fuse(lambda a: a.softmax(axis=-1, dtype='half'), a, atol=3e-4)
@@ -68,7 +68,7 @@ class TestFuse(unittest.TestCase):
def test_fuse_arange_eye(self):
self._test_fuse(lambda: Tensor.arange(10).reshape(10,1).expand(10,10) == Tensor.arange(10).reshape(1,10).expand(10,10))
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
@unittest.skip("needs RANGEIFY>1")
def test_double_gemm(self):
N = 32
with Context(TRACK_MATCH_STATS=0, DEBUG=0):
@@ -91,7 +91,7 @@ class TestFuse(unittest.TestCase):
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
self._test_fuse(embedding, a, atol=1e-5)
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
@unittest.skip("needs RANGEIFY>1")
def test_attention_kernel_count(self):
wq = Tensor.empty(32, 32)
wk = Tensor.empty(32, 32)
@@ -104,7 +104,7 @@ class TestFuse(unittest.TestCase):
s = attn.schedule()
self.assertEqual(len(s), 4) # 3 matmul and 1 attention
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
@unittest.skip("needs RANGEIFY>1")
def test_flash_attention(self):
BS = 4
HEADS = 2
@@ -172,7 +172,7 @@ class TestSoftmaxFusion(unittest.TestCase):
np.testing.assert_allclose(sout.numpy(), out.numpy(), atol=3e-7)
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
@unittest.skip("needs RANGEIFY>1")
def test_auto_softmax(self):
print("*** softmax ***")
with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):

View File

@@ -1,12 +1,10 @@
import unittest
from tinygrad import Tensor, dtypes, TinyJit, UOp
from tinygrad.helpers import RANGEIFY
from tinygrad.apps.llm import apply_rope
#from tinygrad.engine.realize import run_schedule
# TODO: test_scheduler, but just in uint
class TestAttention(unittest.TestCase):
@unittest.skipIf(RANGEIFY > 0, "not half on rangeify")
def test_half_qkv_buffers(self):
BS, seqlen, dim = 10, 4, 100
q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
@@ -14,12 +12,11 @@ class TestAttention(unittest.TestCase):
v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
attn = q.scaled_dot_product_attention(k, v)
sched = attn.schedule()
#run_schedule(sched[:])
# attention has 5 kernels now
self.assertEqual(len(sched), 4 if RANGEIFY else 5)
softmax_inputs = sched[1:4]
for i,si in enumerate(softmax_inputs):
assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=} in kernel {i}"
# attention has 4 kernels now
self.assertEqual(len(sched), 4)
# softmax_inputs = sched[1:4]
# for i,si in enumerate(softmax_inputs):
# assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=} in kernel {i}"
def test_apply_rope(self):
x = Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32)

View File

@@ -1,7 +1,7 @@
import unittest, sys
import numpy as np
from tinygrad import Tensor, GlobalCounters, dtypes, Context, nn
from tinygrad.helpers import CI, Profiling, WINO, RANGEIFY
from tinygrad.helpers import CI, Profiling, WINO
@unittest.skipIf(sys.platform.startswith("win"), "flaky on Windows")
class TestWinogradClose(unittest.TestCase):
@@ -61,9 +61,9 @@ class TestWinograd(unittest.TestCase):
print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
if not RANGEIFY:
self.assertLess(ops_ratio, 2.6) # TODO: there's issues with factorization now
self.assertLess(mem_ratio, 10)
# TODO: what's optimal on this?
self.assertLess(ops_ratio, 4.3)
self.assertLess(mem_ratio, 3)
def test_dtype(self):
IC, OC, X, Y = 4,4,9,9

View File

@@ -1,7 +1,7 @@
from typing import Any, Callable
import functools
from dataclasses import dataclass
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, RANGEIFY
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
from tinygrad.uop.spec import type_verify
from tinygrad.renderer import Renderer
@@ -38,11 +38,10 @@ rewrites_for_linearizer = [
def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]:
# cache with the values of the context vars
return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value, RANGEIFY.value)
return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
@functools.cache
def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL,
_RANGEIFY) -> list[RewriteStep]:
def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
# ** lowerer (rewrite_shapetracker_with_index) **
ret: list[RewriteStep] = []
@@ -52,8 +51,7 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
# split ranges
if _RANGEIFY:
ret.append(RewriteStep(pm_split_ranges+pm_flatten_range, ctx=lambda _: {}, name="split ranges"))
ret.append(RewriteStep(pm_split_ranges+pm_flatten_range, ctx=lambda _: {}, name="split ranges"))
# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
ret.append(RewriteStep(sym+pm_flatten_range, name="initial symbolic"))