mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Add support for default args in kernel wrappers (#1943)
Fixes the case where setting default values for arguments in a kernel
function signature results in a generated kernel wrapper function
without these default values.
For example:
```
@triton.jit
def kernel(x, y, z=3):
...
...
kernel[grid](x,y)
```
Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -2493,7 +2493,7 @@ def test_default(device):
|
||||
ret1 = torch.zeros(1, dtype=torch.int32, device=device)
|
||||
|
||||
@triton.jit
|
||||
def _kernel(ret0, ret1, value):
|
||||
def _kernel(ret0, ret1, value=3):
|
||||
tl.store(ret0, _impl())
|
||||
tl.store(ret1, _impl(value))
|
||||
|
||||
@@ -2501,6 +2501,10 @@ def test_default(device):
|
||||
assert ret0.item() == 10
|
||||
assert ret1.item() == value
|
||||
|
||||
_kernel[(1,)](ret0, ret1)
|
||||
assert ret0.item() == 10
|
||||
assert ret1.item() == 3
|
||||
|
||||
# ---------------
|
||||
# test noop
|
||||
# ----------------
|
||||
|
||||
@@ -316,9 +316,10 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
spec_keys = ', '.join(specializations)
|
||||
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
|
||||
src = f"""
|
||||
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel
|
||||
sig_key = {sig_keys},
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
@@ -327,6 +328,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
if not extern_libs is None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
|
||||
assert grid is not None
|
||||
if callable(grid):
|
||||
grid = grid({{{grid_args}}})
|
||||
grid_size = len(grid)
|
||||
@@ -407,7 +409,8 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
# function signature information
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.has_defaults = any(v.default != inspect._empty for v in signature.parameters.values())
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
self.has_defaults = any(v != inspect._empty for v in self.arg_defaults)
|
||||
# specialization hints
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
|
||||
|
||||
Reference in New Issue
Block a user