Add waves_per_eu as kernel parameter (#319)

* Add waves_per_eu as kernel parameter

* Fix failing tests

* Add default value for waves_per_eu for ttgir_to_llir function

* Remove aot.py
This commit is contained in:
oplavsic
2023-10-06 19:08:34 +02:00
committed by GitHub
parent be95edc63f
commit e801638b40
7 changed files with 36 additions and 26 deletions

View File

@@ -276,13 +276,13 @@ class JITFunction(KernelInterface[T]):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
if JITFunction.cache_hook is None:
return False
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
key = str(key)
class LegacyCompiler:
@@ -292,7 +292,7 @@ class JITFunction(KernelInterface[T]):
pass
kwargs = dict(signature=signature, device=device, constants=constants,
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
configs=configs)
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
@@ -364,7 +364,7 @@ class JITFunction(KernelInterface[T]):
src = f"""
import triton
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
@@ -406,7 +406,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, self.debug)
if not extern_libs is None:
key = (key, tuple(extern_libs.items()))
@@ -434,8 +434,8 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup: