mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
hotfix: tasteful ctrl-c in parallel beam
This commit is contained in:
@@ -29,7 +29,7 @@ repos:
|
|||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
- id: devicetests
|
- id: devicetests
|
||||||
name: select GPU tests
|
name: select GPU tests
|
||||||
entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py
|
entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py test/test_search.py
|
||||||
language: system
|
language: system
|
||||||
always_run: true
|
always_run: true
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
||||||
import itertools, random, math, time, multiprocessing, traceback
|
import itertools, random, math, time, multiprocessing, traceback, signal
|
||||||
from tinygrad.lazy import vars_from_ast
|
from tinygrad.lazy import vars_from_ast
|
||||||
from tinygrad.device import Device, Compiled, Buffer
|
from tinygrad.device import Device, Compiled, Buffer
|
||||||
from tinygrad.ops import MemBuffer
|
from tinygrad.ops import MemBuffer
|
||||||
@@ -101,6 +101,9 @@ def time_program(dev:str, lib:bytes, global_size, local_size, var_vals, rawbufs,
|
|||||||
if early_stop is not None and early_stop < tms[-1]: break
|
if early_stop is not None and early_stop < tms[-1]: break
|
||||||
return tms
|
return tms
|
||||||
|
|
||||||
|
# workers should ignore ctrl c
|
||||||
|
def init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||||
|
|
||||||
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
|
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
|
||||||
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
|
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
|
||||||
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
|
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
|
||||||
@@ -112,32 +115,36 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
|
|||||||
seen_libs = set()
|
seen_libs = set()
|
||||||
|
|
||||||
default_parallel = 1 if Device.DEFAULT == "HIP" else 0
|
default_parallel = 1 if Device.DEFAULT == "HIP" else 0
|
||||||
pool = multiprocessing.Pool(multiprocessing.cpu_count()) if getenv("PARALLEL", default_parallel) else None
|
pool = multiprocessing.Pool(multiprocessing.cpu_count(), init_worker) if getenv("PARALLEL", default_parallel) else None
|
||||||
|
|
||||||
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
|
try:
|
||||||
exiting, st = False, time.perf_counter()
|
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
|
||||||
dev = Device[Device.DEFAULT]
|
exiting, st = False, time.perf_counter()
|
||||||
assert isinstance(dev, Compiled)
|
dev = Device[Device.DEFAULT]
|
||||||
while not exiting:
|
assert isinstance(dev, Compiled)
|
||||||
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
|
while not exiting:
|
||||||
timed_lins: List[Tuple[Linearizer, float]] = []
|
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
|
||||||
for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
|
timed_lins: List[Tuple[Linearizer, float]] = []
|
||||||
if proc is None: continue
|
for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
|
||||||
lib, global_size, local_size = proc
|
if proc is None: continue
|
||||||
if lib in seen_libs: continue
|
lib, global_size, local_size = proc
|
||||||
seen_libs.add(lib)
|
if lib in seen_libs: continue
|
||||||
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None)
|
seen_libs.add(lib)
|
||||||
timed_lins.append((acted_lins[i], min(tms)))
|
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None)
|
||||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
|
timed_lins.append((acted_lins[i], min(tms)))
|
||||||
|
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
|
||||||
|
|
||||||
# done
|
# done
|
||||||
opts = sorted(timed_lins, key=lambda x: x[1])
|
opts = sorted(timed_lins, key=lambda x: x[1])
|
||||||
exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
|
exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
|
||||||
if not exiting: beam = opts[:amt]
|
if not exiting: beam = opts[:amt]
|
||||||
assert len(beam) > 0, "no BEAM items succeeded?!?"
|
assert len(beam) > 0, "no BEAM items succeeded?!?"
|
||||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
|
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
|
||||||
|
if pool is not None: pool.close() # the pool is closed
|
||||||
|
except KeyboardInterrupt as e:
|
||||||
|
if pool is not None: pool.terminate()
|
||||||
|
raise e
|
||||||
|
|
||||||
if pool is not None: pool.close() # the pool is closed
|
|
||||||
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
|
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
|
||||||
if DEBUG >= 3: print(beam[0][0].applied_opts)
|
if DEBUG >= 3: print(beam[0][0].applied_opts)
|
||||||
return beam[0][0]
|
return beam[0][0]
|
||||||
|
|||||||
Reference in New Issue
Block a user