hcq cache invalidation for beam (#5630)

* nv full cache invalidation

* the same command on amd

* linter

* fix amd

* nv no hardcoded consts

* beam default
This commit is contained in:
nimlgen
2024-07-22 18:13:17 +03:00
committed by GitHub
parent c64e9591e3
commit 08a9c0ae5e
4 changed files with 15 additions and 2 deletions

View File

@@ -165,6 +165,7 @@ class NVDriver(VirtDriver):
params.workSubmitToken = gpu_fifo.token
elif struct.cmd == nv_gpu.NVA06C_CTRL_CMD_GPFIFO_SCHEDULE: pass
elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_PERF_BOOST: pass
elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_FB_FLUSH_GPU_CACHE: pass
else: raise RuntimeError(f"Unknown {struct.cmd} to rm_control")
return 0

View File

@@ -46,7 +46,9 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
input_bufs = [rawbufs[i] for i,_ in car.p.globals]
for _ in range(cnt):
if clear_l2:
with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
if hasattr(dev:=Device[p.dname], 'invalidate_caches'): dev.invalidate_caches()
else:
with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
if early_stop is not None and early_stop < min(tms): break
return tms
@@ -150,7 +152,7 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
if lib in seen_libs: continue
#print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
seen_libs.add(lib)
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
except RuntimeError: continue # for runtime issues
timed_lins.append((acted_lins[i], min(tms)))
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(UOpGraph, p.uops).uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501

View File

@@ -492,3 +492,8 @@ class AMDDevice(HCQCompiled):
self.kernargs_ptr = self.kernargs.va_addr
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
if PROFILE: self._prof_process_events()
def invalidate_cache(self):
AMDComputeQueue().memory_barrier().signal(self.timeline_signal, self.timeline_value).submit(self)
self.timeline_value += 1
self.synchronize()

View File

@@ -586,3 +586,8 @@ class NVDevice(HCQCompiled):
NVComputeQueue().setup(local_mem=self.shader_local_mem.va_addr, local_mem_tpc_bytes=bytes_per_tpc) \
.signal(self.timeline_signal, self.timeline_value).submit(self)
self.timeline_value += 1
def invalidate_caches(self):
rmctrl.fb_flush_gpu_cache(self.fd_ctl, self.root, self.subdevice,
flags=((nv_gpu.NV2080_CTRL_FB_FLUSH_GPU_CACHE_FLAGS_WRITE_BACK_YES << 2) | (nv_gpu.NV2080_CTRL_FB_FLUSH_GPU_CACHE_FLAGS_INVALIDATE_YES << 3) |
(nv_gpu.NV2080_CTRL_FB_FLUSH_GPU_CACHE_FLAGS_FLUSH_MODE_FULL_CACHE << 4)))