23 #define GEN_PASS_DEF_CONVERTMATHTOLIBMPASS
24 #include "mlir/Conversion/Passes.h.inc"
32 template <
typename Op>
40 template <
typename Op>
49 template <
typename Op>
54 StringRef floatFunc, StringRef doubleFunc)
56 doubleFunc(doubleFunc) {};
61 std::string floatFunc, doubleFunc;
64 template <
typename OpTy>
67 StringRef doubleFunc) {
68 patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
69 patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
74 template <
typename Op>
77 auto opType = op.getType();
79 auto vecType = dyn_cast<VectorType>(opType);
83 if (!vecType.hasRank())
85 auto shape = vecType.getShape();
86 int64_t numElements = vecType.getNumElements();
92 for (
auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
95 for (
auto input : op->getOperands())
97 rewriter.
create<vector::ExtractOp>(loc, input, positions));
99 rewriter.
create<
Op>(loc, vecType.getElementType(), operands);
101 rewriter.
create<vector::InsertOp>(loc, scalarOp, result, positions);
107 template <
typename Op>
110 auto opType = op.getType();
111 if (!isa<Float16Type, BFloat16Type>(opType))
116 auto extendedOperands = llvm::to_vector(
117 llvm::map_range(op->getOperands(), [&](
Value operand) ->
Value {
118 return rewriter.create<arith::ExtFOp>(loc, f32, operand);
120 auto newOp = rewriter.
create<
Op>(loc, f32, extendedOperands);
125 template <
typename Op>
127 ScalarOpToLibmCall<Op>::matchAndRewrite(
Op op,
130 auto type = op.getType();
131 if (!isa<Float32Type, Float64Type>(type))
134 auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
135 auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
142 rewriter.
getContext(), op->getOperandTypes(), op->getResultTypes());
152 opFunc->
setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
167 populatePatternsForOp<math::AbsFOp>(
patterns, benefit, ctx,
"fabsf",
"fabs");
168 populatePatternsForOp<math::AcosOp>(
patterns, benefit, ctx,
"acosf",
"acos");
169 populatePatternsForOp<math::AcoshOp>(
patterns, benefit, ctx,
"acoshf",
171 populatePatternsForOp<math::AsinOp>(
patterns, benefit, ctx,
"asinf",
"asin");
172 populatePatternsForOp<math::AsinhOp>(
patterns, benefit, ctx,
"asinhf",
174 populatePatternsForOp<math::Atan2Op>(
patterns, benefit, ctx,
"atan2f",
176 populatePatternsForOp<math::AtanOp>(
patterns, benefit, ctx,
"atanf",
"atan");
177 populatePatternsForOp<math::AtanhOp>(
patterns, benefit, ctx,
"atanhf",
179 populatePatternsForOp<math::CbrtOp>(
patterns, benefit, ctx,
"cbrtf",
"cbrt");
180 populatePatternsForOp<math::CeilOp>(
patterns, benefit, ctx,
"ceilf",
"ceil");
181 populatePatternsForOp<math::CosOp>(
patterns, benefit, ctx,
"cosf",
"cos");
182 populatePatternsForOp<math::CoshOp>(
patterns, benefit, ctx,
"coshf",
"cosh");
183 populatePatternsForOp<math::ErfOp>(
patterns, benefit, ctx,
"erff",
"erf");
184 populatePatternsForOp<math::ErfcOp>(
patterns, benefit, ctx,
"erfcf",
"erfc");
185 populatePatternsForOp<math::ExpOp>(
patterns, benefit, ctx,
"expf",
"exp");
186 populatePatternsForOp<math::Exp2Op>(
patterns, benefit, ctx,
"exp2f",
"exp2");
187 populatePatternsForOp<math::ExpM1Op>(
patterns, benefit, ctx,
"expm1f",
189 populatePatternsForOp<math::FloorOp>(
patterns, benefit, ctx,
"floorf",
191 populatePatternsForOp<math::FmaOp>(
patterns, benefit, ctx,
"fmaf",
"fma");
192 populatePatternsForOp<math::LogOp>(
patterns, benefit, ctx,
"logf",
"log");
193 populatePatternsForOp<math::Log2Op>(
patterns, benefit, ctx,
"log2f",
"log2");
194 populatePatternsForOp<math::Log10Op>(
patterns, benefit, ctx,
"log10f",
196 populatePatternsForOp<math::Log1pOp>(
patterns, benefit, ctx,
"log1pf",
198 populatePatternsForOp<math::PowFOp>(
patterns, benefit, ctx,
"powf",
"pow");
199 populatePatternsForOp<math::RoundEvenOp>(
patterns, benefit, ctx,
"roundevenf",
201 populatePatternsForOp<math::RoundOp>(
patterns, benefit, ctx,
"roundf",
203 populatePatternsForOp<math::SinOp>(
patterns, benefit, ctx,
"sinf",
"sin");
204 populatePatternsForOp<math::SinhOp>(
patterns, benefit, ctx,
"sinhf",
"sinh");
205 populatePatternsForOp<math::SqrtOp>(
patterns, benefit, ctx,
"sqrtf",
"sqrt");
206 populatePatternsForOp<math::RsqrtOp>(
patterns, benefit, ctx,
"rsqrtf",
208 populatePatternsForOp<math::TanOp>(
patterns, benefit, ctx,
"tanf",
"tan");
209 populatePatternsForOp<math::TanhOp>(
patterns, benefit, ctx,
"tanhf",
"tanh");
210 populatePatternsForOp<math::TruncOp>(
patterns, benefit, ctx,
"truncf",
215 struct ConvertMathToLibmPass
216 :
public impl::ConvertMathToLibmPassBase<ConvertMathToLibmPass> {
217 void runOnOperation()
override;
221 void ConvertMathToLibmPass::runOnOperation() {
222 auto module = getOperation();
228 target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
229 vector::VectorDialect>();
230 target.addIllegalDialect<math::MathDialect>();
static MLIRContext * getContext(OpFoldResult val)
MLIRContext * getContext() const
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to Libm calls.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...