mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
consolidate IGNORE_OOB=0 tests (#13937)
add a new unit test file and add more cases
This commit is contained in:
@@ -478,143 +478,6 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
for u in uops:
|
for u in uops:
|
||||||
self.assertNotEqual(u.dtype, dtypes.long)
|
self.assertNotEqual(u.dtype, dtypes.long)
|
||||||
|
|
||||||
def test_in_out_of_bounds_access(self):
  # constant indices into a 16-element int buffer with IGNORE_OOB=0
  with Context(IGNORE_OOB=0):
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)

    # first, last and a middle element: to_uops_list must not raise
    for pos in (0, 15, 7):
      in_bounds = UOp(Ops.LOAD, dtypes.int, (buf.index(UOp.const(dtypes.int, pos), ptr=True),))
      to_uops_list([in_bounds])

    # index 42 is past the end of the buffer: expected to raise
    out_of_bounds = UOp(Ops.LOAD, dtypes.int, (buf.index(UOp.const(dtypes.int, 42), ptr=True),))
    with self.assertRaises(RuntimeError):
      to_uops_list([out_of_bounds])
|
|
||||||
|
|
||||||
def test_in_out_of_bounds_access_symbolic(self):
  # symbolic (Variable) indices into a 16-element int buffer with IGNORE_OOB=0
  with Context(IGNORE_OOB=0):
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)

    # variable ranges that stay inside [0, 15]: to_uops_list must not raise
    for lo, hi in ((1, 10), (0, 15)):
      in_bounds = UOp(Ops.LOAD, dtypes.int, (buf.index(Variable("i", lo, hi), ptr=True),))
      to_uops_list([in_bounds])

    # a range reaching 20 exceeds the buffer: expected to raise
    out_of_bounds = UOp(Ops.LOAD, dtypes.int, (buf.index(Variable("i", 0, 20), ptr=True),))
    with self.assertRaises(RuntimeError):
      to_uops_list([out_of_bounds])
|
|
||||||
|
|
||||||
def test_in_out_of_bounds_access_gated_store(self):
  # stores into a 16-element buffer, gated by a validity mask on the index variable
  with Context(IGNORE_OOB=0):
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
    v = Variable("v", 0, 20)

    # gate v<16 keeps every enabled store in bounds
    safe_store = UOp(Ops.STORE, dtypes.void, src=(buf.index(v.valid(v<16)), UOp.const(dtypes.int, 0)))
    to_uops_list([safe_store])

    # gate v<20 still lets indices 16..19 through: expected to raise
    unsafe_store = UOp(Ops.STORE, dtypes.void, (buf.index(v.valid(v<20)), v))
    with self.assertRaises(RuntimeError):
      to_uops_list([unsafe_store])
|
|
||||||
|
|
||||||
@unittest.skip("if not allowed in graph")
def test_in_bounds_access_gated_local(self):
  # gated local store -> barrier inside an IF -> local load -> global store
  with Context(IGNORE_OOB=0):
    # buffers: 400-element global buffer and 8-element local buffer
    global_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0)
    local_buf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")

    # SPECIAL indices built with 416 and 10, both larger than the matching buffer sizes
    gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")
    lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0")
    in_range = (gidx<400) & (lidx<8)

    # local store gated on lidx<8, then a barrier wrapped in an IF on the combined gate
    st_local = UOp(Ops.STORE, dtypes.void, (local_buf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))
    sync = UOp(Ops.BARRIER, dtypes.void, (st_local,))
    guarded_sync = UOp(Ops.IF, dtypes.void, (in_range, sync))

    # load from local memory (after the IF/barrier) and store the value to global memory
    ld_local = UOp(Ops.LOAD, dtypes.uint, (local_buf.index(lidx, ptr=True), guarded_sync))
    st_global = UOp(Ops.STORE, dtypes.void, (global_buf.index(gidx), ld_local))
    to_uops_list([st_global])
|
|
||||||
|
|
||||||
def test_load_with_float_in_index(self):
  # indices that pass through float arithmetic before being cast back to int
  with Context(IGNORE_OOB=0):
    rng = UOp.range(20, 0)
    int_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)

    # range -> float -> scale -> trunc -> int, masked to [0, 16): to_uops_list must not raise
    scaled = (rng.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
    masked_load = UOp(Ops.LOAD, dtypes.int, (int_buf.index(scaled.valid((0<=scaled)&(scaled<16)), ptr=True),))
    to_uops_list([masked_load])

    # index derived from a float value loaded from another buffer
    float_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
    loaded_float = UOp(Ops.LOAD, dtypes.float, (float_buf.index(rng),))
    from_load = (loaded_float+3.14).cast(dtypes.int)
    # NOTE(review): this load is only constructed and never passed to to_uops_list — confirm whether a check is missing
    ld0 = UOp(Ops.LOAD, dtypes.int, (int_buf.index(from_load, ((0<=from_load)&(from_load<16)), ptr=True),))
|
|
||||||
|
|
||||||
def test_load_cast_to_bool(self):
  # 1-element buffer indexed by a range whose mask is "not bool(range)"
  with Context(IGNORE_OOB=0):
    one_elem = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
    rng = UOp.range(20, 0)
    is_zero = rng.cast(dtypes.bool).logical_not()
    gated_load = UOp(Ops.LOAD, dtypes.int, (one_elem.index(rng.valid(is_zero), ptr=True),))
    to_uops_list([gated_load])
|
|
||||||
|
|
||||||
@unittest.skip("Bool load is not supported yet")
def test_load_mask(self):
  # load whose gate is meant to come from a bool buffer (currently skipped)
  with Context(IGNORE_OOB=0):
    data_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
    mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
    ridx = UOp.range(20, 0)
    # NOTE(review): `ridx<16&mask` parses as `ridx < (16 & mask)`, and the src argument below is a
    # parenthesized expression rather than a 1-tuple (no trailing comma) — revisit when enabling this test
    gated_load = UOp(Ops.LOAD, dtypes.int, (data_buf.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
    to_uops_list([gated_load])
|
|
||||||
|
|
||||||
def test_out_of_bounds_off_by_one_access(self):
  # index 16 is exactly one past the end of a 16-element buffer: expected to raise
  with Context(IGNORE_OOB=0):
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
    one_past_end = UOp.const(dtypes.int, 16)
    off_by_one = UOp(Ops.LOAD, dtypes.int, (buf.index(one_past_end, ptr=True),))
    with self.assertRaises(RuntimeError):
      to_uops_list([off_by_one])
|
|
||||||
|
|
||||||
def test_in_out_bounds_access_with_mask(self):
  # range of 42 indexing a 16-element buffer; only the mask keeps the access in bounds
  with Context(IGNORE_OOB=0):
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
    gidx = UOp.range(42, 0, AxisType.GLOBAL)

    # masks 5<gidx<16 and gidx<16 both stop before index 16: to_uops_list must not raise
    windowed = UOp(Ops.LOAD, dtypes.int, (buf.index(gidx.valid((5<gidx)&(gidx<16)), ptr=True),))
    upper_only = UOp(Ops.LOAD, dtypes.int, (buf.index(gidx.valid(gidx<16), ptr=True),))
    to_uops_list([windowed, upper_only])

    # mask gidx<17 admits index 16: expected to raise
    too_wide = UOp(Ops.LOAD, dtypes.int, (buf.index(gidx.valid(gidx<17), ptr=True),))
    with self.assertRaises(RuntimeError):
      to_uops_list([too_wide])
|
|
||||||
|
|
||||||
def test_in_out_of_bounds_access_symbolic_mask(self):
  # Variable in [1, 80] indexing a 16-element buffer under different upper-bound masks
  with Context(IGNORE_OOB=0):
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
    i = Variable("i", 1, 80)

    # masks that cap the index below 16: to_uops_list must not raise
    for limit in (10, 15):
      capped = UOp(Ops.LOAD, dtypes.int, (buf.index(i.valid(i<limit), ptr=True),))
      to_uops_list([capped])

    # mask i<20 admits indices 16..19: expected to raise
    uncapped = UOp(Ops.LOAD, dtypes.int, (buf.index(i.valid(i<20), ptr=True),))
    with self.assertRaises(RuntimeError):
      to_uops_list([uncapped])
|
|
||||||
|
|
||||||
def test_in_out_of_bounds_access_index_load(self):
  # a value loaded from one buffer is doubled and used as the index into a 64-element buffer
  with Context(IGNORE_OOB=0):
    src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
    dst_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
    gidx = UOp.range(42, 0, AxisType.GLOBAL)
    loaded_idx = UOp(Ops.LOAD, dtypes.int, (src_buf.index(gidx.valid(gidx<8), ptr=True),)).cast(dtypes.index)

    # loaded value masked to [0, 32) -> doubled index at most 62: to_uops_list must not raise
    in_bounds = UOp(Ops.LOAD, dtypes.int, (dst_buf.index((loaded_idx*2).valid((loaded_idx>=0)&(loaded_idx<32)), ptr=True),))
    to_uops_list([in_bounds])

    # loaded value masked to [0, 64) -> doubled index can reach 126: expected to raise
    out_of_bounds = UOp(Ops.LOAD, dtypes.int, (dst_buf.index((loaded_idx*2).valid((loaded_idx>=0)&(loaded_idx<64)), ptr=True),))
    with self.assertRaises(RuntimeError):
      to_uops_list([out_of_bounds])
|
|
||||||
|
|
||||||
def test_bounds_with_loaded_bool(self):
  # a bool loaded from memory is the only gate on an index built with bound 16 into an 8-element buffer
  with Context(IGNORE_OOB=0):
    bool_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
    int_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 0)
    gidx = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
    loaded_gate = bool_buf.index(gidx, ptr=True).load()
    gated_load = int_buf.index(gidx.valid(loaded_gate), ptr=True).load()
    # the test expects the gated access to be rejected
    with self.assertRaises(RuntimeError):
      to_uops_list([gated_load])
|
|
||||||
|
|
||||||
def test_fold_gated_load(self):
|
def test_fold_gated_load(self):
|
||||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
|
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
|
||||||
|
|||||||
179
test/unit/test_validate_oob.py
Normal file
179
test/unit/test_validate_oob.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
import unittest
|
||||||
|
from tinygrad import dtypes, Variable
|
||||||
|
from tinygrad.dtype import AddrSpace
|
||||||
|
from tinygrad.helpers import Context
|
||||||
|
from tinygrad.uop.ops import Ops, UOp, AxisType
|
||||||
|
from test.test_uops import to_uops_list
|
||||||
|
|
||||||
|
class TestValidateOOB(unittest.TestCase):
  """Index-bounds validation (z3) with IGNORE_OOB=0, covering several ALU ops and index/mask patterns."""

  # ---- basic index patterns ----

  def test_const_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      # first and last element are valid
      for pos in (0, 15):
        to_uops_list([buf.index(UOp.const(dtypes.int, pos), ptr=True).load(dtype=dtypes.int)])
      # off by one, then way out
      for pos in (16, 42):
        with self.assertRaises(RuntimeError):
          to_uops_list([buf.index(UOp.const(dtypes.int, pos), ptr=True).load(dtype=dtypes.int)])

  def test_variable_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      # variable range exactly matching the buffer is valid
      to_uops_list([buf.index(Variable("i", 0, 15), ptr=True).load(dtype=dtypes.int)])
      # upper bound past the end (oob), then a negative lower bound
      for lo, hi in ((0, 20), (-5, 10)):
        with self.assertRaises(RuntimeError):
          to_uops_list([buf.index(Variable("i", lo, hi), ptr=True).load(dtype=dtypes.int)])

  def test_range_with_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      r = UOp.range(42, 0, AxisType.GLOBAL)
      # mask r<16 is valid
      in_bounds = buf.index(r.valid(r < 16), ptr=True).load(dtype=dtypes.int)
      to_uops_list([in_bounds])
      # mask r<17 admits index 16: oob
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(r.valid(r < 17), ptr=True).load(dtype=dtypes.int)])

  def test_variable_with_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      v = Variable("v", -5, 80)
      # masking both sides is valid
      in_bounds = buf.index(v.valid((v >= 0) & (v < 16)), ptr=True).load(dtype=dtypes.int)
      to_uops_list([in_bounds])
      # negative values are not masked
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(v.valid(v < 20), ptr=True).load(dtype=dtypes.int)])

  def test_gated_store(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      v = Variable("v", 0, 20)
      # store gated on v<16 is valid
      safe_store = buf.index(v.valid(v < 16)).store(0)
      to_uops_list([safe_store])
      # store gated on v<20 is oob
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(v.valid(v < 20)).store(0)])

  # ---- ALU ops in index ----

  def test_idiv(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      # range(32)//2 -> 0..15 valid
      halved = UOp.range(32, 0, AxisType.GLOBAL) // 2
      to_uops_list([buf.index(halved, ptr=True).load(dtype=dtypes.int)])
      # range(34)//2 -> 0..16 oob
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(UOp.range(34, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)])

  def test_mod(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      r = UOp.range(100, 0, AxisType.GLOBAL)
      # r % 16 -> 0..15 valid
      wrapped = r % 16
      to_uops_list([buf.index(wrapped, ptr=True).load(dtype=dtypes.int)])
      # r % 20 -> 0..19 oob
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(r % 20, ptr=True).load(dtype=dtypes.int)])

  def test_shr(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      # range(64)>>2 -> 0..15 valid
      shifted = UOp.range(64, 0, AxisType.GLOBAL) >> 2
      to_uops_list([buf.index(shifted, ptr=True).load(dtype=dtypes.int)])
      # range(128)>>2 -> 0..31 oob
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(UOp.range(128, 0, AxisType.GLOBAL) >> 2, ptr=True).load(dtype=dtypes.int)])

  def test_shl(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
      r = UOp.range(8, 0, AxisType.GLOBAL)
      # r<<2 -> 0..28 valid
      shifted = r << 2
      to_uops_list([buf.index(shifted, ptr=True).load(dtype=dtypes.int)])
      # r<<4 -> 0..112 oob
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(r << 4, ptr=True).load(dtype=dtypes.int)])

  def test_and(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      r = UOp.range(100, 0, AxisType.GLOBAL)
      # r & 15 -> 0..15 valid
      low_bits = r & 15
      to_uops_list([buf.index(low_bits, ptr=True).load(dtype=dtypes.int)])
      # r & 31 -> 0..31 oob
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(r & 31, ptr=True).load(dtype=dtypes.int)])

  def test_max(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      # max(v, 0) with v in -10..15 -> 0..15 valid
      clamped = Variable("v", -10, 15).maximum(0)
      to_uops_list([buf.index(clamped, ptr=True).load(dtype=dtypes.int)])
      # max(v2, 0) with v2 in -10..20 -> 0..20 oob
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(Variable("v2", -10, 20).maximum(0), ptr=True).load(dtype=dtypes.int)])

  def test_xor_in_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      r = UOp.range(32, 0, AxisType.GLOBAL)
      # xor of two disjoint windows covering 0..15 is valid
      in_bounds_mask = (r < 8) ^ ((r >= 8) & (r < 16))
      to_uops_list([buf.index(r.valid(in_bounds_mask), ptr=True).load(dtype=dtypes.int)])
      # xor enabling 0..9 and 20..31 is oob
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(r.valid((r < 10) ^ (r >= 20)), ptr=True).load(dtype=dtypes.int)])

  # ---- cast patterns ----

  def test_float_cast_in_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      r = UOp.range(20, 0)
      # index goes int -> float -> scale -> trunc -> int and is then masked to [0, 16)
      i = (r.cast(dtypes.float) * 0.68).trunc().cast(dtypes.int)
      in_window = (i >= 0) & (i < 16)
      to_uops_list([buf.index(i.valid(in_window), ptr=True).load(dtype=dtypes.int)])

  def test_bool_cast_in_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
      r = UOp.range(20, 0)
      # only r=0 is valid
      is_zero = r.cast(dtypes.bool).logical_not()
      to_uops_list([buf.index(r.valid(is_zero), ptr=True).load(dtype=dtypes.int)])

  # ---- load result as index/mask ----

  def test_load_as_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 1)
      r = UOp.range(42, 0, AxisType.GLOBAL)
      ld0 = buf0.index(r.valid(r < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.index)
      # loaded value masked to [0, 32) and doubled is valid
      in_bounds = buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 32)), ptr=True).load(dtype=dtypes.int)
      to_uops_list([in_bounds])
      # loaded value masked to [0, 64) and doubled is oob
      with self.assertRaises(RuntimeError):
        to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 64)), ptr=True).load(dtype=dtypes.int)])

  def test_load_bool_as_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf_bool = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
      buf_int = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 1)
      gidx = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
      ld_bool = buf_bool.index(gidx, ptr=True).load()
      # gidx 0..15, buf_int size 8
      with self.assertRaises(RuntimeError):
        to_uops_list([buf_int.index(gidx.valid(ld_bool), ptr=True).load()])

  # ---- skipped tests (moved from test_uop_graph.py) ----

  @unittest.skip("if not allowed in graph")
  def test_in_bounds_access_gated_local(self):
    with Context(IGNORE_OOB=0):
      # buffers: 400-element global buffer and 8-element local buffer
      global_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0)
      local_buf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")

      # SPECIAL indices built with 416 and 10, both larger than the matching buffer sizes
      gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")
      lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0")
      in_range = (gidx<400) & (lidx<8)

      # local store gated on lidx<8, then a barrier wrapped in an IF on the combined gate
      st_local = UOp(Ops.STORE, dtypes.void, (local_buf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))
      sync = UOp(Ops.BARRIER, dtypes.void, (st_local,))
      guarded_sync = UOp(Ops.IF, dtypes.void, (in_range, sync))

      # load from local memory (after the IF/barrier) and store the value to global memory
      ld_local = UOp(Ops.LOAD, dtypes.uint, (local_buf.index(lidx, ptr=True), guarded_sync))
      st_global = UOp(Ops.STORE, dtypes.void, (global_buf.index(gidx), ld_local))
      to_uops_list([st_global])

  @unittest.skip("Bool load is not supported yet")
  def test_load_mask(self):
    with Context(IGNORE_OOB=0):
      data_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
      ridx = UOp.range(20, 0)
      # NOTE(review): `ridx<16&mask` parses as `ridx < (16 & mask)`, and the src argument below is a
      # parenthesized expression rather than a 1-tuple (no trailing comma) — revisit when enabling this test
      gated_load = UOp(Ops.LOAD, dtypes.int, (data_buf.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
      to_uops_list([gated_load])
|
||||||
|
|
||||||
|
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
  unittest.main()
|
||||||
Reference in New Issue
Block a user