consolidate IGNORE_OOB=0 tests (#13937)

add a new unit test file and add more cases
This commit is contained in:
chenyu
2025-12-31 15:24:20 -05:00
committed by GitHub
parent c0c1c1c8c8
commit 52acadc160
2 changed files with 179 additions and 137 deletions

View File

@@ -478,143 +478,6 @@ class TestUOpGraph(unittest.TestCase):
for u in uops:
self.assertNotEqual(u.dtype, dtypes.long)
def test_in_out_of_bounds_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0), ptr=True),))
to_uops_list([ld0])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15), ptr=True),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7), ptr=True),))
to_uops_list([ld1])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_gated_store(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
v = Variable("v", 0, 20)
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v.valid(v<16)), UOp.const(dtypes.int, 0)))
to_uops_list([st0])
st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v.valid(v<20)), v))
with self.assertRaises(RuntimeError): to_uops_list([st1])
@unittest.skip("if not allowed in graph")
def test_in_bounds_access_gated_local(self):
with Context(IGNORE_OOB=0):
# Define buffers
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0)
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")
# Define indices, valids and barrier
gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0")
gate = (gidx<400) & (lidx<8)
local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))
barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,))
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))
# Load from local memory (after the IF/barrier)
local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier))
# Store to global memory
global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
to_uops_list([global_store])
def test_load_with_float_in_index(self):
with Context(IGNORE_OOB=0):
ridx = UOp.range(20, 0)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16)), ptr=True),))
to_uops_list([ld0])
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
i = (ldfloat+3.14).cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16)), ptr=True),))
def test_load_cast_to_bool(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not()), ptr=True),))
to_uops_list([ld0])
@unittest.skip("Bool load is not supported yet")
def test_load_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
to_uops_list([ld0])
def test_out_of_bounds_off_by_one_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_bounds_access_with_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16)), ptr=True),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16), ptr=True),))
to_uops_list([ld0, ld1])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = Variable("i", 1, 80)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_index_load(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8), ptr=True),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32)), ptr=True),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64)), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld1])
def test_bounds_with_loaded_bool(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
ld0 = glbl0.index(gidx0, ptr=True).load()
ld1 = glbl1.index(gidx0.valid(ld0), ptr=True).load()
with self.assertRaises(RuntimeError): to_uops_list([ld1])
def test_fold_gated_load(self):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)

View File

@@ -0,0 +1,179 @@
import unittest
from tinygrad import dtypes, Variable
from tinygrad.dtype import AddrSpace
from tinygrad.helpers import Context
from tinygrad.uop.ops import Ops, UOp, AxisType
from test.test_uops import to_uops_list
class TestValidateOOB(unittest.TestCase):
"""Test z3 validation of index bounds for different ALU ops and patterns."""
# basic index patterns
def test_const_index(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
to_uops_list([buf.index(UOp.const(dtypes.int, 0), ptr=True).load(dtype=dtypes.int)]) # valid
to_uops_list([buf.index(UOp.const(dtypes.int, 15), ptr=True).load(dtype=dtypes.int)]) # valid (last element)
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(UOp.const(dtypes.int, 16), ptr=True).load(dtype=dtypes.int)]) # off by one
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(UOp.const(dtypes.int, 42), ptr=True).load(dtype=dtypes.int)]) # way out
def test_variable_index(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
to_uops_list([buf.index(Variable("i", 0, 15), ptr=True).load(dtype=dtypes.int)]) # valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(Variable("i", 0, 20), ptr=True).load(dtype=dtypes.int)]) # oob
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(Variable("i", -5, 10), ptr=True).load(dtype=dtypes.int)]) # negative
def test_range_with_mask(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
r = UOp.range(42, 0, AxisType.GLOBAL)
to_uops_list([buf.index(r.valid(r < 16), ptr=True).load(dtype=dtypes.int)]) # valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(r.valid(r < 17), ptr=True).load(dtype=dtypes.int)]) # oob
def test_variable_with_mask(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
v = Variable("v", -5, 80)
to_uops_list([buf.index(v.valid((v >= 0) & (v < 16)), ptr=True).load(dtype=dtypes.int)]) # valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(v.valid(v < 20), ptr=True).load(dtype=dtypes.int)]) # negative not masked
def test_gated_store(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
v = Variable("v", 0, 20)
to_uops_list([buf.index(v.valid(v < 16)).store(0)]) # valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(v.valid(v < 20)).store(0)]) # oob
# ALU ops in index
def test_idiv(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
to_uops_list([buf.index(UOp.range(32, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(UOp.range(34, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..16 oob
def test_mod(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
r = UOp.range(100, 0, AxisType.GLOBAL)
to_uops_list([buf.index(r % 16, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(r % 20, ptr=True).load(dtype=dtypes.int)]) # 0..19 oob
def test_shr(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
to_uops_list([buf.index(UOp.range(64, 0, AxisType.GLOBAL) >> 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(UOp.range(128, 0, AxisType.GLOBAL) >> 2, ptr=True).load(dtype=dtypes.int)]) # 0..31 oob
def test_shl(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
r = UOp.range(8, 0, AxisType.GLOBAL)
to_uops_list([buf.index(r << 2, ptr=True).load(dtype=dtypes.int)]) # 0..28 valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(r << 4, ptr=True).load(dtype=dtypes.int)]) # 0..112 oob
def test_and(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
r = UOp.range(100, 0, AxisType.GLOBAL)
to_uops_list([buf.index(r & 15, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(r & 31, ptr=True).load(dtype=dtypes.int)]) # 0..31 oob
def test_max(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
to_uops_list([buf.index(Variable("v", -10, 15).maximum(0), ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(Variable("v2", -10, 20).maximum(0), ptr=True).load(dtype=dtypes.int)]) # 0..20 oob
def test_xor_in_mask(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
r = UOp.range(32, 0, AxisType.GLOBAL)
to_uops_list([buf.index(r.valid((r < 8) ^ ((r >= 8) & (r < 16))), ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
with self.assertRaises(RuntimeError):
to_uops_list([buf.index(r.valid((r < 10) ^ (r >= 20)), ptr=True).load(dtype=dtypes.int)]) # 0..9,20..31 oob
# cast patterns
def test_float_cast_in_index(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
r = UOp.range(20, 0)
i = (r.cast(dtypes.float) * 0.68).trunc().cast(dtypes.int)
to_uops_list([buf.index(i.valid((i >= 0) & (i < 16)), ptr=True).load(dtype=dtypes.int)])
def test_bool_cast_in_mask(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
r = UOp.range(20, 0)
to_uops_list([buf.index(r.valid(r.cast(dtypes.bool).logical_not()), ptr=True).load(dtype=dtypes.int)]) # only r=0 valid
# load result as index/mask
def test_load_as_index(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 1)
r = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = buf0.index(r.valid(r < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.index)
to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 32)), ptr=True).load(dtype=dtypes.int)]) # valid
with self.assertRaises(RuntimeError):
to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 64)), ptr=True).load(dtype=dtypes.int)]) # oob
def test_load_bool_as_mask(self):
with Context(IGNORE_OOB=0, SPEC=2):
buf_bool = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
buf_int = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 1)
gidx = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
ld_bool = buf_bool.index(gidx, ptr=True).load()
with self.assertRaises(RuntimeError):
to_uops_list([buf_int.index(gidx.valid(ld_bool), ptr=True).load()]) # gidx 0..15, buf_int size 8
# skipped tests (moved from test_uop_graph.py)
@unittest.skip("if not allowed in graph")
def test_in_bounds_access_gated_local(self):
with Context(IGNORE_OOB=0):
# Define buffers
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0)
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")
# Define indices, valids and barrier
gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0")
gate = (gidx<400) & (lidx<8)
local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))
barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,))
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))
# Load from local memory (after the IF/barrier)
local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier))
# Store to global memory
global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
to_uops_list([global_store])
@unittest.skip("Bool load is not supported yet")
def test_load_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
to_uops_list([ld0])
if __name__ == "__main__":
unittest.main()