23#include "llvm/ADT/DenseMap.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/DebugLog.h"
28#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
29#include "mlir/Conversion/Passes.h.inc"
34#define DEBUG_TYPE "math-to-funcs"
54 IPowIOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
60 LogicalResult matchAndRewrite(math::IPowIOp op,
64 GetFuncCallbackTy getFuncOpCallback;
71 FPowIOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
77 LogicalResult matchAndRewrite(math::FPowIOp op,
81 GetFuncCallbackTy getFuncOpCallback;
88 CtlzOpLowering(
MLIRContext *context, GetFuncCallbackTy cb)
90 getFuncOpCallback(cb) {}
94 LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
98 GetFuncCallbackTy getFuncOpCallback;
102template <
typename Op>
105 Type opType = op.getType();
107 auto vecType = dyn_cast<VectorType>(opType);
111 if (!vecType.hasRank())
114 int64_t numElements = vecType.getNumElements();
116 Type resultElementType = vecType.getElementType();
118 if (isa<FloatType>(resultElementType))
119 initValueAttr = FloatAttr::get(resultElementType, 0.0);
121 initValueAttr = IntegerAttr::get(resultElementType, 0);
125 for (
int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
128 for (
Value input : op->getOperands())
130 vector::ExtractOp::create(rewriter, loc, input, positions));
132 Op::create(rewriter, loc, vecType.getElementType(), operands);
134 vector::InsertOp::create(rewriter, loc, scalarOp,
result, positions);
145 [](
Type ty) { return getElementTypeOrSelf(ty); });
148 [](
Type ty) { return getElementTypeOrSelf(ty); });
149 return FunctionType::get(op->
getContext(), inputTys, resultTys);
184 assert(isa<IntegerType>(elementType) &&
185 "non-integer element type for IPowIOp");
190 std::string funcName(
"__mlir_math_ipowi");
191 llvm::raw_string_ostream nameOS(funcName);
192 nameOS <<
'_' << elementType;
194 FunctionType funcType = FunctionType::get(
195 builder.
getContext(), {elementType, elementType}, elementType);
196 auto funcOp = func::FuncOp::create(builder, funcName, funcType);
197 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
199 LLVM::LinkageAttr::get(builder.
getContext(), inlineLinkage);
200 funcOp->setAttr(
"llvm.linkage", linkage);
203 Block *entryBlock = funcOp.addEntryBlock();
206 Value bArg = funcOp.getArgument(0);
207 Value pArg = funcOp.getArgument(1);
209 Value zeroValue = arith::ConstantOp::create(
211 Value oneValue = arith::ConstantOp::create(
213 Value minusOneValue = arith::ConstantOp::create(
214 builder, elementType,
222 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, zeroValue);
224 func::ReturnOp::create(builder, oneValue);
228 cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
232 auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
237 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, zeroValue);
240 func::ReturnOp::create(
242 arith::DivSIOp::create(builder, oneValue, zeroValue).getResult());
246 cf::CondBranchOp::create(builder, bIsZero, thenBlock, fallthroughBlock);
251 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, oneValue);
254 func::ReturnOp::create(builder, oneValue);
258 cf::CondBranchOp::create(builder, bIsOne, thenBlock, fallthroughBlock);
262 auto bIsMinusOne = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
263 bArg, minusOneValue);
266 auto pIsOdd = arith::CmpIOp::create(
267 builder, arith::CmpIPredicate::ne,
268 arith::AndIOp::create(builder, pArg, oneValue), zeroValue);
271 func::ReturnOp::create(builder, minusOneValue);
275 cf::CondBranchOp::create(builder, pIsOdd, thenBlock, fallthroughBlock);
280 func::ReturnOp::create(builder, oneValue);
284 cf::CondBranchOp::create(builder, bIsMinusOne, pIsOdd->
getBlock(),
290 func::ReturnOp::create(builder, zeroValue);
292 funcBody, funcBody->
end(), {elementType, elementType, elementType},
293 {builder.getLoc(), builder.getLoc(), builder.getLoc()});
297 cf::CondBranchOp::create(builder, pIsNeg, bIsZero->
getBlock(), loopHeader,
315 auto powerTmpIsOdd = arith::CmpIOp::create(
316 builder, arith::CmpIPredicate::ne,
317 arith::AndIOp::create(builder, powerTmp, oneValue), zeroValue);
320 Value newResultTmp = arith::MulIOp::create(builder, resultTmp, baseTmp);
321 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), elementType,
324 cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
327 cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
334 Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, oneValue);
337 auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
338 newPowerTmp, zeroValue);
341 func::ReturnOp::create(builder, newResultTmp);
345 cf::CondBranchOp::create(builder, newPowerIsZero, thenBlock,
351 Value newBaseTmp = arith::MulIOp::create(builder, baseTmp, baseTmp);
353 cf::BranchOp::create(
354 builder,
ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
362IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
364 auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
371 func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
413 FunctionType funcType) {
414 auto baseType = cast<FloatType>(funcType.getInput(0));
415 auto powType = cast<IntegerType>(funcType.getInput(1));
419 std::string funcName(
"__mlir_math_fpowi");
420 llvm::raw_string_ostream nameOS(funcName);
421 nameOS <<
'_' << baseType;
422 nameOS <<
'_' << powType;
423 auto funcOp = func::FuncOp::create(builder, funcName, funcType);
424 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
426 LLVM::LinkageAttr::get(builder.
getContext(), inlineLinkage);
427 funcOp->setAttr(
"llvm.linkage", linkage);
430 Block *entryBlock = funcOp.addEntryBlock();
433 Value bArg = funcOp.getArgument(0);
434 Value pArg = funcOp.getArgument(1);
436 Value oneBValue = arith::ConstantOp::create(
437 builder, baseType, builder.
getFloatAttr(baseType, 1.0));
438 Value zeroPValue = arith::ConstantOp::create(
440 Value onePValue = arith::ConstantOp::create(
442 Value minPValue = arith::ConstantOp::create(
445 powType, llvm::APInt::getSignedMinValue(powType.getWidth())));
446 Value maxPValue = arith::ConstantOp::create(
449 powType, llvm::APInt::getSignedMaxValue(powType.getWidth())));
453 auto pIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg,
456 func::ReturnOp::create(builder, oneBValue);
460 cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
464 auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
468 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, minPValue);
475 Value negP = arith::SubIOp::create(builder, zeroPValue, pArg);
476 auto pInit = arith::SelectOp::create(builder, pIsNeg, negP, pArg);
477 pInit = arith::SelectOp::create(builder, pIsMin, maxPValue, pInit);
490 funcBody, funcBody->
end(), {baseType, baseType, powType},
491 {builder.getLoc(), builder.getLoc(), builder.getLoc()});
494 cf::BranchOp::create(builder, loopHeader,
ValueRange{oneBValue, bArg, pInit});
503 auto powerTmpIsOdd = arith::CmpIOp::create(
504 builder, arith::CmpIPredicate::ne,
505 arith::AndIOp::create(builder, powerTmp, onePValue), zeroPValue);
508 Value newResultTmp = arith::MulFOp::create(builder, resultTmp, baseTmp);
509 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), baseType,
512 cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
515 cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
522 Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, onePValue);
525 auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
526 newPowerTmp, zeroPValue);
536 Value newBaseTmp = arith::MulFOp::create(builder, baseTmp, baseTmp);
538 cf::BranchOp::create(
539 builder,
ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
547 cf::CondBranchOp::create(builder, newPowerIsZero, loopExit, newResultTmp,
553 newResultTmp = loopExit->getArgument(0);
555 fallthroughBlock = builder.
createBlock(funcBody, funcBody->
end(), baseType,
558 cf::CondBranchOp::create(builder, pIsMin, thenBlock, fallthroughBlock,
561 newResultTmp = arith::MulFOp::create(builder, newResultTmp, bArg);
562 cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
572 cf::CondBranchOp::create(builder, pIsNeg, thenBlock, returnBlock,
575 newResultTmp = arith::DivFOp::create(builder, oneBValue, newResultTmp);
576 cf::BranchOp::create(builder, newResultTmp, returnBlock);
580 func::ReturnOp::create(builder, returnBlock->
getArgument(0));
589FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
590 PatternRewriter &rewriter)
const {
591 if (isa<VectorType>(op.getType()))
598 func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
654 if (!isa<IntegerType>(elementType)) {
655 LDBG() <<
"non-integer element type for CtlzFunc; type was: "
657 llvm_unreachable(
"non-integer element type");
665 std::string funcName(
"__mlir_math_ctlz");
666 llvm::raw_string_ostream nameOS(funcName);
667 nameOS <<
'_' << elementType;
668 FunctionType funcType =
669 FunctionType::get(builder.
getContext(), {elementType}, elementType);
670 auto funcOp = func::FuncOp::create(builder, funcName, funcType);
674 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
676 LLVM::LinkageAttr::get(builder.
getContext(), inlineLinkage);
677 funcOp->setAttr(
"llvm.linkage", linkage);
681 Block *funcBody = funcOp.addEntryBlock();
684 Value arg = funcOp.getArgument(0);
686 Value bitWidthValue = arith::ConstantOp::create(
687 builder, elementType, builder.
getIntegerAttr(elementType, bitWidth));
688 Value zeroValue = arith::ConstantOp::create(
692 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, arg, zeroValue);
696 scf::IfOp::create(builder, elementType, inputEqZero,
698 auto thenBuilder = ifOp.getThenBodyBuilder();
699 scf::YieldOp::create(thenBuilder, loc, bitWidthValue);
704 Value oneIndex = arith::ConstantOp::create(elseBuilder, indexType,
705 elseBuilder.getIndexAttr(1));
706 Value oneValue = arith::ConstantOp::create(
707 elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 1));
708 Value bitWidthIndex = arith::ConstantOp::create(
709 elseBuilder, indexType, elseBuilder.getIndexAttr(bitWidth));
710 Value nValue = arith::ConstantOp::create(
711 elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 0));
713 auto loop = scf::ForOp::create(
714 elseBuilder, oneIndex, bitWidthIndex, oneIndex,
727 Value argIter = args[0];
728 Value nIter = args[1];
730 Value argIsNonNegative = arith::CmpIOp::create(
731 b, loc, arith::CmpIPredicate::slt, argIter, zeroValue);
732 scf::IfOp ifOp = scf::IfOp::create(
733 b, loc, argIsNonNegative,
736 scf::YieldOp::create(
b, loc,
ValueRange{argIter, nIter});
740 Value nNext = arith::AddIOp::create(
b, loc, nIter, oneValue);
741 Value argNext = arith::ShLIOp::create(
b, loc, argIter, oneValue);
742 scf::YieldOp::create(
b, loc,
ValueRange{argNext, nNext});
744 scf::YieldOp::create(
b, loc, ifOp.getResults());
746 scf::YieldOp::create(elseBuilder, loop.getResult(1));
748 func::ReturnOp::create(builder, ifOp.getResult(0));
754LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
755 PatternRewriter &rewriter)
const {
756 if (isa<VectorType>(op.getType()))
760 func::FuncOp elementFunc = getFuncOpCallback(op, type);
763 diag <<
"Missing software implementation for op " << op->getName()
764 <<
" and type " << type;
772struct ConvertMathToFuncsPass
773 :
public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
774 ConvertMathToFuncsPass() =
default;
775 ConvertMathToFuncsPass(
const ConvertMathToFuncsOptions &
options)
776 : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(
options) {}
778 void runOnOperation()
override;
784 bool isFPowIConvertible(math::FPowIOp op);
787 bool isConvertible(Operation *op);
791 void generateOpImplementations();
800bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
803 return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
806bool ConvertMathToFuncsPass::isConvertible(Operation *op) {
810void ConvertMathToFuncsPass::generateOpImplementations() {
811 ModuleOp module = getOperation();
813 module.walk([&](Operation *op) {
814 TypeSwitch<Operation *>(op)
815 .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
816 if (!convertCtlz || !isConvertible(op))
822 auto key = std::pair(op->
getName(), resultType);
823 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
827 .Case<math::IPowIOp>([&](math::IPowIOp op) {
828 if (!isConvertible(op))
835 auto key = std::pair(op->getName(), resultType);
836 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
840 .Case<math::FPowIOp>([&](math::FPowIOp op) {
841 if (!isFPowIConvertible(op))
850 auto key = std::pair(op->getName(), funcType);
851 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
858void ConvertMathToFuncsPass::runOnOperation() {
859 ModuleOp module = getOperation();
862 generateOpImplementations();
865 patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
866 VecOpToScalarOp<math::CountLeadingZerosOp>>(
870 auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
871 auto it = funcImpls.find(std::pair(op->
getName(), type));
872 if (it == funcImpls.end())
884 target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
885 func::FuncDialect, scf::SCFDialect,
886 vector::VectorDialect>();
888 target.addDynamicallyLegalOp<math::IPowIOp>(
889 [
this](math::IPowIOp op) {
return !isConvertible(op); });
891 target.addDynamicallyLegalOp<math::CountLeadingZerosOp>(
892 [
this](math::CountLeadingZerosOp op) {
return !isConvertible(op); });
894 target.addDynamicallyLegalOp<math::FPowIOp>(
895 [
this](math::FPowIOp op) {
return !isFPowIConvertible(op); });
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType)
Create a linkonce_odr function to implement the power function with the given elementType type inside module.
static FunctionType getElementalFuncTypeForOp(Operation *op)
static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType)
Create a linkonce_odr function to implement the power function with the given funcType type inside module.
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType)
Create a function to implement the ctlz function for the given elementType type inside module.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getBlock() const
Returns the current block of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
operand_type_iterator operand_type_end()
unsigned getNumOperands()
result_type_iterator result_type_end()
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...