[FRONTEND] expose tl.max_constancy hint (#1951)

Similar to `tl.multiple_of` and `tl.max_contiguous`, `tl.max_constancy`
will expose a compiler hint indicating that all the values are equal in
a block of a certain size.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
David Berard
2023-07-17 11:30:25 -07:00
committed by GitHub
parent f6c4e8de76
commit 7202c6cff0
5 changed files with 32 additions and 0 deletions

View File

@@ -554,6 +554,10 @@ class TritonLangProxy:
def max_contiguous(self, input, values):
    # Compiler-hint only: in this eager proxy the hint is a no-op and the
    # `values` argument is ignored; the tensor is returned unchanged.
    return input
@_tensor_operation
def max_constancy(self, input, values):
    """Hint that `input` is constant in blocks of the given sizes.

    Mirrors `tl.max_constancy`; as with `max_contiguous` above, the eager
    proxy ignores `values` and simply returns `input` unchanged.
    """
    return input
@_tensor_operation
def abs(self, x):
    # Eager implementation of `tl.abs`: element-wise absolute value via torch.
    return torch.abs(x)

View File

@@ -62,6 +62,7 @@ from .core import (
log,
make_block_ptr,
max,
max_constancy,
max_contiguous,
maximum,
min,
@@ -162,6 +163,7 @@ __all__ = [
"log",
"make_block_ptr",
"max",
"max_constancy",
"max_contiguous",
"maximum",
"min",

View File

@@ -1665,6 +1665,24 @@ def max_contiguous(input, values, _builder=None):
values = [x.value for x in values]
return semantic.max_contiguous(input, values)
@builtin
def max_constancy(input, values, _builder=None):
    """
    Let the compiler know that the first `values` values in :code:`input` are constant.

    e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal,
    for example [0, 0, 0, 0, 1, 1, 1, 1].

    :param input: the tensor the hint applies to
    :param values: per-dimension constancy sizes; each must be a ``constexpr`` int
    :raises TypeError: if any element of ``values`` is not a ``constexpr[int]``
    """
    # A single constexpr hint is promoted to a one-element list so the
    # validation loop below handles both call styles uniformly.
    if isinstance(values, constexpr):
        values = [values]
    for i, d in enumerate(values):
        if not isinstance(d, constexpr):
            raise TypeError(f"values element {i} must have type `constexpr`")
        if not isinstance(d.value, int):
            # Fix: the original message was missing the closing backtick.
            raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
    values = [x.value for x in values]
    return semantic.max_constancy(input, values)
# -----------------------
# Debugging functions
# -----------------------

View File

@@ -1457,6 +1457,13 @@ def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
return x
def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor:
    """Attach a `tt.constancy` attribute to *x*'s handle and return *x*.

    *values* must supply one constancy size per dimension of *x*.
    """
    if len(values) != len(x.shape):
        raise ValueError("Shape of input to max_constancy does not match the length of values")
    constancy_attr = ir.make_attr(values, x.handle.get_context())
    x.handle.set_attr("tt.constancy", constancy_attr)
    return x
def debug_barrier(builder: ir.builder) -> tl.tensor:
    """Emit a barrier instruction and wrap its handle in a void tensor."""
    barrier_handle = builder.create_barrier()
    return tl.tensor(barrier_handle, tl.void)