diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 2bb51e998c..d7c7ebf29c 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -642,7 +642,7 @@ jobs:
         ln -s /data/home/tiny/tinygrad/testsig-*.so .
         PYTHONPATH=. CC=clang-19 CPU=1 CPU_LLVM=0 QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
         # benchmark on DSP with NOOPT=1, the devectorizer has issues
-        PYTHONPATH=. CC=clang-19 DSP=1 DONT_REALIZE_EXPAND=1 NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
+        PYTHONPATH=. CC=clang-19 DSP=1 NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
     - name: Run process replay tests
       run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
     - uses: actions/upload-artifact@v4
diff --git a/examples/test_onnx_imagenet.py b/examples/test_onnx_imagenet.py
index 11f469aebd..98a560f1b9 100644
--- a/examples/test_onnx_imagenet.py
+++ b/examples/test_onnx_imagenet.py
@@ -19,8 +19,8 @@ from tinygrad.helpers import fetch, getenv

 # QUANT=1 python3 examples/test_onnx_imagenet.py
 #   https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
-# DONT_REALIZE_EXPAND=1 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
-# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx
+# python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
+# VIZ=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx

 def imagenet_dataloader(cnt=0):
   input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
diff --git a/test/external/external_benchmark_bert_softmax.py b/test/external/external_benchmark_bert_softmax.py
index 176e2061a4..131b05dce4 100644
--- a/test/external/external_benchmark_bert_softmax.py
+++ b/test/external/external_benchmark_bert_softmax.py
@@ -1,4 +1,4 @@
-from tinygrad import Tensor, dtypes, Context, GlobalCounters
+from tinygrad import Tensor, dtypes, GlobalCounters
 dtypes.default_float = dtypes.float16
 from tinygrad.dtype import to_dtype
 from tinygrad.helpers import getenv
@@ -13,6 +13,5 @@ if __name__ == "__main__":

   # test single kernel softmax
   GlobalCounters.reset()
-  with Context(DONT_GROUP_REDUCES=1):
-    single_kernel_softmax(t, -1, acc_dtype).realize()
+  single_kernel_softmax(t, -1, acc_dtype).realize()

diff --git a/test/test_quantize_onnx.py b/test/test_quantize_onnx.py
index cfaa44cc5d..dba794cbc8 100644
--- a/test/test_quantize_onnx.py
+++ b/test/test_quantize_onnx.py
@@ -72,7 +72,7 @@ class TestQuantizeOnnxCPU(unittest.TestCase):
     out_file = get_quantized_model(sz)
     run_onnx = OnnxRunner(out_file)
     inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
-    with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1):
+    with Context(QUANTIZE=1):
       sched = run_onnx({"input":inp})["output"].schedule()
     ei = lower_schedule_item(sched[-2])
     daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_REG]
@@ -86,8 +86,7 @@ class TestQuantizeOnnx(unittest.TestCase):
     # divide is ~1500-2000 without reduce_range, 750-900 with it
     out_file = get_quantized_model(sz)
     run_onnx_jit, _ = load_onnx_model(out_file)
-    with Context(DONT_REALIZE_EXPAND=1):
-      run_onnx_jit(input=Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)))
+    run_onnx_jit(input=Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)))

   def test_prequant_conv2d_1x1(self):
     X = Tensor(np.random.uniform(0, 255, size=(1, 32, 128, 128)).astype(np.uint8))
@@ -109,11 +108,10 @@ class TestQuantizeOnnx(unittest.TestCase):
     N = 512
     X = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(xi))
     W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(wi))
-    with Context(DONT_REALIZE_EXPAND=1):
-      # this divide is interesting and forces the accumulator to actually be an int
-      out = (X.cast("int").matmul(W.cast("int"))//1000).cast("int8")
-      opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
-      sexec(out, opts)
+    # this divide is interesting and forces the accumulator to actually be an int
+    out = (X.cast("int").matmul(W.cast("int"))//1000).cast("int8")
+    opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
+    sexec(out, opts)

   def test_prequant_gemm_handcode(self):
     src = """typedef int int128 __attribute__((aligned(512),vector_size(512)));
@@ -203,14 +201,12 @@ class TestQuantizeOnnx(unittest.TestCase):
   def test_prequant_gemm_intacc(self, xi=np.uint8, wi=np.uint8, replace_src=None, N=512, clip=True, opts=None):
     X = Tensor(m1:=(np.random.uniform(0, 255, size=(N,N)).astype(xi))).realize()
     W = Tensor(m2:=(np.random.uniform(0, 255, size=(N,N)).astype(wi))).realize()
-    # ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant
     tg_dtype = dtypes.int8 if xi == np.int8 else dtypes.uint8
-    with Context(DONT_REALIZE_EXPAND=1):
-      out = (X.int().matmul(W.int())//1000)
-      if clip: out = out.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
-      out = out.cast(tg_dtype)
-      opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] if opts is None else opts
-      sexec(out, opts, replace_src, run_count=1)
+    out = (X.int().matmul(W.int())//1000)
+    if clip: out = out.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
+    out = out.cast(tg_dtype)
+    opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] if opts is None else opts
+    sexec(out, opts, replace_src, run_count=1)
     tout = out.numpy()
     mout = ((m1.astype(np.int32) @ m2.astype(np.int32)) // 1000)
     if clip: mout = mout.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
@@ -225,7 +221,6 @@ class TestQuantizeOnnx(unittest.TestCase):

   def test_prequant_gemv(self):
     N = 2048
-    # ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant
     X = Tensor(np.random.uniform(0, 255, size=(1,N)).astype(np.uint8)).realize()
     W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8)).realize()
     #out = X.cast(dtypes.int) @ W.cast(dtypes.int)
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 7220bb1263..6d7c855bf8 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -171,8 +171,7 @@ class TestSchedule(unittest.TestCase):

   def test_rand_recompute_arange(self):
     x = Tensor.rand(32)
-    with Context(DONT_GROUP_REDUCES=1):
-      check_schedule(x, 3, [Tensor._device_rng_counters[x.device]])
+    check_schedule(x, 3, [Tensor._device_rng_counters[x.device]])

   def test_empty_is_not_realized(self):
     a = Tensor.empty(10)
@@ -276,7 +275,7 @@ class TestSchedule(unittest.TestCase):
     a = Tensor.randn(10,10,10).realize()
     b = Tensor.randn(10,10,1).realize()
     c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
-    with Context(DONT_GROUP_REDUCES=1): run_schedule(check_schedule(c, 1))
+    run_schedule(check_schedule(c, 1))
     np.testing.assert_allclose(c.numpy(), np.sum(a.numpy(), axis=0, keepdims=True).transpose(2,1,0)+b.numpy())

   def test_binop_early_reshape_reduce_fusion(self):
@@ -1976,8 +1975,7 @@ class TestSwizzle(unittest.TestCase):
     a = Tensor.randint(32, 32).realize()
     r = (a+a).sum(1).sum(0)
     # double reduce collapses to a single reduce
-    with Context(DONT_GROUP_REDUCES=1):
-      run_schedule(check_schedule(r, 1))
+    run_schedule(check_schedule(r, 1))
     self.assertEqual(r.numpy(), (a.numpy()+a.numpy()).sum(1).sum(0))

   def test_single_swizzle(self):
@@ -1997,33 +1995,29 @@ class TestSwizzle(unittest.TestCase):
     b = Tensor.randint(4,).realize()
     # parallel reduce!
     add = a.sum(0)+b.sum(0)
-    with Context(DONT_GROUP_REDUCES=1):
-      run_schedule(check_schedule(add, 1))
+    run_schedule(check_schedule(add, 1))
     self.assertEqual(add.numpy(), a.numpy().sum(0)+b.numpy().sum(0))

-  @unittest.skip("TODO: how do we express the norm")
   def test_softmax_one_kernel(self):
     Tensor.manual_seed(0)
     with Context(DEBUG=0, TRACK_MATCH_STATS=0):
       a = Tensor.randn(32, 32).realize()
     t = a.softmax()
-    with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1):
-      check_schedule(t, 1)
+    check_schedule(t, 1)

   def test_argmax_one_kernel(self):
     Tensor.manual_seed(0)
     with Context(DEBUG=0, TRACK_MATCH_STATS=0):
       a = Tensor.randn(10, 20).realize()
     t = a.argmax(0)
-    with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): t.realize()
+    check_schedule(t, 1)

   def test_swizzle_reduceop(self):
     Tensor.manual_seed(0)
     x = Tensor.randn(4,4).realize()
     y = Tensor.randn(4,4,4).realize()
     out = x.reshape(4,4,1).expand(4,4,4).sum(axis=(1,))+y
-    with Context(DONT_REALIZE_EXPAND=1, DONT_GROUP_REDUCES=1):
-      run_schedule(check_schedule(out, 1))
+    run_schedule(check_schedule(out, 1))
     np.testing.assert_allclose(out.numpy(), np.tile(x.numpy().reshape(4,4,1), (1,1,4)).sum(axis=1)+y.numpy())

   def test_permute_rewrite(self):
@@ -2031,7 +2025,7 @@ class TestSwizzle(unittest.TestCase):
     y = Tensor.randn(4, 1, 16).realize()
     z = Tensor.randn(4, 4, 1).realize()
     t = (x*y).sum(axis=(0, 2)).reshape(1, 4, 1).permute(0, 2, 1)+z
-    with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
+    run_schedule(check_schedule(t, 1))
     t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
     np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3)

@@ -2042,14 +2036,14 @@ class TestSwizzle(unittest.TestCase):
     a_reduce = a.sum(axis=(2,), keepdim=True).sum(axis=(1,))
     b_reduce = b.sum(axis=(0,))
     t = a_reduce+b_reduce
-    with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
+    run_schedule(check_schedule(t, 1))

   def test_parallel_reduce_possible(self):
     Tensor.manual_seed(0)
     x = Tensor.randn(4, 2, 2).realize()
     y = Tensor.randn(4, 2, 2).realize()
     t = x.sum(axis=1)+y.sum(axis=1)
-    with Context(DONT_GROUP_REDUCES=1): run_schedule(check_schedule(t, 1))
+    run_schedule(check_schedule(t, 1))
     np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)

   # kernels can only have 1 or n in each dim
@@ -2058,7 +2052,7 @@ class TestSwizzle(unittest.TestCase):
     x = Tensor.randn(4, 2, 2).realize()
     y = Tensor.randn(4, 3, 2).realize()
     t = x.sum(axis=1)+y.sum(axis=1)
-    with Context(DONT_GROUP_REDUCES=1): run_schedule(check_schedule(t, 1))
+    run_schedule(check_schedule(t, 1))
     np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)

   def test_unsafe_pad(self):
diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py
index fc77f9765b..8ccb54f20d 100644
--- a/test/test_softmax_fusion.py
+++ b/test/test_softmax_fusion.py
@@ -165,8 +165,7 @@ class TestSoftmaxFusion(unittest.TestCase):
     sout.realize()

     print("*** single kernel softmax ***")
-    # NOTE: DONT_GROUP_REDUCES is required here
-    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2), DONT_GROUP_REDUCES=1):
+    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
       out = single_kernel_softmax(self.test)
       out.realize()

@@ -186,7 +185,6 @@ class TestSoftmaxFusion(unittest.TestCase):
     np.testing.assert_allclose(sout.numpy(), out.numpy(), atol=3e-7)

-  @unittest.skip("recursion error no longer raised")
   def test_softmax_bw(self):
     print("*** softmax bw ***")
     self.test.requires_grad_()
@@ -197,14 +195,11 @@ class TestSoftmaxFusion(unittest.TestCase):
     self.test.grad = None

     print("*** single kernel softmax bw ***")
-    # NOTE: DONT_GROUP_REDUCES is required here
-    # TODO: fix RecursionError with DONT_GROUP_REDUCES
-    with self.assertRaises(RecursionError):
-      with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2), DONT_GROUP_REDUCES=1):
-        single_kernel_softmax(self.test).sum().backward()
-        g = self.test.grad.realize()
+    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
+      single_kernel_softmax(self.test).sum().backward()
+      g = self.test.grad.realize()

-      np.testing.assert_allclose(sg.numpy(), g.numpy(), atol=1e-7)
+    np.testing.assert_allclose(sg.numpy(), g.numpy(), atol=1e-7)

 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 7b3936b20d..e5fffe2330 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -158,7 +158,6 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), Conte
 PICKLE_BUFFERS, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("LRU", 1)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
 DISABLE_COMPILER_CACHE, BLOCK_REORDER = ContextVar("DISABLE_COMPILER_CACHE", 0), ContextVar("BLOCK_REORDER", 1)
-DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
 QUANTIZE, VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
 CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
ContextVar("MAX_BUFFER_SIZE", 0)