diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index 651f67f792..7e797a2314 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -20,7 +20,7 @@ global_mem_used = 0 def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jitted=False): with Context(JIT=2): tms = [] - for _ in range(4): + for _ in range(2 if CI else 4): early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()] GlobalCounters.reset() Device[Device.DEFAULT].synchronize() diff --git a/test/unit/test_hashing.py b/test/unit/test_hashing.py index d35fac435c..6024b3090e 100644 --- a/test/unit/test_hashing.py +++ b/test/unit/test_hashing.py @@ -57,14 +57,6 @@ class TestKeccak(unittest.TestCase): # https://www.di-mgt.com.au/sha_testvectors.html self.assertEqual(bytes(Tensor(b"abc").keccak().tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) - self.assertEqual(bytes(Tensor(b"").keccak().tolist()), - bytearray.fromhex("a7ffc6f8bf1ed766 51c14756a061d662 f580ff4de43b49fa 82d80a4b80f8434a")) - t = Tensor(b"abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmnhijklmnoijklmnopjklmnopqklmnopqrlmnopqrsmnopqrstnopqrstu").keccak() - self.assertEqual(bytes(t.tolist()), - bytearray.fromhex("916f6061fe879741 ca6469b43971dfdb 28b1a32dc36cb325 4e812be27aad1d18")) - # TODO: this does not run or very slow - # self.assertEqual(bytes(Tensor(b"a" * 1000000).keccak().tolist()), - # bytearray.fromhex("5c8875ae474a3634 ba4fd55ec85bffd6 61f32aca75c6d699 d0cdcb6c115891c1")) @unittest.skipIf(CI, "times out in ci") def test_long(self): @@ -75,24 +67,14 @@ class TestKeccak(unittest.TestCase): self.assertEqual(bytes(Tensor(data).keccak("shake_128").tolist()), hashlib.shake_128(data).digest(16)) def test_variable_bs(self): - data = Tensor([b"abc", b"abc", b"abc"], dtype=dtypes.uint8).repeat(2048, 1) - - bs = UOp.variable("bs", 1, 4096).bind(1) - out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(1, 32) - self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) - - bs = UOp.variable("bs", 1, 4096).bind(2) - out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(2, 32) - self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) - self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) - - bs = UOp.variable("bs", 1, 4096).bind(3) data = Tensor([b"abc", b"abc", b"def"], dtype=dtypes.uint8).repeat(2048, 1) + bs = UOp.variable("bs", 1, 4096).bind(3) out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(3, 32) self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) self.assertEqual(bytes(out[2].tolist()), bytearray.fromhex("8e0d8f672252acb0 ffc5093db8653b18 1513bf9a2097e737 b4f73533dcaf46df")) + @unittest.skipIf(CI, "redundant with test_variable_bs") def test_variable_bs_jit(self): def f(data): return data.keccak() diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 6cd5673b46..5868c1cafb 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -167,7 +167,8 @@ class TestViz(BaseTestViz): (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)), (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)), ]) - with self.assertRaises(RuntimeError): exec_rewrite(a, [pm]) + # use smaller stack limit for faster test (default is 250000) + with Context(REWRITE_STACK_LIMIT=100): self.assertRaises(RuntimeError, exec_rewrite, a, [pm]) graphs = flatten(x["graph"].values() for x in get_viz_details(0, 0)) self.assertEqual(graphs[0], uop_to_json(a)[id(a)]) self.assertEqual(graphs[1], uop_to_json(b)[id(b)]) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 7ff428ce1b..d17ead39dc 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -1061,6 +1061,7 @@ class PatternMatcher: # *** tracking pattern matcher *** TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if VIZ else 0) +REWRITE_STACK_LIMIT = ContextVar("REWRITE_STACK_LIMIT", 250000) match_stats:dict[UPat, list[int|float]] = dict() # TRACK_MATCH_STATS>=2 or VIZ=1 saves all matches @@ -1217,7 +1218,6 @@ class RewriteContext: def unified_rewrite(self, root:UOp) -> UOp: stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)]) on_stack = {root} # all UOps either on the stack or in self.replace, i.e. dont have to be placed again - REWRITE_STACK_LIMIT = getenv("REWRITE_STACK_LIMIT", 250000) while stack: if len(stack) > REWRITE_STACK_LIMIT: raise RuntimeError("infinite loop in graph_rewrite (stack too big)") n, stage, new_n = stack.pop()