mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
pickle ContextVars in process replay [pr] (#8484)
* pickle ContextVars in process replay * add test_pickle_context_var [pr] * more realistic
This commit is contained in:
2
test/external/process_replay/helpers.py
vendored
2
test/external/process_replay/helpers.py
vendored
@@ -14,4 +14,4 @@ def get_process_replay_ctx() -> Tuple[ProcessReplayContext, Dict]:
|
||||
loc = "\n".join(traceback.format_list(stack))
|
||||
try: head_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
|
||||
except Exception: head_sha = ""
|
||||
return ProcessReplayContext(loc, head_sha, getenv("GITHUB_RUN_ID") or None), {k:v.value for k,v in ContextVar._cache.items()}
|
||||
return ProcessReplayContext(loc, head_sha, getenv("GITHUB_RUN_ID") or None), ContextVar._cache
|
||||
|
||||
@@ -60,7 +60,7 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
|
||||
continue
|
||||
# try recreate
|
||||
try:
|
||||
with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
|
||||
with Context(**{k:v.value for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
|
||||
if good is None: continue
|
||||
except Exception as e:
|
||||
changed += 1
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest, pickle, types
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, TinyJit, Variable, dtypes
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
from tinygrad.helpers import GlobalCounters, ContextVar, Context
|
||||
from tinygrad.ops import PatternMatcher, UPat, UOp
|
||||
|
||||
class TestPickle(unittest.TestCase):
|
||||
@@ -95,6 +95,13 @@ class TestPickle(unittest.TestCase):
|
||||
out = add_fxn(x, y)
|
||||
np.testing.assert_equal(out.numpy(), 102)
|
||||
|
||||
def test_pickle_context_var(self):
|
||||
v = ContextVar("test_var", 0)
|
||||
with Context(test_var=1):
|
||||
vs = pickle.dumps(v)
|
||||
v2 = pickle.loads(vs)
|
||||
self.assertEqual(v2.value, 1)
|
||||
|
||||
def test_pickle_schedule(self):
|
||||
a = Tensor([1,2])
|
||||
out = a + 2
|
||||
|
||||
@@ -276,8 +276,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
# capture process replay
|
||||
if getenv("RUN_PROCESS_REPLAY"):
|
||||
PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, {k:v.value for k,v in ContextVar._cache.items()}, sink))
|
||||
if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, ContextVar._cache, sink))
|
||||
return ScheduleItem(sink, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), tuple(si_ctx.metadata),
|
||||
tuple(ubuf for ubuf,ops in si_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user