beam in RKernel

This commit is contained in:
George Hotz
2025-08-29 18:17:23 -07:00
parent 6e57905c6d
commit 59081645f7
2 changed files with 37 additions and 19 deletions

View File

@@ -15,7 +15,7 @@ def shape_to_idx(s, axis_types, start=0):
def get_index(ast:UOp) -> IndexContext:
axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
if len(ast.full_shape) != len(axis_types):
if len(ast.full_shape) != len(axis_types) and ast.st is not None:
axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
return IndexContext(axis_types, [], 0)

View File

@@ -1,10 +1,11 @@
import math, itertools
from tinygrad.dtype import AddrSpace
from tinygrad.uop.ops import UOp, Ops, sint, ssimplify, AxisType, KernelInfo, PatternMatcher, UPat
from tinygrad.helpers import DEBUG
from tinygrad.helpers import DEBUG, BEAM, getenv
from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.renderer import Renderer
from tinygrad.dtype import dtypes
from tinygrad.device import Buffer
def flatten_range(r:UOp):
off = 2 if r.op is Ops.STORE else 1
@@ -44,12 +45,21 @@ class RKernel(Kernel):
self.maxarg = max([x.arg[0] for x in self.rng]) if len(self.rng) else 0
def substitute(self) -> UOp:
self.ast = self.ast.substitute(self.replaces)
self.replaces = {}
return self.ast
def copy(self):
self.substitute()
return RKernel(self.ast, self.opts)
# must be done earlier
def simplify_merge_adjacent(self): return
def apply_opt(self, opt:Opt, append_opt:bool=True) -> UOp|None:
if opt.op == OptOps.PADTO: raise RuntimeError("PAD is not supported yet. needs INVALID")
if opt.op == OptOps.SWAP: raise RuntimeError("SWAP is not supported yet")
if opt.op == OptOps.PADTO: raise KernelOptError("PAD is not supported yet. needs INVALID")
if opt.op == OptOps.SWAP: raise KernelOptError("SWAP is not supported yet")
return super().apply_opt(opt, append_opt)
def shift_to(self, axis:int, amount:int, new_type:AxisType, top:bool=False, insert_at:int|None=None):
@@ -57,7 +67,7 @@ class RKernel(Kernel):
assert old_sz > 0, f"bad old_sz on {axis} {amount} {self.rng[axis]}"
self.maxarg += 1
new_rng = UOp.range(dtypes.int, amount, self.maxarg, new_type)
new_rng = UOp.range(amount, self.maxarg, new_type)
if old_sz == 1:
self.replaces[self.rng[axis]] = new_rng
@@ -92,25 +102,24 @@ class RKernel(Kernel):
@axis_types.setter
def axis_types(self, value): pass
def substitute(self) -> UOp:
self.ast = self.ast.substitute(self.replaces)
self.replaces = {}
return self.ast
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0]
reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE]
if not len(reduceops): raise KernelOptError("no reduce ops")
reduceop = reduceops[0]
if use_tensor_cores and reduceop is not None and reduceop.arg is Ops.ADD:
mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
if mul.op is not Ops.MUL: return False
in0, in1 = mul.src
tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
for tc in tensor_cores:
in0, in1 = list((reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]).src)
if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
# tensor cores have three ranges. X, Y, and REDUCE
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: x.arg[0])
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: x.arg[0])
red_ranges = sorted([u for u in reduceop.src[1:]], key=lambda x: x.arg[0])
red_ranges = sorted(reduceop.src[1:], key=lambda x: x.arg[0])
if DEBUG >= 3:
print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): return None
# pick ranges
@@ -156,10 +165,10 @@ class RKernel(Kernel):
# TODO: remove tc_upcast_axes from the arg
wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0], tag=1),
UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1], tag=1),
UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg, tag=1)
tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2], tag=1)
# preserve extra reduces
reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
@@ -168,9 +177,18 @@ class RKernel(Kernel):
return True
return False
def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
glbls = sorted([x for x in ast.parents if x.op is Ops.DEFINE_GLOBAL], key=lambda x: x.arg)
return [Buffer(dname, x.dtype.size, x.dtype.base) for x in glbls]
def apply_ropt(ast:UOp, renderer:Renderer):
k = RKernel(ast, opts=renderer)
if ast.arg is not None: k.apply_opts(ast.arg.opts_to_apply)
if BEAM >= 1:
from tinygrad.codegen.opt.search import beam_search
kb = RKernel(ast, opts=renderer)
rawbufs = bufs_from_ast(ast, renderer.device)
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
elif ast.arg is not None: k.apply_opts(ast.arg.opts_to_apply)
return k.get_optimized_ast()
pm_postrange_opt = pm_flatten_range+PatternMatcher([