From 712980e167e06fcb1578a16b1201e202f7b92a50 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Fri, 27 Jun 2025 01:51:36 +0300
Subject: [PATCH] fix extract_dataset + add tests to CI (#10995)

* fix extract_dataset + tests

* add CI

* sops.gz itself is same as master

* yml + gzip -c + ge

* don't commit that

* bump limit to 1000

* axis=7

* test_tiny
---
 .github/workflows/test.yml                    |  7 +++
 extra/optimization/extract_dataset.py         |  4 +-
 extra/optimization/helpers.py                 |  3 +-
 .../external/process_replay/process_replay.py | 49 +++++++++++--------
 tinygrad/opt/search.py                        |  2 +-
 5 files changed, 41 insertions(+), 24 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 14ccbba2e8..bbea0ab112 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -371,6 +371,13 @@ jobs:
       run: PYTHONPATH="." python test/external/external_uop_gc.py
     - name: Run process replay tests
      uses: ./.github/actions/process-replay
+    - name: Regen dataset on test_tiny
+      run: |
+        test/external/process_replay/reset.py
+        CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
+        PYTHONPATH=. python extra/optimization/extract_dataset.py
+        gzip -c /tmp/sops > extra/datasets/sops.gz
+        DEBUG=1 MIN_ASTS=1 PYTHONPATH=. python extra/optimization/get_action_space.py
     - name: Repo line count < 14500 lines
       run: MAX_LINE_COUNT=14500 python sz.py
diff --git a/extra/optimization/extract_dataset.py b/extra/optimization/extract_dataset.py
index 5ec093bac8..b33530b209 100755
--- a/extra/optimization/extract_dataset.py
+++ b/extra/optimization/extract_dataset.py
@@ -6,8 +6,8 @@ from test.external.process_replay.process_replay import _pmap
 LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
 
 def extract_ast(*args) -> None:
-  open(LOGOPS, "a").write(str(args[0]).replace("\n", "").replace(" ", "")+"\n")
+  open(LOGOPS, "a").write(str(args[1]).replace("\n", "").replace(" ", "")+"\n")
   return None
 
 if __name__ == "__main__":
-  _pmap("kernel", extract_ast)
+  _pmap({"get_program":extract_ast})
diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py
index bca2eb6b85..4d2549dee6 100644
--- a/extra/optimization/helpers.py
+++ b/extra/optimization/helpers.py
@@ -5,6 +5,7 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
+from tinygrad.helpers import getenv
 inf, nan = float('inf'), float('nan')
 UOps = Ops
 
@@ -26,7 +27,7 @@ from tinygrad.helpers import dedup, DEBUG
 def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
   fn = Path(__file__).parent.parent / "datasets/sops.gz"
   ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n"))
-  assert len(ast_strs) > 5000, f"dataset size = {len(ast_strs)} is too small"
+  assert len(ast_strs) >= getenv("MIN_ASTS", 5000), f"dataset size = {len(ast_strs)} is too small"
   if DEBUG >= 1: print(f"loaded {len(ast_strs)=} before filters")
   if filter_reduce: ast_strs = [x for x in ast_strs if "REDUCE_AXIS" in x]
   if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py
index d26de01163..41d779fe0c 100755
--- a/test/external/process_replay/process_replay.py
+++ b/test/external/process_replay/process_replay.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # compare kernels created by HEAD against master
-import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools
+import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools, functools
 from typing import Callable, Any
 from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
 from tinygrad.kernelize.kernelize import get_kernelize_map
@@ -49,7 +49,7 @@ replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_ke
 
 # *** run replayers on captured rows and print diffs
 
-def diff(offset:int) -> None:
+def diff(offset:int, fxns:dict[str, Callable[..., tuple|None]]) -> None:
   if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning)
   if early_stop.is_set(): return None
   conn = db_connection()
@@ -64,8 +64,10 @@ def diff(offset:int) -> None:
     try:
       name, args, kwargs, ctx_vals, loc, ret = pickle.loads(row[0])
       ctx_vars = {k:v.value for k,v in ctx_vals.items() if k != "DEBUG" and (var:=ContextVar._cache.get(k)) is not None and var.value != v.value}
-      if (replayer:=replayers.get(name)) is None: continue
-      with Context(**ctx_vars): good, compare, metadata = replayer(ret, *args, **kwargs)
+      if (replayer:=fxns.get(name)) is None: continue
+      with Context(**ctx_vars):
+        if (ret:=replayer(ret, *args, **kwargs)) is None: continue
+        good, compare, metadata = ret
       if good != compare:
         for m in metadata: trunc_log(m)
         logging.info(loc)
@@ -79,6 +81,25 @@ def diff(offset:int) -> None:
   conn.commit()
   cur.close()
 
+# *** generic runner to map rows of a table to a function in parallel
+
+def _pmap(fxns:dict[str, Callable]) -> None:
+  conn = db_connection()
+  cur = conn.cursor()
+  try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
+  except sqlite3.OperationalError:
+    raise RuntimeError(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
+  finally:
+    conn.commit()
+    cur.close()
+
+  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count()) as pool:
+    inputs = list(range(0, row_count, PAGE_SIZE))
+    list(tqdm(pool.imap_unordered(functools.partial(diff, fxns=fxns), inputs), total=len(inputs)))
+    pool.close()
+    pool.join()
+    pool.terminate()
+
 # *** main loop
 
 if __name__ == "__main__":
@@ -86,20 +107,8 @@ if __name__ == "__main__":
     logging.info("skipping process replay.")
     exit(0)
 
-  conn = db_connection()
-  cur = conn.cursor()
-  try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
-  except sqlite3.OperationalError:
-    warnings.warn(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning)
-    exit(int(ASSERT_DIFF))
-  finally:
-    conn.commit()
-    cur.close()
-
   logging.info(f"running process replay with {ASSERT_DIFF=}")
-  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count()) as pool:
-    inputs = list(range(0, row_count, PAGE_SIZE))
-    list(tqdm(pool.imap_unordered(diff, inputs), total=len(inputs)))
-    pool.close()
-    pool.join()
-    pool.terminate()
+  try: _pmap(replayers)
+  except Exception as e:
+    logging.info("process replay err", e)
+    exit(int(ASSERT_DIFF))
diff --git a/tinygrad/opt/search.py b/tinygrad/opt/search.py
index 27c9e5e1eb..903a51f281 100644
--- a/tinygrad/opt/search.py
+++ b/tinygrad/opt/search.py
@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.renderer import ProgramSpec
 
-actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
+actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)]
 actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
 actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
 actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
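
Usage note: after this patch, _pmap is the generic entry point for walking the
capture database in parallel, and diff() treats a handler that returns None as
"nothing to compare" and moves on. Below is a minimal sketch of plugging a
custom handler into _pmap, modeled on extract_dataset.py above; log_ast is a
hypothetical name, and the comments assume rows are pickled as (name, args,
kwargs, ctx_vals, loc, ret) exactly as diff() unpacks them:

    from test.external.process_replay.process_replay import _pmap

    def log_ast(*args) -> None:
      # diff() calls a handler as replayer(ret, *args, **kwargs): args[0] is the
      # captured return value and args[1] is the first captured call argument.
      # For "get_program" rows that first argument is the kernel AST, which is
      # why extract_ast above reads args[1] instead of args[0].
      print(str(args[1]).replace("\n", "")[:120])
      return None  # a None return makes diff() skip the good != compare check

    if __name__ == "__main__":
      _pmap({"get_program": log_ast})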
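
The new MIN_ASTS knob is what lets the CI step above pass with a dataset
regenerated from a single test_tiny kernel: load_worlds() now asserts
len(ast_strs) >= getenv("MIN_ASTS", 5000) rather than a hard > 5000. A quick
local check, assuming the repo root is on PYTHONPATH; the values and filter
choices here are illustrative:

    import os
    os.environ["MIN_ASTS"] = "1"  # read by getenv("MIN_ASTS", 5000) inside load_worlds
    from extra.optimization.helpers import load_worlds
    # a dataset captured from TestTiny.test_plus has no REDUCE_AXIS kernels,
    # so the reduce filter is disabled for this check
    ast_strs = load_worlds(filter_reduce=False, filter_noimage=True, filter_novariable=True)
    print(f"loaded {len(ast_strs)} ASTs")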