mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
248 lines
7.7 KiB
Python
248 lines
7.7 KiB
Python
import unittest
|
|
from tinygrad import dtypes
|
|
from tinygrad.uop.ops import UOp, Ops
|
|
from tinygrad.schedule.memory import memory_plan_rewrite
|
|
|
|
global_map = {}
|
|
held_bufs: set[UOp] = set()
|
|
def b(i, base=None, offset=0, pin=False, size=16):
|
|
global global_map
|
|
if i in global_map: return global_map[i]
|
|
if base is not None:
|
|
global_map[i] = global_map[base]
|
|
return global_map[i]
|
|
global_map[i] = UOp.new_buffer("NULL", size, dtypes.int8)
|
|
if pin: held_bufs.add(global_map[i])
|
|
return global_map[i]
|
|
|
|
def _make_linear(buffer_lists, copies=None):
|
|
copy_pairs = {frozenset((id(dst), id(src))) for dst, src in copies} if copies else set()
|
|
calls = []
|
|
for bufs in buffer_lists:
|
|
is_copy = len(bufs) == 2 and frozenset((id(bufs[0]), id(bufs[1]))) in copy_pairs
|
|
calls.append(UOp(Ops.CALL, dtypes.void, (UOp(Ops.COPY if is_copy else Ops.SINK), *bufs)))
|
|
return UOp(Ops.LINEAR, src=tuple(calls))
|
|
|
|
def _get_arena(buf, linear, result):
|
|
for orig_si, new_si in zip(linear.src, result.src):
|
|
for orig, new in zip(orig_si.src[1:], new_si.src[1:]):
|
|
if orig is buf and new.op is Ops.BUFFER_VIEW: return new.src[0]
|
|
return None
|
|
|
|
def check_assign(buffer_lists, copies=None):
|
|
linear = _make_linear(buffer_lists, copies)
|
|
result = memory_plan_rewrite(linear, held_bufs)
|
|
|
|
# build mapping: original buf -> (arena, offset_bytes, nbytes) from the result
|
|
replace_map: dict[int, tuple[UOp, int, int]] = {}
|
|
for orig_si, new_si in zip(linear.src, result.src):
|
|
for orig, new in zip(orig_si.src[1:], new_si.src[1:]):
|
|
if new.op is Ops.BUFFER_VIEW and id(orig) not in replace_map:
|
|
replace_map[id(orig)] = (new.src[0], new.arg[1] * new.dtype.itemsize, new.arg[0] * new.dtype.itemsize)
|
|
|
|
# verify pinned buffers are not planned
|
|
for buf in held_bufs:
|
|
assert id(buf) not in replace_map, "pinned buffer was planned"
|
|
|
|
# compute lifetimes
|
|
first_appearance, last_appearance = {}, {}
|
|
for i, bufs in enumerate(buffer_lists):
|
|
for buf in bufs:
|
|
if buf in held_bufs: continue
|
|
if id(buf) not in first_appearance: first_appearance[id(buf)] = i
|
|
last_appearance[id(buf)] = i
|
|
|
|
# verify non-overlapping: no two live buffers share the same arena region
|
|
taken_parts: set[tuple[int, int, int, int]] = set() # (id(arena), offset, nbytes, id(buf))
|
|
for i, bufs in enumerate(buffer_lists):
|
|
for buf in bufs:
|
|
if buf in held_bufs or id(buf) not in replace_map: continue
|
|
arena, off, nb = replace_map[id(buf)]
|
|
for part in taken_parts:
|
|
assert id(buf) == part[3] or part[0] != id(arena) or part[1] + part[2] <= off or part[1] >= off + nb, \
|
|
f"overlap at step {i}: [{off}, {off+nb}) conflicts with [{part[1]}, {part[1]+part[2]})"
|
|
if first_appearance.get(id(buf)) == i: taken_parts.add((id(arena), off, nb, id(buf)))
|
|
if last_appearance.get(id(buf)) == i: taken_parts.discard((id(arena), off, nb, id(buf)))
|
|
|
|
class TestMemoryPlanner(unittest.TestCase):
|
|
def setUp(self):
|
|
global global_map
|
|
held_bufs.clear()
|
|
global_map = {}
|
|
|
|
def test_simple_buffer(self):
|
|
bs = [
|
|
[b(0), b(1), b(2)],
|
|
[b(1), b(2), b(3)],
|
|
[b(4), b(3)],
|
|
[b(5), b(2)],
|
|
]
|
|
check_assign(bs)
|
|
|
|
def test_simple_pinned(self):
|
|
bs = [
|
|
[b(0, pin=True), b(1), b(2, pin=True)],
|
|
[b(1), b(2), b(3)],
|
|
[b(4), b(3)],
|
|
[b(5), b(2)],
|
|
]
|
|
check_assign(bs)
|
|
|
|
def test_all_pinned(self):
|
|
bs = [
|
|
[b(0, pin=True), b(1, pin=True)],
|
|
[b(1), b(2, pin=True)],
|
|
[b(4, pin=True), b(3, pin=True)],
|
|
]
|
|
check_assign(bs)
|
|
|
|
def test_simple_buffer_offset(self):
|
|
bs = [
|
|
[b(0, pin=True), b(1, base=0, offset=1, size=8), b(2)],
|
|
[b(1), b(2), b(3, base=0, offset=1, size=8)],
|
|
[b(4), b(3)],
|
|
]
|
|
check_assign(bs)
|
|
|
|
def test_buffer_offset(self):
|
|
bs = [
|
|
[b(0, pin=True), b(1, base=0, offset=1, size=8), b(2)],
|
|
[b(1), b(2), b(3, base=0, offset=1, size=8)],
|
|
[b(4), b(3)],
|
|
[b(5, base=2, offset=2, size=8), b(3)],
|
|
[b(6), b(5), b(0)],
|
|
[b(7), b(8, pin=True)],
|
|
[b(8), b(9, base=2, offset=2, size=8)],
|
|
[b(9), b(3), b(5)],
|
|
]
|
|
check_assign(bs)
|
|
|
|
def test_buffer_offset2(self):
|
|
bs = [
|
|
[b(0, pin=True), b(1), b(2)],
|
|
[b(1), b(2), b(3)],
|
|
[b(4), b(3)],
|
|
[b(5), b(3)],
|
|
[b(6), b(5), b(0)],
|
|
[b(7), b(8, pin=True)],
|
|
[b(8), b(9)],
|
|
[b(9), b(3), b(5)],
|
|
[b(11), b(0)],
|
|
[b(11), b(10), b(5)],
|
|
[b(12), b(11), b(0)],
|
|
[b(6), b(12), b(7)],
|
|
[b(13), b(6), b(11)],
|
|
]
|
|
check_assign(bs)
|
|
|
|
def test_all_offsets_of_one(self):
|
|
bs = [
|
|
[b(0, pin=True), b(1)],
|
|
[b(3, base=1, offset=0, size=8), b(2, base=0, offset=0, size=8)],
|
|
[b(5, base=1, offset=8, size=8), b(4, base=0, offset=8, size=8)],
|
|
[b(7, base=1, offset=4, size=8), b(6, base=0, offset=4, size=8)],
|
|
|
|
[b(4), b(5), b(2)],
|
|
[b(3), b(7)],
|
|
[b(10), b(6), b(7)],
|
|
[b(11), b(3), b(2)],
|
|
[b(12), b(5), b(4), b(3), b(2)],
|
|
[b(13), b(6), b(12), b(7)],
|
|
]
|
|
check_assign(bs)
|
|
|
|
def test_very_small_buffers(self):
|
|
bs = [
|
|
[b(0, pin=True), b(1, size=32)],
|
|
[b(3, size=4), b(4, size=6)],
|
|
]
|
|
check_assign(bs)
|
|
|
|
def test_very_big_buffers(self):
|
|
bs = [
|
|
[b(0, pin=True), b(1, size=34359738368000)],
|
|
[b(3, size=1 << 128), b(4, size=1 << 64)],
|
|
]
|
|
check_assign(bs)
|
|
|
|
def test_copy_bufs_separate_from_compute(self):
|
|
bs = [
|
|
[b(0), b(1)],
|
|
[b(1), b(2)],
|
|
[b(3), b(2)],
|
|
]
|
|
linear = _make_linear(bs, copies=[(b(1), b(0))])
|
|
result = memory_plan_rewrite(linear)
|
|
r1_arena, r2_arena = _get_arena(b(1), linear, result), _get_arena(b(2), linear, result)
|
|
assert r1_arena is not None and r2_arena is not None
|
|
assert r1_arena is not r2_arena
|
|
|
|
def test_copy_bufs_reuse_among_copies(self):
|
|
bs = [
|
|
[b(0), b(1)],
|
|
[b(2), b(1)],
|
|
[b(3), b(2)],
|
|
]
|
|
linear = _make_linear(bs, copies=[(b(1), b(0)), (b(2), b(1))])
|
|
result = memory_plan_rewrite(linear)
|
|
r1_arena, r2_arena = _get_arena(b(1), linear, result), _get_arena(b(2), linear, result)
|
|
assert r1_arena is not None and r2_arena is not None
|
|
assert r1_arena is r2_arena
|
|
|
|
def test_compute_bufs_reuse_among_compute(self):
|
|
bs = [
|
|
[b(0), b(1)],
|
|
[b(2), b(1)],
|
|
[b(3), b(2)],
|
|
[b(4), b(3)],
|
|
]
|
|
linear = _make_linear(bs, copies=[(b(1), b(0))])
|
|
result = memory_plan_rewrite(linear)
|
|
r2_arena, r3_arena = _get_arena(b(2), linear, result), _get_arena(b(3), linear, result)
|
|
assert r2_arena is not None and r3_arena is not None
|
|
assert r2_arena is r3_arena
|
|
|
|
def test_copy_and_compute_no_cross_reuse(self):
|
|
bs = [
|
|
[b(0), b(1)],
|
|
[b(2), b(1)],
|
|
[b(3), b(2)],
|
|
]
|
|
linear = _make_linear(bs, copies=[(b(2), b(1))])
|
|
result = memory_plan_rewrite(linear)
|
|
r0_arena, r2_arena = _get_arena(b(0), linear, result), _get_arena(b(2), linear, result)
|
|
assert r0_arena is not None and r2_arena is not None
|
|
assert r0_arena is not r2_arena
|
|
|
|
def test_multiple_copy_bufs_with_offsets(self):
|
|
bs = [
|
|
[b(0, pin=True), b(1), b(2)],
|
|
[b(3, base=0, offset=1, size=8), b(1), b(2)],
|
|
[b(4), b(3)],
|
|
[b(5), b(4)],
|
|
]
|
|
check_assign(bs, copies=[(b(1), b(0)), (b(2), b(0))])
|
|
|
|
def test_copy_bufs_pinned_mixed(self):
|
|
bs = [
|
|
[b(0, pin=True), b(1), b(2)],
|
|
[b(1), b(3), b(2)],
|
|
[b(4), b(3)],
|
|
[b(5), b(4), b(0)],
|
|
]
|
|
check_assign(bs, copies=[(b(1), b(0)), (b(3), b(1))])
|
|
|
|
def test_deferred_copy_frees_chain(self):
|
|
bs = []
|
|
copies = []
|
|
for i in range(6):
|
|
copy_buf, compute_buf = b(i * 2 + 1), b(i * 2 + 2)
|
|
bs.append([copy_buf, b(0, pin=True)])
|
|
bs.append([compute_buf, copy_buf])
|
|
copies.append((copy_buf, b(0, pin=True)))
|
|
bs.append([b(100, pin=True)])
|
|
check_assign(bs, copies=copies)
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|