unset TRACK_MATCH_STATS while initing beam buffers [pr] (#7971)

This commit is contained in:
qazal
2024-11-30 10:56:58 -05:00
committed by GitHub
parent d0735d6489
commit bb8e319680
2 changed files with 3 additions and 3 deletions

View File

@@ -38,7 +38,7 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
if beam_compare == 2:
from tinygrad import Tensor
all_outs: List[List[Tensor]] = []
with Context(DEBUG=0, BEAM=0, CAPTURING=0):
with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0):
rand_bufs = [Tensor.normal(buf.size, std=0.1, dtype=buf.dtype).data() if dtypes.is_float(buf.dtype) else \
(Tensor.randint(buf.size, low=0, high=2).cast(buf.dtype).data() if buf.dtype == dtypes.bool else \
Tensor.randint(buf.size, low=dtypes.min(buf.dtype), high=dtypes.max(buf.dtype), dtype=buf.dtype).data()) \
@@ -47,7 +47,7 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
for buf,data in zip(rawbufs, rand_bufs): buf.ensure_allocated().copyin(data)
time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True, disable_cache=True)
all_outs.append([Tensor(bytes(buf.as_buffer()), dtype=buf.dtype) for buf in rawbufs[:len(ast.src)]])
with Context(DEBUG=0, BEAM=0, CAPTURING=0):
with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0):
for bufs in zip(*all_outs):
for b in bufs[1:]:
if dtypes.is_float(bufs[0].dtype):

View File

@@ -47,7 +47,7 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:Dict[Variable, int], rawbuf
if clear_l2:
if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
else:
with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
if early_stop is not None and early_stop < min(tms): break
return tms