mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
add a flag to log when beam surpassed max limit [pr] (#9533)
This commit is contained in:
@@ -6,6 +6,7 @@ export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
|
||||
|
||||
export BEAM=4 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BEAM_LOG_SURPASS_MAX=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
export RESET_STEP=1
|
||||
|
||||
@@ -17,7 +17,7 @@ DATETIME=$(date "+%m%d%H%M")
|
||||
LOGFILE="bert_green_${DATETIME}_${SEED}.log"
|
||||
|
||||
# init
|
||||
BENCHMARK=10 INITMLPERF=1 RESET_STEP=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
|
||||
BENCHMARK=10 INITMLPERF=1 RESET_STEP=1 BEAM_LOG_SURPASS_MAX=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
|
||||
|
||||
# run
|
||||
PARALLEL=0 RUNMLPERF=1 python3 examples/mlperf/model_train.py | tee -a $LOGFILE
|
||||
|
||||
@@ -6,6 +6,7 @@ export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
|
||||
|
||||
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BEAM_LOG_SURPASS_MAX=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
export RESET_STEP=1
|
||||
|
||||
@@ -17,7 +17,7 @@ DATETIME=$(date "+%m%d%H%M")
|
||||
LOGFILE="bert_red_${DATETIME}_${SEED}.log"
|
||||
|
||||
# init
|
||||
BENCHMARK=10 INITMLPERF=1 RESET_STEP=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
|
||||
BENCHMARK=10 INITMLPERF=1 RESET_STEP=1 BEAM_LOG_SURPASS_MAX=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
|
||||
|
||||
# run
|
||||
PARALLEL=0 RUNMLPERF=1 python3 examples/mlperf/model_train.py | tee -a $LOGFILE
|
||||
|
||||
@@ -65,7 +65,9 @@ def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tup
|
||||
try:
|
||||
p = x[1].to_program(name_override="test")
|
||||
assert p.uops is not None, "uop list wasn't generated?"
|
||||
if len(p.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
|
||||
if len(p.uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0:
|
||||
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
|
||||
raise RuntimeError("too many uops")
|
||||
st = time.perf_counter()
|
||||
prog = compiler.compile(p.src)
|
||||
et = time.perf_counter() - st
|
||||
@@ -121,7 +123,9 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
|
||||
for s,c in zip(lin2.full_shape, lin2.colors()):
|
||||
if c in {"magenta", "yellow"}: up *= s
|
||||
elif c in {"cyan", "green", "white"}: lcl *= s
|
||||
if up//tc_up > max_up or lcl > max_lcl: continue
|
||||
if up//tc_up > max_up or lcl > max_lcl:
|
||||
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many upcast/local. {up//tc_up=}, {max_up=}, {lcl=}, {max_lcl=}")
|
||||
continue
|
||||
acted_lins[i+1] = lin2
|
||||
except KernelOptError: pass
|
||||
return acted_lins
|
||||
|
||||
Reference in New Issue
Block a user