mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Fix math for constant values (#2472)
https://github.com/openai/triton/issues/2470
This commit is contained in:
@@ -860,9 +860,9 @@ def test_unary_op(dtype_x, expr, num_ctas, device):
|
||||
# ----------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x, expr", [(dtype_x, expr) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin']])
|
||||
def test_math_op(dtype_x, expr, device):
|
||||
_test_unary(dtype_x, f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
|
||||
@pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin'] for x in ['x', '3.0']])
|
||||
def test_math_op(dtype_x, expr, device, x):
|
||||
_test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device)
|
||||
|
||||
# ----------------
|
||||
# test abs
|
||||
|
||||
@@ -1286,6 +1286,8 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
|
||||
:type ieee_rounding: bool
|
||||
"""
|
||||
ieee_rounding = _constexpr_to_value(ieee_rounding)
|
||||
x = _to_tensor(x, _builder)
|
||||
y = _to_tensor(y, _builder)
|
||||
return semantic.fdiv(x, y, ieee_rounding, _builder)
|
||||
|
||||
|
||||
@@ -1307,36 +1309,42 @@ def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("exponential")
|
||||
def exp(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.exp(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("natural logarithm")
|
||||
def log(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.log(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("cosine")
|
||||
def cos(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.cos(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("sine")
|
||||
def sin(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.sin(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("square root")
|
||||
def sqrt(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.sqrt(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("absolute value")
|
||||
def abs(x, _builder=None):
|
||||
x = _to_tensor(x, _builder)
|
||||
return semantic.abs(x, _builder)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user