mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix tl.exp for fp16 (#1440)
https://github.com/openai/triton/issues/1438 https://github.com/openai/triton/issues/1360
This commit is contained in:
@@ -995,8 +995,8 @@ struct ExpOpConversionApprox
|
||||
Value createDestOp(mlir::math::ExpOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
// For FP64 input, call __nv_expf for higher-precision calculation
|
||||
if (elemTy.getIntOrFloatBitWidth() == 64)
|
||||
// For non-FP32 input, call __nv_expf for higher-precision calculation
|
||||
if (elemTy.getIntOrFloatBitWidth() != 32)
|
||||
return {};
|
||||
|
||||
const double log2e = 1.4426950408889634;
|
||||
@@ -1117,7 +1117,7 @@ void populateElementwiseOpToLLVMPatterns(
|
||||
|
||||
patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
|
||||
// ExpOpConversionApprox will try using ex2.approx if the input type is
|
||||
// FP32. For FP64 input type, ExpOpConversionApprox will return failure and
|
||||
// FP32. For other input types, ExpOpConversionApprox will return failure and
|
||||
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
|
||||
// __nv_expf for higher-precision calculation
|
||||
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
|
||||
|
||||
@@ -514,11 +514,9 @@ def test_unary_op(dtype_x, expr, device='cuda'):
|
||||
# ----------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("expr", [
|
||||
'exp', 'log', 'cos', 'sin'
|
||||
])
|
||||
def test_math_op(expr, device='cuda'):
|
||||
_test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
|
||||
@pytest.mark.parametrize("dtype_x, expr", [(dtype_x, expr) for dtype_x in float_dtypes for expr in ['exp', 'log', 'cos', 'sin']])
|
||||
def test_math_op(dtype_x, expr, device='cuda'):
|
||||
_test_unary(dtype_x, f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
|
||||
|
||||
# ----------------
|
||||
# test abs
|
||||
|
||||
Reference in New Issue
Block a user