mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108
Conflicts: bin/triton-translate.cpp lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp python/triton/compiler/compiler.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -453,7 +453,7 @@ def _unwrap(tensor):
|
||||
|
||||
builder = Builder()
|
||||
|
||||
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization']
|
||||
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']
|
||||
|
||||
|
||||
class GridExecutor:
|
||||
|
||||
@@ -281,13 +281,21 @@ class JITFunction(KernelInterface[T]):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
|
||||
<<<<<<< HEAD
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
|
||||
=======
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
<<<<<<< HEAD
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
|
||||
=======
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
@@ -297,7 +305,11 @@ class JITFunction(KernelInterface[T]):
|
||||
pass
|
||||
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
<<<<<<< HEAD
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
|
||||
=======
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs,
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
configs=configs)
|
||||
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
|
||||
@@ -351,7 +363,11 @@ class JITFunction(KernelInterface[T]):
|
||||
def regular_args_v(args_proxy):
|
||||
return [args_proxy[arg_name] for arg_name in regular_args]
|
||||
|
||||
<<<<<<< HEAD
|
||||
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type):
|
||||
=======
|
||||
def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type):
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
from ..compiler import (CompiledKernel, compile,
|
||||
get_arch_default_num_stages,
|
||||
get_arch_default_num_warps)
|
||||
@@ -402,7 +418,11 @@ class JITFunction(KernelInterface[T]):
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
<<<<<<< HEAD
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, self.debug)
|
||||
=======
|
||||
key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
@@ -430,8 +450,13 @@ class JITFunction(KernelInterface[T]):
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
<<<<<<< HEAD
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
=======
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
@@ -446,8 +471,13 @@ class JITFunction(KernelInterface[T]):
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
src = f"""
|
||||
import triton
|
||||
<<<<<<< HEAD
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type)
|
||||
=======
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type)
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
"""
|
||||
scope = {"launcher_body": launcher_body}
|
||||
exec(src, scope)
|
||||
|
||||
Reference in New Issue
Block a user