mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
parallel mcts (#5626)
* start work on parallel mcts * compile was linearizing twice * typing + more early stopping * fix compiler error
This commit is contained in:
@@ -3,14 +3,15 @@ from typing import List, Optional, Dict, cast
|
||||
import numpy as np
|
||||
np.set_printoptions(suppress=True)
|
||||
import math, functools, time, random, statistics
|
||||
from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put, flatten
|
||||
from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put, colored, Profiling
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _try_compile_linearized_w_idx, _time_program
|
||||
from tinygrad.device import Buffer, Device, CompileError
|
||||
from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _time_program
|
||||
from tinygrad.ops import LazyOp
|
||||
|
||||
class MCTSNode:
|
||||
def __init__(self, kernel, parent=None):
|
||||
self.kernel = kernel
|
||||
def __init__(self, kernel:Kernel, parent=None):
|
||||
self.kernel:Kernel = kernel
|
||||
self.t = math.inf
|
||||
self.n = 0
|
||||
self.tm = math.inf
|
||||
@@ -38,9 +39,11 @@ def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
|
||||
ucb_explored_children = []
|
||||
for child in node.children:
|
||||
if child.n == 0: unexplored_children.append(child)
|
||||
elif not math.isinf(child.t):
|
||||
explored_children.append(child)
|
||||
ucb_explored_children.append(-child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n))
|
||||
else:
|
||||
ucb = -child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n)
|
||||
if not math.isinf(ucb):
|
||||
explored_children.append(child)
|
||||
ucb_explored_children.append(ucb)
|
||||
if len(unexplored_children): return random.choice(unexplored_children)
|
||||
if not len(explored_children): return node
|
||||
ucb_exp = np.exp(np.array(ucb_explored_children)/TEMP)
|
||||
@@ -87,34 +90,59 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
|
||||
dev = Device[lin.opts.device]
|
||||
root = MCTSNode(lin)
|
||||
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
|
||||
|
||||
st = time.perf_counter()
|
||||
best, best_idx, best_tm = lin, 0, math.inf
|
||||
seen_libs: Dict[bytes, MCTSNode] = {}
|
||||
seen_asts: Dict[LazyOp, MCTSNode] = {}
|
||||
compile_time, runtime_time = 0.0, 0.0
|
||||
for i in range(amt):
|
||||
node = sample_tree(root, best_tm) # sample and expand
|
||||
if node is None: break # finished the whole tree
|
||||
node.i = i # when was node explored
|
||||
|
||||
# rollout
|
||||
_, compile_ret = _compile_fn((0, node.kernel))
|
||||
if compile_ret is None:
|
||||
tm = math.inf
|
||||
opt_ast = node.kernel.get_optimized_ast()
|
||||
if (sibling_node:=seen_asts.get(opt_ast, None)) is not None:
|
||||
# early check for same optimized AST hit
|
||||
remove_node(node)
|
||||
tm = sibling_node.t
|
||||
else:
|
||||
p, lib, _ = compile_ret
|
||||
if (sibling_node:=seen_libs.get(lib, None)) is not None:
|
||||
# remove this node, it's a duplicate
|
||||
remove_node(node)
|
||||
tm = sibling_node.t
|
||||
seen_asts[opt_ast] = node
|
||||
|
||||
# lowering (50% of the time)
|
||||
p = node.kernel.to_program(name_override="test")
|
||||
|
||||
# rollout
|
||||
tm1 = time.perf_counter()
|
||||
try:
|
||||
lib = dev.compiler.compile(p.src)
|
||||
except CompileError:
|
||||
# NOTE: many of these "compiler errors" are caused by bad code output from the lowerer
|
||||
lib = None
|
||||
tm2 = time.perf_counter()
|
||||
if lib is None:
|
||||
tm = math.inf
|
||||
else:
|
||||
seen_libs[lib] = node
|
||||
try: tm = statistics.median(_time_program(p, lib, var_vals, rawbufs, cnt=5, early_stop=best_tm*10/1e6))*1e6
|
||||
except RuntimeError: tm = math.inf
|
||||
node.tm = tm
|
||||
if (sibling_node:=seen_libs.get(lib, None)) is not None:
|
||||
# NOTE: these should all be caught by the AST check, need to canonicalize
|
||||
# remove this node, it's a duplicate
|
||||
remove_node(node)
|
||||
tm = sibling_node.t
|
||||
else:
|
||||
seen_libs[lib] = node
|
||||
try: tm = statistics.median(_time_program(p, lib, var_vals, rawbufs, cnt=3, early_stop=best_tm*5/1e6))*1e6
|
||||
except RuntimeError: tm = math.inf
|
||||
node.tm = tm
|
||||
tm3 = time.perf_counter()
|
||||
compile_time += tm2-tm1
|
||||
runtime_time += tm3-tm2
|
||||
|
||||
# mock rollout
|
||||
#node.tm = tm = random.random() + 0.1
|
||||
|
||||
if tm < best_tm: best, best_idx, best_tm = node.kernel, i, tm
|
||||
if DEBUG>=2: print(f"\r{time.perf_counter() - st:7.2f}s: {tm:12.2f} us best: {best_tm:12.2f} us @ {best_idx+1:4d} {i+1:4d}/{amt:4d} {node.kernel.colored_shape()}\033[K", end="") # noqa: E501
|
||||
et = time.perf_counter() - st
|
||||
if DEBUG>=2: print(f"\r{et:7.2f}s {colored(f'{compile_time*100/et:3.0f}%', 'cyan')} {colored(f'{runtime_time*100/et:3.0f}%', 'red')}: {tm:12.2f} us best: {best_tm:12.2f} us @ {best_idx+1:4d} {i+1:4d}/{amt:4d} {int(round((i+1)/et)):4d}/s {node.kernel.colored_shape()}\033[K", end="") # noqa: E501
|
||||
|
||||
# backprop
|
||||
backprop(node, tm)
|
||||
|
||||
@@ -681,6 +681,7 @@ class Kernel:
|
||||
if self.opts.device == "AMD":
|
||||
reduce_axes = [self.shape_len-self.upcasted]
|
||||
upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
|
||||
# https://gpuopen.com/learn/wmma_on_rdna3/
|
||||
fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0)))
|
||||
fix_st2 = None
|
||||
elif self.opts.device == "METAL":
|
||||
|
||||
Reference in New Issue
Block a user