hotfix: tasteful ctrl-c in parallel beam

This commit is contained in:
George Hotz
2023-12-05 18:20:10 +00:00
parent 35b5e95097
commit ec594cf03c
2 changed files with 32 additions and 25 deletions

View File

@@ -29,7 +29,7 @@ repos:
pass_filenames: false pass_filenames: false
- id: devicetests - id: devicetests
name: select GPU tests name: select GPU tests
entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py test/test_search.py
language: system language: system
always_run: true always_run: true
pass_filenames: false pass_filenames: false

View File

@@ -1,5 +1,5 @@
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, random, math, time, multiprocessing, traceback import itertools, random, math, time, multiprocessing, traceback, signal
from tinygrad.lazy import vars_from_ast from tinygrad.lazy import vars_from_ast
from tinygrad.device import Device, Compiled, Buffer from tinygrad.device import Device, Compiled, Buffer
from tinygrad.ops import MemBuffer from tinygrad.ops import MemBuffer
@@ -101,6 +101,9 @@ def time_program(dev:str, lib:bytes, global_size, local_size, var_vals, rawbufs,
if early_stop is not None and early_stop < tms[-1]: break if early_stop is not None and early_stop < tms[-1]: break
return tms return tms
# workers should ignore ctrl c
def init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer: def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT} key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1: if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
@@ -112,32 +115,36 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
seen_libs = set() seen_libs = set()
default_parallel = 1 if Device.DEFAULT == "HIP" else 0 default_parallel = 1 if Device.DEFAULT == "HIP" else 0
pool = multiprocessing.Pool(multiprocessing.cpu_count()) if getenv("PARALLEL", default_parallel) else None pool = multiprocessing.Pool(multiprocessing.cpu_count(), init_worker) if getenv("PARALLEL", default_parallel) else None
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)} try:
exiting, st = False, time.perf_counter() var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
dev = Device[Device.DEFAULT] exiting, st = False, time.perf_counter()
assert isinstance(dev, Compiled) dev = Device[Device.DEFAULT]
while not exiting: assert isinstance(dev, Compiled)
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin] while not exiting:
timed_lins: List[Tuple[Linearizer, float]] = [] acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))): timed_lins: List[Tuple[Linearizer, float]] = []
if proc is None: continue for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
lib, global_size, local_size = proc if proc is None: continue
if lib in seen_libs: continue lib, global_size, local_size = proc
seen_libs.add(lib) if lib in seen_libs: continue
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None) seen_libs.add(lib)
timed_lins.append((acted_lins[i], min(tms))) tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None)
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") timed_lins.append((acted_lins[i], min(tms)))
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
# done # done
opts = sorted(timed_lins, key=lambda x: x[1]) opts = sorted(timed_lins, key=lambda x: x[1])
exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1]) exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
if not exiting: beam = opts[:amt] if not exiting: beam = opts[:amt]
assert len(beam) > 0, "no BEAM items succeeded?!?" assert len(beam) > 0, "no BEAM items succeeded?!?"
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
if pool is not None: pool.close() # the pool is closed
except KeyboardInterrupt as e:
if pool is not None: pool.terminate()
raise e
if pool is not None: pool.close() # the pool is closed
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts) if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
if DEBUG >= 3: print(beam[0][0].applied_opts) if DEBUG >= 3: print(beam[0][0].applied_opts)
return beam[0][0] return beam[0][0]