23 #define GEN_PASS_DEF_CONVERTMATHTOLIBM
24 #include "mlir/Conversion/Passes.h.inc"
32 template <
typename Op>
40 template <
typename Op>
49 template <
typename Op>
53 ScalarOpToLibmCall(
MLIRContext *context, StringRef floatFunc,
56 doubleFunc(doubleFunc){};
61 std::string floatFunc, doubleFunc;
64 template <
typename OpTy>
66 StringRef floatFunc, StringRef doubleFunc) {
67 patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
68 patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
73 template <
typename Op>
76 auto opType = op.getType();
78 auto vecType = dyn_cast<VectorType>(opType);
82 if (!vecType.hasRank())
84 auto shape = vecType.getShape();
85 int64_t numElements = vecType.getNumElements();
91 for (
auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
94 for (
auto input : op->getOperands())
96 rewriter.
create<vector::ExtractOp>(loc, input, positions));
98 rewriter.
create<
Op>(loc, vecType.getElementType(), operands);
100 rewriter.
create<vector::InsertOp>(loc, scalarOp, result, positions);
106 template <
typename Op>
109 auto opType = op.getType();
110 if (!isa<Float16Type, BFloat16Type>(opType))
115 auto extendedOperands = llvm::to_vector(
116 llvm::map_range(op->getOperands(), [&](
Value operand) ->
Value {
117 return rewriter.create<arith::ExtFOp>(loc, f32, operand);
119 auto newOp = rewriter.
create<
Op>(loc, f32, extendedOperands);
124 template <
typename Op>
126 ScalarOpToLibmCall<Op>::matchAndRewrite(
Op op,
129 auto type = op.getType();
130 if (!isa<Float32Type, Float64Type>(type))
133 auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
134 auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
141 rewriter.
getContext(), op->getOperandTypes(), op->getResultTypes());
151 opFunc->
setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
165 populatePatternsForOp<math::AbsFOp>(
patterns, ctx,
"fabsf",
"fabs");
166 populatePatternsForOp<math::AcosOp>(
patterns, ctx,
"acosf",
"acos");
167 populatePatternsForOp<math::AcoshOp>(
patterns, ctx,
"acoshf",
"acosh");
168 populatePatternsForOp<math::AsinOp>(
patterns, ctx,
"asinf",
"asin");
169 populatePatternsForOp<math::AsinhOp>(
patterns, ctx,
"asinhf",
"asinh");
170 populatePatternsForOp<math::Atan2Op>(
patterns, ctx,
"atan2f",
"atan2");
171 populatePatternsForOp<math::AtanOp>(
patterns, ctx,
"atanf",
"atan");
172 populatePatternsForOp<math::AtanhOp>(
patterns, ctx,
"atanhf",
"atanh");
173 populatePatternsForOp<math::CbrtOp>(
patterns, ctx,
"cbrtf",
"cbrt");
174 populatePatternsForOp<math::CeilOp>(
patterns, ctx,
"ceilf",
"ceil");
175 populatePatternsForOp<math::CosOp>(
patterns, ctx,
"cosf",
"cos");
176 populatePatternsForOp<math::CoshOp>(
patterns, ctx,
"coshf",
"cosh");
177 populatePatternsForOp<math::ErfOp>(
patterns, ctx,
"erff",
"erf");
178 populatePatternsForOp<math::ExpOp>(
patterns, ctx,
"expf",
"exp");
179 populatePatternsForOp<math::Exp2Op>(
patterns, ctx,
"exp2f",
"exp2");
180 populatePatternsForOp<math::ExpM1Op>(
patterns, ctx,
"expm1f",
"expm1");
181 populatePatternsForOp<math::FloorOp>(
patterns, ctx,
"floorf",
"floor");
182 populatePatternsForOp<math::FmaOp>(
patterns, ctx,
"fmaf",
"fma");
183 populatePatternsForOp<math::LogOp>(
patterns, ctx,
"logf",
"log");
184 populatePatternsForOp<math::Log2Op>(
patterns, ctx,
"log2f",
"log2");
185 populatePatternsForOp<math::Log10Op>(
patterns, ctx,
"log10f",
"log10");
186 populatePatternsForOp<math::Log1pOp>(
patterns, ctx,
"log1pf",
"log1p");
187 populatePatternsForOp<math::PowFOp>(
patterns, ctx,
"powf",
"pow");
188 populatePatternsForOp<math::RoundEvenOp>(
patterns, ctx,
"roundevenf",
190 populatePatternsForOp<math::RoundOp>(
patterns, ctx,
"roundf",
"round");
191 populatePatternsForOp<math::SinOp>(
patterns, ctx,
"sinf",
"sin");
192 populatePatternsForOp<math::SinhOp>(
patterns, ctx,
"sinhf",
"sinh");
193 populatePatternsForOp<math::SqrtOp>(
patterns, ctx,
"sqrtf",
"sqrt");
194 populatePatternsForOp<math::RsqrtOp>(
patterns, ctx,
"rsqrtf",
"rsqrt");
195 populatePatternsForOp<math::TanOp>(
patterns, ctx,
"tanf",
"tan");
196 populatePatternsForOp<math::TanhOp>(
patterns, ctx,
"tanhf",
"tanh");
197 populatePatternsForOp<math::TruncOp>(
patterns, ctx,
"truncf",
"trunc");
201 struct ConvertMathToLibmPass
202 :
public impl::ConvertMathToLibmBase<ConvertMathToLibmPass> {
203 void runOnOperation()
override;
207 void ConvertMathToLibmPass::runOnOperation() {
208 auto module = getOperation();
214 target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
215 vector::VectorDialect>();
216 target.addIllegalDialect<math::MathDialect>();
222 return std::make_unique<ConvertMathToLibmPass>();
static MLIRContext * getContext(OpFoldResult val)
MLIRContext * getContext() const
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Math to Libm calls.
const FrozenRewritePatternSet & patterns
std::unique_ptr< OperationPass< ModuleOp > > createConvertMathToLibmPass()
Create a pass to convert Math operations to libm calls.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...