mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
180 lines
8.7 KiB
Python
180 lines
8.7 KiB
Python
import unittest
|
|
from tinygrad import dtypes, Variable
|
|
from tinygrad.dtype import AddrSpace
|
|
from tinygrad.helpers import Context
|
|
from tinygrad.uop.ops import Ops, UOp, AxisType
|
|
from test.test_uops import to_uops_list
|
|
|
|
class TestValidateOOB(unittest.TestCase):
  """Checks the z3-backed index bounds validation across index patterns and ALU ops.

  Each case builds its UOps inside a ``Context`` with ``IGNORE_OOB=0`` and passes them to
  ``to_uops_list``. Accesses marked "valid" must go through without an exception; accesses
  marked "oob" must make ``to_uops_list`` raise ``RuntimeError``.
  """

  # basic index patterns
  def test_const_index(self):
    # constant indices into a 16-element int buffer
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      for in_bounds in (0, 15):  # first element, last element
        ld = glbl.index(UOp.const(dtypes.int, in_bounds), ptr=True).load(dtype=dtypes.int)
        to_uops_list([ld])  # valid
      for out_of_bounds in (16, 42):  # off by one, way out
        with self.assertRaises(RuntimeError):
          ld = glbl.index(UOp.const(dtypes.int, out_of_bounds), ptr=True).load(dtype=dtypes.int)
          to_uops_list([ld])

  def test_variable_index(self):
    # the index is a bounded Variable; its (min, max) decides whether the access is accepted
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      ld = glbl.index(Variable("i", 0, 15), ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # valid
      for vmin, vmax in ((0, 20), (-5, 10)):  # oob, negative
        with self.assertRaises(RuntimeError):
          ld = glbl.index(Variable("i", vmin, vmax), ptr=True).load(dtype=dtypes.int)
          to_uops_list([ld])

  def test_range_with_mask(self):
    # a range larger than the buffer, gated with .valid(...)
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      rng = UOp.range(42, 0, AxisType.GLOBAL)
      gated = rng.valid(rng < 16)
      ld = glbl.index(gated, ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # valid
      with self.assertRaises(RuntimeError):
        gated = rng.valid(rng < 17)
        ld = glbl.index(gated, ptr=True).load(dtype=dtypes.int)
        to_uops_list([ld])  # oob

  def test_variable_with_mask(self):
    # a Variable spanning -5..80, gated on both sides vs. only on the upper side
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      var = Variable("v", -5, 80)
      gated = var.valid((var >= 0) & (var < 16))
      ld = glbl.index(gated, ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # valid
      with self.assertRaises(RuntimeError):
        gated = var.valid(var < 20)
        ld = glbl.index(gated, ptr=True).load(dtype=dtypes.int)
        to_uops_list([ld])  # negative not masked

  def test_gated_store(self):
    # same gating idea as above, but the access is a store of the constant 0
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      var = Variable("v", 0, 20)
      st = glbl.index(var.valid(var < 16)).store(0)
      to_uops_list([st])  # valid
      with self.assertRaises(RuntimeError):
        st = glbl.index(var.valid(var < 20)).store(0)
        to_uops_list([st])  # oob

  # ALU ops in index
  def test_idiv(self):
    # integer division of a range by 2
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      idx = UOp.range(32, 0, AxisType.GLOBAL) // 2
      ld = glbl.index(idx, ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        idx = UOp.range(34, 0, AxisType.GLOBAL) // 2
        ld = glbl.index(idx, ptr=True).load(dtype=dtypes.int)
        to_uops_list([ld])  # 0..16 oob

  def test_mod(self):
    # modulo of a range by a constant
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      rng = UOp.range(100, 0, AxisType.GLOBAL)
      ld = glbl.index(rng % 16, ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        ld = glbl.index(rng % 20, ptr=True).load(dtype=dtypes.int)
        to_uops_list([ld])  # 0..19 oob

  def test_shr(self):
    # right shift of a range by 2
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      idx = UOp.range(64, 0, AxisType.GLOBAL) >> 2
      ld = glbl.index(idx, ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        idx = UOp.range(128, 0, AxisType.GLOBAL) >> 2
        ld = glbl.index(idx, ptr=True).load(dtype=dtypes.int)
        to_uops_list([ld])  # 0..31 oob

  def test_shl(self):
    # left shift of a range; note this case uses a 64-element buffer
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
      rng = UOp.range(8, 0, AxisType.GLOBAL)
      ld = glbl.index(rng << 2, ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # 0..28 valid
      with self.assertRaises(RuntimeError):
        ld = glbl.index(rng << 4, ptr=True).load(dtype=dtypes.int)
        to_uops_list([ld])  # 0..112 oob

  def test_and(self):
    # bitwise AND of a range with a constant
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      rng = UOp.range(100, 0, AxisType.GLOBAL)
      ld = glbl.index(rng & 15, ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        ld = glbl.index(rng & 31, ptr=True).load(dtype=dtypes.int)
        to_uops_list([ld])  # 0..31 oob

  def test_max(self):
    # maximum(0) applied to a Variable whose lower bound is negative
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      idx = Variable("v", -10, 15).maximum(0)
      ld = glbl.index(idx, ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        idx = Variable("v2", -10, 20).maximum(0)
        ld = glbl.index(idx, ptr=True).load(dtype=dtypes.int)
        to_uops_list([ld])  # 0..20 oob

  def test_xor_in_mask(self):
    # the validity mask itself is built with XOR
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      rng = UOp.range(32, 0, AxisType.GLOBAL)
      gate = (rng < 8) ^ ((rng >= 8) & (rng < 16))
      ld = glbl.index(rng.valid(gate), ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        gate = (rng < 10) ^ (rng >= 20)
        ld = glbl.index(rng.valid(gate), ptr=True).load(dtype=dtypes.int)
        to_uops_list([ld])  # 0..9,20..31 oob

  # cast patterns
  def test_float_cast_in_index(self):
    # index goes int -> float -> scaled -> trunc -> int; only checks that lowering does not raise
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      rng = UOp.range(20, 0)
      scaled = rng.cast(dtypes.float) * 0.68
      idx = scaled.trunc().cast(dtypes.int)
      gated = idx.valid((idx >= 0) & (idx < 16))
      ld = glbl.index(gated, ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])

  def test_bool_cast_in_mask(self):
    # 1-element buffer; the mask is a bool cast of the range, negated; only checks that lowering does not raise
    with Context(IGNORE_OOB=0, SPEC=2):
      glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
      rng = UOp.range(20, 0)
      gate = rng.cast(dtypes.bool).logical_not()
      ld = glbl.index(rng.valid(gate), ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # only r=0 valid

  # load result as index/mask
  def test_load_as_index(self):
    # a value loaded from one buffer (cast to dtypes.index) is doubled and used to index a second buffer
    with Context(IGNORE_OOB=0, SPEC=2):
      idx_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      data_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 1)
      rng = UOp.range(42, 0, AxisType.GLOBAL)
      loaded = idx_buf.index(rng.valid(rng < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.index)
      idx = (loaded * 2).valid((loaded >= 0) & (loaded < 32))
      ld = data_buf.index(idx, ptr=True).load(dtype=dtypes.int)
      to_uops_list([ld])  # valid
      with self.assertRaises(RuntimeError):
        idx = (loaded * 2).valid((loaded >= 0) & (loaded < 64))
        ld = data_buf.index(idx, ptr=True).load(dtype=dtypes.int)
        to_uops_list([ld])  # oob

  def test_load_bool_as_mask(self):
    # a bool loaded from one buffer is the only gate on an access into a smaller int buffer
    with Context(IGNORE_OOB=0, SPEC=2):
      mask_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
      data_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 1)
      gidx = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
      gate = mask_buf.index(gidx, ptr=True).load()
      with self.assertRaises(RuntimeError):
        ld = data_buf.index(gidx.valid(gate), ptr=True).load()
        to_uops_list([ld])  # gidx 0..15, data_buf size 8

  # skipped tests (moved from test_uop_graph.py)
  @unittest.skip("if not allowed in graph")
  def test_in_bounds_access_gated_local(self):
    # never executed while the skip decorator is present; kept expression-for-expression
    with Context(IGNORE_OOB=0):
      # Define buffers
      gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0)
      sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")

      # Define indices, valids and barrier
      gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")
      lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0")

      gate = (gidx<400) & (lidx<8)

      local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))

      barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,))
      if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))

      # Load from local memory (after the IF/barrier)
      local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier))

      # Store to global memory
      global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
      to_uops_list([global_store])

  @unittest.skip("Bool load is not supported yet")
  def test_load_mask(self):
    # never executed while the skip decorator is present; kept expression-for-expression
    with Context(IGNORE_OOB=0):
      glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      # NOTE(review): this buffer reuses arg 0, same as glbl0 above -- confirm whether a distinct slot was intended
      mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
      ridx = UOp.range(20, 0)
      # NOTE(review): `&` binds tighter than `<`, so `ridx<16&mask` parses as `ridx < (16 & mask)`.
      # NOTE(review): the last argument is a parenthesised expression, not a 1-tuple (no trailing comma).
      # Both look unintended -- verify before removing the skip.
      ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
      to_uops_list([ld0])
|
|
# Run the test cases in this module when the file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()