mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
add memplanner tests (#9577)
This commit is contained in:
124
test/test_memory_planner.py
Normal file
124
test/test_memory_planner.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import unittest
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.memory import _internal_memory_planner
|
||||
|
||||
global_map = {}
|
||||
def b(i, base=None, offset=0, pin=False, size=16):
|
||||
global global_map
|
||||
if i in global_map: return global_map[i]
|
||||
global_map[i] = Buffer(Device.DEFAULT, size, dtypes.int8, base=global_map[base] if base is not None else None, offset=offset)
|
||||
if pin: global_map[i].ref(1)
|
||||
return global_map[i]
|
||||
|
||||
def check_assign(buffers:list[list[Buffer]|tuple[Buffer, ...]]):
|
||||
assigned = _internal_memory_planner(buffers, noopt_buffers=None)
|
||||
|
||||
taken_parts = set()
|
||||
first_appearance, last_appearance = {}, {}
|
||||
for i,u in enumerate(buffers):
|
||||
for buf in u:
|
||||
if buf.is_allocated() or buf.base.is_allocated() or buf.lb_refcount > 0: continue
|
||||
if buf.base not in first_appearance: first_appearance[buf.base] = i
|
||||
last_appearance[buf.base] = i
|
||||
|
||||
for i,u in enumerate(buffers):
|
||||
for buf in u:
|
||||
if buf.is_allocated() or buf.base.is_allocated() or buf.lb_refcount > 0: continue
|
||||
cur, base = assigned.get(buf, buf), assigned.get(buf.base, buf.base)
|
||||
if buf._base is not None:
|
||||
assert cur.base == base.base and cur.offset == buf.offset + base.offset, f"failed: {buf} {cur} {base} {buf.offset} {base.offset}"
|
||||
else:
|
||||
for part in taken_parts:
|
||||
assert buf.base == part[3] or part[0] != cur.base or part[1] + part[2] <= cur.offset or part[1] >= cur.offset + buf.nbytes
|
||||
if first_appearance[buf.base] == i: taken_parts.add((cur.base, cur.offset, buf.nbytes, buf.base))
|
||||
if last_appearance[buf.base] == i: taken_parts.remove((cur.base, cur.offset, buf.nbytes, buf.base))
|
||||
|
||||
class TestMemoryPlanner(unittest.TestCase):
|
||||
def setUp(self):
|
||||
global global_map
|
||||
global_map = {}
|
||||
|
||||
def test_simple_buffer(self):
|
||||
bs = [
|
||||
[b(0), b(1), b(2)],
|
||||
[b(1), b(2), b(3)],
|
||||
[b(4), b(3)],
|
||||
[b(5), b(2)],
|
||||
]
|
||||
check_assign(bs)
|
||||
|
||||
def test_simple_pinned(self):
|
||||
bs = [
|
||||
[b(0, pin=True), b(1), b(2, pin=True)],
|
||||
[b(1), b(2), b(3)],
|
||||
[b(4), b(3)],
|
||||
[b(5), b(2)],
|
||||
]
|
||||
check_assign(bs)
|
||||
|
||||
def test_all_pinned(self):
|
||||
bs = [
|
||||
[b(0, pin=True), b(1, pin=True)],
|
||||
[b(1), b(2, pin=True)],
|
||||
[b(4, pin=True), b(3, pin=True)],
|
||||
]
|
||||
check_assign(bs)
|
||||
|
||||
def test_simple_buffer_offset(self):
|
||||
bs = [
|
||||
[b(0, pin=True), b(1, base=0, offset=1, size=8), b(2)],
|
||||
[b(1), b(2), b(3, base=0, offset=1, size=8)],
|
||||
[b(4), b(3)],
|
||||
]
|
||||
check_assign(bs)
|
||||
|
||||
def test_buffer_offset(self):
|
||||
bs = [
|
||||
[b(0, pin=True), b(1, base=0, offset=1, size=8), b(2)],
|
||||
[b(1), b(2), b(3, base=0, offset=1, size=8)],
|
||||
[b(4), b(3)],
|
||||
[b(5, base=2, offset=2, size=8), b(3)],
|
||||
[b(6), b(5), b(0)],
|
||||
[b(7), b(8, pin=True)],
|
||||
[b(8), b(9, base=2, offset=2, size=8)],
|
||||
[b(9), b(3), b(5)],
|
||||
]
|
||||
check_assign(bs)
|
||||
|
||||
def test_buffer_offset2(self):
|
||||
bs = [
|
||||
[b(0, pin=True), b(1), b(2)],
|
||||
[b(1), b(2), b(3)],
|
||||
[b(4), b(3)],
|
||||
[b(5), b(3)],
|
||||
[b(6), b(5), b(0)],
|
||||
[b(7), b(8, pin=True)],
|
||||
[b(8), b(9)],
|
||||
[b(9), b(3), b(5)],
|
||||
[b(11), b(0)],
|
||||
[b(11), b(10), b(5)],
|
||||
[b(12), b(11), b(0)],
|
||||
[b(6), b(12), b(7)],
|
||||
[b(13), b(6), b(11)],
|
||||
]
|
||||
check_assign(bs)
|
||||
|
||||
def test_all_offsets_of_one(self):
|
||||
bs = [
|
||||
[b(0, pin=True), b(1)],
|
||||
[b(3, base=1, offset=0, size=8), b(2, base=0, offset=0, size=8)],
|
||||
[b(5, base=1, offset=8, size=8), b(4, base=0, offset=8, size=8)],
|
||||
[b(7, base=1, offset=4, size=8), b(6, base=0, offset=4, size=8)],
|
||||
|
||||
[b(4), b(5), b(2)],
|
||||
[b(3), b(7)],
|
||||
[b(10), b(6), b(7)],
|
||||
[b(11), b(3), b(2)],
|
||||
[b(12), b(5), b(4), b(3), b(2)],
|
||||
[b(13), b(6), b(12), b(7)],
|
||||
]
|
||||
check_assign(bs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user