search: hotfix to make sure TC behavior is all in applied_opts (#3598)

* search: hotfix to make sure TC behavior is all in applied_opts

* fix linter error

* fix mypy
This commit is contained in:
Francis Lam
2024-03-03 18:44:38 -08:00
committed by GitHub
parent 8e5d60a322
commit 7c90005c65
2 changed files with 8 additions and 9 deletions

View File

@@ -4,7 +4,7 @@ from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps
from tinygrad.device import Device, Compiled
from tinygrad.dtype import dtypes, ImageDType, DType
from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up, all_int, get_contraction, BEAM
from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up, all_int, get_contraction
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import View, strides_for_shape
@@ -344,7 +344,7 @@ class Kernel:
# ******************** high level optimizers ********************
def _apply_tc_opt(self, use_tensor_cores:int, axis:int) -> bool:
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores:
for tc in tensor_cores[self.opts.device]:
has_cast = tc.dtype_in != tc.dtype_out
@@ -353,10 +353,9 @@ class Kernel:
mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
if mul_op.op != BinaryOps.MUL: continue
LOOSE_TC = getenv("BEAM_LOOSE_TC", 0)
def buf_index(src: LazyOp) -> Optional[int]:
if src.op == BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
if BEAM and LOOSE_TC and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
if opt_level >= 1 and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
return None
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
@@ -364,7 +363,7 @@ class Kernel:
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] # noqa: E501
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] # noqa: E501
if not(axis_buf0 and axis_buf1 and reduce_sz%tc.dims[2] == 0 and reduce_sz >= tc.dims[2]): continue
if not((self.shape_len-self.first_reduce) == 1 or (BEAM and LOOSE_TC)): continue
if not((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1)): continue
axis_choices = list(itertools.product(axis_buf0, axis_buf1))
if not(axis < len(axis_choices)): continue
@@ -418,9 +417,9 @@ class Kernel:
if opt.op == OptOps.TC:
assert len(self.applied_opts) == 0, "tensor core opts must be first" # TODO: things like PADTO might be fine
assert opt.axis is not None, "tensor core opts must have an axis"
assert opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt"
assert (use_tensor_cores:=getenv("TC", 1)) == 2 or self.opts.has_tensor_cores, "must have tensor cores or TC=2"
assert self._apply_tc_opt(use_tensor_cores, opt.axis), "no tensor core available"
assert self._apply_tc_opt(use_tensor_cores, opt.axis, opt.amt), "no tensor core available"
self.applied_opts.append(opt)
return

View File

@@ -16,8 +16,8 @@ actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,2
actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256] for axis in range(3)]
actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4)]
actions += [Opt(op=OptOps.TC, axis=axis, amt=0) for axis in range(4)]
actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
actions += [Opt(op=OptOps.TC, axis=axis, amt=1) for axis in range(4)]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
def _get_test_global_size(global_size, max_global_size, var_vals):