factor out resource access logic in multigraph base class (#4385)

* factor out resource access logic in multigraph base class

* hsa fixes

* clean

* linter

* linter 2

* no need for this
This commit is contained in:
nimlgen
2024-05-03 00:38:22 +03:00
committed by GitHub
parent ab01a9433d
commit ca6c8ae739
3 changed files with 28 additions and 46 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional
import functools, itertools
from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
import functools, itertools, collections
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored
@@ -81,7 +81,25 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
self.vars = list(var_vals.keys())
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)
class MultiGraphRunner(GraphRunner): pass # pylint: disable=abstract-method
class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
  """GraphRunner that keeps per-buffer dependency bookkeeping for its subclasses.

  Two maps are maintained, both keyed by ``id(buffer._buf)``:
  ``w_dependency_map`` holds the dependency recorded by the last write to a buffer,
  ``r_dependency_map`` holds the dependencies recorded by reads since then.
  """
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    # The maps are created before delegating to the parent initializer.
    self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list)
    self.w_dependency_map: Dict[Any, Any] = {}
    super().__init__(jit_cache, input_rawbuffers, var_vals)

  def _access_resources(self, read, write, new_dependency:Any):
    """Return what must be waited on before touching the buffers, then record `new_dependency` for them.

    A buffer may have one writer or several readers at a time: every accessed buffer waits for its
    last recorded writer, and every written buffer additionally waits for (and forgets) its recorded readers.
    The result is a list without repeated objects, in first-seen order.
    """
    last_writer, readers = self.w_dependency_map, self.r_dependency_map

    # Every access (read or write) has to wait for the previous writer of that buffer.
    touched = (id(buf._buf) for buf in read + write)
    pending = [last_writer[key] for key in touched if key in last_writer]

    # A write also has to wait for all outstanding readers; they are consumed here.
    for buf in write: pending += readers.pop(id(buf._buf), [])

    # Only after the waits are gathered is the new dependency registered.
    for buf in read: readers[id(buf._buf)].append(new_dependency)
    for buf in write: last_writer[id(buf._buf)] = new_dependency

    # Drop repeated objects (compared by identity), keeping the first occurrence.
    unique: Dict[int, Any] = {}
    for dep in pending: unique.setdefault(id(dep), dep)
    return list(unique.values())
ReturnType = TypeVar('ReturnType')
class TinyJit(Generic[ReturnType]):

View File

@@ -1,4 +1,4 @@
import ctypes, collections
import ctypes
from typing import Any, Optional, Tuple, Dict, List, cast
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, GraphException, getenv
@@ -19,8 +19,6 @@ class CUDAGraph(MultiGraphRunner):
self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
self.w_dependency_map: Dict[Any, Any] = {}
self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list)
self.cpu_buffers = []
for j,ji in enumerate(self.jit_cache):
@@ -28,7 +26,7 @@ class CUDAGraph(MultiGraphRunner):
global_size, local_size = ji.prg.launch_dims(var_vals)
new_node = cuda.CUgraphNode()
deps = self.access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=new_node)
deps = self._access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=new_node)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.vars])
@@ -41,7 +39,7 @@ class CUDAGraph(MultiGraphRunner):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
src_dev, dest_dev = cast(CUDADevice, Device[src.device]), cast(CUDADevice, Device[dest.device])
node_from = cuda.CUgraphNode()
deps = self.access_resources(read=[src], write=[dest], new_dependency=node_from)
deps = self._access_resources(read=[src], write=[dest], new_dependency=node_from)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
if getenv("CUDA_P2P", int(CUDADevice.peer_access)):
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
@@ -95,16 +93,3 @@ class CUDAGraph(MultiGraphRunner):
def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
  """Store launch dimensions on `node`: blockDim{X,Y,Z} from `local_size`, gridDim{X,Y,Z} from `global_size`."""
  # Unpack all six values up front so a wrong total length fails before any attribute is written.
  block_x, block_y, block_z, grid_x, grid_y, grid_z = (*local_size, *global_size)
  node.blockDimX, node.blockDimY, node.blockDimZ = block_x, block_y, block_z
  node.gridDimX, node.gridDimY, node.gridDimZ = grid_x, grid_y, grid_z
def access_resources(self, read, write, new_dependency):
# Gathers the previously recorded dependencies for the given buffers and records `new_dependency` for them.
# Both dependency maps are keyed by `rawbuf._buf.value`.
wait_nodes = []
# Any buffer that is read or written waits for the dependency stored by its last write.
for rawbuf in read + write:
if rawbuf._buf.value in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[rawbuf._buf.value])
# A written buffer also waits for its recorded read dependencies, which are removed from the map.
for rawbuf in write:
if rawbuf._buf.value in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(rawbuf._buf.value))
# NOTE(review): indentation is lost in this view, so it is unclear whether the guard below covers only the
# read-side bookkeeping or the write-side bookkeeping as well -- confirm against the original file.
if new_dependency is not None:
for rawbuf in read: self.r_dependency_map[rawbuf._buf.value].append(new_dependency)
for rawbuf in write: self.w_dependency_map[rawbuf._buf.value] = new_dependency
# Keyed by id() so each object appears once; note this returns a dict values view, not a list.
return {id(x):x for x in wait_nodes}.values()

View File

@@ -1,5 +1,5 @@
import ctypes, collections, time, itertools
from typing import List, Any, Dict, cast, Optional, Union, Tuple
from typing import List, Any, Dict, cast, Optional, Tuple
from tinygrad.helpers import GraphException, init_c_var, round_up
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Compiled, CompiledRunner, BufferXfer, Device
@@ -61,8 +61,6 @@ class HSAGraph(MultiGraphRunner):
self.transfers = []
self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
self.signals_to_reset: List[hsa.hsa_signal_t] = []
self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {}
self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list)
self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)
@@ -166,27 +164,8 @@ class HSAGraph(MultiGraphRunner):
return packet.completion_signal
return None
def access_resources(self, read, write, new_dependency=None, sync_with_aql_packets=False):
# To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
# whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
# The tracked dependencies are either hsa signals or ints that reference a specific aql packet.
wait_signals: List[Optional[hsa.hsa_signal_t]] = []
def access_resources(self, read, write, new_dependency, sync_with_aql_packets=False):
rdeps = self._access_resources(read, write, new_dependency)
wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in read+write]
for rawbuf in read:
wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
for rawbuf in write:
wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
if rawbuf._buf in self.r_dependency_map:
rdeps = self.r_dependency_map.pop(rawbuf._buf)
# When synchronizing to aql packets, we only need to sync to the latest one, as they are executed in order.
signal_deps, aql_deps = [x for x in rdeps if isinstance(x, hsa.hsa_signal_t)], [x for x in rdeps if isinstance(x, int)]
deps = signal_deps + ([max(aql_deps)] if len(aql_deps) > 0 else [])
for dep in deps: wait_signals.append(self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets))
if new_dependency is not None:
for rawbuf in read: self.r_dependency_map[rawbuf._buf].append(new_dependency)
for rawbuf in write: self.w_dependency_map[rawbuf._buf] = new_dependency
return dedup_signals(wait_signals)