From ec594cf03cca0b6006457164fc15515b817033bc Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Tue, 5 Dec 2023 18:20:10 +0000
Subject: [PATCH] hotfix: tasteful ctrl-c in parallel beam

---
Review notes (placed after the three-dash marker, so they are not part of
the commit message):

* .pre-commit-config.yaml: the "devicetests" hook's pytest command now also
  runs test/test_search.py.
* tinygrad/features/search.py: adds init_worker(), which sets the SIGINT
  handler to SIG_IGN, and passes it as the second positional argument
  (the initializer) of multiprocessing.Pool.
* beam_search: the search loop and the pool.close() call move inside a try
  block. On KeyboardInterrupt the pool, when one exists, is terminated and
  the exception is re-raised.
* NOTE(review): only KeyboardInterrupt is caught. Any other exception raised
  inside the try block (for example the "no BEAM items succeeded?!?"
  assertion) propagates without pool.close() or pool.terminate() being
  called.
* NOTE(review): the line breaks, leading indentation and blank context lines
  below were reconstructed from a whitespace-collapsed copy using the hunk
  line counts. Runs of spaces inside lines could not be recovered. Check
  with "git apply --check" before relying on this file.

 .pre-commit-config.yaml     |  2 +-
 tinygrad/features/search.py | 55 +++++++++++++++++++++----------------
 2 files changed, 32 insertions(+), 25 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c1a8072fe5..b9b3bb0858 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -29,7 +29,7 @@ repos:
         pass_filenames: false
       - id: devicetests
         name: select GPU tests
-        entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py
+        entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py test/test_search.py
         language: system
         always_run: true
         pass_filenames: false
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index 4b01785324..1f47659a6a 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -1,5 +1,5 @@
 from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
-import itertools, random, math, time, multiprocessing, traceback
+import itertools, random, math, time, multiprocessing, traceback, signal
 from tinygrad.lazy import vars_from_ast
 from tinygrad.device import Device, Compiled, Buffer
 from tinygrad.ops import MemBuffer
@@ -101,6 +101,9 @@ def time_program(dev:str, lib:bytes, global_size, local_size, var_vals, rawbufs,
     if early_stop is not None and early_stop < tms[-1]: break
   return tms
 
+# workers should ignore ctrl c
+def init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
+
 def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
   key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
   if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
@@ -112,32 +115,36 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
   seen_libs = set()
 
   default_parallel = 1 if Device.DEFAULT == "HIP" else 0
-  pool = multiprocessing.Pool(multiprocessing.cpu_count()) if getenv("PARALLEL", default_parallel) else None
+  pool = multiprocessing.Pool(multiprocessing.cpu_count(), init_worker) if getenv("PARALLEL", default_parallel) else None
 
-  var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
-  exiting, st = False, time.perf_counter()
-  dev = Device[Device.DEFAULT]
-  assert isinstance(dev, Compiled)
-  while not exiting:
-    acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
-    timed_lins: List[Tuple[Linearizer, float]] = []
-    for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
-      if proc is None: continue
-      lib, global_size, local_size = proc
-      if lib in seen_libs: continue
-      seen_libs.add(lib)
-      tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None)
-      timed_lins.append((acted_lins[i], min(tms)))
-      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
+  try:
+    var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
+    exiting, st = False, time.perf_counter()
+    dev = Device[Device.DEFAULT]
+    assert isinstance(dev, Compiled)
+    while not exiting:
+      acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
+      timed_lins: List[Tuple[Linearizer, float]] = []
+      for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
+        if proc is None: continue
+        lib, global_size, local_size = proc
+        if lib in seen_libs: continue
+        seen_libs.add(lib)
+        tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None)
+        timed_lins.append((acted_lins[i], min(tms)))
+        if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
 
-    # done
-    opts = sorted(timed_lins, key=lambda x: x[1])
-    exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
-    if not exiting: beam = opts[:amt]
-    assert len(beam) > 0, "no BEAM items succeeded?!?"
-    if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
+      # done
+      opts = sorted(timed_lins, key=lambda x: x[1])
+      exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
+      if not exiting: beam = opts[:amt]
+      assert len(beam) > 0, "no BEAM items succeeded?!?"
+      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
+    if pool is not None: pool.close() # the pool is closed
+  except KeyboardInterrupt as e:
+    if pool is not None: pool.terminate()
+    raise e
 
-  if pool is not None: pool.close() # the pool is closed
   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
   if DEBUG >= 3: print(beam[0][0].applied_opts)
   return beam[0][0]