From e75be6eafc5f1c413c056ab006cb6a8d8e0e7254 Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Thu, 24 Apr 2025 14:06:08 +0200 Subject: [PATCH] [bounty] [pr] index validation with z3 (#9981) * index validation with z3 * Change comment * toposort -> toposort() --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- .github/workflows/test.yml | 28 ++++++++++ extra/optimization/test_beam_search.py | 17 +++--- test/models/test_whisper.py | 11 +++- test/test_linearizer_failures.py | 5 +- test/test_setitem.py | 26 +++++----- test/test_symbolic_jit.py | 11 +++- test/test_symbolic_ops.py | 11 +++- test/test_tensor_variable.py | 54 ++++++++++--------- test/test_tiny.py | 15 +++--- test/test_uop_graph.py | 72 +++++++++++++++++++++++--- tinygrad/helpers.py | 2 +- tinygrad/spec.py | 59 ++++++++++++++++----- 12 files changed, 236 insertions(+), 75 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1a43e107e1..0a9797f77b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -153,6 +153,8 @@ jobs: name: Torch Backend Tests runs-on: ubuntu-latest timeout-minutes: 15 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code uses: actions/checkout@v4 @@ -190,6 +192,8 @@ jobs: name: Torch Backend Tests More runs-on: ubuntu-latest timeout-minutes: 15 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code uses: actions/checkout@v4 @@ -212,6 +216,8 @@ jobs: name: Tensor Core tests runs-on: ubuntu-latest timeout-minutes: 10 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code uses: actions/checkout@v4 @@ -280,6 +286,8 @@ jobs: name: Python Backend runs-on: ubuntu-latest timeout-minutes: 10 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code uses: actions/checkout@v4 @@ -376,6 +384,8 @@ jobs: name: 'GPU IMAGE Tests' runs-on: ubuntu-22.04 timeout-minutes: 10 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code uses: actions/checkout@v4 @@ -400,6 +410,8 @@ jobs: name: 
'openpilot Compile Tests' runs-on: ubuntu-22.04 timeout-minutes: 10 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code uses: actions/checkout@v4 @@ -423,6 +435,8 @@ jobs: name: 'ONNX+Optimization Tests' runs-on: ubuntu-22.04 timeout-minutes: 20 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code @@ -467,6 +481,8 @@ jobs: name: Models (llvm+cpu+gpu) runs-on: ubuntu-22.04 timeout-minutes: 10 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code uses: actions/checkout@v4 @@ -490,6 +506,8 @@ jobs: name: Linux (DSP) runs-on: ubuntu-24.04 timeout-minutes: 15 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code uses: actions/checkout@v4 @@ -559,6 +577,8 @@ jobs: name: Linux (${{ matrix.backend }}) runs-on: ubuntu-22.04 timeout-minutes: 20 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code @@ -601,6 +621,8 @@ jobs: name: MacOS (unit) runs-on: macos-14 timeout-minutes: 20 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code @@ -658,6 +680,8 @@ jobs: run: | python3 -m pytest -n=auto test/test_hcq.py test/test_tiny.py --durations=20 - name: Run process replay tests + env: + IGNORE_OOB: 1 uses: ./.github/actions/process-replay osxwebgpu: @@ -689,6 +713,8 @@ jobs: name: MacOS (${{ matrix.backend }}) runs-on: macos-15 timeout-minutes: 20 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code uses: actions/checkout@v4 @@ -723,6 +749,8 @@ jobs: name: Windows (${{ matrix.backend }}) runs-on: windows-latest timeout-minutes: 15 + env: + IGNORE_OOB: 0 steps: - name: Checkout Code uses: actions/checkout@v4 diff --git a/extra/optimization/test_beam_search.py b/extra/optimization/test_beam_search.py index 8571b863c7..24c3f943b7 100644 --- a/extra/optimization/test_beam_search.py +++ b/extra/optimization/test_beam_search.py @@ -1,7 +1,7 @@ import unittest import numpy as np -from tinygrad.helpers import BEAM, Timing, CI +from tinygrad.helpers import BEAM, Timing, CI, Context from tinygrad import Variable, Tensor from tinygrad.nn import Conv2d @@ -16,8 +16,9 @@ class 
TestBeamSearch(unittest.TestCase): BEAM.value = self.old_beam def test_variable_ast_beam(self): - a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3)) - a = (a+1).realize() + with Context(IGNORE_OOB=1): + a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3)) + a = (a+1).realize() def test_big_prime_number(self): a = rand(367, 367) @@ -43,14 +44,16 @@ class TestBeamSearch(unittest.TestCase): v = Variable("v", 1, 400).bind(367) a = rand(367, 367) b = rand(367, 367) - c = (a.reshape(367, v) @ b.reshape(v, 367)).realize() - np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4) + with Context(IGNORE_OOB=1): + c = (a.reshape(367, v) @ b.reshape(v, 367)).realize() + np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4) def test_variable_shrink_prime_number(self): v = Variable("v", 1, 400).bind(367) a = rand(400, 367) - b = (a.shrink(((0,v), None))+1).reshape(367,367).realize() - np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4) + with Context(IGNORE_OOB=1): + b = (a.shrink(((0,v), None))+1).reshape(367,367).realize() + np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4) def test_no_mutate_rawbuffers(self): a = rand(3, 3).realize() diff --git a/test/models/test_whisper.py b/test/models/test_whisper.py index f1696fd490..9837668bdd 100644 --- a/test/models/test_whisper.py +++ b/test/models/test_whisper.py @@ -1,7 +1,7 @@ import unittest import pathlib from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform -from tinygrad.helpers import CI, fetch +from tinygrad.helpers import CI, fetch, Context from tinygrad import Device, dtypes from tinygrad.device import is_dtype_supported @@ -24,15 +24,24 @@ class TestWhisper(unittest.TestCase): model, enc = init_whisper("tiny.en", batch_size=2) cls.model = model cls.enc = enc + # TODO: whisper has out of bounds access somewhere + cls.context = Context(IGNORE_OOB=1) + 
cls.context.__enter__() @classmethod def tearDownClass(cls): + cls.context.__exit__(None, None, None) del cls.model del cls.enc def test_transcribe_file1(self): self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1) + @unittest.expectedFailure # Test for out of bounds access + def test_transcribe_file1_OOB(self): + with Context(IGNORE_OOB=0): + self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1) + @unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too many tests for CI") def test_transcribe_file2(self): self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index eb1cc5e9cd..341e6f0382 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -6,7 +6,7 @@ from tinygrad.device import is_dtype_supported from tinygrad.ops import UOp, Ops from tinygrad.engine.search import Opt, OptOps from tinygrad import Device, dtypes, Tensor -from tinygrad.helpers import CI +from tinygrad.helpers import CI, Context from test.external.fuzz_linearizer import compare_linearizer from test.helpers import ast_const @@ -1428,7 +1428,8 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.CONST, dtypes.int, arg=-1, src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50257, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUPTOP, axis=0, arg=29), Opt(op=OptOps.PADTO, axis=0, arg=32)] - helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL", "GPU", "AMD", "NV"]) + with Context(IGNORE_OOB=0): + helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL", "GPU", "AMD", "NV", "CUDA"]) if __name__ == '__main__': unittest.main() diff --git a/test/test_setitem.py b/test/test_setitem.py index 06f42dbe3f..d3f066bad4 100644 --- 
a/test/test_setitem.py +++ b/test/test_setitem.py @@ -1,5 +1,6 @@ import unittest from tinygrad import Tensor, TinyJit, Variable, dtypes +from tinygrad.helpers import Context import numpy as np class TestSetitem(unittest.TestCase): @@ -131,20 +132,21 @@ class TestSetitem(unittest.TestCase): np.testing.assert_allclose(t.numpy(), n) def test_jit_setitem_variable_offset(self): - @TinyJit - def f(t:Tensor, a:Tensor, v:Variable): - t.shrink(((v,v+1), None)).assign(a).realize() + with Context(IGNORE_OOB=1): + @TinyJit + def f(t:Tensor, a:Tensor, v:Variable): + t.shrink(((v,v+1), None)).assign(a).realize() - t = Tensor.zeros(6, 6).contiguous().realize() - n = np.zeros((6, 6)) + t = Tensor.zeros(6, 6).contiguous().realize() + n = np.zeros((6, 6)) - for i in range(6): - v = Variable("v", 0, 6).bind(i) - a = Tensor.full((1, 6), fill_value=i+1, dtype=dtypes.float).contiguous() - n[i, :] = i+1 - f(t, a, v) - np.testing.assert_allclose(t.numpy(), n) - np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]]) + for i in range(6): + v = Variable("v", 0, 6).bind(i) + a = Tensor.full((1, 6), fill_value=i+1, dtype=dtypes.float).contiguous() + n[i, :] = i+1 + f(t, a, v) + np.testing.assert_allclose(t.numpy(), n) + np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]]) def test_setitem_overlapping_inplace1(self): t = Tensor([[3.0], [2.0], [1.0]]).contiguous() diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py index 05608180aa..7c2d171bf1 100644 --- a/test/test_symbolic_jit.py +++ b/test/test_symbolic_jit.py @@ -2,9 +2,18 @@ import unittest from test.helpers import assert_jit_cache_len from tinygrad import Variable, Tensor, TinyJit +from tinygrad.helpers import Context import numpy as np class TestSymbolicJit(unittest.TestCase): + def setUp(self): + # A lot of these test are out of bounds, so we ignore the bounds check + self.context = 
Context(IGNORE_OOB=1) + self.context.__enter__() + + def tearDown(self): + self.context.__exit__(None, None, None) + def test_plus1(self): def f(a): return (a+1).realize() jf = TinyJit(f) @@ -290,4 +299,4 @@ class TestSymbolicJit(unittest.TestCase): np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index a654654528..b877b2178b 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -1,10 +1,19 @@ import unittest from tinygrad import Variable +from tinygrad.helpers import Context from tinygrad.tensor import Tensor from examples.gpt2 import Attention import numpy as np class TestSymbolicOps(unittest.TestCase): + def setUp(self): + # A lot of these test are out of bounds, so we ignore the bounds check + self.context = Context(IGNORE_OOB=1) + self.context.__enter__() + + def tearDown(self): + self.context.__exit__(None, None, None) + def test_plus1(self): def f(a): return (a+1).realize() for i in range(1, 5): @@ -188,4 +197,4 @@ class TestSymbolicOps(unittest.TestCase): np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/test_tensor_variable.py b/test/test_tensor_variable.py index eb7ddeb386..a680cf9a6e 100644 --- a/test/test_tensor_variable.py +++ b/test/test_tensor_variable.py @@ -1,6 +1,7 @@ import unittest import numpy as np from tinygrad import Tensor, Variable +from tinygrad.helpers import Context class TestTensorVariable(unittest.TestCase): def test_add_tvar(self): @@ -22,38 +23,43 @@ class TestTensorVariable(unittest.TestCase): assert (Tensor(3) * (vv * 4)).item() == 24 def test_symbolic_mean(self): - vv = Variable("a", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(2, vv) - ret = t.mean().item() - assert ret == 1 + with 
Context(IGNORE_OOB=1): + vv = Variable("a", 1, 10).bind(2) + t = Tensor.ones(2, 2).contiguous().reshape(2, vv) + ret = t.mean().item() + assert ret == 1 def test_symbolic_mean_2d(self): - vv = Variable("a", 1, 10).bind(2) - vv2 = Variable("b", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv) - ret = t.mean().item() - assert ret == 1 + with Context(IGNORE_OOB=1): + vv = Variable("a", 1, 10).bind(2) + vv2 = Variable("b", 1, 10).bind(2) + t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv) + ret = t.mean().item() + assert ret == 1 def test_symbolic_mean_2d_axis_1(self): - vv = Variable("a", 1, 10).bind(2) - vv2 = Variable("b", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv) - ret = t.mean(axis=1).reshape(2, 1).numpy() - assert np.all(ret == 1) + with Context(IGNORE_OOB=1): + vv = Variable("a", 1, 10).bind(2) + vv2 = Variable("b", 1, 10).bind(2) + t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv) + ret = t.mean(axis=1).reshape(2, 1).numpy() + assert np.all(ret == 1) def test_symbolic_mean_2d_add(self): - add_term = Variable("c", 0, 10).bind(1) - vv = Variable("a", 1, 10).bind(1) - vv2 = Variable("b", 1, 10).bind(1) - t = Tensor.ones(2, 2).contiguous().reshape(vv2+add_term, vv+add_term) - ret = t.mean().item() - assert ret == 1 + with Context(IGNORE_OOB=1): + add_term = Variable("c", 0, 10).bind(1) + vv = Variable("a", 1, 10).bind(1) + vv2 = Variable("b", 1, 10).bind(1) + t = Tensor.ones(2, 2).contiguous().reshape(vv2+add_term, vv+add_term) + ret = t.mean().item() + assert ret == 1 def test_symbolic_var(self): - vv = Variable("a", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(2, vv) - ret = t.var().item() - assert ret == 0 + with Context(IGNORE_OOB=1): + vv = Variable("a", 1, 10).bind(2) + t = Tensor.ones(2, 2).contiguous().reshape(2, vv) + ret = t.var().item() + assert ret == 0 def test_symbolic_pad(self): vv = Variable("a", 1, 10).bind(2) diff --git a/test/test_tiny.py b/test/test_tiny.py index 
b83ba6f04a..0d6798fbab 100644 --- a/test/test_tiny.py +++ b/test/test_tiny.py @@ -73,15 +73,17 @@ class TestTiny(unittest.TestCase): def test_symbolic(self): i = Variable('i', 1, 10) - for s in [2,5]: - ret = Tensor.ones(s).contiguous().reshape(i.bind(s)) + 1 - self.assertListEqual(ret.reshape(s).tolist(), [2.0]*s) + with Context(IGNORE_OOB=1): + for s in [2,5]: + ret = Tensor.ones(s).contiguous().reshape(i.bind(s)) + 1 + self.assertListEqual(ret.reshape(s).tolist(), [2.0]*s) def test_symbolic_reduce(self): i = Variable('i', 1, 10) - for s in [2,5]: - ret = Tensor.ones(s).contiguous().reshape(i.bind(s)).sum() - self.assertEqual(ret.item(), s) + with Context(IGNORE_OOB=1): + for s in [2,5]: + ret = Tensor.ones(s).contiguous().reshape(i.bind(s)).sum() + self.assertEqual(ret.item(), s) # *** a model *** @@ -115,4 +117,3 @@ class TestTiny(unittest.TestCase): if __name__ == '__main__': unittest.main() - diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 6a1355c7f4..ed8c938dd1 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -1,7 +1,7 @@ from typing import List import unittest, time, pytest -from tinygrad import dtypes, Device -from tinygrad.helpers import DEBUG +from tinygrad import dtypes, Device, Variable +from tinygrad.helpers import DEBUG, Context from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher, track_rewrites from tinygrad.renderer import Renderer from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index @@ -446,10 +446,70 @@ class TestUOpGraph(unittest.TestCase): uops = to_uops_list([v.bitcast(dt)]) self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}") - def test_out_of_bounds_access(self): - glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42)),)) - with self.assertRaises(RuntimeError): to_uops_list([ld0]) + def test_in_out_of_bounds_access(self): + with Context(IGNORE_OOB=0): + glbl0 = 
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0)),)) + to_uops_list([ld0]) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15)),)) + to_uops_list([ld1]) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7)),)) + to_uops_list([ld1]) + + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42)),)) + with self.assertRaises(RuntimeError): to_uops_list([ld0]) + + def test_in_out_of_bounds_access_symbolic(self): + with Context(IGNORE_OOB=0): + glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10)),)) + to_uops_list([ld0]) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15)),)) + to_uops_list([ld0]) + + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20)),)) + with self.assertRaises(RuntimeError): to_uops_list([ld0]) + + def test_out_of_bounds_off_by_one_access(self): + with Context(IGNORE_OOB=0): + glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16)),)) + with self.assertRaises(RuntimeError): to_uops_list([ld0]) + + def test_in_out_bounds_access_with_mask(self): + with Context(IGNORE_OOB=0): + glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + gidx0 = UOp(Ops.SPECIAL, dtype=dtypes.int, arg=("gidx0", 42)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5=0)&(ld0<32)),)) + to_uops_list([ld1]) + + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),)) + with self.assertRaises(RuntimeError): to_uops_list([ld1]) def test_fold_gated_load(self): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 0fcf90d1b0..9014fb0b0c 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -117,7 +117,7 @@ PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROF 
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1) DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0) DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0) -QUANTIZE, VALIDATE_WITH_CPU = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0) +QUANTIZE, VALIDATE_WITH_CPU, IGNORE_OOB = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("IGNORE_OOB", 1) @dataclass(frozen=True) class Metadata: diff --git a/tinygrad/spec.py b/tinygrad/spec.py index 6760c83176..ccaf78392f 100644 --- a/tinygrad/spec.py +++ b/tinygrad/spec.py @@ -1,7 +1,33 @@ -from typing import cast -from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops +from typing import cast, Callable +from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, RewriteContext from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType -from tinygrad.helpers import all_same, dedup, prod, getenv, DEBUG +from tinygrad.helpers import all_same, dedup, prod, DEBUG, IGNORE_OOB +try: + import z3 + z3_imported = True + + # IDIV is truncated division but z3 does floored division; mod by power of two sometimes uses Ops.AND + def z3_cdiv(a,b): return z3.If(a<0, (a+(b-1))/b, a/b) + z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()), + Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b} + def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef: + s = z3.Int(name, ctx=solver.ctx) + solver.add(vmin <= s, s <= vmax) + return s + + # ctx is (solver, load_number_dict) + z3_renderer = PatternMatcher([ + # Ops.SPECIAL can have a symbolic arg but it won't be in the toposort because it's not a src, so we need to add it manually
(UPat(Ops.SPECIAL, src=(), name="x"), lambda x: UOp(Ops.SPECIAL, arg=x.arg[0], src=(x.ufix(x.arg[1]),))), + (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg, 0, x.src[0].arg-1, ctx[0]))), + (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0]))), + (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"ridx{x.arg}", x.src[0].arg, x.src[1].arg-1, ctx[0]))), + (UPat(Ops.LOAD, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))), + (UPat(Ops.CONST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))), + (UPat(Ops.CAST, name="x"), lambda x: x.src[0]), + (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=z3_alu[x.op](*(s.arg for s in x.src)))), + ]) +except ImportError: z3_imported = False buffer_spec = PatternMatcher([ (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True), @@ -57,16 +83,23 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([ # ***** uop type spec ***** def validate_index(idx:UOp, mask:UOp|None=None): - if getenv("IGNORE_OOB"): return True - # this checks for out of bounds access. it is not complete but should catch some issues - if mask is None and not isinstance(idx.dtype, ImageDType): - # WEBGPU has a BITCAST in the index. TODO: fix - if any(x.op in {Ops.DEFINE_VAR, Ops.BITCAST} or (x.op is Ops.SPECIAL and any(not isinstance(y, int) for y in x.arg[1:])) for x in idx.toposort()): - return True - vmin, vmax, sz = idx.src[1].vmin, idx.src[1].vmax, cast(PtrDType, idx.src[0].dtype).size - if sz != -1 and (vmin < 0 or vmax >= sz): - if DEBUG >= 1: print(f"OUT OF BOUNDS ACCESS in INDEX {vmin} - {vmax} not in 0 - {sz}. 
{idx.src[1].render()=}") - return False + if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := cast(PtrDType, idx.src[0].dtype).size) == -1: return True + # We can use UOp min/max to do a faster check, but it can give false positives since it's not an exact bound and doesn't consider the mask + if 0<=idx.src[1].vmin and idx.src[1].vmax