diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 5b8bcd094e..87db0b844b 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -478,143 +478,6 @@ class TestUOpGraph(unittest.TestCase): for u in uops: self.assertNotEqual(u.dtype, dtypes.long) - def test_in_out_of_bounds_access(self): - with Context(IGNORE_OOB=0): - glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0), ptr=True),)) - to_uops_list([ld0]) - ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15), ptr=True),)) - to_uops_list([ld1]) - ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7), ptr=True),)) - to_uops_list([ld1]) - - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42), ptr=True),)) - with self.assertRaises(RuntimeError): to_uops_list([ld0]) - - def test_in_out_of_bounds_access_symbolic(self): - with Context(IGNORE_OOB=0): - glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10), ptr=True),)) - to_uops_list([ld0]) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15), ptr=True),)) - to_uops_list([ld0]) - - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20), ptr=True),)) - with self.assertRaises(RuntimeError): to_uops_list([ld0]) - - def test_in_out_of_bounds_access_gated_store(self): - with Context(IGNORE_OOB=0): - glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0) - v = Variable("v", 0, 20) - st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v.valid(v<16)), UOp.const(dtypes.int, 0))) - to_uops_list([st0]) - - st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v.valid(v<20)), v)) - with self.assertRaises(RuntimeError): to_uops_list([st1]) - - @unittest.skip("if not allowed in graph") - def test_in_bounds_access_gated_local(self): - with Context(IGNORE_OOB=0): - # Define buffers - gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0) - sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0") - - # Define indices, valids and barrier - gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0") - lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0") - - gate = (gidx<400) & (lidx<8) - - local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1))) - - barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,)) - if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier)) - - # Load from local memory (after the IF/barrier) - local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier)) - - # Store to global memory - global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load)) - to_uops_list([global_store]) - - def test_load_with_float_in_index(self): - with Context(IGNORE_OOB=0): - ridx = UOp.range(20, 0) - glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16)), ptr=True),)) - to_uops_list([ld0]) - glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0) - ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),)) - i = (ldfloat+3.14).cast(dtypes.int) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16)), ptr=True),)) - - def test_load_cast_to_bool(self): - with Context(IGNORE_OOB=0): - glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0) - ridx = UOp.range(20, 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not()), ptr=True),)) - to_uops_list([ld0]) - - @unittest.skip("Bool load is not supported yet") - def test_load_mask(self): - with Context(IGNORE_OOB=0): - glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0) - ridx = UOp.range(20, 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True))) - to_uops_list([ld0]) - - def test_out_of_bounds_off_by_one_access(self): - with Context(IGNORE_OOB=0): - glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16), ptr=True),)) - with self.assertRaises(RuntimeError): to_uops_list([ld0]) - - def test_in_out_bounds_access_with_mask(self): - with Context(IGNORE_OOB=0): - glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - gidx0 = UOp.range(42, 0, AxisType.GLOBAL) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5=0)&(ld0<32)), ptr=True),)) - to_uops_list([ld1]) - - ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64)), ptr=True),)) - with self.assertRaises(RuntimeError): to_uops_list([ld1]) - - def test_bounds_with_loaded_bool(self): - with Context(IGNORE_OOB=0): - glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0) - glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 0) - gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0") - ld0 = glbl0.index(gidx0, ptr=True).load() - ld1 = glbl1.index(gidx0.valid(ld0), ptr=True).load() - with self.assertRaises(RuntimeError): to_uops_list([ld1]) - def test_fold_gated_load(self): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1) diff --git a/test/unit/test_validate_oob.py b/test/unit/test_validate_oob.py new file mode 100644 index 0000000000..8f1fab6291 --- /dev/null +++ b/test/unit/test_validate_oob.py @@ -0,0 +1,179 @@ +import unittest +from tinygrad import dtypes, Variable +from tinygrad.dtype import AddrSpace +from tinygrad.helpers import Context +from tinygrad.uop.ops import Ops, UOp, AxisType +from test.test_uops import to_uops_list + +class TestValidateOOB(unittest.TestCase): + """Test z3 validation of index bounds for different ALU ops and patterns.""" + + # basic index patterns + def test_const_index(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + to_uops_list([buf.index(UOp.const(dtypes.int, 0), ptr=True).load(dtype=dtypes.int)]) # valid + to_uops_list([buf.index(UOp.const(dtypes.int, 15), ptr=True).load(dtype=dtypes.int)]) # valid (last element) + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(UOp.const(dtypes.int, 16), ptr=True).load(dtype=dtypes.int)]) # off by one + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(UOp.const(dtypes.int, 42), ptr=True).load(dtype=dtypes.int)]) # way out + + def test_variable_index(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + to_uops_list([buf.index(Variable("i", 0, 15), ptr=True).load(dtype=dtypes.int)]) # valid + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(Variable("i", 0, 20), ptr=True).load(dtype=dtypes.int)]) # oob + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(Variable("i", -5, 10), ptr=True).load(dtype=dtypes.int)]) # negative + + def test_range_with_mask(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + r = UOp.range(42, 0, AxisType.GLOBAL) + to_uops_list([buf.index(r.valid(r < 16), ptr=True).load(dtype=dtypes.int)]) # valid + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(r.valid(r < 17), ptr=True).load(dtype=dtypes.int)]) # oob + + def test_variable_with_mask(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + v = Variable("v", -5, 80) + to_uops_list([buf.index(v.valid((v >= 0) & (v < 16)), ptr=True).load(dtype=dtypes.int)]) # valid + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(v.valid(v < 20), ptr=True).load(dtype=dtypes.int)]) # negative not masked + + def test_gated_store(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + v = Variable("v", 0, 20) + to_uops_list([buf.index(v.valid(v < 16)).store(0)]) # valid + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(v.valid(v < 20)).store(0)]) # oob + + # ALU ops in index + def test_idiv(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + to_uops_list([buf.index(UOp.range(32, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(UOp.range(34, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..16 oob + + def test_mod(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + r = UOp.range(100, 0, AxisType.GLOBAL) + to_uops_list([buf.index(r % 16, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(r % 20, ptr=True).load(dtype=dtypes.int)]) # 0..19 oob + + def test_shr(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + to_uops_list([buf.index(UOp.range(64, 0, AxisType.GLOBAL) >> 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(UOp.range(128, 0, AxisType.GLOBAL) >> 2, ptr=True).load(dtype=dtypes.int)]) # 0..31 oob + + def test_shl(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0) + r = UOp.range(8, 0, AxisType.GLOBAL) + to_uops_list([buf.index(r << 2, ptr=True).load(dtype=dtypes.int)]) # 0..28 valid + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(r << 4, ptr=True).load(dtype=dtypes.int)]) # 0..112 oob + + def test_and(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + r = UOp.range(100, 0, AxisType.GLOBAL) + to_uops_list([buf.index(r & 15, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(r & 31, ptr=True).load(dtype=dtypes.int)]) # 0..31 oob + + def test_max(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + to_uops_list([buf.index(Variable("v", -10, 15).maximum(0), ptr=True).load(dtype=dtypes.int)]) # 0..15 valid + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(Variable("v2", -10, 20).maximum(0), ptr=True).load(dtype=dtypes.int)]) # 0..20 oob + + def test_xor_in_mask(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + r = UOp.range(32, 0, AxisType.GLOBAL) + to_uops_list([buf.index(r.valid((r < 8) ^ ((r >= 8) & (r < 16))), ptr=True).load(dtype=dtypes.int)]) # 0..15 valid + with self.assertRaises(RuntimeError): + to_uops_list([buf.index(r.valid((r < 10) ^ (r >= 20)), ptr=True).load(dtype=dtypes.int)]) # 0..9,20..31 oob + + # cast patterns + def test_float_cast_in_index(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + r = UOp.range(20, 0) + i = (r.cast(dtypes.float) * 0.68).trunc().cast(dtypes.int) + to_uops_list([buf.index(i.valid((i >= 0) & (i < 16)), ptr=True).load(dtype=dtypes.int)]) + + def test_bool_cast_in_mask(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0) + r = UOp.range(20, 0) + to_uops_list([buf.index(r.valid(r.cast(dtypes.bool).logical_not()), ptr=True).load(dtype=dtypes.int)]) # only r=0 valid + + # load result as index/mask + def test_load_as_index(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 1) + r = UOp.range(42, 0, AxisType.GLOBAL) + ld0 = buf0.index(r.valid(r < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.index) + to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 32)), ptr=True).load(dtype=dtypes.int)]) # valid + with self.assertRaises(RuntimeError): + to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 64)), ptr=True).load(dtype=dtypes.int)]) # oob + + def test_load_bool_as_mask(self): + with Context(IGNORE_OOB=0, SPEC=2): + buf_bool = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0) + buf_int = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 1) + gidx = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0") + ld_bool = buf_bool.index(gidx, ptr=True).load() + with self.assertRaises(RuntimeError): + to_uops_list([buf_int.index(gidx.valid(ld_bool), ptr=True).load()]) # gidx 0..15, buf_int size 8 + + # skipped tests (moved from test_uop_graph.py) + @unittest.skip("if not allowed in graph") + def test_in_bounds_access_gated_local(self): + with Context(IGNORE_OOB=0): + # Define buffers + gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0) + sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0") + + # Define indices, valids and barrier + gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0") + lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0") + + gate = (gidx<400) & (lidx<8) + + local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1))) + + barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,)) + if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier)) + + # Load from local memory (after the IF/barrier) + local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier)) + + # Store to global memory + global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load)) + to_uops_list([global_store]) + + @unittest.skip("Bool load is not supported yet") + def test_load_mask(self): + with Context(IGNORE_OOB=0): + glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) + mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0) + ridx = UOp.range(20, 0) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True))) + to_uops_list([ld0]) + +if __name__ == "__main__": + unittest.main()