mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
beam in RKernel
@@ -15,7 +15,7 @@ def shape_to_idx(s, axis_types, start=0):
 
 def get_index(ast:UOp) -> IndexContext:
   axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
-  if len(ast.full_shape) != len(axis_types):
+  if len(ast.full_shape) != len(axis_types) and ast.st is not None:
     axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
   return IndexContext(axis_types, [], 0)
 
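The get_index change is a guard: the shape-based fallback for axis_types now only runs when the AST has a ShapeTracker (ast.st is not None). The inference itself is unchanged: any axis whose shape entry differs from its full_shape entry is being reduced. A toy illustration in plain Python (tinygrad's resolve is only needed when shapes are symbolic):

# plain-Python sketch of the fallback in get_index: an axis where shape and
# full_shape disagree is a REDUCE axis, every other axis is a LOOP axis
shape, full_shape = (64, 1), (64, 32)
axis_types = tuple("REDUCE" if s != fs else "LOOP" for s, fs in zip(shape, full_shape))
print(axis_types)  # ('LOOP', 'REDUCE')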
@@ -1,10 +1,11 @@
 import math, itertools
 from tinygrad.dtype import AddrSpace
 from tinygrad.uop.ops import UOp, Ops, sint, ssimplify, AxisType, KernelInfo, PatternMatcher, UPat
-from tinygrad.helpers import DEBUG
+from tinygrad.helpers import DEBUG, BEAM, getenv
 from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError
 from tinygrad.renderer import Renderer
 from tinygrad.dtype import dtypes
+from tinygrad.device import Buffer
 
 def flatten_range(r:UOp):
   off = 2 if r.op is Ops.STORE else 1
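BEAM and getenv both come from tinygrad.helpers. BEAM is an environment-backed ContextVar, so the beam search added below can be toggled per run without touching code. A usage sketch, assuming the standard helpers API:

from tinygrad.helpers import BEAM, Context, getenv

print(BEAM.value)      # 0 unless BEAM=<n> is set in the environment
with Context(BEAM=2):  # or override it for a scoped block
  print(BEAM.value)    # 2
print(bool(getenv("BEAM_ESTIMATE", 1)))  # the search flag apply_ropt reads below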
@@ -44,12 +45,21 @@ class RKernel(Kernel):
 
     self.maxarg = max([x.arg[0] for x in self.rng]) if len(self.rng) else 0
 
+  def substitute(self) -> UOp:
+    self.ast = self.ast.substitute(self.replaces)
+    self.replaces = {}
+    return self.ast
+
+  def copy(self):
+    self.substitute()
+    return RKernel(self.ast, self.opts)
+
   # must be done earlier
   def simplify_merge_adjacent(self): return
 
   def apply_opt(self, opt:Opt, append_opt:bool=True) -> UOp|None:
-    if opt.op == OptOps.PADTO: raise RuntimeError("PAD is not supported yet. needs INVALID")
-    if opt.op == OptOps.SWAP: raise RuntimeError("SWAP is not supported yet")
+    if opt.op == OptOps.PADTO: raise KernelOptError("PAD is not supported yet. needs INVALID")
+    if opt.op == OptOps.SWAP: raise KernelOptError("SWAP is not supported yet")
     return super().apply_opt(opt, append_opt)
 
   def shift_to(self, axis:int, amount:int, new_type:AxisType, top:bool=False, insert_at:int|None=None):
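Two changes here serve the search. First, apply_opt now raises KernelOptError instead of RuntimeError for the unsupported PADTO and SWAP opts: KernelOptError is the signal the opt machinery uses for "this action is invalid here", which a searcher can catch and skip, while a RuntimeError would abort the whole search. Second, the new substitute/copy pair lets the search fork kernels: copy() flushes any pending range replacements, then wraps the AST in a fresh RKernel. Roughly how a searcher consumes that contract (an illustrative sketch, not the actual beam_search loop):

from tinygrad.codegen.opt.kernel import KernelOptError

def try_opt(k, opt):
  try:
    k2 = k.copy()   # independent kernel; pending replaces are flushed first
    k2.apply_opt(opt)
    return k2
  except KernelOptError:
    return None     # invalid action: pruned from the search, not fatal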
@@ -57,7 +67,7 @@ class RKernel(Kernel):
     assert old_sz > 0, f"bad old_sz on {axis} {amount} {self.rng[axis]}"
 
     self.maxarg += 1
-    new_rng = UOp.range(dtypes.int, amount, self.maxarg, new_type)
+    new_rng = UOp.range(amount, self.maxarg, new_type)
 
     if old_sz == 1:
       self.replaces[self.rng[axis]] = new_rng
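In shift_to, UOp.range no longer takes an explicit dtype argument (int is implied), and note that the old range is not rewritten on the spot: the mapping is queued in self.replaces and applied in one pass by substitute(). A minimal sketch of that deferred substitution, assuming the UOp.range signature shown in this diff:

from tinygrad.uop.ops import UOp, AxisType

old_rng = UOp.range(4, 0, AxisType.LOOP)    # size 4, range index 0
new_rng = UOp.range(16, 1, AxisType.LOOP)   # size 16, range index 1
expr = old_rng * 2
print(expr.substitute({old_rng: new_rng}))  # every use of old_rng swapped at once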
@@ -92,25 +102,24 @@ class RKernel(Kernel):
   @axis_types.setter
   def axis_types(self, value): pass
 
-  def substitute(self) -> UOp:
-    self.ast = self.ast.substitute(self.replaces)
-    self.replaces = {}
-    return self.ast
-
   def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
-    reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0]
+    reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE]
+    if not len(reduceops): raise KernelOptError("no reduce ops")
+    reduceop = reduceops[0]
     if use_tensor_cores and reduceop is not None and reduceop.arg is Ops.ADD:
+      mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
+      if mul.op is not Ops.MUL: return False
+      in0, in1 = mul.src
       tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
       for tc in tensor_cores:
-        in0, in1 = list((reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]).src)
         if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
           # tensor cores have three ranges. X, Y, and REDUCE
           in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: x.arg[0])
           in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: x.arg[0])
-          red_ranges = sorted([u for u in reduceop.src[1:]], key=lambda x: x.arg[0])
+          red_ranges = sorted(reduceop.src[1:], key=lambda x: x.arg[0])
           if DEBUG >= 3:
             print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
-              "{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
+              f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
           if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): return None
 
           # pick ranges
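_apply_tc_opt is hardened the same way: an AST with no REDUCE now raises KernelOptError instead of failing with an IndexError on [0], and the multiply check is hoisted out of the tensor-core loop, returning False early when the reduce is not fed by a MUL. The pattern being matched, mirrored as a standalone checker (illustrative, built only from ops that appear in the diff):

from tinygrad.uop.ops import Ops

def is_tc_candidate(reduceop) -> bool:
  # tensor cores want REDUCE(ADD) over a (possibly CASTed) elementwise MUL
  if reduceop.arg is not Ops.ADD: return False
  mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
  return mul.op is Ops.MUL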
@@ -156,10 +165,10 @@ class RKernel(Kernel):
           # TODO: remove tc_upcast_axes from the arg
           wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
           wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
-            UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
-            UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
-            UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
-          tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
+            UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0], tag=1),
+            UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1], tag=1),
+            UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg, tag=1)
+          tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2], tag=1)
 
           # preserve extra reduces
           reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
@@ -168,9 +177,18 @@ class RKernel(Kernel):
           return True
     return False
 
+def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
+  glbls = sorted([x for x in ast.parents if x.op is Ops.DEFINE_GLOBAL], key=lambda x: x.arg)
+  return [Buffer(dname, x.dtype.size, x.dtype.base) for x in glbls]
+
 def apply_ropt(ast:UOp, renderer:Renderer):
   k = RKernel(ast, opts=renderer)
-  if ast.arg is not None: k.apply_opts(ast.arg.opts_to_apply)
+  if BEAM >= 1:
+    from tinygrad.codegen.opt.search import beam_search
+    kb = RKernel(ast, opts=renderer)
+    rawbufs = bufs_from_ast(ast, renderer.device)
+    k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
+  elif ast.arg is not None: k.apply_opts(ast.arg.opts_to_apply)
   return k.get_optimized_ast()
 
 pm_postrange_opt = pm_flatten_range+PatternMatcher([
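The net effect in apply_ropt: with BEAM unset, hand-specified opts from the kernel arg are applied as before; with BEAM>=1 they are skipped and beam_search explores the opt space itself, timing candidates on scratch Buffers that bufs_from_ast sizes from the AST's DEFINE_GLOBAL pointers. Typical usage is just the environment variable (BEAM=2 python3 script.py), or a scoped override, as in this usage sketch:

from tinygrad import Tensor
from tinygrad.helpers import Context

with Context(BEAM=2):  # same as running with BEAM=2 in the environment
  (Tensor.rand(512, 512) @ Tensor.rand(512, 512)).realize()  # kernels here get beam-searched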