24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
29 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS
30 #include "mlir/Conversion/Passes.h.inc"
35 #define DEBUG_TYPE "math-to-funcs"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
40 template <
typename Op>
56 IPowIOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
62 LogicalResult matchAndRewrite(math::IPowIOp op,
66 GetFuncCallbackTy getFuncOpCallback;
73 FPowIOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
79 LogicalResult matchAndRewrite(math::FPowIOp op,
83 GetFuncCallbackTy getFuncOpCallback;
90 CtlzOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
92 getFuncOpCallback(cb) {}
96 LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
100 GetFuncCallbackTy getFuncOpCallback;
104 template <
typename Op>
107 Type opType = op.getType();
109 auto vecType = dyn_cast<VectorType>(opType);
113 if (!vecType.hasRank())
116 int64_t numElements = vecType.getNumElements();
118 Type resultElementType = vecType.getElementType();
120 if (isa<FloatType>(resultElementType))
127 for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
132 rewriter.
create<vector::ExtractOp>(loc, input, positions));
134 rewriter.
create<
Op>(loc, vecType.getElementType(), operands);
136 rewriter.
create<vector::InsertOp>(loc, scalarOp, result, positions);
147 [](
Type ty) { return getElementTypeOrSelf(ty); });
150 [](
Type ty) { return getElementTypeOrSelf(ty); });
186 assert(isa<IntegerType>(elementType) &&
187 "non-integer element type for IPowIOp");
192 std::string funcName(
"__mlir_math_ipowi");
193 llvm::raw_string_ostream nameOS(funcName);
194 nameOS <<
'_' << elementType;
197 builder.
getContext(), {elementType, elementType}, elementType);
198 auto funcOp = builder.
create<func::FuncOp>(funcName, funcType);
199 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
202 funcOp->setAttr(
"llvm.linkage", linkage);
205 Block *entryBlock = funcOp.addEntryBlock();
208 Value bArg = funcOp.getArgument(0);
209 Value pArg = funcOp.getArgument(1);
211 Value zeroValue = builder.
create<arith::ConstantOp>(
213 Value oneValue = builder.
create<arith::ConstantOp>(
215 Value minusOneValue = builder.
create<arith::ConstantOp>(
224 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
226 builder.
create<func::ReturnOp>(oneValue);
230 builder.
create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
235 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
239 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
242 builder.
create<func::ReturnOp>(
243 builder.
create<arith::DivSIOp>(oneValue, zeroValue).getResult());
247 builder.
create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
252 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
255 builder.
create<func::ReturnOp>(oneValue);
259 builder.
create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
263 auto bIsMinusOne = builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq,
264 bArg, minusOneValue);
267 auto pIsOdd = builder.
create<arith::CmpIOp>(
268 arith::CmpIPredicate::ne, builder.
create<arith::AndIOp>(pArg, oneValue),
272 builder.
create<func::ReturnOp>(minusOneValue);
276 builder.
create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
281 builder.
create<func::ReturnOp>(oneValue);
285 builder.
create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
291 builder.
create<func::ReturnOp>(zeroValue);
293 funcBody, funcBody->
end(), {elementType, elementType, elementType},
298 builder.
create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
310 Value resultTmp = loopHeader->getArgument(0);
311 Value baseTmp = loopHeader->getArgument(1);
312 Value powerTmp = loopHeader->getArgument(2);
316 auto powerTmpIsOdd = builder.
create<arith::CmpIOp>(
317 arith::CmpIPredicate::ne,
318 builder.
create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
321 Value newResultTmp = builder.
create<arith::MulIOp>(resultTmp, baseTmp);
322 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), elementType,
325 builder.
create<cf::BranchOp>(newResultTmp, fallthroughBlock);
328 builder.
create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
331 newResultTmp = fallthroughBlock->getArgument(0);
335 Value newPowerTmp = builder.
create<arith::ShRUIOp>(powerTmp, oneValue);
338 auto newPowerIsZero = builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq,
339 newPowerTmp, zeroValue);
342 builder.
create<func::ReturnOp>(newResultTmp);
346 builder.
create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
351 Value newBaseTmp = builder.
create<arith::MulIOp>(baseTmp, baseTmp);
353 builder.
create<cf::BranchOp>(
354 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
362 IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
371 func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
413 FunctionType funcType) {
414 auto baseType = cast<FloatType>(funcType.getInput(0));
415 auto powType = cast<IntegerType>(funcType.getInput(1));
419 std::string funcName(
"__mlir_math_fpowi");
420 llvm::raw_string_ostream nameOS(funcName);
421 nameOS <<
'_' << baseType;
422 nameOS <<
'_' << powType;
423 auto funcOp = builder.
create<func::FuncOp>(funcName, funcType);
424 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
427 funcOp->setAttr(
"llvm.linkage", linkage);
430 Block *entryBlock = funcOp.addEntryBlock();
433 Value bArg = funcOp.getArgument(0);
434 Value pArg = funcOp.getArgument(1);
436 Value oneBValue = builder.
create<arith::ConstantOp>(
438 Value zeroPValue = builder.
create<arith::ConstantOp>(
440 Value onePValue = builder.
create<arith::ConstantOp>(
442 Value minPValue = builder.
create<arith::ConstantOp>(
443 powType, builder.
getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
444 powType.getWidth())));
445 Value maxPValue = builder.
create<arith::ConstantOp>(
446 powType, builder.
getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
447 powType.getWidth())));
452 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
454 builder.
create<func::ReturnOp>(oneBValue);
458 builder.
create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
462 auto pIsNeg = builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
466 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);
473 Value negP = builder.
create<arith::SubIOp>(zeroPValue, pArg);
474 auto pInit = builder.
create<arith::SelectOp>(pIsNeg, negP, pArg);
475 pInit = builder.
create<arith::SelectOp>(pIsMin, maxPValue, pInit);
488 funcBody, funcBody->
end(), {baseType, baseType, powType},
492 builder.
create<cf::BranchOp>(loopHeader,
ValueRange{oneBValue, bArg, pInit});
495 Value resultTmp = loopHeader->getArgument(0);
496 Value baseTmp = loopHeader->getArgument(1);
497 Value powerTmp = loopHeader->getArgument(2);
501 auto powerTmpIsOdd = builder.
create<arith::CmpIOp>(
502 arith::CmpIPredicate::ne,
503 builder.
create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
506 Value newResultTmp = builder.
create<arith::MulFOp>(resultTmp, baseTmp);
507 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), baseType,
510 builder.
create<cf::BranchOp>(newResultTmp, fallthroughBlock);
513 builder.
create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
516 newResultTmp = fallthroughBlock->getArgument(0);
520 Value newPowerTmp = builder.
create<arith::ShRUIOp>(powerTmp, onePValue);
523 auto newPowerIsZero = builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq,
524 newPowerTmp, zeroPValue);
534 Value newBaseTmp = builder.
create<arith::MulFOp>(baseTmp, baseTmp);
536 builder.
create<cf::BranchOp>(
537 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
545 builder.
create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
551 newResultTmp = loopExit->getArgument(0);
553 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), baseType,
556 builder.
create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
559 newResultTmp = builder.
create<arith::MulFOp>(newResultTmp, bArg);
560 builder.
create<cf::BranchOp>(newResultTmp, fallthroughBlock);
565 newResultTmp = fallthroughBlock->getArgument(0);
570 builder.
create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
573 newResultTmp = builder.
create<arith::DivFOp>(oneBValue, newResultTmp);
574 builder.
create<cf::BranchOp>(newResultTmp, returnBlock);
587 FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
589 if (dyn_cast<VectorType>(op.getType()))
596 func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
652 if (!isa<IntegerType>(elementType)) {
654 DBGS() <<
"non-integer element type for CtlzFunc; type was: ";
655 elementType.
print(llvm::dbgs());
657 llvm_unreachable(
"non-integer element type");
665 std::string funcName(
"__mlir_math_ctlz");
666 llvm::raw_string_ostream nameOS(funcName);
667 nameOS <<
'_' << elementType;
668 FunctionType funcType =
670 auto funcOp = builder.
create<func::FuncOp>(funcName, funcType);
674 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
677 funcOp->setAttr(
"llvm.linkage", linkage);
681 Block *funcBody = funcOp.addEntryBlock();
684 Value arg = funcOp.getArgument(0);
686 Value bitWidthValue = builder.
create<arith::ConstantOp>(
688 Value zeroValue = builder.
create<arith::ConstantOp>(
692 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue);
695 scf::IfOp ifOp = builder.
create<scf::IfOp>(
696 elementType, inputEqZero,
true,
true);
697 ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
702 Value oneIndex = elseBuilder.create<arith::ConstantOp>(
703 indexType, elseBuilder.getIndexAttr(1));
704 Value oneValue = elseBuilder.create<arith::ConstantOp>(
705 elementType, elseBuilder.getIntegerAttr(elementType, 1));
706 Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
707 indexType, elseBuilder.getIndexAttr(bitWidth));
708 Value nValue = elseBuilder.create<arith::ConstantOp>(
709 elementType, elseBuilder.getIntegerAttr(elementType, 0));
711 auto loop = elseBuilder.create<scf::ForOp>(
712 oneIndex, bitWidthIndex, oneIndex,
725 Value argIter = args[0];
726 Value nIter = args[1];
729 loc, arith::CmpIPredicate::slt, argIter, zeroValue);
730 scf::IfOp ifOp = b.
create<scf::IfOp>(
731 loc, argIsNonNegative,
734 b.create<scf::YieldOp>(loc,
ValueRange{argIter, nIter});
738 Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue);
739 Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue);
740 b.create<scf::YieldOp>(loc,
ValueRange{argNext, nNext});
742 b.
create<scf::YieldOp>(loc, ifOp.getResults());
744 elseBuilder.create<scf::YieldOp>(loop.getResult(1));
746 builder.
create<func::ReturnOp>(ifOp.getResult(0));
752 LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
754 if (dyn_cast<VectorType>(op.getType()))
758 func::FuncOp elementFunc = getFuncOpCallback(op, type);
761 diag <<
"Missing software implementation for op " << op->
getName()
762 <<
" and type " << type;
770 struct ConvertMathToFuncsPass
771 :
public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
772 ConvertMathToFuncsPass() =
default;
773 ConvertMathToFuncsPass(
const ConvertMathToFuncsOptions &
options)
774 :
impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(
options) {}
776 void runOnOperation()
override;
782 bool isFPowIConvertible(math::FPowIOp op);
786 void generateOpImplementations();
795 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
798 return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
801 void ConvertMathToFuncsPass::generateOpImplementations() {
802 ModuleOp module = getOperation();
806 .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
813 auto key = std::pair(op->
getName(), resultType);
814 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
818 .Case<math::IPowIOp>([&](math::IPowIOp op) {
823 auto key = std::pair(op->
getName(), resultType);
824 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
828 .Case<math::FPowIOp>([&](math::FPowIOp op) {
829 if (!isFPowIConvertible(op))
838 auto key = std::pair(op->
getName(), funcType);
839 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
846 void ConvertMathToFuncsPass::runOnOperation() {
847 ModuleOp module = getOperation();
850 generateOpImplementations();
853 patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
854 VecOpToScalarOp<math::CountLeadingZerosOp>>(
855 patterns.getContext());
858 auto getFuncOpByType = [&](
Operation *op,
Type type) -> func::FuncOp {
859 auto it = funcImpls.find(std::pair(op->
getName(), type));
860 if (it == funcImpls.end())
865 patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
869 patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType);
872 target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
873 func::FuncDialect, scf::SCFDialect,
874 vector::VectorDialect>();
876 target.addIllegalOp<math::IPowIOp>();
878 target.addIllegalOp<math::CountLeadingZerosOp>();
879 target.addDynamicallyLegalOp<math::FPowIOp>(
880 [
this](math::FPowIOp op) {
return !isFPowIConvertible(op); });
static MLIRContext * getContext(OpFoldResult val)
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType)
Create linkonce_odr function to implement the power function with the given elementType type inside m...
static FunctionType getElementalFuncTypeForOp(Operation *op)
static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType)
Create linkonce_odr function to implement the power function with the given funcType type inside modu...
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType)
Create function to implement the ctlz function the given elementType type inside module.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
MLIRContext * getContext() const
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
operand_type_iterator operand_type_end()
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
result_type_iterator result_type_end()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
void print(raw_ostream &os) const
Print the current type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...