mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Dot slicing pass (#440)
* First commit
* Implement DotSlicing pass.
* small fixes
* Support chained dot in DotSlicingPass (second GEMM in FA)
* Add lit test for FA dot slicing

---------

Co-authored-by: Ognjen Plavsic <ognjen.plavsic@luxoft.com>
Co-authored-by: Ognjen <oplavsic@luxoft.com>
This commit is contained in:
@@ -100,7 +100,7 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target):
|
||||
|
||||
|
||||
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization,
|
||||
enable_persistent, optimize_epilogue, matrix_inst_type):
|
||||
enable_persistent, optimize_epilogue, matrix_inst_type, slice_k_tile):
|
||||
is_cuda = _is_cuda(target)
|
||||
if is_cuda:
|
||||
capability = target.capability
|
||||
@@ -123,6 +123,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
if optimize_epilogue:
|
||||
pm.add_tritongpu_optimize_epilogue_pass()
|
||||
pm.add_tritonamdgpu_dot_slicing_pass(slice_k_tile)
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
if num_stages == 0 and is_hip() and target["matrix_core_version"] != 0:
|
||||
pm.add_tritongpu_stream_pipeline_pass()
|
||||
@@ -273,6 +274,7 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs):
|
||||
num_ctas = kwargs.get("num_ctas", 1)
|
||||
num_stages = kwargs.get("num_stages", 3)
|
||||
waves_per_eu = kwargs.get("waves_per_eu", 0)
|
||||
slice_k_tile = kwargs.get("slice_k_tile", 0)
|
||||
matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0);
|
||||
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
|
||||
enable_persistent = kwargs.get("enable_persistent", False)
|
||||
@@ -282,7 +284,7 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs):
|
||||
sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
|
||||
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{slice_k_tile}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
ignore_version = kwargs.get('ignore_version', False)
|
||||
@@ -414,6 +416,7 @@ def compile(fn, **kwargs):
|
||||
num_ctas = kwargs.get("num_ctas", 1)
|
||||
num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
|
||||
waves_per_eu = kwargs.get("waves_per_eu", 0)
|
||||
slice_k_tile = kwargs.get("slice_k_tile", 0)
|
||||
matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0)
|
||||
enable_fp_fusion = kwargs.get("enable_fp_fusion", True)
|
||||
# TODO[shuhaoj]: Default should be to enable warp specialization once possible
|
||||
@@ -453,7 +456,7 @@ def compile(fn, **kwargs):
|
||||
if is_cuda:
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
|
||||
ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info,
|
||||
enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
enable_warp_specialization, enable_persistent, optimize_epilogue, slice_k_tile))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
|
||||
add_cuda_stages(target, extern_libs, stages)
|
||||
@@ -472,12 +475,13 @@ def compile(fn, **kwargs):
|
||||
other["optimize_epilogue"] = optimize_epilogue
|
||||
other["tma_infos"] = tma_infos
|
||||
other["waves_per_eu"] = waves_per_eu
|
||||
other["slice_k_tile"] = slice_k_tile
|
||||
other["matrix_instr_nonkdim"] = matrix_instr_nonkdim
|
||||
|
||||
_device_backend.add_stages(target, extern_libs, stages, other)
|
||||
elif device_type == "xpu":
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, slice_k_tile))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
|
||||
_device_backend.add_stages(arch, extern_libs, stages)
|
||||
@@ -556,6 +560,7 @@ def compile(fn, **kwargs):
|
||||
"num_ctas": num_ctas,
|
||||
"num_stages": num_stages,
|
||||
"waves_per_eu": waves_per_eu,
|
||||
"slice_k_tile": slice_k_tile,
|
||||
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
||||
"enable_warp_specialization": enable_warp_specialization,
|
||||
"enable_persistent": enable_persistent,
|
||||
@@ -689,6 +694,7 @@ class CompiledKernel:
|
||||
self.num_ctas = metadata["num_ctas"]
|
||||
self.num_stages = metadata["num_stages"]
|
||||
self.waves_per_eu = metadata["waves_per_eu"]
|
||||
self.slice_k_tile = metadata["slice_k_tile"]
|
||||
self.clusterDims = metadata["clusterDims"]
|
||||
if "tensormaps_info" in metadata:
|
||||
self.tensormaps_info = metadata["tensormaps_info"]
|
||||
|
||||
@@ -351,6 +351,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas,
|
||||
num_stages,
|
||||
waves_per_eu,
|
||||
slice_k_tile,
|
||||
matrix_instr_nonkdim,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
@@ -363,7 +364,7 @@ class JITFunction(KernelInterface[T]):
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{param.name}: {ty}' for param, ty in zip(self.params, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, slice_k_tile={slice_k_tile}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
@@ -381,6 +382,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
waves_per_eu=waves_per_eu,
|
||||
slice_k_tile=slice_k_tile,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
extern_libs=extern_libs,
|
||||
@@ -427,6 +429,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas = get_special_arg("num_ctas", 1)
|
||||
num_stages = get_special_arg("num_stages")
|
||||
waves_per_eu = get_special_arg("waves_per_eu", 0)
|
||||
slice_k_tile = get_special_arg("slice_k_tile", 0)
|
||||
matrix_instr_nonkdim = get_special_arg("matrix_instr_nonkdim", 0)
|
||||
enable_warp_specialization = get_special_arg("enable_warp_specialization", False)
|
||||
enable_fp_fusion = get_special_arg("enable_fp_fusion", True)
|
||||
@@ -503,6 +506,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas,
|
||||
num_stages,
|
||||
waves_per_eu,
|
||||
slice_k_tile,
|
||||
matrix_instr_nonkdim,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
@@ -539,6 +543,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas,
|
||||
num_stages,
|
||||
waves_per_eu,
|
||||
slice_k_tile,
|
||||
matrix_instr_nonkdim,
|
||||
enable_warp_specialization,
|
||||
enable_fp_fusion,
|
||||
@@ -556,6 +561,7 @@ class JITFunction(KernelInterface[T]):
|
||||
num_ctas=num_ctas,
|
||||
num_stages=num_stages,
|
||||
waves_per_eu=waves_per_eu,
|
||||
slice_k_tile=slice_k_tile,
|
||||
matrix_instr_nonkdim=matrix_instr_nonkdim,
|
||||
enable_warp_specialization=enable_warp_specialization,
|
||||
enable_fp_fusion=enable_fp_fusion,
|
||||
|
||||
3
python/triton/third_party/hip/hip_backend.py
vendored
3
python/triton/third_party/hip/hip_backend.py
vendored
@@ -449,10 +449,11 @@ class HIPBackend(BaseBackend):
|
||||
optimize_epilogue = other["optimize_epilogue"]
|
||||
tma_infos = other["tma_infos"]
|
||||
waves_per_eu = other["waves_per_eu"]
|
||||
slice_k_tile = other["slice_k_tile"]
|
||||
matrix_instr_nonkdim = other["matrix_instr_nonkdim"]
|
||||
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim))
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim, slice_k_tile))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user