mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
59 lines
2.2 KiB
Python
59 lines
2.2 KiB
Python
from typing import Tuple, Any, List
|
|
import unittest
|
|
from tinygrad.ops import ASTRunner, GraphBatchExecutor
|
|
|
|
class TestGraph(GraphBatchExecutor):
|
|
def __init__(self, jit_cache: List[Tuple[Any, Any, Any]]):
|
|
super().__init__(jit_cache)
|
|
|
|
self.next_jit = 0
|
|
self.update_called = 0
|
|
self.jcid_to_instid = {}
|
|
self.jc_info = []
|
|
self.exec_set = set()
|
|
self.split_into_graphs(jit_cache)
|
|
assert len(self.jc_info) == len(jit_cache), f"each jit cache entry should be captured into nodes. {len(self.jc_info)} != {len(jit_cache)}"
|
|
|
|
target_size = [4, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
|
for i, inst in enumerate(self.graphs):
|
|
assert len(inst) == target_size[i] or (i == len(self.graphs) - 1 and len(inst) > 0), "unexpected graph size"
|
|
|
|
def create_graph(self, jit_cache: List[Tuple[Any, Any, Any]]):
|
|
for prg, pargs, variables in jit_cache:
|
|
self.jcid_to_instid[len(self.jc_info)] = len(self.graphs)
|
|
assert pargs == self.next_jit, "prog is written 2+ times in the graph or some of them are skipped"
|
|
self.next_jit += 1
|
|
self.jc_info.append((prg, pargs, variables))
|
|
self.graphs.append(jit_cache)
|
|
|
|
def update_node(self, instid, jcid, prg, pargs, variables, updated_args=None):
|
|
self.update_called += 1
|
|
assert instid == self.jcid_to_instid[jcid], "jit cache entry does not belong to the given instance"
|
|
|
|
def exec_instance(self, instid):
|
|
assert 0 <= instid < len(self.graphs), "called unknown instance"
|
|
self.exec_set.add(instid)
|
|
|
|
class TestBatchExec(unittest.TestCase):
|
|
def test_graph_batch_exec_partition(self):
|
|
|
|
def _helper(jit_cache_size, updates_count):
|
|
fake_jit_cache = [(ASTRunner("test", "", [1], [1]), i, i) for i in range(jit_cache_size)]
|
|
updatable_entries = {i:0 for i in range(updates_count)}
|
|
gr = TestGraph(fake_jit_cache)
|
|
gr.exec(fake_jit_cache, updatable_entries)
|
|
|
|
assert gr.update_called == updates_count, "not all updates are called"
|
|
assert len(gr.exec_set) == len(gr.graphs), "every instance should be executed"
|
|
|
|
_helper(512, 512)
|
|
_helper(334, 13)
|
|
_helper(812, 111)
|
|
_helper(2, 0)
|
|
_helper(1, 1)
|
|
_helper(4, 3)
|
|
_helper(7, 2)
|
|
_helper(8, 8)
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main() |