mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
make tests faster (#13734)
This commit is contained in:
@@ -20,7 +20,7 @@ global_mem_used = 0
|
||||
def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jitted=False):
|
||||
with Context(JIT=2):
|
||||
tms = []
|
||||
for _ in range(4):
|
||||
for _ in range(2 if CI else 4):
|
||||
early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
|
||||
GlobalCounters.reset()
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
|
||||
@@ -57,14 +57,6 @@ class TestKeccak(unittest.TestCase):
|
||||
# https://www.di-mgt.com.au/sha_testvectors.html
|
||||
self.assertEqual(bytes(Tensor(b"abc").keccak().tolist()),
|
||||
bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
self.assertEqual(bytes(Tensor(b"").keccak().tolist()),
|
||||
bytearray.fromhex("a7ffc6f8bf1ed766 51c14756a061d662 f580ff4de43b49fa 82d80a4b80f8434a"))
|
||||
t = Tensor(b"abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmnhijklmnoijklmnopjklmnopqklmnopqrlmnopqrsmnopqrstnopqrstu").keccak()
|
||||
self.assertEqual(bytes(t.tolist()),
|
||||
bytearray.fromhex("916f6061fe879741 ca6469b43971dfdb 28b1a32dc36cb325 4e812be27aad1d18"))
|
||||
# TODO: this does not run or very slow
|
||||
# self.assertEqual(bytes(Tensor(b"a" * 1000000).keccak().tolist()),
|
||||
# bytearray.fromhex("5c8875ae474a3634 ba4fd55ec85bffd6 61f32aca75c6d699 d0cdcb6c115891c1"))
|
||||
|
||||
@unittest.skipIf(CI, "times out in ci")
|
||||
def test_long(self):
|
||||
@@ -75,24 +67,14 @@ class TestKeccak(unittest.TestCase):
|
||||
self.assertEqual(bytes(Tensor(data).keccak("shake_128").tolist()), hashlib.shake_128(data).digest(16))
|
||||
|
||||
def test_variable_bs(self):
|
||||
data = Tensor([b"abc", b"abc", b"abc"], dtype=dtypes.uint8).repeat(2048, 1)
|
||||
|
||||
bs = UOp.variable("bs", 1, 4096).bind(1)
|
||||
out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(1, 32)
|
||||
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
|
||||
bs = UOp.variable("bs", 1, 4096).bind(2)
|
||||
out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(2, 32)
|
||||
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
|
||||
bs = UOp.variable("bs", 1, 4096).bind(3)
|
||||
data = Tensor([b"abc", b"abc", b"def"], dtype=dtypes.uint8).repeat(2048, 1)
|
||||
bs = UOp.variable("bs", 1, 4096).bind(3)
|
||||
out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(3, 32)
|
||||
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
self.assertEqual(bytes(out[2].tolist()), bytearray.fromhex("8e0d8f672252acb0 ffc5093db8653b18 1513bf9a2097e737 b4f73533dcaf46df"))
|
||||
|
||||
@unittest.skipIf(CI, "redundant with test_variable_bs")
|
||||
def test_variable_bs_jit(self):
|
||||
def f(data):
|
||||
return data.keccak()
|
||||
|
||||
@@ -167,7 +167,8 @@ class TestViz(BaseTestViz):
|
||||
(UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
|
||||
(UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
|
||||
])
|
||||
with self.assertRaises(RuntimeError): exec_rewrite(a, [pm])
|
||||
# use smaller stack limit for faster test (default is 250000)
|
||||
with Context(REWRITE_STACK_LIMIT=100): self.assertRaises(RuntimeError, exec_rewrite, a, [pm])
|
||||
graphs = flatten(x["graph"].values() for x in get_viz_details(0, 0))
|
||||
self.assertEqual(graphs[0], uop_to_json(a)[id(a)])
|
||||
self.assertEqual(graphs[1], uop_to_json(b)[id(b)])
|
||||
|
||||
@@ -1061,6 +1061,7 @@ class PatternMatcher:
|
||||
# *** tracking pattern matcher ***
|
||||
|
||||
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if VIZ else 0)
|
||||
REWRITE_STACK_LIMIT = ContextVar("REWRITE_STACK_LIMIT", 250000)
|
||||
match_stats:dict[UPat, list[int|float]] = dict()
|
||||
|
||||
# TRACK_MATCH_STATS>=2 or VIZ=1 saves all matches
|
||||
@@ -1217,7 +1218,6 @@ class RewriteContext:
|
||||
def unified_rewrite(self, root:UOp) -> UOp:
|
||||
stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)])
|
||||
on_stack = {root} # all UOps either on the stack or in self.replace, i.e. dont have to be placed again
|
||||
REWRITE_STACK_LIMIT = getenv("REWRITE_STACK_LIMIT", 250000)
|
||||
while stack:
|
||||
if len(stack) > REWRITE_STACK_LIMIT: raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
|
||||
n, stage, new_n = stack.pop()
|
||||
|
||||
Reference in New Issue
Block a user