fix extract_dataset + add tests to CI (#10995)

* fix extract_dataset + tests

* add CI

* sops.gz itself is the same as on master

* yml + gzip -c + ge

* don't commit that

* bump limit to 1000

* axis=7

* test_tiny
Author: qazal
Date: 2025-06-27 01:51:36 +03:00
Committed by: GitHub
Parent: 4572e65f0f
Commit: 712980e167
5 changed files with 41 additions and 24 deletions

.github/workflows/test.yml

@@ -371,6 +371,13 @@ jobs:
       run: PYTHONPATH="." python test/external/external_uop_gc.py
     - name: Run process replay tests
       uses: ./.github/actions/process-replay
+    - name: Regen dataset on test_tiny
+      run: |
+        test/external/process_replay/reset.py
+        CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
+        PYTHONPATH=. python extra/optimization/extract_dataset.py
+        gzip -c /tmp/sops > extra/datasets/sops.gz
+        DEBUG=1 MIN_ASTS=1 PYTHONPATH=. python extra/optimization/get_action_space.py
     - name: Repo line count < 14500 lines
       run: MAX_LINE_COUNT=14500 python sz.py

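The new CI step regenerates the dataset end to end: reset the capture DB, run a single tiny test under CAPTURE_PROCESS_REPLAY=1, replay the captured rows into /tmp/sops, then compress the log into the checked-in dataset and validate it with MIN_ASTS=1. As a minimal sketch, the `gzip -c /tmp/sops > extra/datasets/sops.gz` line is equivalent to the following Python (paths taken from the step above; the compressed bytes may differ in gzip header metadata, but the decompressed contents match):

# stream-compress the per-line AST log into the dataset file
import gzip, shutil
with open("/tmp/sops", "rb") as src, gzip.open("extra/datasets/sops.gz", "wb") as dst:
  shutil.copyfileobj(src, dst)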
extra/optimization/extract_dataset.py

@@ -6,8 +6,8 @@ from test.external.process_replay.process_replay import _pmap
 LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
 def extract_ast(*args) -> None:
-  open(LOGOPS, "a").write(str(args[0]).replace("\n", "").replace(" ", "")+"\n")
+  open(LOGOPS, "a").write(str(args[1]).replace("\n", "").replace(" ", "")+"\n")
   return None
 if __name__ == "__main__":
-  _pmap("kernel", extract_ast)
+  _pmap({"get_program":extract_ast})

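The args[0] -> args[1] switch follows from how diff invokes a replayer below: replayer(ret, *args, **kwargs), so args[0] is the captured return value and args[1] is the first captured call argument (the AST for "get_program" rows). A hedged sketch of driving the new dict-based _pmap with a different one-off callback (the callback name and body are illustrative, not part of the commit):

# illustrative one-off analysis over captured "get_program" rows; returning None
# tells diff() there is nothing to compare, so no diff is reported
from test.external.process_replay.process_replay import _pmap

def show_ast(*args) -> None:
  print(str(args[1])[:80])  # args = (captured return value, *captured call args)
  return None

if __name__ == "__main__":
  _pmap({"get_program": show_ast})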
extra/optimization/helpers.py

@@ -5,6 +5,7 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
+from tinygrad.helpers import getenv
 inf, nan = float('inf'), float('nan')
 UOps = Ops
@@ -26,7 +27,7 @@ from tinygrad.helpers import dedup, DEBUG
 def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
   fn = Path(__file__).parent.parent / "datasets/sops.gz"
   ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n"))
-  assert len(ast_strs) > 5000, f"dataset size = {len(ast_strs)} is too small"
+  assert len(ast_strs) >= getenv("MIN_ASTS", 5000), f"dataset size = {len(ast_strs)} is too small"
   if DEBUG >= 1: print(f"loaded {len(ast_strs)=} before filters")
   if filter_reduce: ast_strs = [x for x in ast_strs if "REDUCE_AXIS" in x]
   if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]

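The getenv("MIN_ASTS", 5000) floor is what lets the CI step above validate a dataset regenerated from one tiny kernel: MIN_ASTS=1 drops the size assert to a single AST. A hedged usage sketch, assuming load_worlds lives in extra/optimization/helpers.py, with filter_reduce off since a lone elementwise kernel has no REDUCE_AXIS:

import os
os.environ["MIN_ASTS"] = "1"  # set before tinygrad caches getenv values
from extra.optimization.helpers import load_worlds  # module path assumed

ast_strs = load_worlds(filter_reduce=False, filter_noimage=False, filter_novariable=False)
print(f"loaded {len(ast_strs)} ASTs")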
test/external/process_replay/process_replay.py

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # compare kernels created by HEAD against master
-import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools
+import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools, functools
 from typing import Callable, Any
 from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
 from tinygrad.kernelize.kernelize import get_kernelize_map
@@ -49,7 +49,7 @@ replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_ke
 # *** run replayers on captured rows and print diffs
-def diff(offset:int) -> None:
+def diff(offset:int, fxns:dict[str, Callable[..., tuple|None]]) -> None:
   if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning)
   if early_stop.is_set(): return None
   conn = db_connection()
@@ -64,8 +64,10 @@ def diff(offset:int) -> None:
     try:
       name, args, kwargs, ctx_vals, loc, ret = pickle.loads(row[0])
       ctx_vars = {k:v.value for k,v in ctx_vals.items() if k != "DEBUG" and (var:=ContextVar._cache.get(k)) is not None and var.value != v.value}
-      if (replayer:=replayers.get(name)) is None: continue
-      with Context(**ctx_vars): good, compare, metadata = replayer(ret, *args, **kwargs)
+      if (replayer:=fxns.get(name)) is None: continue
+      with Context(**ctx_vars):
+        if (ret:=replayer(ret, *args, **kwargs)) is None: continue
+        good, compare, metadata = ret
       if good != compare:
         for m in metadata: trunc_log(m)
         logging.info(loc)
@@ -79,6 +81,25 @@ def diff(offset:int) -> None:
   conn.commit()
   cur.close()
+# *** generic runner to map rows of a table to a function in parallel
+def _pmap(fxns:dict[str, Callable]) -> None:
+  conn = db_connection()
+  cur = conn.cursor()
+  try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
+  except sqlite3.OperationalError:
+    raise RuntimeError(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
+  finally:
+    conn.commit()
+    cur.close()
+  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count()) as pool:
+    inputs = list(range(0, row_count, PAGE_SIZE))
+    list(tqdm(pool.imap_unordered(functools.partial(diff, fxns=fxns), inputs), total=len(inputs)))
+    pool.close()
+    pool.join()
+    pool.terminate()
 # *** main loop
 if __name__ == "__main__":
@@ -86,20 +107,8 @@ if __name__ == "__main__":
     logging.info("skipping process replay.")
     exit(0)
-  conn = db_connection()
-  cur = conn.cursor()
-  try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
-  except sqlite3.OperationalError:
-    warnings.warn(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning)
-    exit(int(ASSERT_DIFF))
-  finally:
-    conn.commit()
-    cur.close()
   logging.info(f"running process replay with {ASSERT_DIFF=}")
-  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count()) as pool:
-    inputs = list(range(0, row_count, PAGE_SIZE))
-    list(tqdm(pool.imap_unordered(diff, inputs), total=len(inputs)))
-    pool.close()
-    pool.join()
-    pool.terminate()
+  try: _pmap(replayers)
+  except Exception as e:
+    logging.info("process replay err", e)
+    exit(int(ASSERT_DIFF))

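Binding the callbacks with functools.partial(diff, fxns=fxns) keeps each pool task a single int offset, and a partial over a module-level function pickles cleanly across the spawn boundary where a lambda or closure would not. A standalone sketch of that dispatch pattern under assumed toy values (PAGE_SIZE and the fake callback are not the repo's):

import functools, multiprocessing

PAGE_SIZE = 100  # illustrative; the real constant lives in process_replay.py

def diff(offset:int, fxns:dict) -> None:
  print(f"rows {offset}..{offset+PAGE_SIZE-1} dispatch to {sorted(fxns)}")

if __name__ == "__main__":
  work = functools.partial(diff, fxns={"get_program": print})
  with multiprocessing.get_context("spawn").Pool(2) as pool:
    list(pool.imap_unordered(work, range(0, 300, PAGE_SIZE)))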
tinygrad/opt/search.py

@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.renderer import ProgramSpec
-actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
+actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)]
 actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
 actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
 actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
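
The range(6) -> range(8) change (the "axis=7" bullet in the commit message) lets UPCAST target axes 6 and 7, growing that slice of the action space from 6 amts x 6 axes = 36 entries to 6 x 8 = 48:

# back-of-envelope count for the UPCAST slice before and after the bump
amts = [0,2,3,4,5,7]
print(len(amts)*6, "->", len(amts)*8)  # prints: 36 -> 48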