pickle ContextVars in process replay [pr] (#8484)

* pickle ContextVars in process replay

* add test_pickle_context_var [pr]

* more realistic
This commit is contained in:
qazal
2025-01-03 17:11:54 +02:00
committed by GitHub
parent bd4d7dc4eb
commit 12fa4340b3
4 changed files with 11 additions and 5 deletions

View File

@@ -14,4 +14,4 @@ def get_process_replay_ctx() -> Tuple[ProcessReplayContext, Dict]:
loc = "\n".join(traceback.format_list(stack))
try: head_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
except Exception: head_sha = ""
return ProcessReplayContext(loc, head_sha, getenv("GITHUB_RUN_ID") or None), {k:v.value for k,v in ContextVar._cache.items()}
return ProcessReplayContext(loc, head_sha, getenv("GITHUB_RUN_ID") or None), ContextVar._cache

View File

@@ -60,7 +60,7 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
continue
# try recreate
try:
with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
with Context(**{k:v.value for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
if good is None: continue
except Exception as e:
changed += 1

View File

@@ -1,7 +1,7 @@
import unittest, pickle, types
import numpy as np
from tinygrad import Tensor, TinyJit, Variable, dtypes
from tinygrad.helpers import GlobalCounters
from tinygrad.helpers import GlobalCounters, ContextVar, Context
from tinygrad.ops import PatternMatcher, UPat, UOp
class TestPickle(unittest.TestCase):
@@ -95,6 +95,13 @@ class TestPickle(unittest.TestCase):
out = add_fxn(x, y)
np.testing.assert_equal(out.numpy(), 102)
def test_pickle_context_var(self):
v = ContextVar("test_var", 0)
with Context(test_var=1):
vs = pickle.dumps(v)
v2 = pickle.loads(vs)
self.assertEqual(v2.value, 1)
def test_pickle_schedule(self):
a = Tensor([1,2])
out = a + 2

View File

@@ -276,8 +276,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# capture process replay
if getenv("RUN_PROCESS_REPLAY"):
PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, {k:v.value for k,v in ContextVar._cache.items()}, sink))
if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, ContextVar._cache, sink))
return ScheduleItem(sink, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), tuple(si_ctx.metadata),
tuple(ubuf for ubuf,ops in si_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops)))