mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
minor schedule differ utils [run_process_replay] (#6348)
* minor schedule differ utils [run_process_replay] * rm
This commit is contained in:
22
test/external/process_replay/diff_schedule.py
vendored
22
test/external/process_replay/diff_schedule.py
vendored
@@ -1,13 +1,12 @@
|
||||
# create a diff of two schedule graphs
|
||||
import shutil, importlib, uuid, os, logging, contextlib
|
||||
from collections import defaultdict
|
||||
from typing import DefaultDict, Dict, List, Set, Tuple
|
||||
from test.external.process_replay.utils import print_diff
|
||||
from typing import DefaultDict, List, Set, Tuple
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem
|
||||
from tinygrad.helpers import CI, DEBUG, Context, ContextVar, colored, diskcache_put, fetch, getenv
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
|
||||
from tinygrad.ops import UOp
|
||||
|
||||
CAPTURING_PROCESS_REPLAY = ContextVar("CAPTURING_PROCESS_REPLAY", getenv("RUN_PROCESS_REPLAY"))
|
||||
|
||||
@@ -29,15 +28,18 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
|
||||
for buf in lsi.outputs:
|
||||
si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
|
||||
changed = 0
|
||||
seen_diffs: Set[Tuple[bytes, ...]] = set()
|
||||
for buf, si in si_for_buf.items():
|
||||
asts: Dict[bytes, UOp] = {x.ast.key:x.ast for x in si}
|
||||
seen_diffs: Set[bytes] = set()
|
||||
for buf,si in si_for_buf.items():
|
||||
si = list({x.ast.key:x for x in si}.values())
|
||||
if len(si) == 1: continue
|
||||
assert len(si) == 2, f"must have a ref and a compare schedule {len(si)}"
|
||||
ref, compare = si
|
||||
# no new kernel for buf
|
||||
if len(asts) == 1: continue
|
||||
if (cache_key:=tuple(asts)) in seen_diffs: continue
|
||||
if ref.ast.key == compare.ast.key: continue
|
||||
if (cache_key:=ref.ast.key+compare.ast.key) in seen_diffs: continue
|
||||
seen_diffs.add(cache_key)
|
||||
changed += 1
|
||||
if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), list(asts.values())))
|
||||
if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), [ref.ast.key, compare.ast.key]))
|
||||
if not CI: print_si_diff(si[0], si[1])
|
||||
if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
|
||||
return changed
|
||||
@@ -50,7 +52,7 @@ def print_si_diff(si0:ScheduleItem, si1:ScheduleItem):
|
||||
ei0 = lower_schedule_item(si0)
|
||||
ei1 = lower_schedule_item(si1)
|
||||
assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner)
|
||||
print_diff(ei0.prg.p.src, ei1.prg.p.src)
|
||||
if DEBUG >= 4: print_diff(ei0.prg.p.src, ei1.prg.p.src)
|
||||
# TODO: create new Buffers for process replay to test correctness
|
||||
if getenv("TIMING"):
|
||||
with Context(DEBUG=2):
|
||||
|
||||
@@ -2,6 +2,7 @@ import difflib, logging
|
||||
from tinygrad.helpers import colored, getenv
|
||||
|
||||
def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
|
||||
if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
if unified:
|
||||
lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines()))
|
||||
diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in lines)
|
||||
@@ -4,7 +4,7 @@ import os, multiprocessing, logging, pickle, sqlite3
|
||||
from typing import Callable, List, cast
|
||||
from tinygrad.helpers import VERSION, Context, ContextVar, db_connection, getenv, tqdm
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from test.external.process_replay.utils import print_diff
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
|
||||
# *** process replay settings
|
||||
|
||||
@@ -14,9 +14,9 @@ REF = os.getenv("GITHUB_REF_NAME", "")
|
||||
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
|
||||
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
|
||||
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
|
||||
early_stop = multiprocessing.Event()
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
os.environ["RUN_PROCESS_REPLAY"] = "0"
|
||||
early_stop = multiprocessing.Event()
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
|
||||
# user config
|
||||
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
|
||||
|
||||
@@ -18,7 +18,6 @@ sys.setrecursionlimit(10000)
|
||||
|
||||
# optionally log the ops to disk
|
||||
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
|
||||
# use graph rewrite for reduceop fusion
|
||||
|
||||
# *** ScheduleItem return type ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user