mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
factor out resource access logic in multigraph base class (#4385)
* factor out resource access logic in multigraph base class
* hsa fixes
* clean
* linter
* linter 2
* not need this
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional
|
||||
import functools, itertools
|
||||
from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
|
||||
import functools, itertools, collections
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored
|
||||
@@ -81,7 +81,25 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
|
||||
self.vars = list(var_vals.keys())
|
||||
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)
|
||||
|
||||
class MultiGraphRunner(GraphRunner): pass # pylint: disable=abstract-method
|
||||
class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
  """GraphRunner that also tracks read/write dependencies between buffers.

  Two tables are kept, both keyed by `id(rawbuf._buf)`:
    w_dependency_map: the dependency registered by the most recent writer of a buffer.
    r_dependency_map: the dependencies registered by readers of a buffer since it was last written.
  """
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    # The tables are created before the base initializer runs.
    self.w_dependency_map: Dict[Any, Any] = {}
    self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list)
    super().__init__(jit_cache, input_rawbuffers, var_vals)

  def _access_resources(self, read, write, new_dependency:Any):
    """Record an access to `read` and `write` buffers and return what it has to wait for.

    A resource can be accessed by either a single writer or multiple readers, so:
      - every accessed buffer waits for its last registered writer;
      - every written buffer additionally waits for (and clears) its registered readers.
    Afterwards `new_dependency` is registered as a reader of each buffer in `read` and as
    the writer of each buffer in `write`.

    Returns a list of the dependencies to wait on, without duplicates (compared by object
    identity), in the order they were first encountered.
    """
    # Keyed by id() of the dependency so that the same object is only reported once.
    pending: Dict[int, Any] = {}

    for buf in read + write:
      key = id(buf._buf)
      if key in self.w_dependency_map:
        writer = self.w_dependency_map[key]
        pending[id(writer)] = writer

    for buf in write:
      key = id(buf._buf)
      # Membership test first: indexing the defaultdict would create an empty entry.
      if key in self.r_dependency_map:
        for reader in self.r_dependency_map.pop(key): pending[id(reader)] = reader

    for buf in read: self.r_dependency_map[id(buf._buf)].append(new_dependency)
    for buf in write: self.w_dependency_map[id(buf._buf)] = new_dependency

    return list(pending.values())
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
class TinyJit(Generic[ReturnType]):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import ctypes, collections
|
||||
import ctypes
|
||||
from typing import Any, Optional, Tuple, Dict, List, cast
|
||||
import tinygrad.runtime.autogen.cuda as cuda
|
||||
from tinygrad.helpers import init_c_var, GraphException, getenv
|
||||
@@ -19,8 +19,6 @@ class CUDAGraph(MultiGraphRunner):
|
||||
self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
|
||||
|
||||
self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
|
||||
self.w_dependency_map: Dict[Any, Any] = {}
|
||||
self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list)
|
||||
self.cpu_buffers = []
|
||||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
@@ -28,7 +26,7 @@ class CUDAGraph(MultiGraphRunner):
|
||||
global_size, local_size = ji.prg.launch_dims(var_vals)
|
||||
|
||||
new_node = cuda.CUgraphNode()
|
||||
deps = self.access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=new_node)
|
||||
deps = self._access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=new_node)
|
||||
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
||||
|
||||
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.vars])
|
||||
@@ -41,7 +39,7 @@ class CUDAGraph(MultiGraphRunner):
|
||||
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
|
||||
src_dev, dest_dev = cast(CUDADevice, Device[src.device]), cast(CUDADevice, Device[dest.device])
|
||||
node_from = cuda.CUgraphNode()
|
||||
deps = self.access_resources(read=[src], write=[dest], new_dependency=node_from)
|
||||
deps = self._access_resources(read=[src], write=[dest], new_dependency=node_from)
|
||||
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
||||
if getenv("CUDA_P2P", int(CUDADevice.peer_access)):
|
||||
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
|
||||
@@ -95,16 +93,3 @@ class CUDAGraph(MultiGraphRunner):
|
||||
|
||||
def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
  """Write launch dimensions onto `node`: `local_size` into blockDimX/Y/Z, `global_size` into gridDimX/Y/Z."""
  # Local sizes come first so they line up with the block* fields in the unpack below.
  dims = (*local_size, *global_size)
  node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = dims
|
||||
|
||||
def access_resources(self, read, write, new_dependency):
  """Return the dependencies an access to `read` and `write` buffers must wait on.

  Buffers are looked up by `rawbuf._buf.value`. Every accessed buffer waits for the
  dependency stored for it in `w_dependency_map`; every written buffer also waits for the
  entries stored for it in `r_dependency_map`, which are removed from that map.

  When `new_dependency` is not None it is appended to `r_dependency_map` for each buffer
  in `read` and stored in `w_dependency_map` for each buffer in `write`.

  Returns a dict values view of the dependencies, without duplicates (compared by object
  identity), in the order they were first encountered.
  """
  # Keyed by id() of the dependency so that the same object is only reported once.
  pending = {}

  for buf in read + write:
    key = buf._buf.value
    if key in self.w_dependency_map:
      writer = self.w_dependency_map[key]
      pending[id(writer)] = writer

  for buf in write:
    key = buf._buf.value
    if key in self.r_dependency_map:
      for reader in self.r_dependency_map.pop(key): pending[id(reader)] = reader

  # With no new dependency the maps are only queried (and readers cleared), not extended.
  if new_dependency is None: return pending.values()

  for buf in read: self.r_dependency_map[buf._buf.value].append(new_dependency)
  for buf in write: self.w_dependency_map[buf._buf.value] = new_dependency
  return pending.values()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import ctypes, collections, time, itertools
|
||||
from typing import List, Any, Dict, cast, Optional, Union, Tuple
|
||||
from typing import List, Any, Dict, cast, Optional, Tuple
|
||||
from tinygrad.helpers import GraphException, init_c_var, round_up
|
||||
from tinygrad.buffer import Buffer, BufferOptions
|
||||
from tinygrad.device import Compiled, CompiledRunner, BufferXfer, Device
|
||||
@@ -61,8 +61,6 @@ class HSAGraph(MultiGraphRunner):
|
||||
self.transfers = []
|
||||
self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
|
||||
self.signals_to_reset: List[hsa.hsa_signal_t] = []
|
||||
self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {}
|
||||
self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list)
|
||||
self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
|
||||
self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)
|
||||
|
||||
@@ -166,27 +164,8 @@ class HSAGraph(MultiGraphRunner):
|
||||
return packet.completion_signal
|
||||
return None
|
||||
|
||||
def access_resources(self, read, write, new_dependency=None, sync_with_aql_packets=False):
|
||||
# To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
|
||||
# whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
|
||||
# The tracked dependencies are either hsa signals or ints that reference a specific aql packet.
|
||||
wait_signals: List[Optional[hsa.hsa_signal_t]] = []
|
||||
|
||||
def access_resources(self, read, write, new_dependency, sync_with_aql_packets=False):
|
||||
rdeps = self._access_resources(read, write, new_dependency)
|
||||
wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
|
||||
if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in read+write]
|
||||
for rawbuf in read:
|
||||
wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
|
||||
for rawbuf in write:
|
||||
wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
|
||||
if rawbuf._buf in self.r_dependency_map:
|
||||
rdeps = self.r_dependency_map.pop(rawbuf._buf)
|
||||
|
||||
# When synchronizing to aql packets, we only need to sync to the latest one, as they are executed in order.
|
||||
signal_deps, aql_deps = [x for x in rdeps if isinstance(x, hsa.hsa_signal_t)], [x for x in rdeps if isinstance(x, int)]
|
||||
deps = signal_deps + ([max(aql_deps)] if len(aql_deps) > 0 else [])
|
||||
for dep in deps: wait_signals.append(self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets))
|
||||
|
||||
if new_dependency is not None:
|
||||
for rawbuf in read: self.r_dependency_map[rawbuf._buf].append(new_dependency)
|
||||
for rawbuf in write: self.w_dependency_map[rawbuf._buf] = new_dependency
|
||||
|
||||
return dedup_signals(wait_signals)
|
||||
|
||||
Reference in New Issue
Block a user