mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
hcq cache invalidation for beam (#5630)
* nv full cache invalidation * the same command on amd * linter * fix amd * nv no hardcoded consts * beam default
This commit is contained in:
@@ -165,6 +165,7 @@ class NVDriver(VirtDriver):
|
||||
params.workSubmitToken = gpu_fifo.token
|
||||
elif struct.cmd == nv_gpu.NVA06C_CTRL_CMD_GPFIFO_SCHEDULE: pass
|
||||
elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_PERF_BOOST: pass
|
||||
elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_FB_FLUSH_GPU_CACHE: pass
|
||||
else: raise RuntimeError(f"Unknown {struct.cmd} to rm_control")
|
||||
return 0
|
||||
|
||||
|
||||
@@ -46,7 +46,9 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
|
||||
input_bufs = [rawbufs[i] for i,_ in car.p.globals]
|
||||
for _ in range(cnt):
|
||||
if clear_l2:
|
||||
with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
|
||||
if hasattr(dev:=Device[p.dname], 'invalidate_caches'): dev.invalidate_caches()
|
||||
else:
|
||||
with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
|
||||
tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
|
||||
if early_stop is not None and early_stop < min(tms): break
|
||||
return tms
|
||||
@@ -150,7 +152,7 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
|
||||
if lib in seen_libs: continue
|
||||
#print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
|
||||
seen_libs.add(lib)
|
||||
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
|
||||
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
|
||||
except RuntimeError: continue # for runtime issues
|
||||
timed_lins.append((acted_lins[i], min(tms)))
|
||||
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(UOpGraph, p.uops).uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
|
||||
|
||||
@@ -492,3 +492,8 @@ class AMDDevice(HCQCompiled):
|
||||
self.kernargs_ptr = self.kernargs.va_addr
|
||||
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
|
||||
if PROFILE: self._prof_process_events()
|
||||
|
||||
def invalidate_cache(self):
|
||||
AMDComputeQueue().memory_barrier().signal(self.timeline_signal, self.timeline_value).submit(self)
|
||||
self.timeline_value += 1
|
||||
self.synchronize()
|
||||
|
||||
@@ -586,3 +586,8 @@ class NVDevice(HCQCompiled):
|
||||
NVComputeQueue().setup(local_mem=self.shader_local_mem.va_addr, local_mem_tpc_bytes=bytes_per_tpc) \
|
||||
.signal(self.timeline_signal, self.timeline_value).submit(self)
|
||||
self.timeline_value += 1
|
||||
|
||||
def invalidate_caches(self):
|
||||
rmctrl.fb_flush_gpu_cache(self.fd_ctl, self.root, self.subdevice,
|
||||
flags=((nv_gpu.NV2080_CTRL_FB_FLUSH_GPU_CACHE_FLAGS_WRITE_BACK_YES << 2) | (nv_gpu.NV2080_CTRL_FB_FLUSH_GPU_CACHE_FLAGS_INVALIDATE_YES << 3) |
|
||||
(nv_gpu.NV2080_CTRL_FB_FLUSH_GPU_CACHE_FLAGS_FLUSH_MODE_FULL_CACHE << 4)))
|
||||
|
||||
Reference in New Issue
Block a user