23 #define GEN_PASS_DEF_CONVERTMATHTOLIBM
24 #include "mlir/Conversion/Passes.h.inc"
32 template <
typename Op>
40 template <
typename Op>
49 template <
typename Op>
53 ScalarOpToLibmCall(
MLIRContext *context, StringRef floatFunc,
56 doubleFunc(doubleFunc){};
61 std::string floatFunc, doubleFunc;
64 template <
typename OpTy>
66 StringRef floatFunc, StringRef doubleFunc) {
67 patterns.
add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
68 patterns.
add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
73 template <
typename Op>
76 auto opType = op.getType();
78 auto vecType = dyn_cast<VectorType>(opType);
82 if (!vecType.hasRank())
84 auto shape = vecType.getShape();
85 int64_t numElements = vecType.getNumElements();
91 for (
auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
96 rewriter.
create<vector::ExtractOp>(loc, input, positions));
98 rewriter.
create<
Op>(loc, vecType.getElementType(), operands);
100 rewriter.
create<vector::InsertOp>(loc, scalarOp, result, positions);
106 template <
typename Op>
109 auto opType = op.getType();
110 if (!isa<Float16Type, BFloat16Type>(opType))
115 auto extendedOperands = llvm::to_vector(
117 return rewriter.create<arith::ExtFOp>(loc, f32, operand);
119 auto newOp = rewriter.
create<
Op>(loc, f32, extendedOperands);
124 template <
typename Op>
126 ScalarOpToLibmCall<Op>::matchAndRewrite(
Op op,
129 auto type = op.getType();
130 if (!isa<Float32Type, Float64Type>(type))
133 auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
134 auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
151 opFunc->
setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
165 populatePatternsForOp<math::AbsFOp>(patterns, ctx,
"fabsf",
"fabs");
166 populatePatternsForOp<math::AcosOp>(patterns, ctx,
"acosf",
"acos");
167 populatePatternsForOp<math::AcoshOp>(patterns, ctx,
"acoshf",
"acosh");
168 populatePatternsForOp<math::AsinOp>(patterns, ctx,
"asinf",
"asin");
169 populatePatternsForOp<math::AsinhOp>(patterns, ctx,
"asinhf",
"asinh");
170 populatePatternsForOp<math::Atan2Op>(patterns, ctx,
"atan2f",
"atan2");
171 populatePatternsForOp<math::AtanOp>(patterns, ctx,
"atanf",
"atan");
172 populatePatternsForOp<math::AtanhOp>(patterns, ctx,
"atanhf",
"atanh");
173 populatePatternsForOp<math::CbrtOp>(patterns, ctx,
"cbrtf",
"cbrt");
174 populatePatternsForOp<math::CeilOp>(patterns, ctx,
"ceilf",
"ceil");
175 populatePatternsForOp<math::CosOp>(patterns, ctx,
"cosf",
"cos");
176 populatePatternsForOp<math::CoshOp>(patterns, ctx,
"coshf",
"cosh");
177 populatePatternsForOp<math::ErfOp>(patterns, ctx,
"erff",
"erf");
178 populatePatternsForOp<math::ExpOp>(patterns, ctx,
"expf",
"exp");
179 populatePatternsForOp<math::Exp2Op>(patterns, ctx,
"exp2f",
"exp2");
180 populatePatternsForOp<math::ExpM1Op>(patterns, ctx,
"expm1f",
"expm1");
181 populatePatternsForOp<math::FloorOp>(patterns, ctx,
"floorf",
"floor");
182 populatePatternsForOp<math::FmaOp>(patterns, ctx,
"fmaf",
"fma");
183 populatePatternsForOp<math::LogOp>(patterns, ctx,
"logf",
"log");
184 populatePatternsForOp<math::Log2Op>(patterns, ctx,
"log2f",
"log2");
185 populatePatternsForOp<math::Log10Op>(patterns, ctx,
"log10f",
"log10");
186 populatePatternsForOp<math::Log1pOp>(patterns, ctx,
"log1pf",
"log1p");
187 populatePatternsForOp<math::PowFOp>(patterns, ctx,
"powf",
"pow");
188 populatePatternsForOp<math::RoundEvenOp>(patterns, ctx,
"roundevenf",
190 populatePatternsForOp<math::RoundOp>(patterns, ctx,
"roundf",
"round");
191 populatePatternsForOp<math::SinOp>(patterns, ctx,
"sinf",
"sin");
192 populatePatternsForOp<math::SinhOp>(patterns, ctx,
"sinhf",
"sinh");
193 populatePatternsForOp<math::SqrtOp>(patterns, ctx,
"sqrtf",
"sqrt");
194 populatePatternsForOp<math::TanOp>(patterns, ctx,
"tanf",
"tan");
195 populatePatternsForOp<math::TanhOp>(patterns, ctx,
"tanhf",
"tanh");
196 populatePatternsForOp<math::TruncOp>(patterns, ctx,
"truncf",
"trunc");
200 struct ConvertMathToLibmPass
201 :
public impl::ConvertMathToLibmBase<ConvertMathToLibmPass> {
202 void runOnOperation()
override;
206 void ConvertMathToLibmPass::runOnOperation() {
207 auto module = getOperation();
213 target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
214 vector::VectorDialect>();
215 target.addIllegalDialect<math::MathDialect>();
221 return std::make_unique<ConvertMathToLibmPass>();
static MLIRContext * getContext(OpFoldResult val)
MLIRContext * getContext() const
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This provides public APIs that all operations should have.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Math to Libm calls.
std::unique_ptr< OperationPass< ModuleOp > > createConvertMathToLibmPass()
Create a pass to convert Math operations to libm calls.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...