[FRONTEND] Support jit functions without arguments (#2043)

Issue https://github.com/openai/triton/issues/1973

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Keren Zhou
2023-08-07 19:05:56 -07:00
committed by GitHub
parent 98523bcc48
commit 30a331e628
4 changed files with 22 additions and 13 deletions

View File

@@ -28,6 +28,11 @@ def kernel_static_print(X, Y, BLOCK: tl.constexpr):
tl.store(Y + tl.arange(0, BLOCK), x)
# Kernel declared with an empty parameter list; it is launched below as
# kernel_no_arg_print[(1,)](num_warps=4), i.e. with launch options only.
@triton.jit
def kernel_no_arg_print():
# Prints an empty-string prefix followed by the value of tl.program_id(0).
print("", tl.program_id(0))
def test_print(func: str, data_type: str):
shape = (128, )
# limit the range of integers so that the sum does not overflow
@@ -39,7 +44,11 @@ def test_print(func: str, data_type: str):
kernel_print[(1,)](x, y, BLOCK=shape[0])
elif func == "static_print":
kernel_static_print[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)
elif func == "no_arg_print":
kernel_no_arg_print[(1,)](num_warps=4)
if func != "no_arg_print":
assert_close(y, x)
if __name__ == "__main__":

View File

@@ -15,7 +15,7 @@ torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32",
@pytest.mark.parametrize("func_type, data_type",
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32")])
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32"), ("no_arg_print", "int32")])
def test_print(func_type: str, data_type: str):
proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False)
outs, _ = proc.communicate()
@@ -29,10 +29,9 @@ def test_print(func_type: str, data_type: str):
new_lines.add(value)
except Exception as e:
print(e)
if func_type != "static_print":
if func_type != "static_print" and func_type != "no_arg_print":
for i in range(128):
assert i in new_lines
assert len(new_lines) == 128
else:
assert len(new_lines) == 1

View File

@@ -198,7 +198,7 @@ def generate_launcher(constants, signature, ids):
PyObject *compiled_kernel = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
return NULL;
}}
@@ -208,7 +208,7 @@ def generate_launcher(constants, signature, ids):
// raise exception asap
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())});
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''});
if (launch_exit_hook != Py_None) {{
PyObject_CallObject(launch_exit_hook, args);
}}
@@ -267,7 +267,7 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
if(gridX*gridY*gridZ > 0){{
if (num_ctas == 1) {{
@@ -356,7 +356,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{
return NULL;
}}
@@ -367,7 +367,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
// raise exception asap
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
_launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
if (launch_exit_hook != Py_None) {{
PyObject_CallObject(launch_exit_hook, args);

View File

@@ -352,11 +352,12 @@ class JITFunction(KernelInterface[T]):
spec_keys = ', '.join(specializations)
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
src = f"""
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_ctas=1, num_stages=3, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
def {self.fn.__name__}({args_signature}grid=None, num_warps=4, num_ctas=1, num_stages=3, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel
sig_key = {sig_keys},
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
@@ -399,7 +400,7 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_ctas=1, num
if bin is not None:
# build dict of constant values
args = [{args}]
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()}
configs = self._get_config(*all_args),
constants = self._make_constants(constexpr_key)
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
@@ -413,7 +414,7 @@ def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_ctas=1, num
else:
# build dict of constant values
args = [{args}]
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
all_args = {', '.join([f'{arg}' for arg in self.arg_names]) + ', ' if len(self.arg_names) > 0 else ()}
configs = self._get_config(*all_args),
constants = self._make_constants(constexpr_key)
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})