mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
graph fuzzer (#5082)
* graph fuzzer * more options * mypy * no underscores for funcs
This commit is contained in:
129
test/external/fuzz_graph.py
vendored
Normal file
129
test/external/fuzz_graph.py
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
import random, ctypes
|
||||
import numpy as np
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.helpers import Context, getenv, from_mv
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner
|
||||
from tinygrad.engine.jit import apply_graph_to_jit
|
||||
|
||||
BUF_LEN = getenv("BUF_LEN", 128)
|
||||
|
||||
cached_prgs = {}
|
||||
def gen_prg(device, inputs_cnt):
|
||||
if (device, inputs_cnt) in cached_prgs: return cached_prgs[(device, inputs_cnt)]
|
||||
|
||||
with Context(DEBUG=0):
|
||||
fst = [Tensor.randn(BUF_LEN, dtype=dtypes.int).realize() for i in range(inputs_cnt)]
|
||||
s = fst[0]
|
||||
for i in range(1, inputs_cnt): s = s.xor(fst[i])
|
||||
|
||||
si = create_schedule([s.lazydata])[-1]
|
||||
prg = get_runner(device, si.ast)
|
||||
cached_prgs[(device, inputs_cnt)] = prg
|
||||
return prg
|
||||
|
||||
def alloc_rawbuffer(device, fill=False):
|
||||
rawbuf = Buffer(device, BUF_LEN, dtypes.int).ensure_allocated()
|
||||
if fill:
|
||||
with Context(DEBUG=0):
|
||||
data = np.random.randint(-10000, 10000, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
|
||||
rawbuf.copyin(Tensor(data).realize().lazydata.realized.as_buffer())
|
||||
return rawbuf
|
||||
|
||||
def gen_kernel_ji(device, deps):
|
||||
assert len(deps) >= 2
|
||||
out = alloc_rawbuffer(device)
|
||||
prg = gen_prg(device, len(deps))
|
||||
return ExecItem(prg, [out] + deps)
|
||||
|
||||
def gen_copy_ji(device, deps):
|
||||
assert len(deps) == 1
|
||||
out = alloc_rawbuffer(device)
|
||||
prg = BufferXfer(deps[0].nbytes, device, deps[0].device)
|
||||
return ExecItem(prg, [out] + deps)
|
||||
|
||||
def gen_graph():
|
||||
input_buffers = []
|
||||
all_buffers = []
|
||||
jis = []
|
||||
|
||||
last_n_deps = getenv("LAST_N_DEPS", 0)
|
||||
|
||||
kernel_count = random.randint(2, getenv("MAX_KERNELS", 128))
|
||||
for i in range(kernel_count):
|
||||
target_device_id = random.randint(0, getenv("MAX_DEVICES", 6) - 1)
|
||||
target_device = f"{Device.DEFAULT}:{target_device_id}"
|
||||
is_copy = random.randint(0, 10) < 3
|
||||
|
||||
if is_copy:
|
||||
deps_pool = [buf for buf in all_buffers[-last_n_deps:] if buf.device != target_device]
|
||||
if len(deps_pool) == 0: deps = []
|
||||
else: deps = random.sample(deps_pool, 1)
|
||||
else:
|
||||
deps_pool = [buf for buf in all_buffers[-last_n_deps:] if buf.device == target_device]
|
||||
deps_count = random.randint(0, min(getenv("MAX_DEPS_COUNT", 6), len(deps_pool)))
|
||||
if deps_count == 0: deps = []
|
||||
else: deps = random.sample(deps_pool, deps_count)
|
||||
|
||||
if len(deps) == 0 or (not is_copy and len(deps) < 2):
|
||||
buf = alloc_rawbuffer(target_device, fill=True)
|
||||
input_buffers.append(buf)
|
||||
all_buffers.append(buf)
|
||||
elif is_copy:
|
||||
jis.append(gen_copy_ji(target_device, deps))
|
||||
all_buffers.append(jis[-1].bufs[0])
|
||||
else:
|
||||
jis.append(gen_kernel_ji(target_device, deps))
|
||||
all_buffers.append(jis[-1].bufs[0])
|
||||
|
||||
return jis, all_buffers, input_buffers
|
||||
|
||||
def run_jit(jis, all_buffers, input_buffers, var_vals):
|
||||
with Context(DEBUG=0):
|
||||
for rawbuf in all_buffers:
|
||||
if rawbuf in input_buffers: continue
|
||||
mv = memoryview(bytearray(rawbuf.size * rawbuf.dtype.itemsize))
|
||||
ctypes.memset(from_mv(mv), 0, len(mv))
|
||||
rawbuf.copyin(mv)
|
||||
|
||||
for ei in jis: ei.run(var_vals, jit=True)
|
||||
|
||||
with Context(DEBUG=0):
|
||||
res_buffers = []
|
||||
for rawbuf in all_buffers: res_buffers.append(rawbuf.as_buffer())
|
||||
return res_buffers
|
||||
|
||||
def fuzz_graph(jis, all_buffers, input_buffers):
|
||||
ground_thruth_bufs = run_jit(jis, input_buffers, all_buffers, {})
|
||||
ground_truth_np = [np.frombuffer(x, _to_np_dtype(all_buffers[i].dtype)) for i,x in enumerate(ground_thruth_bufs)]
|
||||
|
||||
for _ in range(getenv("FUZZ_GRAPH_SPLIT_RUNS", 64)):
|
||||
max_split_points = len(jis) // 3
|
||||
split_points = random.randint(0, min(max_split_points, getenv("FUZZ_GRAPH_MAX_SPLITS", 8)))
|
||||
split = [0]
|
||||
for i in range(split_points - 1):
|
||||
split.append(random.randint(split[-1] + 2, len(jis) - 2 * (max_split_points - i)))
|
||||
split.append(len(jis))
|
||||
|
||||
graphed_jit = []
|
||||
for sp in range(len(split)-1):
|
||||
graphed_jit += apply_graph_to_jit(jis[split[sp]:split[sp+1]], [], {})
|
||||
|
||||
for _ in range(getenv("FUZZ_GRAPH_SPLIT_RETRY_RUNS", 4)):
|
||||
test_bufs = run_jit(graphed_jit, input_buffers, all_buffers, {})
|
||||
test_bufs_np = [np.frombuffer(x, _to_np_dtype(all_buffers[i].dtype)) for i,x in enumerate(test_bufs)]
|
||||
for i in range(len(ground_thruth_bufs)): np.testing.assert_equal(ground_truth_np[i], test_bufs_np[i])
|
||||
|
||||
if __name__ == "__main__":
|
||||
SEED = getenv("SEED", 42)
|
||||
random.seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
|
||||
next_graph_id = 0
|
||||
while True:
|
||||
print("Running graph", next_graph_id)
|
||||
jis, all_buffers, input_buffers = gen_graph()
|
||||
fuzz_graph(jis, all_buffers, input_buffers)
|
||||
next_graph_id += 1
|
||||
Reference in New Issue
Block a user