hcqfuzz: init (#10049)
* hcqfuzz: init
* fix fuzz
* linter
* graph
* that test
* update readme
extra/hcqfuzz/.gitignore (new file, 1 line)
@@ -0,0 +1 @@
reports
extra/hcqfuzz/fuzzer.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import os, random, subprocess, shlex, datetime, time, signal
from extra.hcqfuzz.tools import create_report, on_start_run, collect_tests, init_log, log
from extra.hcqfuzz.spec import AMSpec

def run_test(dev, test):
  on_start_run(dev, test)

  dev_env = dev.get_exec_state()
  test_env, cmd, timeout = test.get_exec_state()
  env = {**dev_env, **test_env}

  if isinstance(cmd, str): cmd = shlex.split(cmd)
  assert isinstance(cmd, list), "cmd must be list or str"

  if env is None: env = os.environ.copy()
  else:
    env = {k: str(v) for k, v in env.items()}
    env = {**os.environ, **env}

  start_ts = datetime.datetime.now()
  t0 = time.perf_counter()
  log(f"[{start_ts:%Y-%m-%d %H:%M:%S}] running: {test.name()}: {' '.join(cmd)}", end="", flush=True)

  proc = subprocess.Popen(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
  try:
    stdout, stderr = proc.communicate(timeout=timeout)
    ret = proc.returncode
  except KeyboardInterrupt:
    print("\nExiting...", flush=True)
    proc.send_signal(signal.SIGINT)
    try: stdout, stderr = proc.communicate(timeout=5)
    except subprocess.TimeoutExpired:
      proc.kill()
      stdout, stderr = proc.communicate()
    raise
  except subprocess.TimeoutExpired:
    cur_time = datetime.datetime.now()
    log(f"\r[{cur_time:%Y-%m-%d %H:%M:%S}] {test.name()} send SIGKILL", end="", flush=True)

    proc.kill()
    stdout, stderr = proc.communicate()
    ret = -9

  finish_time = datetime.datetime.now()
  elapsed = time.perf_counter() - t0
  if ret != 0:
    log(f"\r[{finish_time:%Y-%m-%d %H:%M:%S}] {test.name()} failed with {ret} after {elapsed:.1f}s", flush=True)
    create_report(dev, test, ret, stdout, stderr)
  else:
    log(f"\r[{finish_time:%Y-%m-%d %H:%M:%S}] {test.name()} exited {ret} after {elapsed:.1f}s", flush=True)

if __name__ == "__main__":
  init_log()
  device_name = "AM"
  dev = AMSpec()

  start_seed = os.environ.get("SEED", 3332)
  random.seed(start_seed)

  log(f"Starting with seed {start_seed}")

  test_set = collect_tests()
  log(f"Found {len(test_set)} tests:")
  for test in test_set: log(f" - {test.name()}")

  while True:
    seed = random.randint(0, 2**31)
    test = random.choice(test_set)

    dev.prepare(seed)
    test.prepare(dev, seed)
    run_test(dev, test)
extra/hcqfuzz/readme (new file, 12 lines)
@@ -0,0 +1,12 @@
# Fuzzing Infra

To add a new test, define a `TestSpec`-based class in a file in the `tests/` folder.

You can choose which tests to load from which file:
```bash
RUN_FILES="hcq,allocator" python3 extra/hcqfuzz/fuzzer.py
```
Or skip tests from any file:
```bash
SKIP_FILES="allocator" python3 extra/hcqfuzz/fuzzer.py
```
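To make the readme's instruction concrete, here is a minimal sketch of what such a test file could look like. Everything specific in it is hypothetical (the `ExampleFuzz` name, the env values, the script path); only the `prepare`/`get_exec_state` shape follows the `TestSpec` interface in `spec.py` and the existing files under `tests/`.

```python
# hypothetical extra/hcqfuzz/tests/example.py -- an illustration, not part of this commit
from extra.hcqfuzz.spec import TestSpec
import random

class ExampleFuzz(TestSpec):
  def prepare(self, dev, seed):
    random.seed(seed)

    # env vars for the child process; fuzzer.run_test layers these over the DeviceSpec env
    self.env = {
      "SEED": seed,
      "ITERS": random.randint(100, 10000),
    }

    # command line and per-run timeout, handed back to fuzzer.run_test via get_exec_state
    self.cmd = "python3 test/external/some_external_fuzzer.py"  # hypothetical script path
    self.timeout = 10 * 60  # 10 minutes

  def get_exec_state(self): return self.env, self.cmd, self.timeout
```

Dropping a file like this into `tests/` should be enough for `collect_tests` in `tools.py` to pick it up, subject to the `RUN_FILES`/`SKIP_FILES` filters above.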
extra/hcqfuzz/spec.py (new file, 42 lines)
@@ -0,0 +1,42 @@
import os, random

class TestSpec:
  def prepare(self, device, seed):
    raise NotImplementedError("prepare must be implemented in the derived class")
  def get_exec_state(self):
    raise NotImplementedError("get_exec_state must be implemented in the derived class")
  def name(self): return self.__class__.__name__

class DeviceSpec:
  def prepare(self, seed):
    raise NotImplementedError("prepare must be implemented in the derived class")
  def get_exec_state(self):
    raise NotImplementedError("get_exec_state must be implemented in the derived class")
  def name(self): return self.__class__.__name__

class HCQSpec(DeviceSpec): pass
class AMDSpec(HCQSpec):
  def __init__(self):
    assert os.path.exists('/sys/module/amdgpu'), "amdgpu module should be loaded"

  def prepare(self, seed):
    self.env = {
      "AMD": 1,
      "AMD_LLVM": 0
    }

  def get_exec_state(self): return self.env

class AMSpec(AMDSpec):
  def __init__(self):
    assert not os.path.exists('/sys/module/amdgpu'), "amdgpu module should not be loaded"

  def prepare(self, seed):
    super().prepare(seed)

    self.env = {
      **self.env, # from AMDSpec
      "AMD_SDMA_BIND": random.randint(0, 1),
      "AMD_ALLOC_QUEUE_DEV_MEM": 0, # random.randint(0, 1) need to validate
      "AMD_QUEUE_SIZE": 1 << random.randint(10, 26),
    }
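The `DeviceSpec` side mirrors the test side: `prepare(seed)` picks the backend env vars and `get_exec_state()` hands them to the fuzzer, which layers the test env on top and both over `os.environ`. Below is a sketch of how another backend might plug in, assuming tinygrad's `NV=1` device selector; the class and its env choices are illustrative only, not part of this commit.

```python
# hypothetical NVSpec next to AMDSpec/AMSpec -- an illustration only
import random
from extra.hcqfuzz.spec import HCQSpec

class NVSpec(HCQSpec):
  def prepare(self, seed):
    random.seed(seed)
    # backend env vars; fuzzer.run_test overlays the test env on these, and both over os.environ
    self.env = {"NV": 1}

  def get_exec_state(self): return self.env
```

Note that `fuzzer.py` currently constructs `AMSpec()` directly, so using a different `DeviceSpec` would mean swapping that line.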
extra/hcqfuzz/tests/allocator.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from extra.hcqfuzz.spec import TestSpec
import random

class TLSFAllocator(TestSpec):
  def prepare(self, dev, seed):
    random.seed(seed)

    self.env = {
      "SEED": seed,
      "ITERS": random.randint(10000, 1000000),
    }

    self.cmd = "python3 test/external/external_fuzz_tlsf.py"
    self.timeout = 60 * 60 # 60 minutes

  def get_exec_state(self): return self.env, self.cmd, self.timeout
extra/hcqfuzz/tests/allreduce.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from extra.hcqfuzz.spec import TestSpec
import random

class RingAllreduce(TestSpec):
  def prepare(self, dev, seed):
    random.seed(seed)

    self.env = {
      "GPUS": random.choice([2, 3, 4, 5, 6]),
      "ITERS": random.randint(10, 1000),
      "DEBUG": 2,
    }

    self.cmd = "python3 test/external/external_benchmark_multitensor_allreduce.py"
    self.timeout = 10 * 60 # 10 minutes

  def get_exec_state(self): return self.env, self.cmd, self.timeout
extra/hcqfuzz/tests/bert.py (new file, 83 lines)
@@ -0,0 +1,83 @@
from extra.hcqfuzz.spec import TestSpec
import random

bert_train_params = {
  "DEFAULT_FLOAT": "HALF",
  "SUM_DTYPE": "HALF",
  "GPUS": 6,
  "BS": 96,
  "EVAL_BS": 96,
  "FUSE_ARANGE": 1,
  "FUSE_ARANGE_UINT": 0,
  "BASEDIR": "/raid/datasets/wiki",
}

class TrainBert(TestSpec):
  def prepare(self, dev, seed):
    random.seed(seed)

    self.env = {
      **bert_train_params,
      "IGNORE_BEAM_CACHE": 1,
      "BEAM": 5,
      "BEAM_UOPS_MAX": 10000,
      "BEAM_UPCAST_MAX": 256,
      "BEAM_LOCAL_MAX": 1024,
      "BEAM_MIN_PROGRESS": 5,
      "IGNORE_JIT_FIRST_BEAM": 1,
      "LOGMLPERF": 0,
      "SEED": seed,
    }

    self.cmd = "python3 examples/mlperf/model_train.py"
    self.timeout = 7 * 60 * 60 # 7 hours

  def get_exec_state(self): return self.env, self.cmd, self.timeout

class TrainBertShort(TestSpec):
  def prepare(self, dev, seed):
    random.seed(seed)

    self.env = {
      **bert_train_params,
      "IGNORE_BEAM_CACHE": 1,
      "BEAM": 5,
      "BEAM_UOPS_MAX": 10000,
      "BEAM_UPCAST_MAX": 256,
      "BEAM_LOCAL_MAX": 1024,
      "BEAM_MIN_PROGRESS": 5,
      "IGNORE_JIT_FIRST_BEAM": 1,
      "SEED": seed,
      "BENCHMARK": 4096,
      "JIT": 2
    }

    self.cmd = "python3 examples/mlperf/model_train.py"
    self.timeout = 2 * 60 * 60 # 2 hours

  def get_exec_state(self): return self.env, self.cmd, self.timeout

class BertBeam(TestSpec):
  def prepare(self, dev, seed):
    random.seed(seed)

    self.env = {
      **bert_train_params,
      "IGNORE_BEAM_CACHE": 1,
      "BEAM": random.choice([1, 2, 3, 4, 5]),
      "BEAM_UOPS_MAX": 10000,
      "BEAM_UPCAST_MAX": 256,
      "BEAM_LOCAL_MAX": 1024,
      "BEAM_MIN_PROGRESS": 5,
      "IGNORE_JIT_FIRST_BEAM": 1,
      "SEED": seed,
      "RESET_STEP": 1,
      "BENCHMARK": 10,
      "BERT_LAYERS": 2,
      "SEED": seed,
    }

    self.cmd = "python3 examples/mlperf/model_train.py"
    self.timeout = 1 * 60 * 60 # 1 hour

  def get_exec_state(self): return self.env, self.cmd, self.timeout
extra/hcqfuzz/tests/hcq.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from extra.hcqfuzz.spec import TestSpec
import random

class HCQSignalFuzzer(TestSpec):
  def prepare(self, dev, seed):
    random.seed(seed)

    self.env = {
      "GPUS": random.choice([2, 3, 4, 5, 6]),
      "ITERS": random.randint(1000000, 10000000),
      "SEED": seed,
    }

    self.cmd = "python3 test/external/external_fuzz_hcq_signals.py"
    self.timeout = 30 * 60 # 30 minutes

  def get_exec_state(self): return self.env, self.cmd, self.timeout

class HCQGraphFuzzer(TestSpec):
  def prepare(self, dev, seed):
    random.seed(seed)

    self.env = {
      "FUZZ_GRAPH_SPLIT_RUNS": random.randint(48, 64),
      "FUZZ_GRAPH_MAX_SPLITS": random.randint(4, 16),
      "FUZZ_GRAPH_SPLIT_RETRY_RUNS": random.randint(4, 8),
      "MAX_KERNELS": random.randint(32, 512),
      "MAX_DEVICES": random.choice([2, 3, 4, 5, 6]),
      "ITERS": random.randint(100, 1000),
    }

    self.cmd = "python3 test/external/fuzz_graph.py"
    self.timeout = 60 * 60 # 60 minutes

  def get_exec_state(self): return self.env, self.cmd, self.timeout
extra/hcqfuzz/tests/resnet.py (new file, 66 lines)
@@ -0,0 +1,66 @@
from extra.hcqfuzz.spec import TestSpec
import random

resnet_train_params = {
  "DEFAULT_FLOAT": "HALF",
  "SUM_DTYPE": "HALF",
  "GPUS": 6,
  "BS": 1536,
  "EVAL_BS": 192,
  "TRAIN_BEAM": 4,
  "IGNORE_JIT_FIRST_BEAM": 1,
  "BEAM_UOPS_MAX": 2000,
  "BEAM_UPCAST_MAX": 96,
  "BEAM_LOCAL_MAX": 1024,
  "BEAM_MIN_PROGRESS": 5,
  "BEAM_PADTO": 0,
  "EVAL_START_EPOCH": 3,
  "EVAL_FREQ": 4
}

class TrainResnet(TestSpec):
  def prepare(self, dev, seed):
    random.seed(seed)

    self.env = {
      **resnet_train_params,
      "IGNORE_BEAM_CACHE": 1,
      "SEED": seed,
    }

    self.cmd = "python3 examples/mlperf/model_train.py"
    self.timeout = 4 * 60 * 60 # 4 hours

  def get_exec_state(self): return self.env, self.cmd, self.timeout

class TrainResnetShort(TestSpec):
  def prepare(self, dev, seed):
    random.seed(seed)

    self.env = {
      **resnet_train_params,
      "SEED": seed,
      "BENCHMARK": 4096,
      "JIT": 2,
    }

    self.cmd = "python3 examples/mlperf/model_train.py"
    self.timeout = 2 * 60 * 60 # 2 hours

  def get_exec_state(self): return self.env, self.cmd, self.timeout

class ResnetBeam(TestSpec):
  def prepare(self, dev, seed):
    random.seed(seed)

    self.env = {
      **resnet_train_params,
      "IGNORE_BEAM_CACHE": 1,
      "BENCHMARK": 10,
      "SEED": seed,
    }

    self.cmd = "python3 examples/mlperf/model_train.py"
    self.timeout = 1 * 60 * 60 # 1 hour

  def get_exec_state(self): return self.env, self.cmd, self.timeout
extra/hcqfuzz/tools.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import pickle, datetime, os, tempfile, subprocess, zipfile, importlib.util
from extra.hcqfuzz.spec import TestSpec
from tinygrad.helpers import getenv

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_DIR = os.path.join(BASE_DIR, "tests")
REPORTS_DIR = os.path.join(BASE_DIR, "reports")

def collect_tests():
  run_files = getenv("RUN_FILES", "").split(",")
  skip_tests = getenv("SKIP_FILES", "").split(",")

  tests = []
  for filename in os.listdir(TEST_DIR):
    if filename.endswith(".py") and not filename.startswith("__"):
      if run_files and filename[:-3] not in run_files: continue
      if skip_tests and filename[:-3] in skip_tests: continue

      filepath = os.path.join(TEST_DIR, filename)
      module_name = f"tests.{filename[:-3]}"
      module = importlib.import_module(module_name)
      for attr_name in dir(module):
        attr = getattr(module, attr_name)
        if isinstance(attr, type) and issubclass(attr, TestSpec) and attr is not TestSpec:
          tests.append(attr())
  return tests

def on_start_run(dev, test):
  os.makedirs(REPORTS_DIR, exist_ok=True)
  pickle.dump((dev, test), open(f"{REPORTS_DIR}/last_launch.pkl", "wb"))

def create_report(dev, test, result, stdout, stderr):
  os.makedirs(REPORTS_DIR, exist_ok=True)

  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
  report_name = f"{timestamp}_{test.name()}_report"
  report_path = os.path.join(REPORTS_DIR, report_name)

  os.makedirs(report_path, exist_ok=False)

  pickle_path = os.path.join(report_path, "repro.pkl")
  with open(pickle_path, "wb") as f: pickle.dump((dev, test), f)

  stdout_path = os.path.join(report_path, "stdout.txt")
  with open(stdout_path, "w") as f: f.write(stdout)

  stderr_path = os.path.join(report_path, "stderr.txt")
  with open(stderr_path, "w") as f: f.write(stderr)

  dmesg_path = os.path.join(report_path, "dmesg.txt")
  dmesg_output = subprocess.check_output(["sudo", "dmesg", "--ctime", "--color=never"], text=True)
  with open(dmesg_path, "w") as f: f.write(dmesg_output)

  summary_path = os.path.join(report_path, "summary.txt")
  with open(summary_path, "w") as f:
    f.write(f"Test: {test.name()}\n")
    f.write(f"Dev params: {vars(dev)}\n")
    f.write(f"Test params: {vars(test)}\n")
    f.write(f"Exit Code: {result}\n")

  print(f"Crash report saved to {report_path}")

_log_file = None
def init_log():
  global _log_file
  os.makedirs(REPORTS_DIR, exist_ok=True)

  ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
  name = f"log_{ts}.log"
  _log_file = open(f"{REPORTS_DIR}/{name}", "a", buffering=1)

def log(msg="", end="\n", flush=False):
  global _log_file
  _log_file.write(msg.replace("\r", "\n") + end)
  if flush: _log_file.flush()
  print(msg + " " * 60, end=end, flush=flush)
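Since `on_start_run` pickles the prepared `(dev, test)` pair to `reports/last_launch.pkl` and `create_report` saves the same pair as `repro.pkl` inside each report, a failing configuration can be replayed. A possible replay helper is sketched below; it assumes the same PYTHONPATH/working-directory setup used to run `fuzzer.py`, so that both the `extra.hcqfuzz` modules and the pickled `tests.*` classes resolve. The script is an illustration, not part of this commit.

```python
# hypothetical replay.py -- reruns a pickled (dev, test) pair; an illustration only
import pickle, sys
from extra.hcqfuzz.fuzzer import run_test
from extra.hcqfuzz.tools import init_log

if __name__ == "__main__":
  # argv[1]: path to reports/last_launch.pkl or a report's repro.pkl
  with open(sys.argv[1], "rb") as f: dev, test = pickle.load(f)
  # dev.prepare/test.prepare already ran before pickling, so env, cmd and timeout are set
  init_log()
  run_test(dev, test)
```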
test/external/external_fuzz_hcq_signals.py (2 changed lines)
@@ -24,7 +24,7 @@ def main():
     dev.timeline_value += 1
 
     if sync:=random.randint(0, 10) < 3: dev.synchronize()
-    if DEBUG >= 2: print(f"{i}: {q_t.__name__} {dev.device_id} timeline {dev.timeline_value}, wait for {[d.device_id for d in wait_devs]}, {sync=}")
+    if DEBUG >= 2: print(f"{i}: {q_t} {dev.device_id} timeline {dev.timeline_value}, wait for {[d.device_id for d in wait_devs]}, {sync=}")
     elif i % 100 == 0: print(f"\rCompleted {i} iterations", end='')
 
 if __name__ == "__main__":
test/external/external_fuzz_tlsf.py (7 changed lines)
@@ -1,6 +1,6 @@
 import random
-from typing import Dict, Optional
 
+from tinygrad.helpers import getenv
 from tinygrad.runtime.support.allocator import TLSFAllocator
 
 class AllocatorFuzzer:
@@ -62,7 +62,7 @@ class AllocatorFuzzer:
     return True
 
   def run(self):
-    for i in range(10000000):
+    for i in range(getenv("ITERS", 100000)):
       if (random.random() < self.alloc_probability or not self.allocations): self.random_alloc()
       else: self.random_free()
 
@@ -72,5 +72,8 @@ class AllocatorFuzzer:
     print("Fuzzing completed successfully!")
 
 if __name__ == "__main__":
+  SEED = getenv("SEED", 42)
+  random.seed(SEED)
+
   fuzzer = AllocatorFuzzer(1 << 30)
   fuzzer.run()
test/external/fuzz_graph.py (2 changed lines)
@@ -121,7 +121,7 @@ if __name__ == "__main__":
   np.random.seed(SEED)
 
   next_graph_id = 0
-  while True:
+  for i in range(getenv("ITERS", 1000)):
     print("Running graph", next_graph_id)
     jis, all_buffers, input_buffers = gen_graph()
     fuzz_graph(jis, all_buffers, input_buffers)