diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py index 29f9b53598..a3ca2df9e2 100644 --- a/tinygrad/runtime/graph/cuda.py +++ b/tinygrad/runtime/graph/cuda.py @@ -1,7 +1,7 @@ import ctypes, collections from typing import Any, Optional, Tuple, Dict, List, cast import tinygrad.runtime.autogen.cuda as cuda -from tinygrad.helpers import init_c_var, GraphException +from tinygrad.helpers import init_c_var, GraphException, getenv from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution from tinygrad.shape.symbolic import Variable @@ -43,23 +43,27 @@ class CUDAGraph(MultiDeviceJITGraph): elif isinstance(ji.prg, BufferXfer): dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]] src_dev, dest_dev = cast(CUDADevice, Device[src.device]), cast(CUDADevice, Device[dest.device]) - cpu_buffer = Buffer(device=src.device, dtype=src.dtype, size=src.size, options=BufferOptions(host=True)).allocate() - self.cpu_buffers.append(cpu_buffer) - - node_to = cuda.CUgraphNode() node_from = cuda.CUgraphNode() deps = self.access_resources(read=[src], write=[dest], new_dependency=node_from) c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None + if getenv("CUDA_P2P"): + cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1, + dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1, + WidthInBytes=dest.nbytes, Height=1, Depth=1) + check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context)) + else: + self.cpu_buffers.append(cpu_buffer:=Buffer(device=src.device, dtype=src.dtype, size=src.size, options=BufferOptions(host=True)).allocate()) - cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1, - dstMemoryType=cuda.CU_MEMORYTYPE_HOST, dstHost=cpu_buffer._buf, dstPitch=dest.nbytes, dstHeight=1, - WidthInBytes=dest.nbytes, Height=1, Depth=1) - check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_to), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context)) - cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_HOST, srcHost=cpu_buffer._buf, srcPitch=src.nbytes, srcHeight=1, - dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1, - WidthInBytes=dest.nbytes, Height=1, Depth=1) - check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, (cuda.CUgraphNode*1)(node_to), 1, - ctypes.byref(cp_params), dest_dev.context)) + node_to = cuda.CUgraphNode() + cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1, + dstMemoryType=cuda.CU_MEMORYTYPE_HOST, dstHost=cpu_buffer._buf, dstPitch=dest.nbytes, dstHeight=1, + WidthInBytes=dest.nbytes, Height=1, Depth=1) + check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_to), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context)) + cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_HOST, srcHost=cpu_buffer._buf, srcPitch=src.nbytes, srcHeight=1, + dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1, + WidthInBytes=dest.nbytes, Height=1, Depth=1) + check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, (cuda.CUgraphNode*1)(node_to), 1, + ctypes.byref(cp_params), dest_dev.context)) if j in self.jc_idxs_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True) self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))