mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Add waves_per_eu as kernel parameter (#319)
* Add waves_per_eu as kernel parameter
* Fix failing tests
* Add default value for waves_per_eu for ttgir_to_llir function
* Remove aot.py
This commit is contained in:
@@ -276,13 +276,13 @@ class JITFunction(KernelInterface[T]):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
@@ -292,7 +292,7 @@ class JITFunction(KernelInterface[T]):
|
||||
pass
|
||||
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
|
||||
@@ -364,7 +364,7 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
src = f"""
|
||||
import triton
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
|
||||
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
@@ -406,7 +406,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, self.debug)
|
||||
if not extern_libs is None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
@@ -434,8 +434,8 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
|
||||
Reference in New Issue
Block a user