mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix extract_dataset + add tests to CI (#10995)
* fix extract_dataset + tests
* add CI
* sops.gz itself is same as master
* yml + gzip -c + ge
* don't commit that
* bump limit to 1000
* axis=7
* test_tiny
This commit is contained in:
7
.github/workflows/test.yml
vendored
7
.github/workflows/test.yml
vendored
@@ -371,6 +371,13 @@ jobs:
|
||||
run: PYTHONPATH="." python test/external/external_uop_gc.py
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
- name: Regen dataset on test_tiny
|
||||
run: |
|
||||
test/external/process_replay/reset.py
|
||||
CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
|
||||
PYTHONPATH=. python extra/optimization/extract_dataset.py
|
||||
gzip -c /tmp/sops > extra/datasets/sops.gz
|
||||
DEBUG=1 MIN_ASTS=1 PYTHONPATH=. python extra/optimization/get_action_space.py
|
||||
- name: Repo line count < 14500 lines
|
||||
run: MAX_LINE_COUNT=14500 python sz.py
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ from test.external.process_replay.process_replay import _pmap
|
||||
LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
|
||||
|
||||
def extract_ast(*args) -> None:
|
||||
open(LOGOPS, "a").write(str(args[0]).replace("\n", "").replace(" ", "")+"\n")
|
||||
open(LOGOPS, "a").write(str(args[1]).replace("\n", "").replace(" ", "")+"\n")
|
||||
return None
|
||||
|
||||
if __name__ == "__main__":
|
||||
_pmap("kernel", extract_ast)
|
||||
_pmap({"get_program":extract_ast})
|
||||
|
||||
@@ -5,6 +5,7 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.helpers import getenv
|
||||
inf, nan = float('inf'), float('nan')
|
||||
UOps = Ops
|
||||
|
||||
@@ -26,7 +27,7 @@ from tinygrad.helpers import dedup, DEBUG
|
||||
def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
|
||||
fn = Path(__file__).parent.parent / "datasets/sops.gz"
|
||||
ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n"))
|
||||
assert len(ast_strs) > 5000, f"dataset size = {len(ast_strs)} is too small"
|
||||
assert len(ast_strs) >= getenv("MIN_ASTS", 5000), f"dataset size = {len(ast_strs)} is too small"
|
||||
if DEBUG >= 1: print(f"loaded {len(ast_strs)=} before filters")
|
||||
if filter_reduce: ast_strs = [x for x in ast_strs if "REDUCE_AXIS" in x]
|
||||
if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
|
||||
|
||||
49
test/external/process_replay/process_replay.py
vendored
49
test/external/process_replay/process_replay.py
vendored
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# compare kernels created by HEAD against master
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools, functools
|
||||
from typing import Callable, Any
|
||||
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
|
||||
from tinygrad.kernelize.kernelize import get_kernelize_map
|
||||
@@ -49,7 +49,7 @@ replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_ke
|
||||
|
||||
# *** run replayers on captured rows and print diffs
|
||||
|
||||
def diff(offset:int) -> None:
|
||||
def diff(offset:int, fxns:dict[str, Callable[..., tuple|None]]) -> None:
|
||||
if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning)
|
||||
if early_stop.is_set(): return None
|
||||
conn = db_connection()
|
||||
@@ -64,8 +64,10 @@ def diff(offset:int) -> None:
|
||||
try:
|
||||
name, args, kwargs, ctx_vals, loc, ret = pickle.loads(row[0])
|
||||
ctx_vars = {k:v.value for k,v in ctx_vals.items() if k != "DEBUG" and (var:=ContextVar._cache.get(k)) is not None and var.value != v.value}
|
||||
if (replayer:=replayers.get(name)) is None: continue
|
||||
with Context(**ctx_vars): good, compare, metadata = replayer(ret, *args, **kwargs)
|
||||
if (replayer:=fxns.get(name)) is None: continue
|
||||
with Context(**ctx_vars):
|
||||
if (ret:=replayer(ret, *args, **kwargs)) is None: continue
|
||||
good, compare, metadata = ret
|
||||
if good != compare:
|
||||
for m in metadata: trunc_log(m)
|
||||
logging.info(loc)
|
||||
@@ -79,6 +81,25 @@ def diff(offset:int) -> None:
|
||||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
# *** generic runner to map rows of a table to a function in parallel
|
||||
|
||||
def _pmap(fxns:dict[str, Callable]) -> None:
|
||||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
|
||||
except sqlite3.OperationalError:
|
||||
raise RuntimeError(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
|
||||
finally:
|
||||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count()) as pool:
|
||||
inputs = list(range(0, row_count, PAGE_SIZE))
|
||||
list(tqdm(pool.imap_unordered(functools.partial(diff, fxns=fxns), inputs), total=len(inputs)))
|
||||
pool.close()
|
||||
pool.join()
|
||||
pool.terminate()
|
||||
|
||||
# *** main loop
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -86,20 +107,8 @@ if __name__ == "__main__":
|
||||
logging.info("skipping process replay.")
|
||||
exit(0)
|
||||
|
||||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
|
||||
except sqlite3.OperationalError:
|
||||
warnings.warn(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning)
|
||||
exit(int(ASSERT_DIFF))
|
||||
finally:
|
||||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
logging.info(f"running process replay with {ASSERT_DIFF=}")
|
||||
with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count()) as pool:
|
||||
inputs = list(range(0, row_count, PAGE_SIZE))
|
||||
list(tqdm(pool.imap_unordered(diff, inputs), total=len(inputs)))
|
||||
pool.close()
|
||||
pool.join()
|
||||
pool.terminate()
|
||||
try: _pmap(replayers)
|
||||
except Exception as e:
|
||||
logging.info("process replay err", e)
|
||||
exit(int(ASSERT_DIFF))
|
||||
|
||||
@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
|
||||
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
|
||||
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)]
|
||||
actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
|
||||
actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
|
||||
actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
|
||||
|
||||
Reference in New Issue
Block a user