search: hotfix to make sure TC behavior is all in applied_opts (#3598)

* search: hotfix to make sure TC behavior is all in applied_opts
* fix linter error
* fix mypy
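Context for the change: before this commit, BEAM's "loose" tensor-core matching was toggled by the BEAM_LOOSE_TC environment variable inside the kernel optimizer, so a kernel's applied_opts list did not fully determine how it was compiled. This commit moves that switch into the TC opt itself, as an opt level carried in opt.amt. The hunks below lost their file headers in extraction; the Kernel ones appear to come from the kernel optimizer and the actions one from the BEAM search action table. A minimal stand-alone sketch of the idea; the OptOps/Opt definitions below are simplified stand-ins, not tinygrad's real classes:

from enum import Enum, auto
from typing import NamedTuple, Optional

class OptOps(Enum):  # stand-in for tinygrad's OptOps
  TC = auto()

class Opt(NamedTuple):  # mirrors tinygrad's Opt(op, axis, amt)
  op: OptOps
  axis: Optional[int] = None
  amt: Optional[int] = None

# amt now doubles as the TC opt_level, so the behavior is recorded in the opt:
strict = Opt(op=OptOps.TC, axis=0, amt=0)  # opt_level 0: strict matching
loose  = Opt(op=OptOps.TC, axis=0, amt=1)  # opt_level 1: the old BEAM_LOOSE_TC=1 behavior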
@@ -4,7 +4,7 @@ from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
 from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps
 from tinygrad.device import Device, Compiled
 from tinygrad.dtype import dtypes, ImageDType, DType
-from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up, all_int, get_contraction, BEAM
+from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up, all_int, get_contraction
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sint
 from tinygrad.shape.view import View, strides_for_shape
@@ -344,7 +344,7 @@ class Kernel:
 
   # ******************** high level optimizers ********************
 
-  def _apply_tc_opt(self, use_tensor_cores:int, axis:int) -> bool:
+  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
     if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores:
       for tc in tensor_cores[self.opts.device]:
         has_cast = tc.dtype_in != tc.dtype_out
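Note: _apply_tc_opt now takes the opt level as an explicit parameter instead of reading it from the environment, so the caller decides and the decision is replayable from applied_opts.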
@@ -353,10 +353,9 @@
         mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
         if mul_op.op != BinaryOps.MUL: continue
 
-        LOOSE_TC = getenv("BEAM_LOOSE_TC", 0)
         def buf_index(src: LazyOp) -> Optional[int]:
           if src.op == BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
-          if BEAM and LOOSE_TC and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+          if opt_level >= 1 and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
           return None
         if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
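What the opt levels permit, restated as a self-contained sketch (FakeOp and the string tags below are simplified stand-ins for LazyOp and its ops, not tinygrad's real types):

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class FakeOp:                       # stand-in for LazyOp
  op: str                           # "LOAD" or "CAST"
  dtype: str                        # dtype loaded, or the cast's target dtype
  src: Tuple["FakeOp", ...] = ()

def buf_index(src: FakeOp, tc_dtype_in: str, opt_level: int) -> Optional[str]:
  # opt_level 0: only a direct LOAD of the tensor core's input dtype matches
  if src.op == "LOAD" and src.dtype == tc_dtype_in: return "buffer:" + src.dtype
  # opt_level >= 1: also accept a LOAD seen through a CAST to the input dtype
  if opt_level >= 1 and src.op == "CAST" and src.dtype == tc_dtype_in:
    return "buffer:" + src.src[0].dtype
  return None

casted_load = FakeOp("CAST", "half", (FakeOp("LOAD", "float"),))
assert buf_index(casted_load, "half", opt_level=0) is None      # strict: rejected
assert buf_index(casted_load, "half", opt_level=1) is not None  # loose: accepted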
@@ -364,7 +363,7 @@
         axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] # noqa: E501
         axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] # noqa: E501
         if not(axis_buf0 and axis_buf1 and reduce_sz%tc.dims[2] == 0 and reduce_sz >= tc.dims[2]): continue
-        if not((self.shape_len-self.first_reduce) == 1 or (BEAM and LOOSE_TC)): continue
+        if not((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1)): continue
 
         axis_choices = list(itertools.product(axis_buf0, axis_buf1))
         if not(axis < len(axis_choices)): continue
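Note: the other gate moved onto opt_level here is dimensionality. At opt_level 0 the tensor-core path still requires exactly one reduce dimension (shape_len - first_reduce == 1); opt_level >= 1 also admits kernels with extra dimensions past the first reduce, which previously required both BEAM and BEAM_LOOSE_TC to be set.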
@@ -418,9 +417,9 @@
 
     if opt.op == OptOps.TC:
       assert len(self.applied_opts) == 0, "tensor core opts must be first" # TODO: things like PADTO might be fine
-      assert opt.axis is not None, "tensor core opts must have an axis"
+      assert opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt"
       assert (use_tensor_cores:=getenv("TC", 1)) == 2 or self.opts.has_tensor_cores, "must have tensor cores or TC=2"
-      assert self._apply_tc_opt(use_tensor_cores, opt.axis), "no tensor core available"
+      assert self._apply_tc_opt(use_tensor_cores, opt.axis, opt.amt), "no tensor core available"
       self.applied_opts.append(opt)
       return
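The asserts now demand amt as well and forward it as opt_level. Continuing the stand-in sketch above, the contract apply_opt enforces for TC opts is roughly the following (a hedged paraphrase, not the real method; kernel internals elided):

def check_tc_opt(opt: Opt, applied_opts: list) -> None:
  # paraphrase of the asserts apply_opt runs for OptOps.TC
  assert len(applied_opts) == 0, "tensor core opts must be first"
  assert opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt"

check_tc_opt(Opt(op=OptOps.TC, axis=0, amt=1), [])  # ok: amt carries the opt level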
@@ -16,8 +16,8 @@ actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,2
 actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256] for axis in range(3)]
 actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
 actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
-actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
+actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4)]
+actions += [Opt(op=OptOps.TC, axis=axis, amt=0) for axis in range(4)]
+actions += [Opt(op=OptOps.TC, axis=axis, amt=1) for axis in range(4)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
 
 def _get_test_global_size(global_size, max_global_size, var_vals):
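The payoff of encoding the opt level in the action set itself: a kernel found by BEAM can be rebuilt from its applied_opts alone, with no environment state. A hedged sketch of such a replay loop (replay is a hypothetical helper; apply_opt is the method patched above):

def replay(kernel, applied_opts):
  # re-apply the recorded opts in order; a TC opt's amt is its opt_level,
  # so loose tensor-core matching reproduces without BEAM_LOOSE_TC being set
  for opt in applied_opts:
    kernel.apply_opt(opt)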