parallel mcts (#5626)

* start work on parallel mcts

* compile was linearizing twice

* typing + more early stopping

* fix compiler error
George Hotz authored 2024-07-21 14:53:23 -07:00 (committed by GitHub)
parent c56c9c7519
commit 6c6d74d922
2 changed files with 52 additions and 23 deletions

View File

@@ -3,14 +3,15 @@ from typing import List, Optional, Dict, cast
 import numpy as np
 np.set_printoptions(suppress=True)
 import math, functools, time, random, statistics
-from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put, flatten
+from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put, colored, Profiling
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.device import Buffer, Device
-from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _try_compile_linearized_w_idx, _time_program
+from tinygrad.device import Buffer, Device, CompileError
+from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _time_program
+from tinygrad.ops import LazyOp
 
 class MCTSNode:
-  def __init__(self, kernel, parent=None):
-    self.kernel = kernel
+  def __init__(self, kernel:Kernel, parent=None):
+    self.kernel:Kernel = kernel
     self.t = math.inf
     self.n = 0
     self.tm = math.inf
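The `MCTSNode` fields above are the usual MCTS bookkeeping: `n` is the visit count, `t` is the score used for selection (a kernel time here, so lower is better), and `tm` is the node's own measured time. A minimal sketch of that structure together with a min-time backprop, assuming `t` tracks the best time seen in a node's subtree (the real `backprop` and the `parent`/`children` wiring are not part of this diff):

```python
import math
from typing import Optional, List

class Node:
    """Stand-in for MCTSNode: a search state plus visit/score statistics."""
    def __init__(self, state, parent: Optional["Node"] = None):
        self.state, self.parent = state, parent
        self.children: List["Node"] = []
        self.n = 0           # visit count, the denominator in the UCB exploration term
        self.t = math.inf    # best (lowest) rollout time observed in this subtree (assumed semantics)
        self.tm = math.inf   # this node's own measured rollout time

def backprop(node: Optional["Node"], tm: float) -> None:
    """Propagate one rollout result up to the root: bump visits, keep the minimum time."""
    while node is not None:
        node.n += 1
        node.t = min(node.t, tm)
        node = node.parent
```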
@@ -38,9 +39,11 @@ def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
   ucb_explored_children = []
   for child in node.children:
     if child.n == 0: unexplored_children.append(child)
-    elif not math.isinf(child.t):
-      explored_children.append(child)
-      ucb_explored_children.append(-child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n))
+    else:
+      ucb = -child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n)
+      if not math.isinf(ucb):
+        explored_children.append(child)
+        ucb_explored_children.append(ucb)
   if len(unexplored_children): return random.choice(unexplored_children)
   if not len(explored_children): return node
   ucb_exp = np.exp(np.array(ucb_explored_children)/TEMP)
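The selection step this hunk touches is UCB1 with the exploitation term expressed in units of the best time found so far, followed by a softmax (temperature `TEMP`) over the explored children's scores; the change moves the `isinf` guard from `child.t` to the full UCB value. A self-contained sketch of that sampling rule, with illustrative `C` and `TEMP` constants (the real values are defined elsewhere in this file), not the file's exact code:

```python
import math, random
import numpy as np

C, TEMP = 0.5, 0.5  # exploration weight and softmax temperature; placeholder values

def sample_child(node, best_tm: float):
    """Prefer an unexplored child at random; otherwise softmax-sample explored children by UCB."""
    unexplored, explored, ucbs = [], [], []
    for child in node.children:
        if child.n == 0:
            unexplored.append(child)
        else:
            # exploitation: faster kernels (small child.t) score higher;
            # exploration: rarely-visited children get a sqrt(log N / n) bonus
            ucb = -child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n)
            if not math.isinf(ucb):
                explored.append(child)
                ucbs.append(ucb)
    if unexplored: return random.choice(unexplored)
    if not explored: return node   # nothing usable below this node
    probs = np.exp(np.array(ucbs)/TEMP)
    return explored[np.random.choice(len(explored), p=probs/probs.sum())]
```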
@@ -87,34 +90,59 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
   var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
   dev = Device[lin.opts.device]
   root = MCTSNode(lin)
-  _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
   st = time.perf_counter()
   best, best_idx, best_tm = lin, 0, math.inf
   seen_libs: Dict[bytes, MCTSNode] = {}
+  seen_asts: Dict[LazyOp, MCTSNode] = {}
+  compile_time, runtime_time = 0.0, 0.0
   for i in range(amt):
     node = sample_tree(root, best_tm) # sample and expand
     if node is None: break # finished the whole tree
     node.i = i # when was node explored
-    # rollout
-    _, compile_ret = _compile_fn((0, node.kernel))
-    if compile_ret is None:
-      tm = math.inf
+    opt_ast = node.kernel.get_optimized_ast()
+    if (sibling_node:=seen_asts.get(opt_ast, None)) is not None:
+      # early check for same optimized AST hit
+      remove_node(node)
+      tm = sibling_node.t
     else:
-      p, lib, _ = compile_ret
-      if (sibling_node:=seen_libs.get(lib, None)) is not None:
-        # remove this node, it's a duplicate
-        remove_node(node)
-        tm = sibling_node.t
+      seen_asts[opt_ast] = node
+      # lowering (50% of the time)
+      p = node.kernel.to_program(name_override="test")
+      # rollout
+      tm1 = time.perf_counter()
+      try:
+        lib = dev.compiler.compile(p.src)
+      except CompileError:
+        # NOTE: many of these "compiler errors" are caused by bad code output from the lowerer
+        lib = None
+      tm2 = time.perf_counter()
+      if lib is None:
+        tm = math.inf
      else:
-        seen_libs[lib] = node
-        try: tm = statistics.median(_time_program(p, lib, var_vals, rawbufs, cnt=5, early_stop=best_tm*10/1e6))*1e6
-        except RuntimeError: tm = math.inf
-    node.tm = tm
+        if (sibling_node:=seen_libs.get(lib, None)) is not None:
+          # NOTE: these should all be caught by the AST check, need to canonicalize
+          # remove this node, it's a duplicate
+          remove_node(node)
+          tm = sibling_node.t
+        else:
+          seen_libs[lib] = node
+          try: tm = statistics.median(_time_program(p, lib, var_vals, rawbufs, cnt=3, early_stop=best_tm*5/1e6))*1e6
+          except RuntimeError: tm = math.inf
+      node.tm = tm
+      tm3 = time.perf_counter()
+      compile_time += tm2-tm1
+      runtime_time += tm3-tm2
     # mock rollout
     #node.tm = tm = random.random() + 0.1
     if tm < best_tm: best, best_idx, best_tm = node.kernel, i, tm
-    if DEBUG>=2: print(f"\r{time.perf_counter() - st:7.2f}s: {tm:12.2f} us best: {best_tm:12.2f} us @ {best_idx+1:4d} {i+1:4d}/{amt:4d} {node.kernel.colored_shape()}\033[K", end="") # noqa: E501
+    et = time.perf_counter() - st
+    if DEBUG>=2: print(f"\r{et:7.2f}s {colored(f'{compile_time*100/et:3.0f}%', 'cyan')} {colored(f'{runtime_time*100/et:3.0f}%', 'red')}: {tm:12.2f} us best: {best_tm:12.2f} us @ {best_idx+1:4d} {i+1:4d}/{amt:4d} {int(round((i+1)/et)):4d}/s {node.kernel.colored_shape()}\033[K", end="") # noqa: E501
     # backprop
     backprop(node, tm)
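Net effect of this hunk: instead of going through the pooled `_try_compile_linearized_w_idx` helper, the rollout now lowers the kernel once with `to_program`, compiles the source directly on the device compiler (treating `CompileError` as an infinitely slow candidate), deduplicates first by optimized AST and then by compiled binary, and accumulates how the wall clock splits between compiling and benchmarking. A stripped-down sketch of that rollout body, where `compile_src` and `time_program` are placeholders standing in for `dev.compiler.compile` and `_time_program`, and the dedup dicts map to times rather than nodes:

```python
import math, statistics, time

def rollout(kernel, compile_src, time_program, seen_asts: dict, seen_libs: dict, best_tm: float):
    """One rollout. Returns (time_us, compile_seconds, bench_seconds); inf marks an unusable candidate."""
    opt_ast = kernel.get_optimized_ast()
    if opt_ast in seen_asts:                       # same optimized AST already measured: reuse it
        return seen_asts[opt_ast], 0.0, 0.0
    prg = kernel.to_program(name_override="test")  # lowering, roughly half the per-candidate cost

    t0 = time.perf_counter()
    try: lib = compile_src(prg.src)                # device compile; bad codegen shows up as an exception
    except Exception: lib = None
    t1 = time.perf_counter()

    if lib is None:
        tm = math.inf
    elif lib in seen_libs:                         # identical binary already timed
        tm = seen_libs[lib]
    else:
        # median of a few runs, bailing out early once clearly slower than the best so far
        try: tm = statistics.median(time_program(prg, lib, cnt=3, early_stop=best_tm*5/1e6))*1e6
        except RuntimeError: tm = math.inf
        seen_libs[lib] = tm
    seen_asts[opt_ast] = tm
    t2 = time.perf_counter()
    return tm, t1 - t0, t2 - t1
```

The cyan and red percentages in the new DEBUG line are exactly these two accumulated durations shown as a share of total elapsed time.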

View File

@@ -681,6 +681,7 @@ class Kernel:
       if self.opts.device == "AMD":
         reduce_axes = [self.shape_len-self.upcasted]
         upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
+        # https://gpuopen.com/learn/wmma_on_rdna3/
         fix_st1 = functools.partial(fix_st, (8,2,2), (16,8), (16,2,4), ((1,2), (0,2), (1,1), (0,1)), ((1,0), (0,0)))
         fix_st2 = None
       elif self.opts.device == "METAL":