mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
hotfix: tasteful ctrl-c in parallel beam
This commit is contained in:
@@ -29,7 +29,7 @@ repos:
|
||||
pass_filenames: false
|
||||
- id: devicetests
|
||||
name: select GPU tests
|
||||
entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py
|
||||
entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py test/test_search.py
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
||||
import itertools, random, math, time, multiprocessing, traceback
|
||||
import itertools, random, math, time, multiprocessing, traceback, signal
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
from tinygrad.device import Device, Compiled, Buffer
|
||||
from tinygrad.ops import MemBuffer
|
||||
@@ -101,6 +101,9 @@ def time_program(dev:str, lib:bytes, global_size, local_size, var_vals, rawbufs,
|
||||
if early_stop is not None and early_stop < tms[-1]: break
|
||||
return tms
|
||||
|
||||
# workers should ignore ctrl c
|
||||
def init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
|
||||
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
|
||||
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
|
||||
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
|
||||
@@ -112,32 +115,36 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
|
||||
seen_libs = set()
|
||||
|
||||
default_parallel = 1 if Device.DEFAULT == "HIP" else 0
|
||||
pool = multiprocessing.Pool(multiprocessing.cpu_count()) if getenv("PARALLEL", default_parallel) else None
|
||||
pool = multiprocessing.Pool(multiprocessing.cpu_count(), init_worker) if getenv("PARALLEL", default_parallel) else None
|
||||
|
||||
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
|
||||
exiting, st = False, time.perf_counter()
|
||||
dev = Device[Device.DEFAULT]
|
||||
assert isinstance(dev, Compiled)
|
||||
while not exiting:
|
||||
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
|
||||
timed_lins: List[Tuple[Linearizer, float]] = []
|
||||
for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
|
||||
if proc is None: continue
|
||||
lib, global_size, local_size = proc
|
||||
if lib in seen_libs: continue
|
||||
seen_libs.add(lib)
|
||||
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None)
|
||||
timed_lins.append((acted_lins[i], min(tms)))
|
||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
|
||||
try:
|
||||
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
|
||||
exiting, st = False, time.perf_counter()
|
||||
dev = Device[Device.DEFAULT]
|
||||
assert isinstance(dev, Compiled)
|
||||
while not exiting:
|
||||
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
|
||||
timed_lins: List[Tuple[Linearizer, float]] = []
|
||||
for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
|
||||
if proc is None: continue
|
||||
lib, global_size, local_size = proc
|
||||
if lib in seen_libs: continue
|
||||
seen_libs.add(lib)
|
||||
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None)
|
||||
timed_lins.append((acted_lins[i], min(tms)))
|
||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
|
||||
|
||||
# done
|
||||
opts = sorted(timed_lins, key=lambda x: x[1])
|
||||
exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
|
||||
if not exiting: beam = opts[:amt]
|
||||
assert len(beam) > 0, "no BEAM items succeeded?!?"
|
||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
|
||||
# done
|
||||
opts = sorted(timed_lins, key=lambda x: x[1])
|
||||
exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
|
||||
if not exiting: beam = opts[:amt]
|
||||
assert len(beam) > 0, "no BEAM items succeeded?!?"
|
||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
|
||||
if pool is not None: pool.close() # the pool is closed
|
||||
except KeyboardInterrupt as e:
|
||||
if pool is not None: pool.terminate()
|
||||
raise e
|
||||
|
||||
if pool is not None: pool.close() # the pool is closed
|
||||
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
|
||||
if DEBUG >= 3: print(beam[0][0].applied_opts)
|
||||
return beam[0][0]
|
||||
|
||||
Reference in New Issue
Block a user