[bounty] [pr] index validation with z3 (#9981)

* index validation with z3

* Change comment

* toposort -> toposort()

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Sieds Lykles
2025-04-24 14:06:08 +02:00
committed by GitHub
parent 9e49721c47
commit e75be6eafc
12 changed files with 236 additions and 75 deletions

View File

@@ -153,6 +153,8 @@ jobs:
name: Torch Backend Tests
runs-on: ubuntu-latest
timeout-minutes: 15
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -190,6 +192,8 @@ jobs:
name: Torch Backend Tests More
runs-on: ubuntu-latest
timeout-minutes: 15
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -212,6 +216,8 @@ jobs:
name: Tensor Core tests
runs-on: ubuntu-latest
timeout-minutes: 10
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -280,6 +286,8 @@ jobs:
name: Python Backend
runs-on: ubuntu-latest
timeout-minutes: 10
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -376,6 +384,8 @@ jobs:
name: 'GPU IMAGE Tests'
runs-on: ubuntu-22.04
timeout-minutes: 10
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -400,6 +410,8 @@ jobs:
name: 'openpilot Compile Tests'
runs-on: ubuntu-22.04
timeout-minutes: 10
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -423,6 +435,8 @@ jobs:
name: 'ONNX+Optimization Tests'
runs-on: ubuntu-22.04
timeout-minutes: 20
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
@@ -467,6 +481,8 @@ jobs:
name: Models (llvm+cpu+gpu)
runs-on: ubuntu-22.04
timeout-minutes: 10
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -490,6 +506,8 @@ jobs:
name: Linux (DSP)
runs-on: ubuntu-24.04
timeout-minutes: 15
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -559,6 +577,8 @@ jobs:
name: Linux (${{ matrix.backend }})
runs-on: ubuntu-22.04
timeout-minutes: 20
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
@@ -601,6 +621,8 @@ jobs:
name: MacOS (unit)
runs-on: macos-14
timeout-minutes: 20
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
@@ -658,6 +680,8 @@ jobs:
run: |
python3 -m pytest -n=auto test/test_hcq.py test/test_tiny.py --durations=20
- name: Run process replay tests
env:
IGNORE_OOB: 1
uses: ./.github/actions/process-replay
osxwebgpu:
@@ -689,6 +713,8 @@ jobs:
name: MacOS (${{ matrix.backend }})
runs-on: macos-15
timeout-minutes: 20
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -723,6 +749,8 @@ jobs:
name: Windows (${{ matrix.backend }})
runs-on: windows-latest
timeout-minutes: 15
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4

View File

@@ -1,7 +1,7 @@
import unittest
import numpy as np
from tinygrad.helpers import BEAM, Timing, CI
from tinygrad.helpers import BEAM, Timing, CI, Context
from tinygrad import Variable, Tensor
from tinygrad.nn import Conv2d
@@ -16,8 +16,9 @@ class TestBeamSearch(unittest.TestCase):
BEAM.value = self.old_beam
def test_variable_ast_beam(self):
a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
a = (a+1).realize()
with Context(IGNORE_OOB=1):
a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3))
a = (a+1).realize()
def test_big_prime_number(self):
a = rand(367, 367)
@@ -43,14 +44,16 @@ class TestBeamSearch(unittest.TestCase):
v = Variable("v", 1, 400).bind(367)
a = rand(367, 367)
b = rand(367, 367)
c = (a.reshape(367, v) @ b.reshape(v, 367)).realize()
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)
with Context(IGNORE_OOB=1):
c = (a.reshape(367, v) @ b.reshape(v, 367)).realize()
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)
def test_variable_shrink_prime_number(self):
v = Variable("v", 1, 400).bind(367)
a = rand(400, 367)
b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
with Context(IGNORE_OOB=1):
b = (a.shrink(((0,v), None))+1).reshape(367,367).realize()
np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4)
def test_no_mutate_rawbuffers(self):
a = rand(3, 3).realize()

View File

@@ -1,7 +1,7 @@
import unittest
import pathlib
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
from tinygrad.helpers import CI, fetch
from tinygrad.helpers import CI, fetch, Context
from tinygrad import Device, dtypes
from tinygrad.device import is_dtype_supported
@@ -24,15 +24,24 @@ class TestWhisper(unittest.TestCase):
model, enc = init_whisper("tiny.en", batch_size=2)
cls.model = model
cls.enc = enc
# TODO: whisper has out of bounds access somewhere
cls.context = Context(IGNORE_OOB=1)
cls.context.__enter__()
@classmethod
def tearDownClass(cls):
cls.context.__exit__(None, None, None)
del cls.model
del cls.enc
def test_transcribe_file1(self):
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)
@unittest.expectedFailure # Test for out of bounds access
def test_transcribe_file1_OOB(self):
with Context(IGNORE_OOB=0):
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)
@unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too many tests for CI")
def test_transcribe_file2(self):
self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2)

View File

@@ -6,7 +6,7 @@ from tinygrad.device import is_dtype_supported
from tinygrad.ops import UOp, Ops
from tinygrad.engine.search import Opt, OptOps
from tinygrad import Device, dtypes, Tensor
from tinygrad.helpers import CI
from tinygrad.helpers import CI, Context
from test.external.fuzz_linearizer import compare_linearizer
from test.helpers import ast_const
@@ -1428,7 +1428,8 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.CONST, dtypes.int, arg=-1, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50257, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
opts = [Opt(op=OptOps.GROUPTOP, axis=0, arg=29), Opt(op=OptOps.PADTO, axis=0, arg=32)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL", "GPU", "AMD", "NV"])
with Context(IGNORE_OOB=0):
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL", "GPU", "AMD", "NV", "CUDA"])
if __name__ == '__main__':
unittest.main()

View File

@@ -1,5 +1,6 @@
import unittest
from tinygrad import Tensor, TinyJit, Variable, dtypes
from tinygrad.helpers import Context
import numpy as np
class TestSetitem(unittest.TestCase):
@@ -131,20 +132,21 @@ class TestSetitem(unittest.TestCase):
np.testing.assert_allclose(t.numpy(), n)
def test_jit_setitem_variable_offset(self):
@TinyJit
def f(t:Tensor, a:Tensor, v:Variable):
t.shrink(((v,v+1), None)).assign(a).realize()
with Context(IGNORE_OOB=1):
@TinyJit
def f(t:Tensor, a:Tensor, v:Variable):
t.shrink(((v,v+1), None)).assign(a).realize()
t = Tensor.zeros(6, 6).contiguous().realize()
n = np.zeros((6, 6))
t = Tensor.zeros(6, 6).contiguous().realize()
n = np.zeros((6, 6))
for i in range(6):
v = Variable("v", 0, 6).bind(i)
a = Tensor.full((1, 6), fill_value=i+1, dtype=dtypes.float).contiguous()
n[i, :] = i+1
f(t, a, v)
np.testing.assert_allclose(t.numpy(), n)
np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]])
for i in range(6):
v = Variable("v", 0, 6).bind(i)
a = Tensor.full((1, 6), fill_value=i+1, dtype=dtypes.float).contiguous()
n[i, :] = i+1
f(t, a, v)
np.testing.assert_allclose(t.numpy(), n)
np.testing.assert_allclose(t.numpy(), [[1,1,1,1,1,1],[2,2,2,2,2,2],[3,3,3,3,3,3],[4,4,4,4,4,4],[5,5,5,5,5,5],[6,6,6,6,6,6]])
def test_setitem_overlapping_inplace1(self):
t = Tensor([[3.0], [2.0], [1.0]]).contiguous()

View File

@@ -2,9 +2,18 @@ import unittest
from test.helpers import assert_jit_cache_len
from tinygrad import Variable, Tensor, TinyJit
from tinygrad.helpers import Context
import numpy as np
class TestSymbolicJit(unittest.TestCase):
def setUp(self):
# A lot of these test are out of bounds, so we ignore the bounds check
self.context = Context(IGNORE_OOB=1)
self.context.__enter__()
def tearDown(self):
self.context.__exit__(None, None, None)
def test_plus1(self):
def f(a): return (a+1).realize()
jf = TinyJit(f)
@@ -290,4 +299,4 @@ class TestSymbolicJit(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -1,10 +1,19 @@
import unittest
from tinygrad import Variable
from tinygrad.helpers import Context
from tinygrad.tensor import Tensor
from examples.gpt2 import Attention
import numpy as np
class TestSymbolicOps(unittest.TestCase):
def setUp(self):
# A lot of these test are out of bounds, so we ignore the bounds check
self.context = Context(IGNORE_OOB=1)
self.context.__enter__()
def tearDown(self):
self.context.__exit__(None, None, None)
def test_plus1(self):
def f(a): return (a+1).realize()
for i in range(1, 5):
@@ -188,4 +197,4 @@ class TestSymbolicOps(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -1,6 +1,7 @@
import unittest
import numpy as np
from tinygrad import Tensor, Variable
from tinygrad.helpers import Context
class TestTensorVariable(unittest.TestCase):
def test_add_tvar(self):
@@ -22,38 +23,43 @@ class TestTensorVariable(unittest.TestCase):
assert (Tensor(3) * (vv * 4)).item() == 24
def test_symbolic_mean(self):
vv = Variable("a", 1, 10).bind(2)
t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
ret = t.mean().item()
assert ret == 1
with Context(IGNORE_OOB=1):
vv = Variable("a", 1, 10).bind(2)
t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
ret = t.mean().item()
assert ret == 1
def test_symbolic_mean_2d(self):
vv = Variable("a", 1, 10).bind(2)
vv2 = Variable("b", 1, 10).bind(2)
t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv)
ret = t.mean().item()
assert ret == 1
with Context(IGNORE_OOB=1):
vv = Variable("a", 1, 10).bind(2)
vv2 = Variable("b", 1, 10).bind(2)
t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv)
ret = t.mean().item()
assert ret == 1
def test_symbolic_mean_2d_axis_1(self):
vv = Variable("a", 1, 10).bind(2)
vv2 = Variable("b", 1, 10).bind(2)
t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv)
ret = t.mean(axis=1).reshape(2, 1).numpy()
assert np.all(ret == 1)
with Context(IGNORE_OOB=1):
vv = Variable("a", 1, 10).bind(2)
vv2 = Variable("b", 1, 10).bind(2)
t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv)
ret = t.mean(axis=1).reshape(2, 1).numpy()
assert np.all(ret == 1)
def test_symbolic_mean_2d_add(self):
add_term = Variable("c", 0, 10).bind(1)
vv = Variable("a", 1, 10).bind(1)
vv2 = Variable("b", 1, 10).bind(1)
t = Tensor.ones(2, 2).contiguous().reshape(vv2+add_term, vv+add_term)
ret = t.mean().item()
assert ret == 1
with Context(IGNORE_OOB=1):
add_term = Variable("c", 0, 10).bind(1)
vv = Variable("a", 1, 10).bind(1)
vv2 = Variable("b", 1, 10).bind(1)
t = Tensor.ones(2, 2).contiguous().reshape(vv2+add_term, vv+add_term)
ret = t.mean().item()
assert ret == 1
def test_symbolic_var(self):
vv = Variable("a", 1, 10).bind(2)
t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
ret = t.var().item()
assert ret == 0
with Context(IGNORE_OOB=1):
vv = Variable("a", 1, 10).bind(2)
t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
ret = t.var().item()
assert ret == 0
def test_symbolic_pad(self):
vv = Variable("a", 1, 10).bind(2)

View File

@@ -73,15 +73,17 @@ class TestTiny(unittest.TestCase):
def test_symbolic(self):
i = Variable('i', 1, 10)
for s in [2,5]:
ret = Tensor.ones(s).contiguous().reshape(i.bind(s)) + 1
self.assertListEqual(ret.reshape(s).tolist(), [2.0]*s)
with Context(IGNORE_OOB=1):
for s in [2,5]:
ret = Tensor.ones(s).contiguous().reshape(i.bind(s)) + 1
self.assertListEqual(ret.reshape(s).tolist(), [2.0]*s)
def test_symbolic_reduce(self):
i = Variable('i', 1, 10)
for s in [2,5]:
ret = Tensor.ones(s).contiguous().reshape(i.bind(s)).sum()
self.assertEqual(ret.item(), s)
with Context(IGNORE_OOB=1):
for s in [2,5]:
ret = Tensor.ones(s).contiguous().reshape(i.bind(s)).sum()
self.assertEqual(ret.item(), s)
# *** a model ***
@@ -115,4 +117,3 @@ class TestTiny(unittest.TestCase):
if __name__ == '__main__':
unittest.main()

View File

@@ -1,7 +1,7 @@
from typing import List
import unittest, time, pytest
from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad import dtypes, Device, Variable
from tinygrad.helpers import DEBUG, Context
from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher, track_rewrites
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
@@ -446,10 +446,70 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([v.bitcast(dt)])
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}")
def test_out_of_bounds_access(self):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42)),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0)),))
to_uops_list([ld0])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15)),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7)),))
to_uops_list([ld1])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42)),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10)),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15)),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20)),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_out_of_bounds_off_by_one_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16)),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_bounds_access_with_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtype=dtypes.int, arg=("gidx0", 42))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5<gidx0)&(gidx0<16)),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<16),))
to_uops_list([ld0, ld1])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<17),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = Variable("i", 1, 80)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<10),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<15),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<20),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_index_load(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtype=dtypes.int, arg=("gidx0", 42))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),))
with self.assertRaises(RuntimeError): to_uops_list([ld1])
def test_fold_gated_load(self):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)

View File

@@ -117,7 +117,7 @@ PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROF
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0)
DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
QUANTIZE, VALIDATE_WITH_CPU = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0)
QUANTIZE, VALIDATE_WITH_CPU, IGNORE_OOB = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("IGNORE_OOB", 1)
@dataclass(frozen=True)
class Metadata:

View File

@@ -1,7 +1,33 @@
from typing import cast
from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
from typing import cast, Callable
from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, RewriteContext
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
from tinygrad.helpers import all_same, dedup, prod, getenv, DEBUG
from tinygrad.helpers import all_same, dedup, prod, DEBUG, IGNORE_OOB
try:
import z3
z3_imported = True
# IDIV is truncated division but z3 does floored division; mod by power of two sometimes uses Ops.AND
def z3_cdiv(a,b): return z3.If(a<0, (a+(b-1))/b, a/b)
z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b}
def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef:
s = z3.Int(name, ctx=solver.ctx)
solver.add(vmin <= s, s <= vmax)
return s
# ctx is (solver, load_number_dict)
z3_renderer = PatternMatcher([
# Ops.SPECIAL can have symbolic arg but it wont be in the toposort beacuse its not a src, we need to add it manually
(UPat(Ops.SPECIAL, src=(), name="x"), lambda x: UOp(Ops.SPECIAL, arg=x.arg[0], src=(x.ufix(x.arg[1]),))),
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg, 0, x.src[0].arg-1, ctx[0]))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0]))),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"ridx{x.arg}", x.src[0].arg, x.src[1].arg-1, ctx[0]))),
(UPat(Ops.LOAD, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))),
(UPat(Ops.CONST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))),
(UPat(Ops.CAST, name="x"), lambda x: x.src[0]),
(UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=z3_alu[x.op](*(s.arg for s in x.src)))),
])
except ImportError: z3_imported = False
buffer_spec = PatternMatcher([
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
@@ -57,16 +83,23 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
# ***** uop type spec *****
def validate_index(idx:UOp, mask:UOp|None=None):
if getenv("IGNORE_OOB"): return True
# this checks for out of bounds access. it is not complete but should catch some issues
if mask is None and not isinstance(idx.dtype, ImageDType):
# WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op in {Ops.DEFINE_VAR, Ops.BITCAST} or (x.op is Ops.SPECIAL and any(not isinstance(y, int) for y in x.arg[1:])) for x in idx.toposort()):
return True
vmin, vmax, sz = idx.src[1].vmin, idx.src[1].vmax, cast(PtrDType, idx.src[0].dtype).size
if sz != -1 and (vmin < 0 or vmax >= sz):
if DEBUG >= 1: print(f"OUT OF BOUNDS ACCESS in INDEX {vmin} - {vmax} not in 0 - {sz}. {idx.src[1].render()=}")
return False
if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := cast(PtrDType, idx.src[0].dtype).size) == -1: return True
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
# WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
if not z3_imported: raise ImportError("z3 is required for bounds checking, try IGNORE_OOB=0 or \"pip install z3-solver\"")
solver = z3.Solver(ctx=z3.Context())
rewriter = RewriteContext(z3_renderer, ctx=(solver, {})) # Use RewriteContext directly to keep rewrite cache between index and mask
z3_idx = rewriter.top_down_rewrite(idx.src[1]).arg
if mask is not None: solver.add(rewriter.top_down_rewrite(mask).arg)
if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
print(f"idx={idx.src[1].render(simplify=False)}")
if mask is not None: print(f"mask={mask.render(simplify=False)}")
print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
return False
return True
# this is the matcher for the final rendered UOps