mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[ROCM] Core Functionality for AMD (#1983)
* this pr adds a third party backend for triton that works on AMD * this expose a lot of the work that has been done in our [fork](https://github.com/ROCmSoftwarePlatform/triton) * most unit tests on `test_core.py` pass * it skips some unit tests for various reasons * we plan to follow up with more prs improving Functionality and Performance in the future --------- Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import importlib.util
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from ..runtime.driver import DriverBase
|
||||
@@ -94,7 +95,7 @@ def get_backend(device_type: str):
|
||||
try:
|
||||
importlib.import_module(device_backend_package_name, package=__spec__.name)
|
||||
except Exception:
|
||||
return None
|
||||
traceback.print_exc()
|
||||
else:
|
||||
return None
|
||||
return _backends[device_type] if device_type in _backends else None
|
||||
|
||||
@@ -5,19 +5,18 @@ import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple
|
||||
from typing import Any
|
||||
|
||||
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
|
||||
compile_ptx_to_cubin, get_env_vars, get_num_warps,
|
||||
get_shared_memory_size, ir, runtime,
|
||||
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
|
||||
translate_triton_gpu_to_llvmir, get_arch_info,
|
||||
get_warp_size)
|
||||
translate_llvmir_to_ptx,
|
||||
translate_triton_gpu_to_llvmir)
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..common.build import is_hip
|
||||
# from ..runtime import driver, jit, JITFunction
|
||||
# TODO: runtime.errors
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
@@ -214,71 +213,6 @@ def ptx_to_cubin(ptx: str, arch: int):
|
||||
return compile_ptx_to_cubin(ptx, ptxas, arch)
|
||||
|
||||
|
||||
# AMDGCN translation
|
||||
|
||||
def get_amdgcn_bitcode_paths(arch):
|
||||
gpu_arch_agnostic_bitcode_libraries = ["opencl.bc",
|
||||
"ocml.bc",
|
||||
"ockl.bc",
|
||||
"oclc_finite_only_off.bc",
|
||||
"oclc_daz_opt_off.bc",
|
||||
"oclc_correctly_rounded_sqrt_on.bc",
|
||||
"oclc_unsafe_math_off.bc",
|
||||
"oclc_wavefrontsize64_on.bc",
|
||||
"oclc_abi_version_400.bc",]
|
||||
|
||||
gfx_arch = arch[1]
|
||||
gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip()
|
||||
|
||||
gpu_arch_specific_bitcode_library = 'oclc_isa_version_' + gfx_arch_id + ".bc"
|
||||
bitcode_path_dir = os.path.join(Path(__file__).parent.parent.resolve(), "third_party/rocm/lib/bitcode/")
|
||||
|
||||
amdgcn_bitcode_paths = {}
|
||||
i = 0
|
||||
for bc_lib in gpu_arch_agnostic_bitcode_libraries:
|
||||
bc_path = bitcode_path_dir + bc_lib
|
||||
if os.path.exists(bc_path):
|
||||
amdgcn_bitcode_paths['library_' + str(i)] = bc_path
|
||||
i += 1
|
||||
bc_gfx_path = bitcode_path_dir + gpu_arch_specific_bitcode_library
|
||||
if os.path.exists(bc_gfx_path):
|
||||
amdgcn_bitcode_paths['library_' + str(i)] = bc_gfx_path
|
||||
|
||||
return amdgcn_bitcode_paths
|
||||
|
||||
|
||||
def get_amdgpu_arch_fulldetails():
|
||||
"""
|
||||
get the amdgpu full ISA details for compiling:
|
||||
i.e., arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack-
|
||||
"""
|
||||
try:
|
||||
# TODO: package rocm.cc with Triton
|
||||
arch_info = get_arch_info()
|
||||
warp_size = get_warp_size()
|
||||
gfx_arch_details = re.search('amd.*', arch_info).group(0).strip().split('--')
|
||||
arch_triple = gfx_arch_details[0]
|
||||
arch_name_features = gfx_arch_details[1].split(':')
|
||||
arch_name = arch_name_features[0]
|
||||
arch_features = ""
|
||||
|
||||
return [arch_triple, arch_name, arch_features, warp_size]
|
||||
except BaseException as e:
|
||||
print("Error: Attempting to get amgpu ISA Details {}".format(e))
|
||||
return None
|
||||
|
||||
|
||||
def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]:
|
||||
'''
|
||||
Translate TritonGPU module to HSACO code based on full details of gpu architecture.
|
||||
:param mod: a TritonGPU dialect module
|
||||
:return:
|
||||
- AMDGCN code
|
||||
- Path to HSACO object
|
||||
'''
|
||||
return translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# compiler
|
||||
# ------------------------------------------------------------------------------
|
||||
@@ -347,8 +281,10 @@ arg_type_pattern = {
|
||||
"ttgir": mlir_arg_type_pattern,
|
||||
"ptx": ptx_arg_type_pattern,
|
||||
}
|
||||
|
||||
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
|
||||
if is_hip():
|
||||
ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:'
|
||||
else:
|
||||
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
|
||||
|
||||
|
||||
def _get_jsonable_constants(constants):
|
||||
@@ -389,17 +325,10 @@ def is_hip():
|
||||
from ..language.semantic import gpu_matrix_core_version
|
||||
|
||||
def get_architecture_descriptor(capability):
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError("Triton requires PyTorch to be installed")
|
||||
if capability is None:
|
||||
if torch.version.hip is None:
|
||||
device = get_current_device()
|
||||
capability = get_device_capability(device)
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
else:
|
||||
capability = get_amdgpu_arch_fulldetails()
|
||||
device = get_current_device()
|
||||
capability = get_device_capability(device)
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
return capability
|
||||
|
||||
|
||||
@@ -429,23 +358,6 @@ def get_arch_default_num_stages(device_type, capability=None):
|
||||
return num_stages
|
||||
|
||||
|
||||
def add_rocm_stages(arch, extern_libs, stages):
|
||||
extern_libs.update(get_amdgcn_bitcode_paths(arch))
|
||||
|
||||
for key in list(extern_libs):
|
||||
if extern_libs[key] == '' or extern_libs[key] is None:
|
||||
extern_libs.pop(key)
|
||||
|
||||
gfx_arch_full_details = arch
|
||||
gfx_arch = os.environ.get('MI_GPU_ARCH', gfx_arch_full_details[1])
|
||||
if gfx_arch is None:
|
||||
raise RuntimeError('gfx_arch is None (not specified)')
|
||||
stages["amdgcn"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
|
||||
gfx_arch_full_details[0],
|
||||
gfx_arch_full_details[2]))
|
||||
|
||||
|
||||
def add_cuda_stages(arch, extern_libs, stages):
|
||||
|
||||
stages["ptx"] = (lambda path: Path(path).read_text(),
|
||||
@@ -457,23 +369,22 @@ def add_cuda_stages(arch, extern_libs, stages):
|
||||
def compile(fn, **kwargs):
|
||||
# Get device type to decide which backend should be used
|
||||
device_type = kwargs.get("device_type", "cuda")
|
||||
_device_backend = get_backend(device_type)
|
||||
capability = kwargs.get("cc", None)
|
||||
|
||||
if device_type in ["cuda", "hip"]:
|
||||
# hip with kwargs.get("cc", None) causes multiprocessing issues in torch.compile
|
||||
if device_type == "hip":
|
||||
arch = get_architecture_descriptor(None if type(capability) is int else capability)
|
||||
else:
|
||||
arch = get_architecture_descriptor(capability)
|
||||
if is_hip():
|
||||
device_type = "hip"
|
||||
|
||||
if device_type == "cuda":
|
||||
_device_backend = get_backend(device_type)
|
||||
arch = get_architecture_descriptor(capability)
|
||||
else:
|
||||
_device_backend = get_backend(device_type)
|
||||
assert _device_backend
|
||||
arch = _device_backend.get_architecture_descriptor(**kwargs)
|
||||
|
||||
is_cuda = device_type == "cuda" and _is_cuda(arch)
|
||||
is_hip = device_type in ["cuda", "hip"] and not is_cuda
|
||||
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3]
|
||||
if is_hip():
|
||||
is_cuda = False
|
||||
context = ir.context()
|
||||
constants = kwargs.get("constants", dict())
|
||||
num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type))
|
||||
@@ -506,14 +417,20 @@ def compile(fn, **kwargs):
|
||||
stages["ast"] = (lambda path: fn, None)
|
||||
stages["ttir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu))
|
||||
if is_cuda:
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
|
||||
add_cuda_stages(arch, extern_libs, stages)
|
||||
elif is_hip:
|
||||
add_rocm_stages(arch, extern_libs, stages)
|
||||
elif device_type == "hip":
|
||||
_device_backend.add_stages(arch, extern_libs, stages, num_warps=num_warps, num_stages=num_stages)
|
||||
elif device_type == "xpu":
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
|
||||
_device_backend.add_stages(arch, extern_libs, stages)
|
||||
else:
|
||||
# pass the user's configuration to the backend device.
|
||||
arch["num_warps"] = num_warps
|
||||
@@ -632,17 +549,23 @@ def compile(fn, **kwargs):
|
||||
else:
|
||||
asm[ir_name] = str(next_module)
|
||||
if ir_name == "llir" and "shared" not in metadata:
|
||||
metadata["shared"] = get_shared_memory_size(module)
|
||||
if is_hip():
|
||||
metadata["shared"] = _device_backend.get_shared_memory_size(module)
|
||||
else:
|
||||
metadata["shared"] = get_shared_memory_size(module)
|
||||
if ir_name == "ttgir":
|
||||
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
|
||||
if metadata["enable_warp_specialization"]:
|
||||
metadata["num_warps"] = get_num_warps(next_module)
|
||||
if is_hip():
|
||||
metadata["num_warps"] = _device_backend.get_num_warps(next_module)
|
||||
else:
|
||||
metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
|
||||
if metadata["enable_warp_specialization"]:
|
||||
metadata["num_warps"] = get_num_warps(next_module)
|
||||
if ir_name == "ptx":
|
||||
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
|
||||
if ir_name == "amdgcn":
|
||||
metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
|
||||
asm["hsaco_path"] = next_module[1]
|
||||
if not is_cuda and not is_hip:
|
||||
if not is_cuda and not is_hip():
|
||||
_device_backend.add_meta_info(ir_name, module, next_module, metadata, asm)
|
||||
module = next_module
|
||||
|
||||
@@ -667,7 +590,7 @@ def compile(fn, **kwargs):
|
||||
ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else ()
|
||||
ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs}
|
||||
# cache manager
|
||||
if is_cuda or is_hip:
|
||||
if is_cuda:
|
||||
so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
|
||||
else:
|
||||
so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
|
||||
@@ -707,7 +630,7 @@ class CompiledKernel:
|
||||
self.tensormaps_info = metadata["tensormaps_info"]
|
||||
self.constants = metadata["constants"]
|
||||
self.device_type = metadata["device_type"]
|
||||
self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda", "hip"] else None
|
||||
self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda"] else None
|
||||
# initialize asm dict
|
||||
self.asm = asm
|
||||
# binaries are lazily initialized
|
||||
@@ -721,7 +644,7 @@ class CompiledKernel:
|
||||
if self.cu_module is not None:
|
||||
return
|
||||
|
||||
if self.device_type in ["cuda", "hip"]:
|
||||
if self.device_type in ["cuda"]:
|
||||
device = get_current_device()
|
||||
bin_path = {
|
||||
driver.HIP: "hsaco_path",
|
||||
@@ -767,7 +690,7 @@ class CompiledKernel:
|
||||
def runner(*args, stream=None):
|
||||
args_expand = self.assemble_tensormap_to_arg(args)
|
||||
if stream is None:
|
||||
if self.device_type in ["cuda", "hip"]:
|
||||
if self.device_type in ["cuda"]:
|
||||
stream = get_cuda_stream()
|
||||
else:
|
||||
stream = get_backend(self.device_type).get_stream(None)
|
||||
|
||||
@@ -3,16 +3,11 @@ import os
|
||||
import tempfile
|
||||
|
||||
from ..common import _build
|
||||
from ..common.build import is_hip
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..runtime.jit import version_key
|
||||
from .utils import generate_cu_signature
|
||||
|
||||
|
||||
def is_hip():
|
||||
import torch
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
# ----- stub --------
|
||||
|
||||
|
||||
@@ -103,150 +98,9 @@ def generate_launcher(constants, signature, ids):
|
||||
format = "iiiiiiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
|
||||
|
||||
# generate glue code
|
||||
if is_hip():
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
|
||||
src = f"""
|
||||
#define __HIP_PLATFORM_AMD__
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <Python.h>
|
||||
#include <stdbool.h>
|
||||
#include <dlfcn.h>
|
||||
|
||||
static inline void gpuAssert(hipError_t code, const char *file, int line)
|
||||
{{
|
||||
if (code != HIP_SUCCESS)
|
||||
{{
|
||||
const char* prefix = "Triton Error [HIP]: ";
|
||||
const char* str = hipGetErrorString(code);
|
||||
char err[1024] = {{0}};
|
||||
snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str );
|
||||
PyErr_SetString(PyExc_RuntimeError, err);
|
||||
}}
|
||||
}}
|
||||
|
||||
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
|
||||
// printf("_launch hip kernel\\n");
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
|
||||
if (gridX*gridY*gridZ > 0) {{
|
||||
HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0));
|
||||
}}
|
||||
}}
|
||||
|
||||
typedef struct _DevicePtrInfo {{
|
||||
hipDeviceptr_t dev_ptr;
|
||||
bool valid;
|
||||
}} DevicePtrInfo;
|
||||
|
||||
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
|
||||
DevicePtrInfo ptr_info;
|
||||
ptr_info.dev_ptr = 0;
|
||||
ptr_info.valid = true;
|
||||
if (PyLong_Check(obj)) {{
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj);
|
||||
return ptr_info;
|
||||
}}
|
||||
if (obj == Py_None) {{
|
||||
// valid nullptr
|
||||
return ptr_info;
|
||||
}}
|
||||
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
|
||||
if(ptr){{
|
||||
PyObject *empty_tuple = PyTuple_New(0);
|
||||
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
|
||||
Py_DECREF(empty_tuple);
|
||||
Py_DECREF(ptr);
|
||||
if (!PyLong_Check(ret)) {{
|
||||
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
|
||||
ptr_info.valid = false;
|
||||
return ptr_info;
|
||||
}}
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret);
|
||||
if(!ptr_info.dev_ptr)
|
||||
return ptr_info;
|
||||
uint64_t dev_ptr;
|
||||
hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
|
||||
if (status == hipErrorInvalidValue) {{
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
|
||||
ptr_info.valid = false;
|
||||
}}
|
||||
ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr;
|
||||
Py_DECREF(ret);
|
||||
return ptr_info;
|
||||
}}
|
||||
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
|
||||
return ptr_info;
|
||||
}}
|
||||
|
||||
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
// printf("launch\\n");
|
||||
int gridX, gridY, gridZ;
|
||||
uint64_t _stream;
|
||||
uint64_t _function;
|
||||
int num_warps;
|
||||
int num_ctas;
|
||||
int clusterDimX;
|
||||
int clusterDimY;
|
||||
int clusterDimZ;
|
||||
int shared_memory;
|
||||
PyObject *launch_enter_hook = NULL;
|
||||
PyObject *launch_exit_hook = NULL;
|
||||
PyObject *compiled_kernel = NULL;
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
if (launch_enter_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_enter_hook, args);
|
||||
}}
|
||||
|
||||
|
||||
// raise exception asap
|
||||
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
||||
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
|
||||
|
||||
if (launch_exit_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_exit_hook, args);
|
||||
}}
|
||||
|
||||
if(PyErr_Occurred()) {{
|
||||
return NULL;
|
||||
}}
|
||||
// return None
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {{
|
||||
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||
{{NULL, NULL, 0, NULL}} // sentinel
|
||||
}};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {{
|
||||
PyModuleDef_HEAD_INIT,
|
||||
\"__triton_launcher\",
|
||||
NULL, //documentation
|
||||
-1, //size
|
||||
ModuleMethods
|
||||
}};
|
||||
|
||||
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if(m == NULL) {{
|
||||
return NULL;
|
||||
}}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}}
|
||||
"""
|
||||
else:
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
|
||||
src = f"""
|
||||
folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']]
|
||||
params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)]
|
||||
src = f"""
|
||||
#include \"cuda.h\"
|
||||
#include <stdbool.h>
|
||||
#include <Python.h>
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import functools
|
||||
import os
|
||||
|
||||
from ..common.build import is_hip
|
||||
from . import core
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def libdevice_path():
|
||||
import torch
|
||||
third_party_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party")
|
||||
if torch.version.hip is None:
|
||||
default = os.path.join(third_party_dir, "cuda", "lib", "libdevice.10.bc")
|
||||
if is_hip():
|
||||
default = os.path.join(third_party_dir, "hip", "lib", "bitcode", "cuda2gcn.bc")
|
||||
else:
|
||||
default = os.path.join(third_party_dir, "rocm", "lib", "bitcode", "cuda2gcn.bc")
|
||||
default = os.path.join(third_party_dir, "cuda", "lib", "libdevice.10.bc")
|
||||
|
||||
return os.getenv("TRITON_LIBDEVICE_PATH", default)
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from functools import wraps
|
||||
from typing import List, Optional, Sequence, Tuple, TypeVar
|
||||
|
||||
from .._C.libtriton.triton import ir
|
||||
from ..common.build import is_hip
|
||||
from . import core as tl
|
||||
|
||||
import triton._C.libtriton.triton as _triton
|
||||
@@ -1301,6 +1302,19 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
def gpu_has_mfma() -> bool:
|
||||
if not is_hip():
|
||||
return False
|
||||
return True # mfma supported in ['gfx908', 'gfx90a']
|
||||
|
||||
|
||||
def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
|
||||
if not gpu_has_mfma():
|
||||
return False
|
||||
# TODO: Add check for configurations and types.
|
||||
return True
|
||||
|
||||
|
||||
def dot(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
|
||||
@@ -383,20 +383,20 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
device_type = self._conclude_device_type(device_types, {pinned_memory_flags})
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ['cuda', 'hip']:
|
||||
if device_type not in ['cuda']:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError('Cannot find backend for ' + device_type)
|
||||
|
||||
if device is None:
|
||||
if device_type in ['cuda', 'hip']:
|
||||
if device_type in ['cuda']:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
if device_type in ['cuda', 'hip']:
|
||||
if device_type in ['cuda']:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
|
||||
Reference in New Issue
Block a user