diff --git a/test/external/process_replay/local.sh b/test/external/process_replay/local.sh index efd13953b8..38ab1d123c 100755 --- a/test/external/process_replay/local.sh +++ b/test/external/process_replay/local.sh @@ -4,5 +4,6 @@ HEAD=$(git rev-parse --abbrev-ref HEAD) python test/external/process_replay/reset.py RUN_PROCESS_REPLAY=1 python test/test_ops.py TestOps.test_add git checkout master +git checkout $HEAD -- test/external/process_replay/process_replay.py ASSERT_PROCESS_REPLAY=1 python test/external/process_replay/process_replay.py git checkout $HEAD diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index a588ecdb9a..933dfc4a29 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -34,8 +34,8 @@ class ProcessReplayWarning(Warning): pass # *** recreators def recreate_sched(big_sink:UOp) -> list[UOp]: - sched_sink = get_becomes_map(big_sink)[0][big_sink] - return dedup(u.src[1].arg.ast for u in sched_sink.toposort if u.op is Ops.ASSIGN) + sched_sink = get_becomes_map(big_sink)[big_sink] + return dedup(u.arg.ast for u in sched_sink.toposort if u.op is Ops.KERNEL) def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str, _) -> str: k = Kernel(ast, opts=opts)