From ae51bdd06ad68998c7c099f6169396e55fa1fc8c Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 9 Oct 2025 14:29:38 +0800
Subject: [PATCH] remove trivial use of RANGEIFY flag (#12550)

some tests need update still
---
 docs/abstractions2.py              |  1 -
 extra/gemm/amd_uop_matmul.py       |  5 +--
 test/external/external_test_opt.py | 14 +++----
 test/helpers.py                    |  7 +---
 test/opt/test_kernel_opts.py       |  3 +-
 test/test_arange.py                |  4 +-
 test/test_assign.py                | 59 +++++++++++---------------
 test/test_ops.py                   |  4 +-
 test/test_rangeify.py              |  5 +--
 test/test_schedule.py              | 66 +++++++++++------------------
 test/test_symbolic_jit.py          |  5 +--
 test/test_tensor.py                | 19 +++------
 test/test_uops_stats.py            |  9 ++--
 test/unit/test_kernelize.py        |  5 +--
 test/unit/test_shm_tensor.py       |  4 +-
 test/unit/test_winograd.py         |  4 +-
 tinygrad/tensor.py                 |  4 +-
 17 files changed, 86 insertions(+), 132 deletions(-)

diff --git a/docs/abstractions2.py b/docs/abstractions2.py
index 747b628644..708933118c 100644
--- a/docs/abstractions2.py
+++ b/docs/abstractions2.py
@@ -80,7 +80,6 @@ print("******** third, the UOp ***********")
 
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.schedule import create_schedule_with_vars
-from tinygrad.helpers import RANGEIFY
 from tinygrad.schedule.rangeify import get_rangeify_map
 
 # allocate some values + load in values
diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py
index 78dbf81a1d..4b5dddd777 100644
--- a/extra/gemm/amd_uop_matmul.py
+++ b/extra/gemm/amd_uop_matmul.py
@@ -49,8 +49,7 @@ def rangeify_kernel3():
   b = Tensor.empty(N,N)
   c = a@b
   #c = c.reshape((32,2,16,4,32,2,16,4)).contiguous()
-  with Context(RANGEIFY=1):
-    sink = c.schedule()[-1].ast
+  sink = c.schedule()[-1].ast
   #print(sink)
 
   opts = [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.UPCAST, 0, 2)]
@@ -329,7 +328,7 @@ if __name__ == "__main__":
   elif HL == 1: hprg = hl_spec_kernel3()
   else: hprg = hand_spec_kernel3()
   if HL == 3:
-    with Context(RANGEIFY=1, BLOCK_REORDER=0):
+    with Context(BLOCK_REORDER=0):
       prg = get_program(hprg, Device.default.renderer)
   else:
     prg = get_program(hprg, Device.default.renderer)
diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py
index 45bb87fd50..f1bab81d26 100644
--- a/test/external/external_test_opt.py
+++ b/test/external/external_test_opt.py
@@ -4,7 +4,7 @@ import numpy as np
 import torch
 
 from tinygrad import GlobalCounters, Tensor, Device
-from tinygrad.helpers import getenv, RANGEIFY
+from tinygrad.helpers import getenv
 from tinygrad.nn.state import get_parameters
 from tinygrad.engine.realize import capturing
 from tinygrad.tensor import _to_np_dtype
@@ -164,7 +164,7 @@ class TestOpt(unittest.TestCase):
 
   def test_permute_was_pushed(self):
     a = Tensor.randn(16, 16, 16)
-    with CLCache(1 if RANGEIFY else 2):
+    with CLCache(1):
       c = a.sum(2)
       d = c.permute(1,0).contiguous()
       d.realize()
@@ -172,7 +172,7 @@ class TestOpt(unittest.TestCase):
 
   def test_permute_was_pushed_through_contract_reshape(self):
     a = Tensor.randn(4, 4, 4, 4, 4)
-    with CLCache(1 if RANGEIFY else 2):
+    with CLCache(1):
      c = a.sum(-1)
      d = c.reshape(16,16).permute(1,0).contiguous()
      d.realize()
@@ -180,7 +180,7 @@ class TestOpt(unittest.TestCase):
 
   def test_permute_was_pushed_through_contractw1s_reshape(self):
     a = Tensor.randn(4, 4, 4, 4, 4)
-    with CLCache(1 if RANGEIFY else 2):
+    with CLCache(1):
      c = a.sum(-1)
      d = c.reshape(16,1,16).permute(2,1,0).contiguous()
      d.realize()
@@ -188,7 +188,7 @@ class TestOpt(unittest.TestCase):
 
   def test_permute_was_pushed_through_expand_reshape(self):
     a = Tensor.randn(16, 16, 16)
-    with CLCache(1 if RANGEIFY else 2):
+    with CLCache(1):
       c = a.sum(2)
       d = c.reshape(4,4,4,4).permute(2,3,0,1).contiguous()
       d.realize()
@@ -220,7 +220,7 @@ class TestOpt(unittest.TestCase):
     for axis in [0, 1]:
       for n in [4, 8, 16]:
         b = torch.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis)
-        with CLCache(allowed=3 if RANGEIFY else 2):
+        with CLCache(allowed=3):
           a = Tensor.ones(n, n).contiguous().sum(axis).reshape(n, 1).expand(n, n).sum(axis)
           a.realize()
           np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
@@ -229,7 +229,7 @@ class TestOpt(unittest.TestCase):
     axis1, axis2 = 0, 1
     for n in [4, 8, 16]:
       b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
-      with CLCache(allowed=3 if RANGEIFY else 2):
+      with CLCache(allowed=3):
         a = Tensor.ones(n, n).contiguous().sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
         a.realize()
         np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
diff --git a/test/helpers.py b/test/helpers.py
index 98b5978f98..cee64595f3 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -1,4 +1,4 @@
-import time, struct, unittest
+import time, struct
 from typing import Any, Callable
 import numpy as np
 from tinygrad import Tensor, dtypes, Device
@@ -7,7 +7,7 @@ from tinygrad.tensor import _to_np_dtype
 from tinygrad.engine.realize import Runner
 from tinygrad.dtype import DType
 from tinygrad.nn.state import get_parameters
-from tinygrad.helpers import T, CI, RANGEIFY
+from tinygrad.helpers import T, CI
 from tinygrad.codegen import full_rewrite
 from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler
 
@@ -62,6 +62,3 @@ def not_support_multi_device():
 
 # NOTE: This will open REMOTE if it's the default device
 REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device)
-
-def expect_rangeify_fails(fxn): return (unittest.expectedFailure if RANGEIFY else (lambda f:f))(fxn)
-def expect_nonrangeify_fails(fxn): return (unittest.expectedFailure if not RANGEIFY else (lambda f:f))(fxn)
diff --git a/test/opt/test_kernel_opts.py b/test/opt/test_kernel_opts.py
index fda46a36c1..d1e5d35164 100644
--- a/test/opt/test_kernel_opts.py
+++ b/test/opt/test_kernel_opts.py
@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Device, Tensor, dtypes
-from tinygrad.helpers import CI, RANGEIFY
+from tinygrad.helpers import CI
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
 
 # TODO: write a clean version of this
@@ -351,7 +351,6 @@ class TestKernelOpts(unittest.TestCase):
     ] + [[Opt(OptOps.THREAD, 0, 4)] if Device[Device.DEFAULT].renderer.global_max[0] >= 4 else []]
       + [[Opt(OptOps.THREAD, 0, 8)] if Device[Device.DEFAULT].renderer.global_max[0] >= 8 else []])
 
-  @unittest.skipUnless(RANGEIFY>=1, "Kernel only fuses with rangeify")
   def test_double_sum_group(self):
     a = Tensor.rand(4, 4, 4)
     r = a.sum((1, 2)).sum()
diff --git a/test/test_arange.py b/test/test_arange.py
index 3f31b71303..248cba3d56 100644
--- a/test/test_arange.py
+++ b/test/test_arange.py
@@ -1,7 +1,7 @@
 import unittest
 import numpy as np
 from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device, Variable
-from tinygrad.helpers import CI, Context, getenv, RANGEIFY
+from tinygrad.helpers import CI, Context, getenv
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
 from tinygrad.uop.ops import Ops
@@ -95,7 +95,7 @@ class TestIndexing(unittest.TestCase):
       X = dataset[idxs]
       assert X.shape == (4,DDIM)
       sched = X.schedule()
-      self.assertEqual(len(sched), 1 if RANGEIFY else 2)
+      self.assertEqual(len(sched), 1)
       run_schedule(sched)
       assert GlobalCounters.global_ops < 4*DSET, f"too many ops {GlobalCounters.global_ops} != {4*DSET}"
       np.testing.assert_allclose(real_index, X.numpy())
diff --git a/test/test_assign.py b/test/test_assign.py
index b517c8e39d..b23f172207 100644
--- a/test/test_assign.py
+++ b/test/test_assign.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 import unittest
-import contextlib
 import numpy as np
 from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
 from tinygrad.device import is_dtype_supported
@@ -271,8 +270,6 @@ class TestAssign(unittest.TestCase):
     b.assign(a.contiguous()).realize()
     assert GlobalCounters.kernel_count - kc == 2
 
-  # passing in RANGEIFY=1, RANGEIFY=0 asserts permuted assigns it can't fuse
-  def assert_permuted_assign(self): return self.assertRaisesRegex(RuntimeError, "contiguous") if not RANGEIFY else contextlib.nullcontext()
   def test_permuted_assignment(self):
     a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
     b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
@@ -280,14 +277,13 @@ class TestAssign(unittest.TestCase):
     b.realize()
     ba1 = a.uop.base.realized
     bb1 = b.uop.base.realized
-    with self.assert_permuted_assign():
-      a = a.permute(1,0)
-      a += b
-      a.realize()
-      ba2 = a.uop.base.realized
-      np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
-      # permute and base are the same buffer
-      assert ba1 == ba2 and ba1 != bb1
+    a = a.permute(1,0)
+    a += b
+    a.realize()
+    ba2 = a.uop.base.realized
+    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
+    # permute and base are the same buffer
+    assert ba1 == ba2 and ba1 != bb1
 
   def test_post_permuted_assignment(self):
     a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
@@ -297,13 +293,12 @@ class TestAssign(unittest.TestCase):
     #GlobalCounters.cache = []
     ba1 = a.uop.base.realized # noqa: F841
     bb1 = b.uop.base.realized # noqa: F841
-    with self.assert_permuted_assign():
-      a.assign(a.permute(1,0) + b)  # this should not work!
-      a.realize()
-      ba2 = a.uop.base.realized # noqa: F841
-      # NOTE: don't test that it's assigned
-      #assert ba1 == ba2 and ba1 != bb1
-      np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
+    a.assign(a.permute(1,0) + b)  # this should not work!
+    a.realize()
+    ba2 = a.uop.base.realized # noqa: F841
+    # NOTE: don't test that it's assigned
+    #assert ba1 == ba2 and ba1 != bb1
+    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
 
   @unittest.skipUnless(RANGEIFY, "only correct in rangeify")
   def test_post_permuted_assignment_alt(self):
@@ -345,21 +340,18 @@ class TestAssign(unittest.TestCase):
   def test_permuted_assignment_correct(self):
     a = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
     b = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
-    # TODO: swizzler.py limitation, should NOT raise AssertionError from numpy.
-    with self.assert_permuted_assign():
-      a = a.permute(1, 0)
-      new_val = a + b
-      a.assign(new_val)
-      np.testing.assert_equal(a.numpy(), np.arange(4 * 4).reshape(4, 4).transpose(1, 0) + np.arange(4 * 4).reshape(4, 4))
+    a = a.permute(1, 0)
+    new_val = a + b
+    a.assign(new_val)
+    np.testing.assert_equal(a.numpy(), np.arange(4 * 4).reshape(4, 4).transpose(1, 0) + np.arange(4 * 4).reshape(4, 4))
 
   def test_permuted_reduceop_child_dual_use(self):
     a = Tensor.randn(32, 32, 32).realize()
     b = Tensor.full((32, 32), 1.).contiguous().realize()
-    with self.assert_permuted_assign():
-      r = a.sum(axis=1)
-      b.assign(r + b.permute(1, 0))
-      b.realize()
-      np.testing.assert_allclose(b.numpy(), a.numpy().sum(axis=1)+np.ones((32, 32)).transpose(1, 0), atol=1e-6, rtol=1e-3)
+    r = a.sum(axis=1)
+    b.assign(r + b.permute(1, 0))
+    b.realize()
+    np.testing.assert_allclose(b.numpy(), a.numpy().sum(axis=1)+np.ones((32, 32)).transpose(1, 0), atol=1e-6, rtol=1e-3)
 
   @unittest.skip("multi output not supported anymore")
   def test_permuted_reduceop_multioutput_dual_use(self):
@@ -401,11 +393,10 @@ class TestAssign(unittest.TestCase):
 
   def test_permuted_assignment_masked_view_not_contiguous(self):
     a = Tensor.ones(4, 4).contiguous().realize()
-    with self.assert_permuted_assign():
-      b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2).permute(1, 0)
-      a.assign(a + b)
-      a.realize()
-      self.assertListEqual(a.tolist(), [[2.,2.,2.,2.],[2.,2.,2.,2.],[3.,3.,3.,3.], [3.,3.,3.,3.]])
+    b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2).permute(1, 0)
+    a.assign(a + b)
+    a.realize()
+    self.assertListEqual(a.tolist(), [[2.,2.,2.,2.],[2.,2.,2.,2.],[3.,3.,3.,3.], [3.,3.,3.,3.]])
 
   # TODO: is there a way to sneak in a permute such that it returns the wrong answer?
 
diff --git a/test/test_ops.py b/test/test_ops.py
index 952f3a84f0..6f33a15548 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -3164,8 +3164,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
     helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf))
 
-  @unittest.skipIf(RANGEIFY and (getenv("MOCKGPU") or Device.DEFAULT == "PYTHON"), "very slow on MOCKGPU because reduce does not fold")
-  @unittest.skipIf(RANGEIFY and Device.DEFAULT == "WEBGPU", "webgpu runtime issue")
+  @unittest.skipIf((getenv("MOCKGPU") or Device.DEFAULT == "PYTHON"), "very slow on MOCKGPU because reduce does not fold")
+  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu runtime issue")
   def test_masked_select(self):
     helper_test_op([(32, 10)], lambda x: x.masked_select(x>0.5), lambda x: x.masked_select(x>0.5), forward_only=True)
     helper_test_op([(32, 10)], lambda x: x.masked_select(torch.tensor(True)), lambda x: x.masked_select(Tensor(True)), forward_only=True)
diff --git a/test/test_rangeify.py b/test/test_rangeify.py
index fe4a673d8d..4f22a1dcc8 100644
--- a/test/test_rangeify.py
+++ b/test/test_rangeify.py
@@ -1,9 +1,8 @@
 import unittest
 from tinygrad import Tensor, nn
-from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
+from tinygrad.helpers import Context, GlobalCounters
 from tinygrad.uop.ops import UOp, graph_rewrite, PatternMatcher, UPat, Ops
 
-@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
 class TestRangeifyAssign(unittest.TestCase):
   def test_assign_permuted(self):
     A = Tensor.empty(4, 4, dtype='int')
@@ -55,7 +54,6 @@ class TestRangeifyOpt(unittest.TestCase):
     A = Tensor.empty(8,8,8,8).permute(1,0,3,2).flatten()
     A.sum().realize()
 
-@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
 class TestRangeify(unittest.TestCase):
   def test_groupnorm(self):
     # ranges 1 and 3 are merging
@@ -230,7 +228,6 @@ class TestRangeify(unittest.TestCase):
 
 # contiguous + reduce can support ranges?
 @unittest.skip("okay to disable this for now")
-@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
 class TestOuterworld(unittest.TestCase):
   def test_passthrough_range(self):
     t = Tensor.rand(10, 10).realize()
diff --git a/test/test_schedule.py b/test/test_schedule.py
index e13292e0e5..c06fadbe45 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -17,7 +17,6 @@ from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context,
 from tinygrad.schedule.rangeify import get_rangeify_map, Kernel
 from tinygrad.engine.schedule import create_schedule_with_vars
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
-from test.helpers import expect_rangeify_fails, expect_nonrangeify_fails
 
 class KernelCountException(Exception): pass
 def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
@@ -117,7 +116,7 @@ class TestSchedule(unittest.TestCase):
     a = Tensor.empty(10)
     b = Tensor.empty((1,), device="CPU").expand(10).contiguous()
     c = a+b
-    with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 2 if RANGEIFY else 1)
+    with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 2)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half) and getenv("CAST_AFTER_EXPAND"), "need half and CAST_AFTER_EXPAND=1")
   @unittest.skip("CAST_AFTER_EXPAND is not supported")
@@ -343,7 +342,7 @@ class TestSchedule(unittest.TestCase):
     r1 = (x - r0).sum(axis=0).div(2)
     out0 = r0 + y
     out1 = r1 + y
-    schedule = check_schedule([out0, out1], 2 if RANGEIFY else 4)
+    schedule = check_schedule([out0, out1], 2)
     reduceops = [x for si in schedule for x in si.ast.toposort() if x.op in {Ops.REDUCE_AXIS, Ops.REDUCE}]
     assert len(reduceops) in [2,3] # why is RANGEIFY different?
 
@@ -712,7 +711,7 @@ class TestSchedule(unittest.TestCase):
     check_schedule(b, 0)
     self.assertEqual(b.item(), 1)
 
-  @expect_rangeify_fails
+  @unittest.expectedFailure
   def test_multioutput_ast(self):
     a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
     b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
@@ -919,7 +918,7 @@ class TestSchedule(unittest.TestCase):
     out0 = a.sum() + 2
     out1 = a.sum() + 4
     out2 = out0 * out1
-    run_schedule(check_schedule([out0, out1, out2], 1 if RANGEIFY else 4))
+    run_schedule(check_schedule([out0, out1, out2], 1))
     np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-6)
     np.testing.assert_allclose(out1.numpy(), out1_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
     np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6)
@@ -930,7 +929,7 @@ class TestSchedule(unittest.TestCase):
     out0 = a.sum().exp2()
     # out1 has two paths to a.sum()
     out1 = a.sum() + out0
-    run_schedule(check_schedule([out0, out1], 1 if RANGEIFY else 3))
+    run_schedule(check_schedule([out0, out1], 1))
     np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+out0_np, atol=1e-4, rtol=1e-6)
 
@@ -1022,7 +1021,7 @@ class TestSchedule(unittest.TestCase):
     b = Tensor.empty(10,)
     c = a.sum() + b[0]
     d = a.sum() + 2
-    check_schedule([c, d], 1 if RANGEIFY else 3)
+    check_schedule([c, d], 1)
 
   def test_reduce_multiple_paths_midshrink(self):
     a = Tensor.empty(4, 4)
@@ -1186,14 +1185,14 @@ class TestSchedule(unittest.TestCase):
     np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
-  @expect_rangeify_fails
+  @unittest.expectedFailure
   def test_softmax_upcast(self):
     # input half, softmax in float
     Tensor.manual_seed(0)
     x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize()
     out = x.softmax(dtype=dtypes.float)
     sched = out.schedule()
-    self.assertEqual(len(sched), 2 if RANGEIFY else 3)
+    self.assertEqual(len(sched), 2)
     self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)
 
     # input float, softmax in float
@@ -1323,7 +1322,7 @@ class TestSchedule(unittest.TestCase):
     check_schedule(opt.schedule_step(), 14)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
-  @expect_rangeify_fails
+  @unittest.expectedFailure
   def test_prefer_half_buffer(self):
     x = Tensor.ones(4).contiguous().realize()
     # y = Tensor.ones(4).contiguous().realize()
@@ -1475,7 +1474,7 @@ class TestSchedule(unittest.TestCase):
     e = c * d
     f = b.sum() - e
     # run_schedule(check_schedule([c, d, e, f], 1))
-    run_schedule(check_schedule([c, d, e, f], 2 if RANGEIFY else 5))
+    run_schedule(check_schedule([c, d, e, f], 2))
     np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4)
@@ -1690,7 +1689,7 @@ class TestSchedule(unittest.TestCase):
   def test_late_fusion_post_expand(self):
     self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)
 
-  @expect_rangeify_fails
+  @unittest.expectedFailure
   def test_cast_padded_view(self):
     a = Tensor.arange(4).reshape(1, 4)
     casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
@@ -1720,7 +1719,7 @@ class TestSchedule(unittest.TestCase):
     self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
 
   @given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all))
-  @expect_rangeify_fails
+  @unittest.expectedFailure
   def test_cast_padded_const(self, dt1, dt2):
     assume(is_dtype_supported(dt1) and is_dtype_supported(dt2))
     a = Tensor(1, dtype=dt1).reshape(1, 1).pad(((1, 1), None))
@@ -1891,9 +1890,7 @@ class TestSchedule(unittest.TestCase):
     tst = x.shrink((None, (0, 2))).assign(a).realize()
     xref[:, :2] = np.arange(8).reshape(4, 2)+y.numpy()
     np.testing.assert_equal(x.numpy(), xref)
-    if RANGEIFY > 0:
-      # NOTE: this is a bug on non rangeify
-      np.testing.assert_equal(tst.numpy(), a.numpy())
+    np.testing.assert_equal(tst.numpy(), a.numpy())
 
   def test_setitem_sched(self, mop=lambda x:x, expected_kcount=1):
     a = Tensor.arange(16, device="CPU").reshape(4, 4).contiguous().realize()
@@ -1904,7 +1901,6 @@ class TestSchedule(unittest.TestCase):
     run_schedule(sched)
     self.assertListEqual(a.tolist(), expected)
     self.assertEqual(kcount, expected_kcount)
-  @unittest.skipUnless(RANGEIFY>0, "this asserts on non rangeify")
   def test_setitem_permuted_sched(self): self.test_setitem_sched(lambda x: x.T, 2)
   def test_setitem_paddded_sched(self): self.test_setitem_sched(lambda x: x.shrink_to(4, 1).pad_to(4, 4), 1)
 
@@ -1943,7 +1939,7 @@ class TestSchedule(unittest.TestCase):
     r = (X+Tensor.arange(16).reshape(4, 4)).sum()
     out0 = r+2
     out1 = r+3
-    run_schedule(check_schedule([out0, out1], 1 if RANGEIFY else 3))
+    run_schedule(check_schedule([out0, out1], 1))
     r_ref = (X.numpy()+np.arange(16).reshape(4, 4)).sum()
     np.testing.assert_allclose(out0.numpy(), r_ref+2, rtol=2e-7)
     np.testing.assert_allclose(out1.numpy(), r_ref+3, rtol=2e-7)
@@ -2088,7 +2084,7 @@ class TestView(unittest.TestCase):
     run_schedule(sched)
     np.testing.assert_equal(b.numpy(), 0)
 
-  @expect_rangeify_fails
+  @unittest.expectedFailure
   def test_mask_dim_1(self):
     # mask out dim = 1 works too
     a = Tensor.rand(10, 10).realize()
@@ -2236,7 +2232,6 @@ class TestCopyFolding(unittest.TestCase):
     b.realize()
     self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
 
-  @expect_nonrangeify_fails
   def test_permute_on_disk_contiguous(self):
     with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
     a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
@@ -2251,8 +2246,6 @@ class TestCopyFolding(unittest.TestCase):
     self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
 
   # NOTE: disk permute must come after COPY
-  # TODO: this is wrong because of the permute
-  @expect_nonrangeify_fails
   def test_permute_after_shrink_on_disk(self):
     with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_buffer())
     a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
@@ -2396,12 +2389,8 @@ class TestUOpBecome(unittest.TestCase):
     a = Tensor.empty(4, 1)
     b = a.expand(4, 4).reciprocal()
     check_schedule(b, 1)
-    if RANGEIFY:
-      self.assertEqual(b.uop.base.buffer.size, 4)
-      self.assertEqual(b.uop.shape, (4, 4))
-      return
-    self.assertEqual(b.uop.base.buffer.size, 16)
-    self.assertEqual(b.uop.st, ShapeTracker.from_shape((4, 4)))
+    self.assertEqual(b.uop.base.buffer.size, 4)
+    self.assertEqual(b.uop.shape, (4, 4))
 
   def test_reorder_expand_alt(self):
     x = Tensor.empty(4, 1)
@@ -2410,7 +2399,7 @@ class TestUOpBecome(unittest.TestCase):
     z = (img*x) / y
     check_schedule(z, 1)
 
-  @expect_rangeify_fails
+  @unittest.expectedFailure
   def test_become_existing_buffer(self):
     a = Tensor.empty(4, 4)
     b = a*1
@@ -2444,7 +2433,7 @@ class TestUOpBecome(unittest.TestCase):
     assert UPat(Ops.CONST, arg=3).match(const_add.uop.base, {})
 
   # tensors can become another realized tensor source
-  @expect_rangeify_fails
+  @unittest.expectedFailure
  def test_become_existing_buf_simple(self):
     a = Tensor.empty(4, 4)
     b = a+0
@@ -2453,14 +2442,14 @@ class TestUOpBecome(unittest.TestCase):
     check_schedule(b, 0)
     self.assertIs(a.uop, b.uop)
 
   # they can also chain other movement ops on top of the tensor source
-  @expect_rangeify_fails
+  @unittest.expectedFailure
   def test_become_existing_buf_view(self):
     a = Tensor.empty(4, 4)
     b = a.permute((1, 0))+0
     check_schedule(b, 0)
     self.assertEqual(b.uop.st, a.uop.permute((1, 0)).st)
 
-  @expect_rangeify_fails
+  @unittest.expectedFailure
   def test_become_existing_buf_view_alt(self):
     a = Tensor.empty(4, 4)
     b = a.permute((1, 0)).reshape((8, 2))+0
@@ -2468,7 +2457,7 @@ class TestUOpBecome(unittest.TestCase):
     self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st)
 
   # they can also have other base parents that simplified, in that case we just backtrack to the chained mops
-  @expect_rangeify_fails
+  @unittest.expectedFailure
   def test_become_existing_buf_complex(self):
     a = Tensor.empty(4, 4)
     b = (a.permute((1, 0))+0).reshape((8, 2))+0
@@ -2476,7 +2465,7 @@ class TestUOpBecome(unittest.TestCase):
     self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st)
     assert b.uop.base.op is Ops.BUFFER
 
-  @expect_rangeify_fails
+  @unittest.expectedFailure
   def test_become_multiple_choices(self):
     a = Tensor.empty(16)
     b = (a.reshape(1, 1, 4, 1, 4)+0).reshape(1, 1, 4, 4).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
@@ -2494,13 +2483,8 @@ class TestUOpBecome(unittest.TestCase):
     b.realize()
     assert a.uop.is_realized
     assert a.uop.buffer._base is None
-    # b is a subbuffer of a (buffer_view in non rangeify, rangeify just makes a shrink)
-    if RANGEIFY:
-      assert b.uop.op_in_backward_slice_with_self(Ops.SHRINK)
-      assert b.uop.base is a.uop.base
-      return
-    assert b.uop.op is Ops.BUFFER_VIEW
-    assert b.uop.src[0] is a.uop
+    assert b.uop.op_in_backward_slice_with_self(Ops.SHRINK)
+    assert b.uop.base is a.uop.base
 
   def test_setitem_offset(self):
     a = Tensor.full((16,), 0.).contiguous().realize()
diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py
index f28d274dcc..9174a47187 100644
--- a/test/test_symbolic_jit.py
+++ b/test/test_symbolic_jit.py
@@ -2,7 +2,6 @@ import unittest
 
 from test.helpers import assert_jit_cache_len
 from tinygrad import Variable, Tensor, TinyJit
-from tinygrad.helpers import RANGEIFY
 import numpy as np
 
 class TestSymbolicJit(unittest.TestCase):
@@ -27,7 +26,7 @@ class TestSymbolicJit(unittest.TestCase):
       symbolic = jf(a[:, :vi]).numpy()
       expected = f(a[:, :i]).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-    assert_jit_cache_len(jf, 1 if RANGEIFY else 2) # one add and one pad, can be one kernel?
+    assert_jit_cache_len(jf, 1)
 
   def test_add(self):
     def f(a, b): return (a+b).realize()
@@ -80,7 +79,7 @@ class TestSymbolicJit(unittest.TestCase):
      symbolic = jf(q, k[:, :vi], v[:, :vi])[:2, :4, :1, :8].numpy()
      expected = f(q, k[:, :i], v[:, :i]).numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-    assert_jit_cache_len(jf, 4 if RANGEIFY else 5)
+    assert_jit_cache_len(jf, 4)
 
   def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
diff --git a/test/test_tensor.py b/test/test_tensor.py
index e67d776dbf..b046378118 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -4,7 +4,7 @@ import torch
 import unittest, copy, mmap, random, math, array
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.tensor import _METADATA
-from tinygrad.helpers import getenv, temp, mv_address, RANGEIFY
+from tinygrad.helpers import getenv, temp, mv_address
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
 from hypothesis import given, settings, strategies as strat
 from tinygrad.device import is_dtype_supported
@@ -872,18 +872,11 @@ class TestTensorMetadata(unittest.TestCase):
     self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
     self.assertTrue(y.grad.uop.metadata[0].backward)
     si = Tensor.schedule(out, x.grad, y.grad)[-1]
-    if not RANGEIFY:
-      self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
-      self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"})
-      bw = [m for m in si.metadata if m.backward]
-      self.assertEqual(len(bw), 2)
-      self.assertEqual(bw[0].name, "sigmoid")
-    else:
-      self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
-      self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"})
-      bw = [m for m in si.metadata if m.backward]
-      self.assertEqual(len(bw), 1)
-      self.assertEqual(bw[0].name, "sigmoid")
+    self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
+    self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"})
+    bw = [m for m in si.metadata if m.backward]
+    self.assertEqual(len(bw), 1)
+    self.assertEqual(bw[0].name, "sigmoid")
 
 class TestIdxUpcast(unittest.TestCase):
   def _find_op(self, ast: UOp, op: Ops):
diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py
index 83dfcf0be6..845ab8b325 100644
--- a/test/test_uops_stats.py
+++ b/test/test_uops_stats.py
@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Tensor
-from tinygrad.helpers import getenv, GlobalCounters, EMULATE, RANGEIFY
+from tinygrad.helpers import getenv, GlobalCounters, EMULATE
 from tinygrad.engine.realize import lower_schedule_item, ProgramSpec, get_program
 from tinygrad.renderer import Estimates
 from tinygrad.codegen import full_rewrite
@@ -51,11 +51,8 @@ class TestMemoryCount(unittest.TestCase):
     a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
     b = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
     _, mem = get_stats(a+b)
-    if RANGEIFY:
-      # rangeify is smart!
-      self.assertEqual(mem, 1024 + 2*1024) # 2 lil reads + 1 lil write
-    else:
-      self.assertEqual(mem, 1024*1024 + 2*1024) # 2 lil reads + 1 write
+    # rangeify is smart!
+    self.assertEqual(mem, 1024 + 2*1024) # 2 lil reads + 1 lil write
 
   def test_self_add(self):
     a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
diff --git a/test/unit/test_kernelize.py b/test/unit/test_kernelize.py
index baa49baff3..e571c1d297 100644
--- a/test/unit/test_kernelize.py
+++ b/test/unit/test_kernelize.py
@@ -1,7 +1,6 @@
 import unittest
 from tinygrad import Tensor
 from tinygrad.uop import Ops
-from tinygrad.helpers import RANGEIFY
 
 class TestKernelize(unittest.TestCase):
   def test_add_reshaped(self):
@@ -18,8 +17,8 @@ class TestKernelize(unittest.TestCase):
     a1 = a.sum(axis=1)
     a0 = a1.sum(axis=0)
     a0.kernelize()
-    self.assertEqual(len([s for s in a0.uop.toposort() if s.op is Ops.KERNEL]), 2 if RANGEIFY else 3)
-    self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS if RANGEIFY else Ops.ASSIGN)
+    self.assertEqual(len([s for s in a0.uop.toposort() if s.op is Ops.KERNEL]), 2)
+    self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
     # input Tensor and user contiguous kernelize
     self.assertIs(a0.uop.base.op, Ops.ASSIGN)
     self.assertIs(a.uop.base.op, Ops.ASSIGN)
diff --git a/test/unit/test_shm_tensor.py b/test/unit/test_shm_tensor.py
index 0c953a7767..93b26c7568 100644
--- a/test/unit/test_shm_tensor.py
+++ b/test/unit/test_shm_tensor.py
@@ -1,11 +1,11 @@
 import unittest
 import multiprocessing.shared_memory as shared_memory
-from tinygrad.helpers import CI, WIN, RANGEIFY
+from tinygrad.helpers import CI, WIN
 from tinygrad.tensor import Tensor, Device
 import numpy as np
 
 class TestRawShmBuffer(unittest.TestCase):
-  @unittest.skipIf(WIN and CI and RANGEIFY, "only fails with RANGEIFY on CI windows instance")
+  @unittest.skipIf(WIN and CI, "only fails on CI windows instance")
   def test_e2e(self):
     t = Tensor.randn(2, 2, 2).realize()
 
diff --git a/test/unit/test_winograd.py b/test/unit/test_winograd.py
index 0a7855ff06..54b54fc2b1 100644
--- a/test/unit/test_winograd.py
+++ b/test/unit/test_winograd.py
@@ -35,14 +35,14 @@ class TestWinograd(unittest.TestCase):
   def test_forward_kernels(self):
     x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
     out = Tensor.conv2d(x,w)
-    self.assertEqual(len(out.schedule()), 2 if RANGEIFY else 4)
+    self.assertEqual(len(out.schedule()), 2)
 
   def test_backward_kernels(self):
     x,w = Tensor.empty(1,4,9,9,requires_grad=True).realize(), Tensor.empty(4,4,3,3,requires_grad=True).realize()
     out = Tensor.conv2d(x,w, padding=1)
     out.mean().backward()
     backward_schedule = Tensor.schedule(x.grad, w.grad)
-    self.assertEqual(len(backward_schedule), 4 if RANGEIFY else 9)
+    self.assertEqual(len(backward_schedule), 4)
 
   def test_counters(self):
     IC, OC, X, Y = 4,4,9,9
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index ceb3c52aff..71c25bd472 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -6,7 +6,7 @@ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, Suppor
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION
+from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, FUSE_ATTENTION
 from tinygrad.helpers import suppress_finalizing
 from tinygrad.gradient import compute_gradient
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, \
@@ -227,7 +227,7 @@ class Tensor(MathTrait):
 
     # verify Tensors match the spec
     if __debug__: type_verify(list(big_sink.toposort()), tensor_uop_spec)
 
-    if RANGEIFY and any(isinstance(x._device, tuple) for x in big_sink.toposort()):
+    if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
       _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
       big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
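
Editor's note, not part of the patch: the change above is mechanical. Expectations that branched on the RANGEIFY flag keep only their rangeify-path values, and `with Context(RANGEIFY=1)` guards are dropped, on the assumption that the rangeify scheduler is now the default. Below is a minimal sketch of the before/after pattern using only APIs that appear in the patch (Tensor, Context, Tensor.schedule); the kernel count mentioned in the comment is taken from the updated test_permute_was_pushed, not re-verified here.

    from tinygrad import Tensor
    from tinygrad.helpers import Context  # Context is how flags like RANGEIFY were toggled per-block

    a = Tensor.randn(16, 16, 16)
    d = a.sum(2).permute(1, 0).contiguous()

    # before this patch, tests branched on the flag, e.g. expected = 1 if RANGEIFY else 2,
    # and some scripts forced it on for a block:
    #   with Context(RANGEIFY=1): sched = d.schedule()
    # after this patch, rangeify behavior is assumed by default, so plain scheduling suffices
    sched = d.schedule()
    print(len(sched))  # the updated test allows a single fused kernel for this pattern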