24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
29 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS
30 #include "mlir/Conversion/Passes.h.inc"
35 #define DEBUG_TYPE "math-to-funcs"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
40 template <
typename Op>
56 IPowIOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
62 LogicalResult matchAndRewrite(math::IPowIOp op,
66 GetFuncCallbackTy getFuncOpCallback;
73 FPowIOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
79 LogicalResult matchAndRewrite(math::FPowIOp op,
83 GetFuncCallbackTy getFuncOpCallback;
90 CtlzOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
92 getFuncOpCallback(cb) {}
96 LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
100 GetFuncCallbackTy getFuncOpCallback;
104 template <
typename Op>
107 Type opType = op.getType();
109 auto vecType = dyn_cast<VectorType>(opType);
113 if (!vecType.hasRank())
116 int64_t numElements = vecType.getNumElements();
118 Type resultElementType = vecType.getElementType();
120 if (isa<FloatType>(resultElementType))
127 for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
130 for (
Value input : op->getOperands())
132 rewriter.
create<vector::ExtractOp>(loc, input, positions));
134 rewriter.
create<
Op>(loc, vecType.getElementType(), operands);
136 rewriter.
create<vector::InsertOp>(loc, scalarOp, result, positions);
147 [](
Type ty) { return getElementTypeOrSelf(ty); });
150 [](
Type ty) { return getElementTypeOrSelf(ty); });
186 assert(isa<IntegerType>(elementType) &&
187 "non-integer element type for IPowIOp");
192 std::string funcName(
"__mlir_math_ipowi");
193 llvm::raw_string_ostream nameOS(funcName);
194 nameOS <<
'_' << elementType;
197 builder.
getContext(), {elementType, elementType}, elementType);
198 auto funcOp = builder.
create<func::FuncOp>(funcName, funcType);
199 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
202 funcOp->setAttr(
"llvm.linkage", linkage);
205 Block *entryBlock = funcOp.addEntryBlock();
208 Value bArg = funcOp.getArgument(0);
209 Value pArg = funcOp.getArgument(1);
211 Value zeroValue = builder.
create<arith::ConstantOp>(
213 Value oneValue = builder.
create<arith::ConstantOp>(
215 Value minusOneValue = builder.
create<arith::ConstantOp>(
224 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
226 builder.
create<func::ReturnOp>(oneValue);
230 builder.
create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
235 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
239 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
242 builder.
create<func::ReturnOp>(
243 builder.
create<arith::DivSIOp>(oneValue, zeroValue).getResult());
247 builder.
create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
252 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
255 builder.
create<func::ReturnOp>(oneValue);
259 builder.
create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
263 auto bIsMinusOne = builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq,
264 bArg, minusOneValue);
267 auto pIsOdd = builder.
create<arith::CmpIOp>(
268 arith::CmpIPredicate::ne, builder.
create<arith::AndIOp>(pArg, oneValue),
272 builder.
create<func::ReturnOp>(minusOneValue);
276 builder.
create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
281 builder.
create<func::ReturnOp>(oneValue);
285 builder.
create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
291 builder.
create<func::ReturnOp>(zeroValue);
293 funcBody, funcBody->
end(), {elementType, elementType, elementType},
298 builder.
create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
310 Value resultTmp = loopHeader->getArgument(0);
311 Value baseTmp = loopHeader->getArgument(1);
312 Value powerTmp = loopHeader->getArgument(2);
316 auto powerTmpIsOdd = builder.
create<arith::CmpIOp>(
317 arith::CmpIPredicate::ne,
318 builder.
create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
321 Value newResultTmp = builder.
create<arith::MulIOp>(resultTmp, baseTmp);
322 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), elementType,
325 builder.
create<cf::BranchOp>(newResultTmp, fallthroughBlock);
328 builder.
create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
331 newResultTmp = fallthroughBlock->getArgument(0);
335 Value newPowerTmp = builder.
create<arith::ShRUIOp>(powerTmp, oneValue);
338 auto newPowerIsZero = builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq,
339 newPowerTmp, zeroValue);
342 builder.
create<func::ReturnOp>(newResultTmp);
346 builder.
create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
351 Value newBaseTmp = builder.
create<arith::MulIOp>(baseTmp, baseTmp);
353 builder.
create<cf::BranchOp>(
354 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
362 IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
364 auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
371 func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
413 FunctionType funcType) {
414 auto baseType = cast<FloatType>(funcType.getInput(0));
415 auto powType = cast<IntegerType>(funcType.getInput(1));
419 std::string funcName(
"__mlir_math_fpowi");
420 llvm::raw_string_ostream nameOS(funcName);
421 nameOS <<
'_' << baseType;
422 nameOS <<
'_' << powType;
423 auto funcOp = builder.
create<func::FuncOp>(funcName, funcType);
424 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
427 funcOp->setAttr(
"llvm.linkage", linkage);
430 Block *entryBlock = funcOp.addEntryBlock();
433 Value bArg = funcOp.getArgument(0);
434 Value pArg = funcOp.getArgument(1);
436 Value oneBValue = builder.
create<arith::ConstantOp>(
438 Value zeroPValue = builder.
create<arith::ConstantOp>(
440 Value onePValue = builder.
create<arith::ConstantOp>(
442 Value minPValue = builder.
create<arith::ConstantOp>(
443 powType, builder.
getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
444 powType.getWidth())));
445 Value maxPValue = builder.
create<arith::ConstantOp>(
446 powType, builder.
getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
447 powType.getWidth())));
452 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
454 builder.
create<func::ReturnOp>(oneBValue);
458 builder.
create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
462 auto pIsNeg = builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
466 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);
473 Value negP = builder.
create<arith::SubIOp>(zeroPValue, pArg);
474 auto pInit = builder.
create<arith::SelectOp>(pIsNeg, negP, pArg);
475 pInit = builder.
create<arith::SelectOp>(pIsMin, maxPValue, pInit);
488 funcBody, funcBody->
end(), {baseType, baseType, powType},
492 builder.
create<cf::BranchOp>(loopHeader,
ValueRange{oneBValue, bArg, pInit});
495 Value resultTmp = loopHeader->getArgument(0);
496 Value baseTmp = loopHeader->getArgument(1);
497 Value powerTmp = loopHeader->getArgument(2);
501 auto powerTmpIsOdd = builder.
create<arith::CmpIOp>(
502 arith::CmpIPredicate::ne,
503 builder.
create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
506 Value newResultTmp = builder.
create<arith::MulFOp>(resultTmp, baseTmp);
507 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), baseType,
510 builder.
create<cf::BranchOp>(newResultTmp, fallthroughBlock);
513 builder.
create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
516 newResultTmp = fallthroughBlock->getArgument(0);
520 Value newPowerTmp = builder.
create<arith::ShRUIOp>(powerTmp, onePValue);
523 auto newPowerIsZero = builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq,
524 newPowerTmp, zeroPValue);
534 Value newBaseTmp = builder.
create<arith::MulFOp>(baseTmp, baseTmp);
536 builder.
create<cf::BranchOp>(
537 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
545 builder.
create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
551 newResultTmp = loopExit->getArgument(0);
553 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), baseType,
556 builder.
create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
559 newResultTmp = builder.
create<arith::MulFOp>(newResultTmp, bArg);
560 builder.
create<cf::BranchOp>(newResultTmp, fallthroughBlock);
565 newResultTmp = fallthroughBlock->getArgument(0);
570 builder.
create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
573 newResultTmp = builder.
create<arith::DivFOp>(oneBValue, newResultTmp);
574 builder.
create<cf::BranchOp>(newResultTmp, returnBlock);
587 FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
589 if (dyn_cast<VectorType>(op.getType()))
596 func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
652 if (!isa<IntegerType>(elementType)) {
654 DBGS() <<
"non-integer element type for CtlzFunc; type was: ";
655 elementType.
print(llvm::dbgs());
657 llvm_unreachable(
"non-integer element type");
665 std::string funcName(
"__mlir_math_ctlz");
666 llvm::raw_string_ostream nameOS(funcName);
667 nameOS <<
'_' << elementType;
668 FunctionType funcType =
670 auto funcOp = builder.
create<func::FuncOp>(funcName, funcType);
674 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
677 funcOp->setAttr(
"llvm.linkage", linkage);
681 Block *funcBody = funcOp.addEntryBlock();
684 Value arg = funcOp.getArgument(0);
686 Value bitWidthValue = builder.
create<arith::ConstantOp>(
688 Value zeroValue = builder.
create<arith::ConstantOp>(
692 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue);
695 scf::IfOp ifOp = builder.
create<scf::IfOp>(
696 elementType, inputEqZero,
true,
true);
697 ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
702 Value oneIndex = elseBuilder.create<arith::ConstantOp>(
703 indexType, elseBuilder.getIndexAttr(1));
704 Value oneValue = elseBuilder.create<arith::ConstantOp>(
705 elementType, elseBuilder.getIntegerAttr(elementType, 1));
706 Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
707 indexType, elseBuilder.getIndexAttr(bitWidth));
708 Value nValue = elseBuilder.create<arith::ConstantOp>(
709 elementType, elseBuilder.getIntegerAttr(elementType, 0));
711 auto loop = elseBuilder.create<scf::ForOp>(
712 oneIndex, bitWidthIndex, oneIndex,
725 Value argIter = args[0];
726 Value nIter = args[1];
729 loc, arith::CmpIPredicate::slt, argIter, zeroValue);
730 scf::IfOp ifOp = b.
create<scf::IfOp>(
731 loc, argIsNonNegative,
734 b.create<scf::YieldOp>(loc,
ValueRange{argIter, nIter});
738 Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue);
739 Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue);
740 b.create<scf::YieldOp>(loc,
ValueRange{argNext, nNext});
742 b.
create<scf::YieldOp>(loc, ifOp.getResults());
744 elseBuilder.create<scf::YieldOp>(loop.getResult(1));
746 builder.
create<func::ReturnOp>(ifOp.getResult(0));
752 LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
754 if (dyn_cast<VectorType>(op.getType()))
758 func::FuncOp elementFunc = getFuncOpCallback(op, type);
761 diag <<
"Missing software implementation for op " << op->getName()
762 <<
" and type " << type;
770 struct ConvertMathToFuncsPass
771 :
public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
772 ConvertMathToFuncsPass() =
default;
773 ConvertMathToFuncsPass(
const ConvertMathToFuncsOptions &
options)
774 :
impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(
options) {}
776 void runOnOperation()
override;
782 bool isFPowIConvertible(math::FPowIOp op);
789 void generateOpImplementations();
798 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
801 return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
804 bool ConvertMathToFuncsPass::isConvertible(
Operation *op) {
808 void ConvertMathToFuncsPass::generateOpImplementations() {
809 ModuleOp module = getOperation();
813 .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
814 if (!convertCtlz || !isConvertible(op))
820 auto key = std::pair(op->
getName(), resultType);
821 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
825 .Case<math::IPowIOp>([&](math::IPowIOp op) {
826 if (!isConvertible(op))
833 auto key = std::pair(op->getName(), resultType);
834 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
838 .Case<math::FPowIOp>([&](math::FPowIOp op) {
839 if (!isFPowIConvertible(op))
848 auto key = std::pair(op->getName(), funcType);
849 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
856 void ConvertMathToFuncsPass::runOnOperation() {
857 ModuleOp module = getOperation();
860 generateOpImplementations();
863 patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
864 VecOpToScalarOp<math::CountLeadingZerosOp>>(
865 patterns.getContext());
868 auto getFuncOpByType = [&](
Operation *op,
Type type) -> func::FuncOp {
869 auto it = funcImpls.find(std::pair(op->
getName(), type));
870 if (it == funcImpls.end())
875 patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
879 patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType);
882 target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
883 func::FuncDialect, scf::SCFDialect,
884 vector::VectorDialect>();
886 target.addDynamicallyLegalOp<math::IPowIOp>(
887 [
this](math::IPowIOp op) {
return !isConvertible(op); });
889 target.addDynamicallyLegalOp<math::CountLeadingZerosOp>(
890 [
this](math::CountLeadingZerosOp op) {
return !isConvertible(op); });
892 target.addDynamicallyLegalOp<math::FPowIOp>(
893 [
this](math::FPowIOp op) {
return !isFPowIConvertible(op); });
static MLIRContext * getContext(OpFoldResult val)
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType)
Create linkonce_odr function to implement the power function with the given elementType type inside module.
static FunctionType getElementalFuncTypeForOp(Operation *op)
static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType)
Create linkonce_odr function to implement the power function with the given funcType type inside module.
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType)
Create function to implement the ctlz function for the given elementType type inside module.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
MLIRContext * getContext() const
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
operand_type_iterator operand_type_end()
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumOperands()
result_type_iterator result_type_end()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacement).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
void print(raw_ostream &os) const
Print the current type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...