Keren Zhou
2023-03-29 16:34:23 -07:00
committed by GitHub
parent f53bb6a1bc
commit 43eed392df
2 changed files with 6 additions and 8 deletions

View File

@@ -995,8 +995,8 @@ struct ExpOpConversionApprox
Value createDestOp(mlir::math::ExpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
- // For FP64 input, call __nv_expf for higher-precision calculation
- if (elemTy.getIntOrFloatBitWidth() == 64)
+ // For non-FP32 input, call __nv_expf for higher-precision calculation
+ if (elemTy.getIntOrFloatBitWidth() != 32)
return {};
const double log2e = 1.4426950408889634;
@@ -1117,7 +1117,7 @@ void populateElementwiseOpToLLVMPatterns(
patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
// ExpOpConversionApprox will try using ex2.approx if the input type is
- // FP32. For FP64 input type, ExpOpConversionApprox will return failure and
+ // FP32. For other input types, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
// __nv_expf for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);

View File

@@ -514,11 +514,9 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# ----------------
- @pytest.mark.parametrize("expr", [
- 'exp', 'log', 'cos', 'sin'
- ])
- def test_math_op(expr, device='cuda'):
- _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
+ @pytest.mark.parametrize("dtype_x, expr", [(dtype_x, expr) for dtype_x in float_dtypes for expr in ['exp', 'log', 'cos', 'sin']])
+ def test_math_op(dtype_x, expr, device='cuda'):
+ _test_unary(dtype_x, f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
# ----------------
# test abs