minor schedule differ utils [run_process_replay] (#6348)

* minor schedule differ utils [run_process_replay]

* rm
This commit is contained in:
qazal
2024-09-04 03:41:38 +08:00
committed by GitHub
parent fc30e4825d
commit 99018a4aa1
4 changed files with 16 additions and 14 deletions

View File

@@ -1,13 +1,12 @@
# create a diff of two schedule graphs
import shutil, importlib, uuid, os, logging, contextlib
from collections import defaultdict
from typing import DefaultDict, Dict, List, Set, Tuple
from test.external.process_replay.utils import print_diff
from typing import DefaultDict, List, Set, Tuple
from test.external.process_replay.helpers import print_diff
from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem
from tinygrad.helpers import CI, DEBUG, Context, ContextVar, colored, diskcache_put, fetch, getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
from tinygrad.ops import UOp
CAPTURING_PROCESS_REPLAY = ContextVar("CAPTURING_PROCESS_REPLAY", getenv("RUN_PROCESS_REPLAY"))
@@ -29,15 +28,18 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
for buf in lsi.outputs:
si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
changed = 0
seen_diffs: Set[Tuple[bytes, ...]] = set()
for buf, si in si_for_buf.items():
asts: Dict[bytes, UOp] = {x.ast.key:x.ast for x in si}
seen_diffs: Set[bytes] = set()
for buf,si in si_for_buf.items():
si = list({x.ast.key:x for x in si}.values())
if len(si) == 1: continue
assert len(si) == 2, f"must have a ref and a compare schedule {len(si)}"
ref, compare = si
# no new kernel for buf
if len(asts) == 1: continue
if (cache_key:=tuple(asts)) in seen_diffs: continue
if ref.ast.key == compare.ast.key: continue
if (cache_key:=ref.ast.key+compare.ast.key) in seen_diffs: continue
seen_diffs.add(cache_key)
changed += 1
if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), list(asts.values())))
if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), [ref.ast.key, compare.ast.key]))
if not CI: print_si_diff(si[0], si[1])
if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
return changed
@@ -50,7 +52,7 @@ def print_si_diff(si0:ScheduleItem, si1:ScheduleItem):
ei0 = lower_schedule_item(si0)
ei1 = lower_schedule_item(si1)
assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner)
print_diff(ei0.prg.p.src, ei1.prg.p.src)
if DEBUG >= 4: print_diff(ei0.prg.p.src, ei1.prg.p.src)
# TODO: create new Buffers for process replay to test correctness
if getenv("TIMING"):
with Context(DEBUG=2):

View File

@@ -2,6 +2,7 @@ import difflib, logging
from tinygrad.helpers import colored, getenv
def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s")
if unified:
lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines()))
diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in lines)

View File

@@ -4,7 +4,7 @@ import os, multiprocessing, logging, pickle, sqlite3
from typing import Callable, List, cast
from tinygrad.helpers import VERSION, Context, ContextVar, db_connection, getenv, tqdm
from tinygrad.codegen.kernel import Kernel
from test.external.process_replay.utils import print_diff
from test.external.process_replay.helpers import print_diff
# *** process replay settings
@@ -14,9 +14,9 @@ REF = os.getenv("GITHUB_REF_NAME", "")
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format='%(message)s')
os.environ["RUN_PROCESS_REPLAY"] = "0"
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format="%(message)s")
# user config
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))

View File

@@ -18,7 +18,6 @@ sys.setrecursionlimit(10000)
# optionally log the ops to disk
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
# use graph rewrite for reduceop fusion
# *** ScheduleItem return type ***