mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
hotfix: scheduler differ (#6115)
* hotfix: scheduler differ * add the test back * track keys
This commit is contained in:
11
test/external/process_replay/diff_schedule.py
vendored
11
test/external/process_replay/diff_schedule.py
vendored
@@ -3,9 +3,8 @@ import shutil, importlib, uuid, os, logging
|
||||
from collections import defaultdict
|
||||
from typing import DefaultDict, List, Set, Tuple
|
||||
from test.external.process_replay.utils import print_diff
|
||||
from tinygrad.codegen.uops import UOp
|
||||
from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem
|
||||
from tinygrad.helpers import DEBUG, Context, colored, dedup, diskcache_put, fetch, getenv
|
||||
from tinygrad.helpers import DEBUG, Context, colored, diskcache_put, fetch, getenv
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
|
||||
|
||||
@@ -14,7 +13,7 @@ def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List
|
||||
ref_schedule = getenv("REF_COMMIT_HASH", "master")
|
||||
fp = __file__.replace("diff_schedule", "master_schedule")
|
||||
if not os.path.isfile(fp):
|
||||
shutil.copyfile(fetch(f"https://raw.githubusercontent.com/tinygrad/tinygrad/{ref_schedule}/tinygrad/engine/schedule.py"), fp)
|
||||
shutil.copyfile(fetch(f"https://raw.githubusercontent.com/tinygrad/tinygrad/{ref_schedule}/tinygrad/engine/schedule.py", allow_caching=False), fp)
|
||||
# create the reference graph
|
||||
ref_graph, ref_in_degree = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs, set())
|
||||
# compare
|
||||
@@ -27,15 +26,15 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
|
||||
for buf in lsi.outputs:
|
||||
si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
|
||||
changed = 0
|
||||
seen_diffs: Set[Tuple[UOp, ...]] = set()
|
||||
seen_diffs: Set[Tuple[bytes, ...]] = set()
|
||||
for buf, si in si_for_buf.items():
|
||||
asts = tuple(dedup([x.ast for x in si]))
|
||||
asts = tuple({x.ast.key:x.ast for x in si})
|
||||
# kernels didn't change
|
||||
if len(si) > 1 and len(asts) == 1: continue
|
||||
if asts in seen_diffs: continue
|
||||
seen_diffs.add(asts)
|
||||
changed += 1
|
||||
if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), asts))
|
||||
if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), [x.ast for x in si]))
|
||||
if len(asts) == 1:
|
||||
print(f"{buf} folded in the second schedule")
|
||||
else: print_si_diff(si[0], si[1])
|
||||
|
||||
@@ -6,7 +6,6 @@ from tinygrad.helpers import Context
|
||||
from tinygrad.engine.schedule import _graph_schedule
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
||||
@unittest.skip("TODO: uop compare")
|
||||
class TestDiffSchedule(unittest.TestCase):
|
||||
def test_diff_arange(self):
|
||||
# diff a single arange kernel
|
||||
|
||||
Reference in New Issue
Block a user