[FRONTEND] Add support for default args in kernel wrappers (#1943)

Fixes the case where setting default values for arguments in a kernel
function signature results in a generated kernel wrapper function
without these default values.

For example:
```
@triton.jit
def kernel(x, y, z=3):
    ...

...
kernel[grid](x,y)
```

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Alex Collins
2023-07-14 22:32:47 +01:00
committed by GitHub
parent 4042bd57a0
commit 80163a9c1e
2 changed files with 10 additions and 3 deletions

View File

@@ -2493,7 +2493,7 @@ def test_default(device):
ret1 = torch.zeros(1, dtype=torch.int32, device=device)
@triton.jit
def _kernel(ret0, ret1, value):
def _kernel(ret0, ret1, value=3):
tl.store(ret0, _impl())
tl.store(ret1, _impl(value))
@@ -2501,6 +2501,10 @@ def test_default(device):
assert ret0.item() == 10
assert ret1.item() == value
_kernel[(1,)](ret0, ret1)
assert ret0.item() == 10
assert ret1.item() == 3
# ---------------
# test noop
# ----------------

View File

@@ -316,9 +316,10 @@ class JITFunction(KernelInterface[T]):
spec_keys = ', '.join(specializations)
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
src = f"""
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel
sig_key = {sig_keys},
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
@@ -327,6 +328,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
if not extern_libs is None:
key = (key, tuple(extern_libs.items()))
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
assert grid is not None
if callable(grid):
grid = grid({{{grid_args}}})
grid_size = len(grid)
@@ -407,7 +409,8 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
# function signature information
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.has_defaults = any(v.default != inspect._empty for v in signature.parameters.values())
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
# specialization hints
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}