[MFMA] Switch between MFMA types (#352)

This PR introduces the matrix_instr_nonkdim flag to switch
between MFMA 16 and MFMA 32.
This commit is contained in:
Alexander Efimov
2023-10-18 16:57:34 +02:00
committed by GitHub
parent 4d539d7dae
commit 20f316b19a
14 changed files with 371 additions and 247 deletions

View File

@@ -276,13 +276,13 @@ class JITFunction(KernelInterface[T]):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
key = str(key)
class LegacyCompiler:
@@ -364,7 +364,7 @@ class JITFunction(KernelInterface[T]):
src = f"""
import triton
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
@@ -406,7 +406,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, self.debug)
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, self.debug)
if not extern_libs is None:
key = (key, tuple(extern_libs.items()))
@@ -434,8 +434,8 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup: