mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Support jit functions without arguments (#2043)
Issue https://github.com/openai/triton/issues/1973 Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -28,6 +28,11 @@ def kernel_static_print(X, Y, BLOCK: tl.constexpr):
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_no_arg_print():
|
||||
print("", tl.program_id(0))
|
||||
|
||||
|
||||
def test_print(func: str, data_type: str):
|
||||
shape = (128, )
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
@@ -39,7 +44,11 @@ def test_print(func: str, data_type: str):
|
||||
kernel_print[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "static_print":
|
||||
kernel_static_print[(1,)](x, y, BLOCK=shape[0])
|
||||
assert_close(y, x)
|
||||
elif func == "no_arg_print":
|
||||
kernel_no_arg_print[(1,)](num_warps=4)
|
||||
|
||||
if func != "no_arg_print":
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -15,7 +15,7 @@ torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32",
|
||||
|
||||
|
||||
@pytest.mark.parametrize("func_type, data_type",
|
||||
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32")])
|
||||
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32"), ("no_arg_print", "int32")])
|
||||
def test_print(func_type: str, data_type: str):
|
||||
proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False)
|
||||
outs, _ = proc.communicate()
|
||||
@@ -29,10 +29,9 @@ def test_print(func_type: str, data_type: str):
|
||||
new_lines.add(value)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if func_type != "static_print":
|
||||
if func_type != "static_print" and func_type != "no_arg_print":
|
||||
for i in range(128):
|
||||
assert i in new_lines
|
||||
assert len(new_lines) == 128
|
||||
else:
|
||||
assert len(new_lines) == 1
|
||||
|
||||
|
||||
@@ -198,7 +198,7 @@ def generate_launcher(constants, signature, ids):
|
||||
PyObject *compiled_kernel = NULL;
|
||||
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
||||
if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
@@ -208,7 +208,7 @@ def generate_launcher(constants, signature, ids):
|
||||
|
||||
// raise exception asap
|
||||
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
||||
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())});
|
||||
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});
|
||||
if (launch_exit_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_exit_hook, args);
|
||||
}}
|
||||
@@ -267,7 +267,7 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
|
||||
|
||||
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
|
||||
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
|
||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
|
||||
if(gridX*gridY*gridZ > 0){{
|
||||
if (num_ctas == 1) {{
|
||||
@@ -356,7 +356,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
PyObject *launch_exit_hook = NULL;
|
||||
PyObject *compiled_kernel = NULL;
|
||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
||||
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
|
||||
return NULL;
|
||||
}}
|
||||
|
||||
@@ -367,7 +367,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||
|
||||
// raise exception asap
|
||||
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
|
||||
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
|
||||
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
|
||||
|
||||
if (launch_exit_hook != Py_None) {{
|
||||
PyObject_CallObject(launch_exit_hook, args);
|
||||
|
||||
@@ -352,11 +352,12 @@ class JITFunction(KernelInterface[T]):
|
||||
spec_keys = ', '.join(specializations)
|
||||
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
|
||||
src = f"""
|
||||
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_ctas=1, num_stages=3, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=4, num_ctas=1, num_stages=3, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel
|
||||
sig_key = {sig_keys},
|
||||
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
|
||||
@@ -399,7 +400,7 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_ctas=1, num
|
||||
if bin is not None:
|
||||
# build dict of constant values
|
||||
args = [{args}]
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()}
|
||||
configs = self._get_config(*all_args),
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
|
||||
@@ -413,7 +414,7 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_ctas=1, num
|
||||
else:
|
||||
# build dict of constant values
|
||||
args = [{args}]
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
|
||||
all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()}
|
||||
configs = self._get_config(*all_args),
|
||||
constants = self._make_constants(constexpr_key)
|
||||
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
|
||||
|
||||
Reference in New Issue
Block a user