mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] expose tl.max_constancy hint (#1951)
Similar to `tl.multiple_of` and `tl.max_contiguous`, `tl.max_constancy` will expose a compiler hint indicating that all the values are equal in a block of a certain size. --------- Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -554,6 +554,10 @@ class TritonLangProxy:
|
||||
def max_contiguous(self, input, values):
|
||||
return input
|
||||
|
||||
@_tensor_operation
def max_constancy(self, input, values):
    """Accept the constancy hint and hand `input` back unchanged.

    `values` is ignored here: this implementation performs no work with the
    hint and simply returns its first argument.
    """
    return input
|
||||
|
||||
@_tensor_operation
def abs(self, x):
    """Return the result of `torch.abs` applied to `x`."""
    return torch.abs(x)
|
||||
|
||||
@@ -62,6 +62,7 @@ from .core import (
|
||||
log,
|
||||
make_block_ptr,
|
||||
max,
|
||||
max_constancy,
|
||||
max_contiguous,
|
||||
maximum,
|
||||
min,
|
||||
@@ -162,6 +163,7 @@ __all__ = [
|
||||
"log",
|
||||
"make_block_ptr",
|
||||
"max",
|
||||
"max_constancy",
|
||||
"max_contiguous",
|
||||
"maximum",
|
||||
"min",
|
||||
|
||||
@@ -1665,6 +1665,24 @@ def max_contiguous(input, values, _builder=None):
|
||||
values = [x.value for x in values]
|
||||
return semantic.max_contiguous(input, values)
|
||||
|
||||
|
||||
@builtin
def max_constancy(input, values, _builder=None):
    """
    Let the compiler know that the :code:`values` first values in :code:`input` are constant.

    e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal,
    for example [0, 0, 0, 0, 1, 1, 1, 1].

    :param input: the tensor the hint applies to.
    :param values: a single `constexpr[int]`, or a sequence of them. A single
        `constexpr` is wrapped into a one-element list before validation.
    :param _builder: accepted for the `@builtin` calling convention; not used
        by this function.
    :raises TypeError: if an element of :code:`values` is not a `constexpr`,
        or if its wrapped value is not an `int`.
    :return: whatever :code:`semantic.max_constancy` returns for the unwrapped
        integer list.
    """
    # Allow the scalar form `max_constancy(x, 4)` as shorthand for `[4]`.
    if isinstance(values, constexpr):
        values = [values]
    for i, d in enumerate(values):
        if not isinstance(d, constexpr):
            raise TypeError(f"values element {i} must have type `constexpr`")
        if not isinstance(d.value, int):
            # The closing backtick was missing from this message originally.
            raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`")
    # The semantic layer works on plain Python ints, not constexpr wrappers.
    values = [x.value for x in values]
    return semantic.max_constancy(input, values)
|
||||
# -----------------------
|
||||
# Debugging functions
|
||||
# -----------------------
|
||||
|
||||
@@ -1457,6 +1457,13 @@ def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||
return x
|
||||
|
||||
|
||||
def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor:
    """Attach a `tt.constancy` attribute built from `values` to `x`'s handle.

    `values` must contain exactly one entry per dimension of `x`; otherwise a
    `ValueError` is raised and `x` is left untouched. Returns `x` itself.
    """
    rank = len(x.shape)
    if rank != len(values):
        raise ValueError("Shape of input to max_constancy does not match the length of values")

    handle = x.handle
    constancy_attr = ir.make_attr(values, handle.get_context())
    handle.set_attr("tt.constancy", constancy_attr)
    return x
|
||||
|
||||
|
||||
def debug_barrier(builder: ir.builder) -> tl.tensor:
    """Create a barrier through `builder` and wrap it as a `tl.void` tensor."""
    barrier_op = builder.create_barrier()
    return tl.tensor(barrier_op, tl.void)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user