mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
320 lines
12 KiB
Python
320 lines
12 KiB
Python
import numpy as np
|
|
import functools, unittest, ctypes
|
|
|
|
from tinygrad.device import Device, Buffer
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.helpers import Context, from_mv
|
|
from tinygrad.dtype import dtypes
|
|
from tinygrad.engine.jit import MultiGraphRunner
|
|
from tinygrad.schedule import linear_to_schedule
|
|
from tinygrad.uop.ops import UOp, Ops, buffers
|
|
|
|
from test.helpers import needs_second_gpu
|
|
|
|
np.random.seed(1337)
|
|
Tensor.manual_seed(1337)
|
|
BUF_SIZE = 4096
|
|
RUN_CNT = 5
|
|
|
|
# cache AST by (device, num_inputs)
|
|
cached_asts: dict[tuple[str, int], UOp] = {}
|
|
def get_ast(device:str, num_inputs:int) -> UOp:
|
|
if (device, num_inputs) not in cached_asts:
|
|
with Context(DEBUG=0):
|
|
fst = [Tensor.randn(BUF_SIZE, dtype=dtypes.int).realize() for _ in range(num_inputs)]
|
|
s = fst[0]
|
|
for i in range(1, num_inputs): s = s.bitwise_xor(fst[i])
|
|
cached_asts[(device, num_inputs)] = s.schedule()[-1].ast
|
|
return cached_asts[(device, num_inputs)]
|
|
|
|
def make_buffer(device, size=BUF_SIZE, fill=False):
|
|
buf = Buffer(device, size, dtypes.int).ensure_allocated()
|
|
if fill:
|
|
with Context(DEBUG=0):
|
|
buf.copyin(Tensor(np.random.randint(-10000, 10000, size=size, dtype=np.int32)).realize().uop.base.realized.as_memoryview())
|
|
return buf
|
|
|
|
def make_view(base, offset_elems, size_elems):
|
|
return Buffer(base.device, size_elems, base.dtype, base=base, offset=offset_elems * base.dtype.itemsize).ensure_allocated()
|
|
|
|
def get_buf_uop(buf:Buffer, cache:dict[Buffer,UOp]) -> UOp:
|
|
if buf not in cache:
|
|
cache[buf] = u = UOp.new_buffer(buf.device, buf.size, buf.dtype)
|
|
buffers[u] = buf
|
|
return cache[buf]
|
|
|
|
def make_graph(graph_cls, calls:list[UOp]):
|
|
linear = UOp(Ops.LINEAR, src=tuple(calls))
|
|
cf = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(linear,), arg="graph")
|
|
return graph_cls(cf, [])
|
|
|
|
def run_schedule(calls:list[UOp]):
|
|
for ei in linear_to_schedule(UOp(Ops.LINEAR, src=tuple(calls))): ei.lower().run({})
|
|
|
|
def zero_bufs(bufs):
|
|
for b in bufs:
|
|
mv = memoryview(bytearray(b.nbytes))
|
|
ctypes.memset(from_mv(mv), 0, len(mv))
|
|
b.copyin(mv)
|
|
|
|
@unittest.skipUnless(Device[Device.DEFAULT].graph is not None, "graph support required")
|
|
class TestGraph(unittest.TestCase):
|
|
def skip_if_no_offset(self):
|
|
if not hasattr(Device[Device.DEFAULT].allocator, "_offset"): self.skipTest("device does not support _offset")
|
|
|
|
def skip_if_not_multigraph(self):
|
|
graph = g.func if isinstance(g:=(d:=Device[Device.DEFAULT]).graph, functools.partial) else g
|
|
if not issubclass(graph, MultiGraphRunner): self.skipTest("graph is not supported (not MultiGraphRunner)")
|
|
if not hasattr(d.allocator, '_transfer') or not d.allocator.supports_transfer: self.skipTest("device is not supported (no transfers)")
|
|
|
|
def test_order_2_writes_to_same_buf(self):
|
|
d0 = Device.DEFAULT
|
|
b = [make_buffer(d0, fill=True) for _ in range(5)]
|
|
c: dict[Buffer,UOp] = {}
|
|
|
|
calls = [
|
|
get_ast(d0, 2).call(get_buf_uop(b[0],c), get_buf_uop(b[1],c), get_buf_uop(b[2],c), metadata=()),
|
|
get_ast(d0, 2).call(get_buf_uop(b[0],c), get_buf_uop(b[3],c), get_buf_uop(b[4],c), metadata=()),
|
|
]
|
|
|
|
zero_bufs([b[0]])
|
|
run_schedule(calls)
|
|
expected = [np.frombuffer(x.as_memoryview(), np.int32).copy() for x in b]
|
|
|
|
for _ in range(RUN_CNT):
|
|
zero_bufs([b[0]])
|
|
make_graph(Device[d0].graph, calls)([], {})
|
|
for i, buf in enumerate(b): np.testing.assert_equal(expected[i], np.frombuffer(buf.as_memoryview(), np.int32))
|
|
|
|
def test_order_read_write_same_buf(self):
|
|
d0 = Device.DEFAULT
|
|
b = [make_buffer(d0, fill=True) for _ in range(5)]
|
|
c: dict[Buffer,UOp] = {}
|
|
|
|
calls = [
|
|
get_ast(d0, 2).call(get_buf_uop(b[0],c), get_buf_uop(b[1],c), get_buf_uop(b[2],c), metadata=()),
|
|
get_ast(d0, 2).call(get_buf_uop(b[1],c), get_buf_uop(b[3],c), get_buf_uop(b[4],c), metadata=()),
|
|
]
|
|
|
|
zero_bufs([b[0], b[1]])
|
|
run_schedule(calls)
|
|
expected = [np.frombuffer(x.as_memoryview(), np.int32).copy() for x in b]
|
|
|
|
for _ in range(RUN_CNT):
|
|
zero_bufs([b[0], b[1]])
|
|
make_graph(Device[d0].graph, calls)([], {})
|
|
for i, buf in enumerate(b): np.testing.assert_equal(expected[i], np.frombuffer(buf.as_memoryview(), np.int32))
|
|
|
|
def test_order_write_read_same_buf(self):
|
|
d0 = Device.DEFAULT
|
|
b = [make_buffer(d0, fill=True) for _ in range(5)]
|
|
c: dict[Buffer,UOp] = {}
|
|
|
|
calls = [
|
|
get_ast(d0, 2).call(get_buf_uop(b[0],c), get_buf_uop(b[1],c), get_buf_uop(b[2],c), metadata=()),
|
|
get_ast(d0, 2).call(get_buf_uop(b[1],c), get_buf_uop(b[0],c), get_buf_uop(b[4],c), metadata=()),
|
|
]
|
|
|
|
zero_bufs([b[0], b[1]])
|
|
run_schedule(calls)
|
|
expected = [np.frombuffer(x.as_memoryview(), np.int32).copy() for x in b]
|
|
|
|
for _ in range(RUN_CNT):
|
|
zero_bufs([b[0], b[1]])
|
|
make_graph(Device[d0].graph, calls)([], {})
|
|
for i, buf in enumerate(b): np.testing.assert_equal(expected[i], np.frombuffer(buf.as_memoryview(), np.int32))
|
|
|
|
def test_order_copy_writed(self):
|
|
self.skip_if_not_multigraph()
|
|
d0 = Device.DEFAULT
|
|
b = [make_buffer(d0, fill=True) for _ in range(4)]
|
|
c: dict[Buffer,UOp] = {}
|
|
|
|
calls = [
|
|
get_ast(d0, 2).call(get_buf_uop(b[0],c), get_buf_uop(b[1],c), get_buf_uop(b[2],c), metadata=()),
|
|
UOp(Ops.COPY).call(get_buf_uop(b[3],c), get_buf_uop(b[0],c), metadata=()),
|
|
]
|
|
|
|
zero_bufs([b[0], b[3]])
|
|
run_schedule(calls)
|
|
expected = [np.frombuffer(x.as_memoryview(), np.int32).copy() for x in b]
|
|
|
|
for _ in range(RUN_CNT):
|
|
zero_bufs([b[0], b[3]])
|
|
make_graph(Device[d0].graph, calls)([], {})
|
|
for i, buf in enumerate(b): np.testing.assert_equal(expected[i], np.frombuffer(buf.as_memoryview(), np.int32))
|
|
|
|
def test_order_copy_then_read(self):
|
|
self.skip_if_not_multigraph()
|
|
d0 = Device.DEFAULT
|
|
b = [make_buffer(d0, fill=True) for _ in range(4)]
|
|
c: dict[Buffer,UOp] = {}
|
|
|
|
calls = [
|
|
UOp(Ops.COPY).call(get_buf_uop(b[1],c), get_buf_uop(b[0],c), metadata=()),
|
|
get_ast(d0, 2).call(get_buf_uop(b[3],c), get_buf_uop(b[1],c), get_buf_uop(b[2],c), metadata=()),
|
|
]
|
|
|
|
zero_bufs([b[1], b[3]])
|
|
run_schedule(calls)
|
|
expected = [np.frombuffer(x.as_memoryview(), np.int32).copy() for x in b]
|
|
|
|
for _ in range(RUN_CNT):
|
|
zero_bufs([b[1], b[3]])
|
|
make_graph(Device[d0].graph, calls)([], {})
|
|
for i, buf in enumerate(b): np.testing.assert_equal(expected[i], np.frombuffer(buf.as_memoryview(), np.int32))
|
|
|
|
def test_read_write_several_graphs(self):
|
|
d0 = Device.DEFAULT
|
|
b = [make_buffer(d0, fill=True) for _ in range(8)]
|
|
c: dict[Buffer,UOp] = {}
|
|
|
|
calls1 = [get_ast(d0, 2).call(get_buf_uop(b[3],c), get_buf_uop(b[1],c), get_buf_uop(b[2],c), metadata=())]
|
|
calls2 = [get_ast(d0, 2).call(get_buf_uop(b[4],c), get_buf_uop(b[1],c), get_buf_uop(b[3],c), metadata=())]
|
|
calls3 = [get_ast(d0, 2).call(get_buf_uop(b[5],c), get_buf_uop(b[4],c), get_buf_uop(b[2],c), metadata=())]
|
|
|
|
out = [b[3], b[4], b[5]]
|
|
zero_bufs(out)
|
|
run_schedule(calls1 + calls2 + calls3)
|
|
expected = [np.frombuffer(x.as_memoryview(), np.int32).copy() for x in b]
|
|
|
|
for _ in range(RUN_CNT):
|
|
zero_bufs(out)
|
|
make_graph(Device[d0].graph, calls1)([], {})
|
|
make_graph(Device[d0].graph, calls2)([], {})
|
|
make_graph(Device[d0].graph, calls3)([], {})
|
|
for i, buf in enumerate(b): np.testing.assert_equal(expected[i], np.frombuffer(buf.as_memoryview(), np.int32))
|
|
|
|
@needs_second_gpu
|
|
def test_copies_2_devs(self):
|
|
self.skip_if_not_multigraph()
|
|
d0, d1 = Device.DEFAULT, f"{Device.DEFAULT}:1"
|
|
b0 = [make_buffer(d0, fill=True) for _ in range(3)]
|
|
b1 = [make_buffer(d1, fill=True)]
|
|
c: dict[Buffer,UOp] = {}
|
|
|
|
calls = [
|
|
UOp(Ops.COPY).call(get_buf_uop(b1[0],c), get_buf_uop(b0[0],c), metadata=()),
|
|
get_ast(d0, 2).call(get_buf_uop(b0[2],c), get_buf_uop(b0[0],c), get_buf_uop(b0[1],c), metadata=()),
|
|
]
|
|
|
|
out = [b1[0], b0[2]]
|
|
zero_bufs(out)
|
|
run_schedule(calls)
|
|
expected = {buf: np.frombuffer(buf.as_memoryview(), np.int32).copy() for buf in b0 + b1}
|
|
|
|
for _ in range(RUN_CNT):
|
|
zero_bufs(out)
|
|
make_graph(Device[d0].graph, calls)([], {})
|
|
for buf in b0 + b1: np.testing.assert_equal(expected[buf], np.frombuffer(buf.as_memoryview(), np.int32))
|
|
|
|
def test_graph_offset_bufs(self):
|
|
self.skip_if_not_multigraph()
|
|
d0 = Device.DEFAULT
|
|
if not hasattr(Device[d0].allocator, "_offset"): self.skipTest("device does not support _offset")
|
|
|
|
b0 = make_buffer(d0, fill=True)
|
|
b1 = make_view(b0, 0, b0.size)
|
|
b2 = make_view(b0, 0, b0.size)
|
|
c: dict[Buffer,UOp] = {}
|
|
|
|
calls = [
|
|
UOp(Ops.COPY).call(get_buf_uop(b0,c), get_buf_uop(b2,c), metadata=()),
|
|
get_ast(d0, 2).call(get_buf_uop(b1,c), get_buf_uop(b0,c), get_buf_uop(b2,c), metadata=()),
|
|
]
|
|
|
|
zero_bufs([b0])
|
|
run_schedule(calls)
|
|
expected = np.frombuffer(b0.as_memoryview(), np.int32).copy()
|
|
|
|
for _ in range(RUN_CNT):
|
|
zero_bufs([b0])
|
|
make_graph(Device[d0].graph, calls)([], {})
|
|
np.testing.assert_equal(expected, np.frombuffer(b0.as_memoryview(), np.int32))
|
|
|
|
def test_partial_write_preserves_write_dep(self):
|
|
self.skip_if_not_multigraph()
|
|
self.skip_if_no_offset()
|
|
d0 = Device.DEFAULT
|
|
|
|
base = make_buffer(d0, BUF_SIZE * 2, fill=True)
|
|
copy_src_full = make_buffer(d0, BUF_SIZE * 2, fill=True)
|
|
copy_src_lo = make_buffer(d0, fill=True)
|
|
v_lo, v_hi = make_view(base, 0, BUF_SIZE), make_view(base, BUF_SIZE, BUF_SIZE)
|
|
a, out = make_buffer(d0, fill=True), make_buffer(d0, fill=True)
|
|
c: dict[Buffer,UOp] = {}
|
|
|
|
calls = [
|
|
UOp(Ops.COPY).call(get_buf_uop(base,c), get_buf_uop(copy_src_full,c), metadata=()),
|
|
UOp(Ops.COPY).call(get_buf_uop(v_lo,c), get_buf_uop(copy_src_lo,c), metadata=()),
|
|
get_ast(d0, 2).call(get_buf_uop(out,c), get_buf_uop(v_hi,c), get_buf_uop(a,c), metadata=()),
|
|
]
|
|
|
|
zero_bufs([base, out])
|
|
run_schedule(calls)
|
|
expected = {base: np.frombuffer(base.as_memoryview(), np.int32).copy(), out: np.frombuffer(out.as_memoryview(), np.int32).copy()}
|
|
|
|
for _ in range(RUN_CNT):
|
|
zero_bufs([base, out])
|
|
make_graph(Device[d0].graph, calls)([], {})
|
|
for buf in [base, out]: np.testing.assert_equal(expected[buf], np.frombuffer(buf.as_memoryview(), np.int32))
|
|
|
|
def test_partial_write_preserves_read_dep(self):
|
|
self.skip_if_not_multigraph()
|
|
self.skip_if_no_offset()
|
|
d0 = Device.DEFAULT
|
|
|
|
base = make_buffer(d0, BUF_SIZE * 2, fill=True)
|
|
copy_dst = make_buffer(d0, BUF_SIZE * 2, fill=True)
|
|
copy_src_lo = make_buffer(d0, fill=True)
|
|
v_lo, v_hi = make_view(base, 0, BUF_SIZE), make_view(base, BUF_SIZE, BUF_SIZE)
|
|
a, b = make_buffer(d0, fill=True), make_buffer(d0, fill=True)
|
|
c: dict[Buffer,UOp] = {}
|
|
|
|
calls = [
|
|
UOp(Ops.COPY).call(get_buf_uop(copy_dst,c), get_buf_uop(base,c), metadata=()),
|
|
UOp(Ops.COPY).call(get_buf_uop(v_lo,c), get_buf_uop(copy_src_lo,c), metadata=()),
|
|
get_ast(d0, 2).call(get_buf_uop(v_hi,c), get_buf_uop(a,c), get_buf_uop(b,c), metadata=()),
|
|
]
|
|
|
|
zero_bufs([copy_dst, base])
|
|
run_schedule(calls)
|
|
expected = {copy_dst: np.frombuffer(copy_dst.as_memoryview(), np.int32).copy(), base: np.frombuffer(base.as_memoryview(), np.int32).copy()}
|
|
|
|
for _ in range(RUN_CNT):
|
|
zero_bufs([copy_dst, base])
|
|
make_graph(Device[d0].graph, calls)([], {})
|
|
for buf in [copy_dst, base]: np.testing.assert_equal(expected[buf], np.frombuffer(buf.as_memoryview(), np.int32))
|
|
|
|
def test_middle_write_splits_write_dep(self):
|
|
self.skip_if_not_multigraph()
|
|
self.skip_if_no_offset()
|
|
d0 = Device.DEFAULT
|
|
|
|
base = make_buffer(d0, BUF_SIZE * 3, fill=True)
|
|
copy_src_full = make_buffer(d0, BUF_SIZE * 3, fill=True)
|
|
copy_src_mid = make_buffer(d0, fill=True)
|
|
v_lo, v_mid, v_hi = make_view(base, 0, BUF_SIZE), make_view(base, BUF_SIZE, BUF_SIZE), make_view(base, BUF_SIZE * 2, BUF_SIZE)
|
|
a, out1, out2 = make_buffer(d0, fill=True), make_buffer(d0, fill=True), make_buffer(d0, fill=True)
|
|
c: dict[Buffer,UOp] = {}
|
|
|
|
calls = [
|
|
UOp(Ops.COPY).call(get_buf_uop(base,c), get_buf_uop(copy_src_full,c), metadata=()),
|
|
UOp(Ops.COPY).call(get_buf_uop(v_mid,c), get_buf_uop(copy_src_mid,c), metadata=()),
|
|
get_ast(d0, 2).call(get_buf_uop(out1,c), get_buf_uop(v_lo,c), get_buf_uop(a,c), metadata=()),
|
|
get_ast(d0, 2).call(get_buf_uop(out2,c), get_buf_uop(v_hi,c), get_buf_uop(a,c), metadata=()),
|
|
]
|
|
|
|
outs = [base, out1, out2]
|
|
zero_bufs(outs)
|
|
run_schedule(calls)
|
|
expected = {buf: np.frombuffer(buf.as_memoryview(), np.int32).copy() for buf in outs}
|
|
|
|
for _ in range(RUN_CNT):
|
|
zero_bufs(outs)
|
|
make_graph(Device[d0].graph, calls)([], {})
|
|
for buf in outs: np.testing.assert_equal(expected[buf], np.frombuffer(buf.as_memoryview(), np.int32))
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|