mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
mask out writable bufs in runtime access_resources (#7234)
This commit is contained in:
@@ -25,8 +25,7 @@ class CUDAGraph(MultiGraphRunner):
|
||||
global_size, local_size = ji.prg.p.launch_dims(var_vals)
|
||||
|
||||
new_node = cuda.CUgraphNode()
|
||||
deps = self._access_resources([x.base for x in ji.bufs[ji.prg.p.outcount:] if x is not None],
|
||||
[x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node)
|
||||
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
|
||||
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
||||
|
||||
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
|
||||
@@ -39,7 +38,7 @@ class CUDAGraph(MultiGraphRunner):
|
||||
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
|
||||
src_dev = cast(CUDADevice, Device[src.device])
|
||||
node_from = cuda.CUgraphNode()
|
||||
deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
|
||||
deps = self._access_resources(rawbufs=[dest.base, src.base], write=[0], new_dependency=node_from)
|
||||
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
||||
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
|
||||
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
|
||||
|
||||
@@ -56,7 +56,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
|
||||
|
||||
# Get dependencies based on input and output buffers.
|
||||
rdeps = self._access_resources(ji.bufs[(wb:=ji.prg.p.outcount if is_exec_prg else 1):], ji.bufs[:wb], (enqueue_queue, j + 1)) #type:ignore
|
||||
rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore
|
||||
|
||||
# Update dependencies to include previous kernel in queue. This is required for timeline signals.
|
||||
opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])
|
||||
|
||||
Reference in New Issue
Block a user