tc_select noop (#8801)

* tc_select noop

* revert changes in test
This commit is contained in:
Ignacio Sica
2025-01-29 15:53:23 -03:00
committed by GitHub
parent ec120ce6b9
commit 260df1a17f
7 changed files with 41 additions and 30 deletions

View File

@@ -26,7 +26,7 @@ class TestNV(unittest.TestCase):
def test_oor_kernels(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [Opt(op=OptOps.TC, axis=6, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=2)] # noqa: E501
opts = [Opt(op=OptOps.TC, axis=6, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=2)] # noqa: E501
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"])
def test_error_on_huge_dims(self):

View File

@@ -25,7 +25,7 @@ def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer
bufs = [Buffer((x).device, x.size, x.dtype).allocate() if x in s[-1].outputs else x for x in s[-1].bufs]
return s[-1].ast, bufs
def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0):
def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0):
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
np_a, np_b = a.numpy(), b.numpy()
r = a.matmul(b, acc_dtype=dtype_out)
@@ -34,7 +34,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
run_schedule(sched)
out = r.numpy()
k = Kernel(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.apply_tensor_cores(1, axis=axis, tc_select=tc_select, tc_opt=tc_opt)
k.linearize()
assert len([uop for uop in k.uops if uop.op is Ops.WMMA]) > 0, "tensor core not triggered"
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
@@ -44,13 +44,14 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
else: tc_atol, tc_rtol = 5e-3, 1e-4
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0, ensure_triggered:bool=True):
def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0,
ensure_triggered:bool=True):
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
r = a.matmul(b, acc_dtype=dtype_out)
sched = r.schedule()
realized_ast = sched[-1].ast
k = Kernel(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.apply_tensor_cores(1, axis=axis, tc_select=tc_select, tc_opt=tc_opt)
k.linearize()
wmmas = len([uop for uop in k.uops if uop.op is Ops.WMMA])
tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
@@ -1959,7 +1960,7 @@ class TestKernelOpts(unittest.TestCase):
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
with self.assertRaises(KernelOptError):
k.apply_opt(Opt(OptOps.TC, 0, 1))
k.apply_opt(Opt(OptOps.TC, 0, (-1, 1)))
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_core_opts(self):

View File

@@ -35,7 +35,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
ast_const(dtypes.half, 0.0, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)]
opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)]
k = Kernel(ast, opts=Device["METAL"].renderer)
k.required_optimizations()
for opt in opts: k.apply_opt(opt)

View File

@@ -740,7 +740,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=0, arg=1), Opt(op=OptOps.PADTO, axis=2, arg=32)]
opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 1)), Opt(op=OptOps.PADTO, axis=2, arg=32)]
helper_test_lin(Kernel(ast), opts, failed_platforms=[], atol=1.0)
def test_failure_30(self):
@@ -800,7 +800,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=16)]
opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=16)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[], atol=0.1, rtol=0.05)
def test_failure_33(self):
@@ -861,7 +861,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(0, 0, 10, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
ast_const(dtypes.float, 0.0, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UNROLL, axis=0, arg=0)] if unroll else [Opt(op=OptOps.TC, axis=0, arg=2)]
opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2)), Opt(op=OptOps.UNROLL, axis=0, arg=0)] if unroll else [Opt(op=OptOps.TC, axis=0, arg=(-1, 2))]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
def test_failure_35(self): self.test_failure_34(True)
@@ -910,7 +910,7 @@ class TestLinearizerFailures(unittest.TestCase):
ast_const(dtypes.float, 0.0, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
for axis in [0,1,2,3,4,5]:
opts = [Opt(op=OptOps.TC, axis=axis, arg=2)]
opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2))]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
def test_failure_38(self):
@@ -930,7 +930,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 32, 24, 24, 1, 5, 5, 256), strides=(18432, 0, 576, 24, 1, 0, 0, 0, 36864), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
for axis in [0,1,3,4]:
opts = [Opt(op=OptOps.TC, axis=axis, arg=2)]
opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2))]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
@unittest.skip("very slow, similar to test_failure_37")
@@ -958,7 +958,7 @@ class TestLinearizerFailures(unittest.TestCase):
ast_const(dtypes.float, 0.0, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
for axis in [0,1,2,3,4,5]:
opts = [Opt(op=OptOps.TC, axis=axis, arg=2)]
opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2))]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
def test_failure_40(self):
@@ -996,7 +996,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
opts=[Opt(op=OptOps.TC, axis=5, arg=2), Opt(op=OptOps.UNROLL, axis=0, arg=0)]
opts=[Opt(op=OptOps.TC, axis=5, arg=(-1, 2)), Opt(op=OptOps.UNROLL, axis=0, arg=0)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02)
# llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile
@@ -1151,7 +1151,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 3136, 56, 1, 0, 0, 0, 200704), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=0, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2)]
opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 0)), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
def test_failure_49(self):
@@ -1168,7 +1168,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(0, 1, 6), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
def test_failure_50(self):
@@ -1229,7 +1229,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=(
x6,)),)),)),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=0, arg=2)]
opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2))]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
@unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI")
@@ -1251,7 +1251,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16)]
opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
def test_failure_53(self):
@@ -1305,7 +1305,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UPCAST, axis=1, arg=2)]
opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UPCAST, axis=1, arg=2)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"])
@unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI")

View File

@@ -11,7 +11,7 @@ from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
from tinygrad.dtype import ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.linearize import linearize_uop
@@ -291,9 +291,10 @@ class Kernel:
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
for tc in self.opts.tensor_cores:
tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
for tc in tensor_cores:
tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
# can only fuse reduces with the same tc options
assert all_same(tensor_core_opts)
@@ -312,8 +313,9 @@ class Kernel:
return True
return False
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=None, axis:int=0, tc_opt:Optional[int]=None) -> bool:
""" Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=None, axis:int=0, tc_select:Optional[int]=None,
tc_opt:Optional[int]=None) -> bool:
""" Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
Keyword arguments:
@@ -322,15 +324,19 @@ class Kernel:
1: enable tensor cores
2: apply tensor core shape but don't use UOp.WMMA
extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
tc_select -- specifies which tensor core(s) to use for optimization (default -1)
-1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
[0-N]: uses only the n'th tensor core available; useful for search
tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL
1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers
2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
"""
if tc_select is None: tc_select = TC_SELECT.value
if tc_opt is None: tc_opt = TC_OPT.value
if not self.opts.tensor_cores and use_tensor_cores != 2: return False
try: # check TC first and apply hand-coded opts if successful
self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))
if (tc_opts:=self.tensor_core_opts) is not None:
if extra_opts is not None:
@@ -353,9 +359,12 @@ class Kernel:
if opt.op is OptOps.TC:
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
check(opt.axis is not None and opt.arg is not None, "tensor core opts must have an axis and arg")
check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.arg)), "no tensor core available")
check(opt.axis is not None, "tensor core opts must have an axis")
check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 2, "tensor core opts must have tc_select and tc_opt")
check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
self.applied_opts.append(opt)
return

View File

@@ -19,8 +19,8 @@ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,
actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
actions += [Opt(op=OptOps.TC, axis=0, arg=0)]
actions += [Opt(op=OptOps.TC, axis=axis, arg=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))]
actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2))) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

View File

@@ -105,7 +105,8 @@ class ContextVar:
DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
TRANSCENDENTAL = ContextVar("TRANSCENDENTAL", 1)
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)