diff --git a/test/external/process_replay/helpers.py b/test/external/process_replay/helpers.py index 7475893b11..8b9db50494 100644 --- a/test/external/process_replay/helpers.py +++ b/test/external/process_replay/helpers.py @@ -8,10 +8,10 @@ class ProcessReplayContext: loc: str head_sha: str run_id: Optional[int] -def get_process_replay_ctx() -> Tuple[Dict, ProcessReplayContext]: +def get_process_replay_ctx() -> Tuple[ProcessReplayContext, Dict]: stack = filter(lambda x: "tinygrad" in x.filename and not any(n in x.filename for n in ["engine/schedule.py", "engine/realize.py", \ "codegen/kernel.py", "unittest"]), traceback.extract_stack()[:-1]) loc = "\n".join(traceback.format_list(stack)) try: head_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode() except Exception: head_sha = "" - return {k:v.value for k,v in ContextVar._cache.items()}, ProcessReplayContext(loc, head_sha, getenv("GITHUB_RUN_ID") or None) + return ProcessReplayContext(loc, head_sha, getenv("GITHUB_RUN_ID") or None), {k:v.value for k,v in ContextVar._cache.items()} diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 3d6f117b7e..5f90bcd515 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -29,7 +29,7 @@ if REF == "master": SKIP_PROCESS_REPLAY = True # *** recreators def recreate_sched(sink:UOp, ctx:ScheduleItemContext) -> UOp: return full_ast_rewrite(sink, ctx) -def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str) -> str: +def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str, _) -> str: k = Kernel(ast, opts=opts) for opt in applied_opts: k.apply_opt(opt) # NOTE: replay with the captured renderer, not the one in master diff --git a/test/test_tensor.py b/test/test_tensor.py index 799dcac56b..51e8002002 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -714,51 +714,51 @@ class TestTensorMetadata(unittest.TestCase): x = Tensor.rand(3, requires_grad=True) W = Tensor.rand(3, 3, requires_grad=True) out = x.matmul(W) - assert out.lazydata.metadata.name == "matmul" + self.assertEqual(out.lazydata.metadata.name, "matmul") s = create_schedule([out.lazydata]) - assert len(s[-1].metadata) == 1 - assert s[-1].metadata[0].name == "matmul" + self.assertEqual(len(s[-1].metadata), 1) + self.assertEqual(s[-1].metadata[0].name, "matmul") def test_relu(self): _METADATA.set(None) x = Tensor.rand(3, requires_grad=True) out = x.relu() - assert out.lazydata.metadata.name == "relu" + self.assertEqual(out.lazydata.metadata.name, "relu") s = create_schedule([out.lazydata]) - assert len(s[-1].metadata) == 1 - assert s[-1].metadata[0].name == "relu" + self.assertEqual(len(s[-1].metadata), 1) + self.assertEqual(s[-1].metadata[0].name, "relu") def test_complex(self): _METADATA.set(None) x = Tensor.rand(3, requires_grad=True) y = Tensor.rand(3, requires_grad=True) out = x.relu() * y.sigmoid() - assert out.lazydata.metadata.name == "__mul__" - assert out.lazydata.srcs[0].metadata.name == "relu" - assert out.lazydata.srcs[1].metadata.name == "sigmoid" + self.assertEqual(out.lazydata.metadata.name, "__mul__") + self.assertEqual(out.lazydata.srcs[0].metadata.name, "relu") + self.assertEqual(out.lazydata.srcs[1].metadata.name, "sigmoid") s = create_schedule([out.lazydata]) - assert len(s[-1].metadata) == 3 - assert s[-1].metadata[0].name == "relu" - assert s[-1].metadata[1].name == "sigmoid" - assert s[-1].metadata[2].name == "__mul__" + self.assertEqual(len(s[-1].metadata), 3) + self.assertEqual(s[-1].metadata[0].name, "relu") + self.assertEqual(s[-1].metadata[1].name, "sigmoid") + self.assertEqual(s[-1].metadata[2].name, "__mul__") def test_complex_backward(self): _METADATA.set(None) x = Tensor.rand(3, requires_grad=True) y = Tensor.rand(3, requires_grad=True) out = (x.relu() * y.sigmoid()).sum() - assert out.lazydata.metadata.name == "sum" + self.assertEqual(out.lazydata.metadata.name, "sum") out.backward() - assert x.grad.lazydata.metadata.name == "relu" - assert x.grad.lazydata.metadata.backward - assert y.grad.lazydata.metadata.name == "sigmoid" - assert y.grad.lazydata.metadata.backward + self.assertEqual(x.grad.lazydata.metadata.name, "relu") + self.assertTrue(x.grad.lazydata.metadata.backward) + self.assertEqual(y.grad.lazydata.metadata.name, "sigmoid") + self.assertTrue(y.grad.lazydata.metadata.backward) s = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata]) - assert len(s[-1].metadata) == 3 - assert s[-1].metadata[0].name == "sigmoid" - assert s[-1].metadata[1].name == "sigmoid" - assert s[-1].metadata[1].backward - assert s[-1].metadata[2].name == "relu" + self.assertEqual(len(s[-1].metadata), 3) + self.assertEqual(s[-1].metadata[0].name, "sigmoid") + self.assertEqual(s[-1].metadata[1].name, "sigmoid") + self.assertTrue(s[-1].metadata[1].backward) + self.assertEqual(s[-1].metadata[2].name, "relu") if __name__ == '__main__': unittest.main()