Files
tinygrad/test/test_exec.py
nimlgen 2e89fd264f Refactor hipgraph (#2141)
* refactor hip graph

* linter happy

* happy liner
2023-10-24 15:45:56 -04:00

59 lines
2.2 KiB
Python

from typing import Tuple, Any, List
import unittest
from tinygrad.ops import ASTRunner, GraphBatchExecutor
class TestGraph(GraphBatchExecutor):
  """Fake GraphBatchExecutor that records how a jit cache gets split into graphs.

  No device work happens here: every "graph" built by ``create_graph`` is just the
  chunk of jit cache entries it received, appended to ``self.graphs`` (that list is
  not created in this class -- presumably ``GraphBatchExecutor.__init__`` sets it up;
  TODO confirm). After ``split_into_graphs`` runs, the constructor asserts that every
  entry was captured in order and that the graph sizes follow the expected schedule.
  ``update_node`` and ``exec_instance`` only record and sanity-check their calls.
  """

  def __init__(self, jit_cache: List[Tuple[Any, Any, Any]]):
    super().__init__(jit_cache)
    # Bookkeeping filled in by the overridden hooks below.
    self.next_jit = 0         # pargs value the next captured entry must carry
    self.update_called = 0    # number of update_node invocations
    self.jcid_to_instid = {}  # jit cache position -> index of the graph that captured it
    self.jc_info = []         # (prg, pargs, variables) of every captured entry, in capture order
    self.exec_set = set()     # graph indices handed to exec_instance
    self.split_into_graphs(jit_cache)

    captured_all = len(self.jc_info) == len(jit_cache)
    assert captured_all, f"each jit cache entry should be captured into nodes. {len(self.jc_info)} != {len(jit_cache)}"

    # Expected graph sizes in order; only the final graph may differ, but it must be non-empty.
    # NOTE(review): schedule[idx] is evaluated before the last-graph exemption, so a split
    # producing more than len(schedule) graphs raises IndexError instead of failing the check.
    schedule = [4, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
    last_idx = len(self.graphs) - 1
    for idx, graph in enumerate(self.graphs):
      assert len(graph) == schedule[idx] or (idx == last_idx and len(graph) > 0), "unexpected graph size"

  def create_graph(self, jit_cache: List[Tuple[Any, Any, Any]]):
    """Record one chunk of jit cache entries as a new fake graph.

    Each entry is mapped to the index this chunk will occupy in ``self.graphs``.
    Entries must arrive in order: an entry's ``pargs`` has to equal ``self.next_jit``.
    """
    graph_idx = len(self.graphs)  # position the chunk takes once appended below
    for prg, pargs, variables in jit_cache:
      self.jcid_to_instid[len(self.jc_info)] = graph_idx
      assert pargs == self.next_jit, "prog is written 2+ times in the graph or some of them are skipped"
      self.next_jit += 1
      self.jc_info.append((prg, pargs, variables))
    self.graphs.append(jit_cache)

  def update_node(self, instid, jcid, prg, pargs, variables, updated_args=None):
    """Count an update request and check that entry ``jcid`` belongs to graph ``instid``."""
    self.update_called += 1
    assert instid == self.jcid_to_instid[jcid], "jit cache entry does not belong to the given instance"

  def exec_instance(self, instid):
    """Remember that graph ``instid`` was executed, after checking that it exists."""
    assert 0 <= instid < len(self.graphs), "called unknown instance"
    self.exec_set.add(instid)
class TestBatchExec(unittest.TestCase):
  """Checks that GraphBatchExecutor partitions, updates and executes a jit cache correctly."""

  def test_graph_batch_exec_partition(self):
    """Run several (cache size, update count) combinations through TestGraph.

    The partition itself is validated inside ``TestGraph.__init__``. This test
    additionally checks that every requested update reached ``update_node`` and
    that every created graph was executed.
    """
    def _helper(jit_cache_size, updates_count):
      # Each fake entry carries its own index as pargs/variables, which is what
      # TestGraph.create_graph uses to verify in-order, exactly-once capture.
      fake_jit_cache = [(ASTRunner("test", "", [1], [1]), i, i) for i in range(jit_cache_size)]
      # The first `updates_count` entries are marked updatable. The value 0 is a
      # placeholder; its meaning depends on GraphBatchExecutor.exec (TODO confirm).
      updatable_entries = dict.fromkeys(range(updates_count), 0)
      gr = TestGraph(fake_jit_cache)
      gr.exec(fake_jit_cache, updatable_entries)
      # assertEqual (unlike a bare assert) survives `python -O` and reports both values.
      self.assertEqual(gr.update_called, updates_count, "not all updates are called")
      self.assertEqual(len(gr.exec_set), len(gr.graphs), "every instance should be executed")

    # (jit cache size, number of updatable entries). The cases cover exact schedule
    # boundaries, partial last graphs, and tiny caches.
    cases = [(512, 512), (334, 13), (812, 111), (2, 0), (1, 1), (4, 3), (7, 2), (8, 8)]
    for jit_cache_size, updates_count in cases:
      # subTest keeps running the remaining cases after a failure and tags each report with its parameters.
      with self.subTest(jit_cache_size=jit_cache_size, updates_count=updates_count):
        _helper(jit_cache_size, updates_count)
if __name__ == '__main__':
  # Allow running this file directly (e.g. `python test/test_exec.py`) instead of only via a test runner.
  unittest.main()