23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/DebugLog.h"
28 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS
29 #include "mlir/Conversion/Passes.h.inc"
34 #define DEBUG_TYPE "math-to-funcs"
38 template <
typename Op>
54 IPowIOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
60 LogicalResult matchAndRewrite(math::IPowIOp op,
64 GetFuncCallbackTy getFuncOpCallback;
71 FPowIOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
77 LogicalResult matchAndRewrite(math::FPowIOp op,
81 GetFuncCallbackTy getFuncOpCallback;
88 CtlzOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
90 getFuncOpCallback(cb) {}
94 LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
98 GetFuncCallbackTy getFuncOpCallback;
102 template <
typename Op>
105 Type opType = op.getType();
107 auto vecType = dyn_cast<VectorType>(opType);
111 if (!vecType.hasRank())
114 int64_t numElements = vecType.getNumElements();
116 Type resultElementType = vecType.getElementType();
118 if (isa<FloatType>(resultElementType))
122 Value result = arith::ConstantOp::create(
125 for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
128 for (
Value input : op->getOperands())
130 vector::ExtractOp::create(rewriter, loc, input, positions));
132 Op::create(rewriter, loc, vecType.getElementType(), operands);
134 vector::InsertOp::create(rewriter, loc, scalarOp, result, positions);
145 [](
Type ty) { return getElementTypeOrSelf(ty); });
148 [](
Type ty) { return getElementTypeOrSelf(ty); });
184 assert(isa<IntegerType>(elementType) &&
185 "non-integer element type for IPowIOp");
190 std::string funcName(
"__mlir_math_ipowi");
191 llvm::raw_string_ostream nameOS(funcName);
192 nameOS <<
'_' << elementType;
195 builder.
getContext(), {elementType, elementType}, elementType);
196 auto funcOp = func::FuncOp::create(builder, funcName, funcType);
197 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
200 funcOp->setAttr(
"llvm.linkage", linkage);
203 Block *entryBlock = funcOp.addEntryBlock();
206 Value bArg = funcOp.getArgument(0);
207 Value pArg = funcOp.getArgument(1);
209 Value zeroValue = arith::ConstantOp::create(
211 Value oneValue = arith::ConstantOp::create(
213 Value minusOneValue = arith::ConstantOp::create(
214 builder, elementType,
222 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, zeroValue);
224 func::ReturnOp::create(builder, oneValue);
228 cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
232 auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
237 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, zeroValue);
240 func::ReturnOp::create(
242 arith::DivSIOp::create(builder, oneValue, zeroValue).getResult());
246 cf::CondBranchOp::create(builder, bIsZero, thenBlock, fallthroughBlock);
251 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, oneValue);
254 func::ReturnOp::create(builder, oneValue);
258 cf::CondBranchOp::create(builder, bIsOne, thenBlock, fallthroughBlock);
262 auto bIsMinusOne = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
263 bArg, minusOneValue);
266 auto pIsOdd = arith::CmpIOp::create(
267 builder, arith::CmpIPredicate::ne,
268 arith::AndIOp::create(builder, pArg, oneValue), zeroValue);
271 func::ReturnOp::create(builder, minusOneValue);
275 cf::CondBranchOp::create(builder, pIsOdd, thenBlock, fallthroughBlock);
280 func::ReturnOp::create(builder, oneValue);
284 cf::CondBranchOp::create(builder, bIsMinusOne, pIsOdd->
getBlock(),
290 func::ReturnOp::create(builder, zeroValue);
292 funcBody, funcBody->
end(), {elementType, elementType, elementType},
293 {builder.getLoc(), builder.getLoc(), builder.getLoc()});
297 cf::CondBranchOp::create(builder, pIsNeg, bIsZero->
getBlock(), loopHeader,
315 auto powerTmpIsOdd = arith::CmpIOp::create(
316 builder, arith::CmpIPredicate::ne,
317 arith::AndIOp::create(builder, powerTmp, oneValue), zeroValue);
320 Value newResultTmp = arith::MulIOp::create(builder, resultTmp, baseTmp);
321 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), elementType,
324 cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
327 cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
334 Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, oneValue);
337 auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
338 newPowerTmp, zeroValue);
341 func::ReturnOp::create(builder, newResultTmp);
345 cf::CondBranchOp::create(builder, newPowerIsZero, thenBlock,
351 Value newBaseTmp = arith::MulIOp::create(builder, baseTmp, baseTmp);
353 cf::BranchOp::create(
354 builder,
ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
362 IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
364 auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
371 func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
413 FunctionType funcType) {
414 auto baseType = cast<FloatType>(funcType.getInput(0));
415 auto powType = cast<IntegerType>(funcType.getInput(1));
419 std::string funcName(
"__mlir_math_fpowi");
420 llvm::raw_string_ostream nameOS(funcName);
421 nameOS <<
'_' << baseType;
422 nameOS <<
'_' << powType;
423 auto funcOp = func::FuncOp::create(builder, funcName, funcType);
424 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
427 funcOp->setAttr(
"llvm.linkage", linkage);
430 Block *entryBlock = funcOp.addEntryBlock();
433 Value bArg = funcOp.getArgument(0);
434 Value pArg = funcOp.getArgument(1);
436 Value oneBValue = arith::ConstantOp::create(
437 builder, baseType, builder.
getFloatAttr(baseType, 1.0));
438 Value zeroPValue = arith::ConstantOp::create(
440 Value onePValue = arith::ConstantOp::create(
442 Value minPValue = arith::ConstantOp::create(
445 powType, llvm::APInt::getSignedMinValue(powType.getWidth())));
446 Value maxPValue = arith::ConstantOp::create(
449 powType, llvm::APInt::getSignedMaxValue(powType.getWidth())));
453 auto pIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg,
456 func::ReturnOp::create(builder, oneBValue);
460 cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
464 auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
468 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, minPValue);
475 Value negP = arith::SubIOp::create(builder, zeroPValue, pArg);
476 auto pInit = arith::SelectOp::create(builder, pIsNeg, negP, pArg);
477 pInit = arith::SelectOp::create(builder, pIsMin, maxPValue, pInit);
490 funcBody, funcBody->
end(), {baseType, baseType, powType},
491 {builder.getLoc(), builder.getLoc(), builder.getLoc()});
494 cf::BranchOp::create(builder, loopHeader,
ValueRange{oneBValue, bArg, pInit});
503 auto powerTmpIsOdd = arith::CmpIOp::create(
504 builder, arith::CmpIPredicate::ne,
505 arith::AndIOp::create(builder, powerTmp, onePValue), zeroPValue);
508 Value newResultTmp = arith::MulFOp::create(builder, resultTmp, baseTmp);
509 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), baseType,
512 cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
515 cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
522 Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, onePValue);
525 auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
526 newPowerTmp, zeroPValue);
536 Value newBaseTmp = arith::MulFOp::create(builder, baseTmp, baseTmp);
538 cf::BranchOp::create(
539 builder,
ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
547 cf::CondBranchOp::create(builder, newPowerIsZero, loopExit, newResultTmp,
553 newResultTmp = loopExit->getArgument(0);
555 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), baseType,
558 cf::CondBranchOp::create(builder, pIsMin, thenBlock, fallthroughBlock,
561 newResultTmp = arith::MulFOp::create(builder, newResultTmp, bArg);
562 cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
572 cf::CondBranchOp::create(builder, pIsNeg, thenBlock, returnBlock,
575 newResultTmp = arith::DivFOp::create(builder, oneBValue, newResultTmp);
576 cf::BranchOp::create(builder, newResultTmp, returnBlock);
580 func::ReturnOp::create(builder, returnBlock->
getArgument(0));
589 FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
591 if (isa<VectorType>(op.getType()))
598 func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
654 if (!isa<IntegerType>(elementType)) {
655 LDBG() <<
"non-integer element type for CtlzFunc; type was: "
657 llvm_unreachable(
"non-integer element type");
665 std::string funcName(
"__mlir_math_ctlz");
666 llvm::raw_string_ostream nameOS(funcName);
667 nameOS <<
'_' << elementType;
668 FunctionType funcType =
670 auto funcOp = func::FuncOp::create(builder, funcName, funcType);
674 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
677 funcOp->setAttr(
"llvm.linkage", linkage);
681 Block *funcBody = funcOp.addEntryBlock();
684 Value arg = funcOp.getArgument(0);
686 Value bitWidthValue = arith::ConstantOp::create(
687 builder, elementType, builder.
getIntegerAttr(elementType, bitWidth));
688 Value zeroValue = arith::ConstantOp::create(
692 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, arg, zeroValue);
696 scf::IfOp::create(builder, elementType, inputEqZero,
698 auto thenBuilder = ifOp.getThenBodyBuilder();
699 scf::YieldOp::create(thenBuilder, loc, bitWidthValue);
704 Value oneIndex = arith::ConstantOp::create(elseBuilder, indexType,
705 elseBuilder.getIndexAttr(1));
706 Value oneValue = arith::ConstantOp::create(
707 elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 1));
708 Value bitWidthIndex = arith::ConstantOp::create(
709 elseBuilder, indexType, elseBuilder.getIndexAttr(bitWidth));
710 Value nValue = arith::ConstantOp::create(
711 elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 0));
713 auto loop = scf::ForOp::create(
714 elseBuilder, oneIndex, bitWidthIndex, oneIndex,
727 Value argIter = args[0];
728 Value nIter = args[1];
730 Value argIsNonNegative = arith::CmpIOp::create(
731 b, loc, arith::CmpIPredicate::slt, argIter, zeroValue);
732 scf::IfOp ifOp = scf::IfOp::create(
733 b, loc, argIsNonNegative,
736 scf::YieldOp::create(b, loc,
ValueRange{argIter, nIter});
740 Value nNext = arith::AddIOp::create(b, loc, nIter, oneValue);
741 Value argNext = arith::ShLIOp::create(b, loc, argIter, oneValue);
742 scf::YieldOp::create(b, loc,
ValueRange{argNext, nNext});
744 scf::YieldOp::create(b, loc, ifOp.getResults());
746 scf::YieldOp::create(elseBuilder, loop.getResult(1));
748 func::ReturnOp::create(builder, ifOp.getResult(0));
754 LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
756 if (isa<VectorType>(op.getType()))
760 func::FuncOp elementFunc = getFuncOpCallback(op, type);
763 diag <<
"Missing software implementation for op " << op->getName()
764 <<
" and type " << type;
772 struct ConvertMathToFuncsPass
773 :
public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
774 ConvertMathToFuncsPass() =
default;
775 ConvertMathToFuncsPass(
const ConvertMathToFuncsOptions &
options)
776 :
impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(
options) {}
778 void runOnOperation()
override;
784 bool isFPowIConvertible(math::FPowIOp op);
791 void generateOpImplementations();
800 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
803 return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
806 bool ConvertMathToFuncsPass::isConvertible(
Operation *op) {
810 void ConvertMathToFuncsPass::generateOpImplementations() {
811 ModuleOp module = getOperation();
815 .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
816 if (!convertCtlz || !isConvertible(op))
822 auto key = std::pair(op->
getName(), resultType);
823 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
827 .Case<math::IPowIOp>([&](math::IPowIOp op) {
828 if (!isConvertible(op))
835 auto key = std::pair(op->getName(), resultType);
836 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
840 .Case<math::FPowIOp>([&](math::FPowIOp op) {
841 if (!isFPowIConvertible(op))
850 auto key = std::pair(op->getName(), funcType);
851 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
858 void ConvertMathToFuncsPass::runOnOperation() {
859 ModuleOp module = getOperation();
862 generateOpImplementations();
865 patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
866 VecOpToScalarOp<math::CountLeadingZerosOp>>(
870 auto getFuncOpByType = [&](
Operation *op,
Type type) -> func::FuncOp {
871 auto it = funcImpls.find(std::pair(op->
getName(), type));
872 if (it == funcImpls.end())
884 target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
885 func::FuncDialect, scf::SCFDialect,
886 vector::VectorDialect>();
888 target.addDynamicallyLegalOp<math::IPowIOp>(
889 [
this](math::IPowIOp op) {
return !isConvertible(op); });
891 target.addDynamicallyLegalOp<math::CountLeadingZerosOp>(
892 [
this](math::CountLeadingZerosOp op) {
return !isConvertible(op); });
894 target.addDynamicallyLegalOp<math::FPowIOp>(
895 [
this](math::FPowIOp op) {
return !isFPowIConvertible(op); });
static MLIRContext * getContext(OpFoldResult val)
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType)
Create linkonce_odr function to implement the power function with the given elementType type inside m...
static FunctionType getElementalFuncTypeForOp(Operation *op)
static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType)
Create linkonce_odr function to implement the power function with the given funcType type inside modu...
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType)
Create function to implement the ctlz function with the given elementType type inside module.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
MLIRContext * getContext() const
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * getBlock() const
Returns the current block of the builder.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
operand_type_iterator operand_type_end()
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumOperands()
result_type_iterator result_type_end()
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...