diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 1b5f5bce61..286058d802 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -4,7 +4,7 @@ from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
 from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps
 from tinygrad.device import Device, Compiled
 from tinygrad.dtype import dtypes, ImageDType, DType
-from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up, all_int, get_contraction, BEAM
+from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up, all_int, get_contraction
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sint
 from tinygrad.shape.view import View, strides_for_shape
@@ -344,7 +344,7 @@ class Kernel:
 
   # ******************** high level optimizers ********************
 
-  def _apply_tc_opt(self, use_tensor_cores:int, axis:int) -> bool:
+  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
     if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores:
       for tc in tensor_cores[self.opts.device]:
         has_cast = tc.dtype_in != tc.dtype_out
@@ -353,10 +353,9 @@ class Kernel:
         mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
         if mul_op.op != BinaryOps.MUL: continue
 
-        LOOSE_TC = getenv("BEAM_LOOSE_TC", 0)
         def buf_index(src: LazyOp) -> Optional[int]:
           if src.op == BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
-          if BEAM and LOOSE_TC and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+          if opt_level >= 1 and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
           return None
         if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
 
@@ -364,7 +363,7 @@ class Kernel:
         axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] # noqa: E501
         axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] # noqa: E501
         if not(axis_buf0 and axis_buf1 and reduce_sz%tc.dims[2] == 0 and reduce_sz >= tc.dims[2]): continue
-        if not((self.shape_len-self.first_reduce) == 1 or (BEAM and LOOSE_TC)): continue
+        if not((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1)): continue
 
         axis_choices = list(itertools.product(axis_buf0, axis_buf1))
         if not(axis < len(axis_choices)): continue
@@ -418,9 +417,9 @@ class Kernel:
 
     if opt.op == OptOps.TC:
       assert len(self.applied_opts) == 0, "tensor core opts must be first" # TODO: things like PADTO might be fine
-      assert opt.axis is not None, "tensor core opts must have an axis"
+      assert opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt"
       assert (use_tensor_cores:=getenv("TC", 1)) == 2 or self.opts.has_tensor_cores, "must have tensor cores or TC=2"
-      assert self._apply_tc_opt(use_tensor_cores, opt.axis), "no tensor core available"
+      assert self._apply_tc_opt(use_tensor_cores, opt.axis, opt.amt), "no tensor core available"
       self.applied_opts.append(opt)
       return
 
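Note on the kernel.py half: the `BEAM_LOOSE_TC` env var is gone, and matching strictness now rides along as `opt_level`, sourced from `Opt.amt`. Level 0 keeps the strict behavior (direct LOADs of `tc.dtype_in`, exactly one reduce axis); level >= 1 also accepts CASTs to `tc.dtype_in` and relaxes the shape check. A minimal sketch of probing both levels from the outside, assuming `Kernel`, `Opt`, `OptOps`, and the `copy()` used by the search code are available at this revision (illustrative only, not part of the patch):

```python
from tinygrad.codegen.kernel import Kernel, Opt, OptOps  # assumed import path

def try_tensor_cores(k: Kernel) -> bool:
  # Try strict matching first (amt=0), then the looser level (amt=1).
  # apply_opt raises AssertionError when _apply_tc_opt finds no usable
  # tensor core; apply each attempt to a copy to avoid partial state.
  for opt_level in (0, 1):
    for axis in range(4):
      try:
        k.copy().apply_opt(Opt(op=OptOps.TC, axis=axis, amt=opt_level))
        return True
      except AssertionError:
        pass
  return False
```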
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index cfa19f2f4f..09d81dd2b7 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -16,8 +16,8 @@ actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,2
 actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256] for axis in range(3)]
 actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
 actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
-actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4)]
-actions += [Opt(op=OptOps.TC, axis=axis, amt=0) for axis in range(4)]
+actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
+actions += [Opt(op=OptOps.TC, axis=axis, amt=1) for axis in range(4)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
 
 def _get_test_global_size(global_size, max_global_size, var_vals):
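And the search half: the strict TC action moves into the one-off list (axis 0, amt=0), while the loose level is searched across four axes with amt=1. A quick sanity check of the resulting action space (import paths assumed for this revision; illustrative only):

```python
from tinygrad.codegen.kernel import OptOps  # assumed import path
from tinygrad.features.search import actions

# Opt.amt now carries the tensor-core opt level, so BEAM can reach both the
# strict (amt=0) and loose (amt=1) variants without any env var.
tc = [(a.axis, a.amt) for a in actions if a.op == OptOps.TC]
assert tc == [(0, 0), (0, 1), (1, 1), (2, 1), (3, 1)]
```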