mask out writable bufs in runtime access_resources (#7234)

This commit is contained in:
qazal
2024-10-23 16:13:50 +03:00
committed by GitHub
parent d2b608233a
commit aeeb917b6e
6 changed files with 15 additions and 20 deletions

View File

@@ -25,8 +25,7 @@ class CUDAGraph(MultiGraphRunner):
global_size, local_size = ji.prg.p.launch_dims(var_vals)
new_node = cuda.CUgraphNode()
deps = self._access_resources([x.base for x in ji.bufs[ji.prg.p.outcount:] if x is not None],
[x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node)
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
@@ -39,7 +38,7 @@ class CUDAGraph(MultiGraphRunner):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
src_dev = cast(CUDADevice, Device[src.device])
node_from = cuda.CUgraphNode()
deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
deps = self._access_resources(rawbufs=[dest.base, src.base], write=[0], new_dependency=node_from)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,

View File

@@ -56,7 +56,7 @@ class HCQGraph(MultiGraphRunner):
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
# Get dependencies based on input and output buffers.
rdeps = self._access_resources(ji.bufs[(wb:=ji.prg.p.outcount if is_exec_prg else 1):], ji.bufs[:wb], (enqueue_queue, j + 1)) #type:ignore
rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore
# Update dependencies to include previous kernel in queue. This is required for timeline signals.
opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])