20#include "llvm/ADT/SmallVectorExtras.h"
23#define GEN_PASS_DEF_CONVERTMATHTOLIBMPASS
24#include "mlir/Conversion/Passes.h.inc"
35 using OpRewritePattern<
Op>::OpRewritePattern;
37 LogicalResult matchAndRewrite(
Op op, PatternRewriter &rewriter)
const final;
43 using OpRewritePattern<
Op>::OpRewritePattern;
45 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter)
const final;
52 using OpRewritePattern<
Op>::OpRewritePattern;
53 ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
54 StringRef floatFunc, StringRef doubleFunc)
55 : OpRewritePattern<
Op>(context, benefit), floatFunc(floatFunc),
56 doubleFunc(doubleFunc) {};
58 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter)
const final;
61 std::string floatFunc, doubleFunc;
64template <
typename OpTy>
67 StringRef doubleFunc) {
68 patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
69 patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
77 auto opType = op.getType();
79 auto vecType = dyn_cast<VectorType>(opType);
83 if (!vecType.hasRank())
85 auto shape = vecType.getShape();
86 int64_t numElements = vecType.getNumElements();
91 FloatAttr::get(vecType.getElementType(), 0.0)));
93 for (
auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
96 for (
auto input : op->getOperands())
98 vector::ExtractOp::create(rewriter, loc, input, positions));
100 Op::create(rewriter, loc, vecType.getElementType(), operands);
102 vector::InsertOp::create(rewriter, loc, scalarOp,
result, positions);
108template <
typename Op>
111 auto opType = op.getType();
112 if (!isa<Float16Type, BFloat16Type>(opType))
117 auto extendedOperands =
118 llvm::map_to_vector(op->getOperands(), [&](
Value operand) ->
Value {
119 return arith::ExtFOp::create(rewriter, loc, f32, operand);
121 auto newOp = Op::create(rewriter, loc, f32, extendedOperands);
126template <
typename Op>
128ScalarOpToLibmCall<Op>::matchAndRewrite(
Op op,
130 auto module = SymbolTable::getNearestSymbolTable(op);
131 auto type = op.getType();
132 if (!isa<Float32Type, Float64Type>(type))
135 auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
136 auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
142 auto opFunctionTy = FunctionType::get(
143 rewriter.
getContext(), op->getOperandTypes(), op->getResultTypes());
144 opFunc = func::FuncOp::create(rewriter, rewriter.
getUnknownLoc(), name,
153 opFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
168 populatePatternsForOp<math::AbsFOp>(
patterns, benefit, ctx,
"fabsf",
"fabs");
169 populatePatternsForOp<math::AcosOp>(
patterns, benefit, ctx,
"acosf",
"acos");
170 populatePatternsForOp<math::AcoshOp>(
patterns, benefit, ctx,
"acoshf",
172 populatePatternsForOp<math::AsinOp>(
patterns, benefit, ctx,
"asinf",
"asin");
173 populatePatternsForOp<math::AsinhOp>(
patterns, benefit, ctx,
"asinhf",
175 populatePatternsForOp<math::Atan2Op>(
patterns, benefit, ctx,
"atan2f",
177 populatePatternsForOp<math::AtanOp>(
patterns, benefit, ctx,
"atanf",
"atan");
178 populatePatternsForOp<math::AtanhOp>(
patterns, benefit, ctx,
"atanhf",
180 populatePatternsForOp<math::CbrtOp>(
patterns, benefit, ctx,
"cbrtf",
"cbrt");
181 populatePatternsForOp<math::CeilOp>(
patterns, benefit, ctx,
"ceilf",
"ceil");
182 populatePatternsForOp<math::CosOp>(
patterns, benefit, ctx,
"cosf",
"cos");
183 populatePatternsForOp<math::CoshOp>(
patterns, benefit, ctx,
"coshf",
"cosh");
184 populatePatternsForOp<math::ErfOp>(
patterns, benefit, ctx,
"erff",
"erf");
185 populatePatternsForOp<math::ErfcOp>(
patterns, benefit, ctx,
"erfcf",
"erfc");
186 populatePatternsForOp<math::ExpOp>(
patterns, benefit, ctx,
"expf",
"exp");
187 populatePatternsForOp<math::Exp2Op>(
patterns, benefit, ctx,
"exp2f",
"exp2");
188 populatePatternsForOp<math::ExpM1Op>(
patterns, benefit, ctx,
"expm1f",
190 populatePatternsForOp<math::FloorOp>(
patterns, benefit, ctx,
"floorf",
192 populatePatternsForOp<math::FmaOp>(
patterns, benefit, ctx,
"fmaf",
"fma");
193 populatePatternsForOp<math::LogOp>(
patterns, benefit, ctx,
"logf",
"log");
194 populatePatternsForOp<math::Log2Op>(
patterns, benefit, ctx,
"log2f",
"log2");
195 populatePatternsForOp<math::Log10Op>(
patterns, benefit, ctx,
"log10f",
197 populatePatternsForOp<math::Log1pOp>(
patterns, benefit, ctx,
"log1pf",
199 populatePatternsForOp<math::PowFOp>(
patterns, benefit, ctx,
"powf",
"pow");
200 populatePatternsForOp<math::RoundEvenOp>(
patterns, benefit, ctx,
"roundevenf",
202 populatePatternsForOp<math::RoundOp>(
patterns, benefit, ctx,
"roundf",
204 populatePatternsForOp<math::SinOp>(
patterns, benefit, ctx,
"sinf",
"sin");
205 populatePatternsForOp<math::SinhOp>(
patterns, benefit, ctx,
"sinhf",
"sinh");
206 populatePatternsForOp<math::SqrtOp>(
patterns, benefit, ctx,
"sqrtf",
"sqrt");
207 populatePatternsForOp<math::RsqrtOp>(
patterns, benefit, ctx,
"rsqrtf",
209 populatePatternsForOp<math::TanOp>(
patterns, benefit, ctx,
"tanf",
"tan");
210 populatePatternsForOp<math::TanhOp>(
patterns, benefit, ctx,
"tanhf",
"tanh");
211 populatePatternsForOp<math::TruncOp>(
patterns, benefit, ctx,
"truncf",
216struct ConvertMathToLibmPass
218 void runOnOperation()
override;
222void ConvertMathToLibmPass::runOnOperation() {
223 auto module = getOperation();
229 target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
230 vector::VectorDialect>();
231 target.addIllegalDialect<math::MathDialect>();
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
Include the generated interface declarations.
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to Libm calls.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...