MLIR  16.0.0git
MathToFuncs.cpp
Go to the documentation of this file.
1 //===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Pass/Pass.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS
28 #include "mlir/Conversion/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 
33 namespace {
34 // Pattern to convert vector operations to scalar operations.
35 template <typename Op>
36 struct VecOpToScalarOp : public OpRewritePattern<Op> {
37 public:
39 
40  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
41 };
42 
43 // Callback type for getting pre-generated FuncOp implementing
44 // a power operation of the given type.
45 using GetPowerFuncCallbackTy = function_ref<func::FuncOp(Type)>;
46 
47 // Pattern to convert scalar IPowIOp into a call of outlined
48 // software implementation.
49 struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
50 
51 private:
52  GetPowerFuncCallbackTy getFuncOpCallback;
53 
54 public:
55  IPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb)
56  : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
57 
58  /// Convert IPowI into a call to a local function implementing
59  /// the power operation. The local function computes a scalar result,
60  /// so vector forms of IPowI are linearized.
61  LogicalResult matchAndRewrite(math::IPowIOp op,
62  PatternRewriter &rewriter) const final;
63 };
64 } // namespace
65 
66 template <typename Op>
68 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
69  Type opType = op.getType();
70  Location loc = op.getLoc();
71  auto vecType = opType.template dyn_cast<VectorType>();
72 
73  if (!vecType)
74  return rewriter.notifyMatchFailure(op, "not a vector operation");
75  if (!vecType.hasRank())
76  return rewriter.notifyMatchFailure(op, "unknown vector rank");
77  ArrayRef<int64_t> shape = vecType.getShape();
78  int64_t numElements = vecType.getNumElements();
79 
80  Value result = rewriter.create<arith::ConstantOp>(
82  vecType, IntegerAttr::get(vecType.getElementType(), 0)));
83  SmallVector<int64_t> strides = computeStrides(shape);
84  for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
85  SmallVector<int64_t> positions = delinearize(strides, linearIndex);
86  SmallVector<Value> operands;
87  for (Value input : op->getOperands())
88  operands.push_back(
89  rewriter.create<vector::ExtractOp>(loc, input, positions));
90  Value scalarOp =
91  rewriter.create<Op>(loc, vecType.getElementType(), operands);
92  result =
93  rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
94  }
95  rewriter.replaceOp(op, result);
96  return success();
97 }
98 
99 /// Create a private linkonce_odr function implementing the integer power
100 /// operation for \p elementType and append it to the body of \p module.
101 /// \p elementType must be an IntegerType (asserted below).
     /// The function takes two \p elementType arguments (base, power), returns
     /// one \p elementType value, and is named "__mlir_math_ipowi_" followed by
     /// the printed form of \p elementType. Its body is emitted as explicit
     /// basic blocks mirroring the following pseudo-code:
102 ///
103 /// template <typename T>
104 /// T __mlir_math_ipowi_*(T b, T p) {
105 /// if (p == T(0))
106 /// return T(1);
107 /// if (p < T(0)) {
108 /// if (b == T(0))
109 /// return T(1) / T(0); // trigger div-by-zero
110 /// if (b == T(1))
111 /// return T(1);
112 /// if (b == T(-1)) {
113 /// if (p & T(1))
114 /// return T(-1);
115 /// return T(1);
116 /// }
117 /// return T(0);
118 /// }
119 /// T result = T(1);
120 /// while (true) {
121 /// if (p & T(1))
122 /// result *= b;
123 /// p >>= T(1);
124 /// if (p == T(0))
125 /// return result;
126 /// b *= b;
127 /// }
128 /// }
     ///
     /// \returns the newly created function.
129 static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
130  assert(elementType.isa<IntegerType>() &&
131  "non-integer element type for IPowIOp");
132 
133  // Build at the end of the module body; every operation created through
     // this builder implicitly gets the module's location.
134  ImplicitLocOpBuilder builder =
135  ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
136 
     // The function name is the fixed prefix plus '_' and the printed type.
137  std::string funcName("__mlir_math_ipowi");
138  llvm::raw_string_ostream nameOS(funcName);
139  nameOS << '_' << elementType;
140 
141  FunctionType funcType = FunctionType::get(
142  builder.getContext(), {elementType, elementType}, elementType);
143  auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
     // Tag the function with linkonce_odr LLVM linkage and make the symbol
     // private.
144  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
145  Attribute linkage =
146  LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
147  funcOp->setAttr("llvm.linkage", linkage);
148  funcOp.setPrivate();
149 
150  Block *entryBlock = funcOp.addEntryBlock();
151  Region *funcBody = entryBlock->getParent();
152 
     // Argument 0 is the base 'b', argument 1 is the power 'p'.
153  Value bArg = funcOp.getArgument(0);
154  Value pArg = funcOp.getArgument(1);
155  builder.setInsertionPointToEnd(entryBlock);
     // Constants 0, 1 and -1 of the element type, shared by all blocks below.
156  Value zeroValue = builder.create<arith::ConstantOp>(
157  elementType, builder.getIntegerAttr(elementType, 0));
158  Value oneValue = builder.create<arith::ConstantOp>(
159  elementType, builder.getIntegerAttr(elementType, 1));
     // -1 is built explicitly as a signed APInt of the element's bit width.
160  Value minusOneValue = builder.create<arith::ConstantOp>(
161  elementType,
162  builder.getIntegerAttr(elementType,
163  APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
164  /*isSigned=*/true)));
165 
166  // if (p == T(0))
167  // return T(1);
     // Pattern used throughout: create the compare first, then the 'then' and
     // fallthrough blocks, and finally go back to the compare's block to
     // terminate it with the conditional branch.
168  auto pIsZero =
169  builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
170  Block *thenBlock = builder.createBlock(funcBody);
171  builder.create<func::ReturnOp>(oneValue);
172  Block *fallthroughBlock = builder.createBlock(funcBody);
173  // Set up conditional branch for (p == T(0)).
174  builder.setInsertionPointToEnd(pIsZero->getBlock());
175  builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
176 
177  // if (p < T(0)) {
     // p == 0 has already branched away above, so on this path the 'sle'
     // predicate selects exactly the negative powers.
178  builder.setInsertionPointToEnd(fallthroughBlock);
179  auto pIsNeg =
180  builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
181  // if (b == T(0))
182  builder.createBlock(funcBody);
183  auto bIsZero =
184  builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
185  // return T(1) / T(0);
186  thenBlock = builder.createBlock(funcBody);
187  builder.create<func::ReturnOp>(
188  builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
189  fallthroughBlock = builder.createBlock(funcBody);
190  // Set up conditional branch for (b == T(0)).
191  builder.setInsertionPointToEnd(bIsZero->getBlock());
192  builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
193 
194  // if (b == T(1))
195  builder.setInsertionPointToEnd(fallthroughBlock);
196  auto bIsOne =
197  builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
198  // return T(1);
199  thenBlock = builder.createBlock(funcBody);
200  builder.create<func::ReturnOp>(oneValue);
201  fallthroughBlock = builder.createBlock(funcBody);
202  // Set up conditional branch for (b == T(1)).
203  builder.setInsertionPointToEnd(bIsOne->getBlock());
204  builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
205 
206  // if (b == T(-1)) {
207  builder.setInsertionPointToEnd(fallthroughBlock);
208  auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
209  bArg, minusOneValue);
210  // if (p & T(1))
211  builder.createBlock(funcBody);
212  auto pIsOdd = builder.create<arith::CmpIOp>(
213  arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
214  zeroValue);
215  // return T(-1);
216  thenBlock = builder.createBlock(funcBody);
217  builder.create<func::ReturnOp>(minusOneValue);
218  fallthroughBlock = builder.createBlock(funcBody);
219  // Set up conditional branch for (p & T(1)).
220  builder.setInsertionPointToEnd(pIsOdd->getBlock());
221  builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
222 
223  // return T(1);
224  // } // b == T(-1)
225  builder.setInsertionPointToEnd(fallthroughBlock);
226  builder.create<func::ReturnOp>(oneValue);
227  fallthroughBlock = builder.createBlock(funcBody);
228  // Set up conditional branch for (b == T(-1)).
229  builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
230  builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
231  fallthroughBlock);
232 
233  // return T(0);
234  // } // (p < T(0))
235  builder.setInsertionPointToEnd(fallthroughBlock);
236  builder.create<func::ReturnOp>(zeroValue);
     // The loop header carries three block arguments: 'result', 'b' and 'p'.
237  Block *loopHeader = builder.createBlock(
238  funcBody, funcBody->end(), {elementType, elementType, elementType},
239  {builder.getLoc(), builder.getLoc(), builder.getLoc()});
240  // Set up conditional branch for (p < T(0)).
241  builder.setInsertionPointToEnd(pIsNeg->getBlock());
242  // Set initial values of 'result', 'b' and 'p' for the loop.
243  builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
244  ValueRange{oneValue, bArg, pArg});
245 
246  // T result = T(1);
247  // while (true) {
248  // if (p & T(1))
249  // result *= b;
250  // p >>= T(1);
251  // if (p == T(0))
252  // return result;
253  // b *= b;
254  // }
255  Value resultTmp = loopHeader->getArgument(0);
256  Value baseTmp = loopHeader->getArgument(1);
257  Value powerTmp = loopHeader->getArgument(2);
258  builder.setInsertionPointToEnd(loopHeader);
259 
260  // if (p & T(1))
261  auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
262  arith::CmpIPredicate::ne,
263  builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
264  thenBlock = builder.createBlock(funcBody);
265  // result *= b;
266  Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
     // The merge block takes the (possibly updated) 'result' as its argument.
267  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
268  builder.getLoc());
269  builder.setInsertionPointToEnd(thenBlock);
270  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
271  // Set up conditional branch for (p & T(1)).
     // When 'p' is even the unmodified 'result' is forwarded to the merge block.
272  builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
273  builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
274  resultTmp);
275  // Merged 'result'.
276  newResultTmp = fallthroughBlock->getArgument(0);
277 
278  // p >>= T(1);
     // Emitted as an unsigned (logical) shift right.
279  builder.setInsertionPointToEnd(fallthroughBlock);
280  Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
281 
282  // if (p == T(0))
283  auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
284  newPowerTmp, zeroValue);
285  // return result;
286  thenBlock = builder.createBlock(funcBody);
287  builder.create<func::ReturnOp>(newResultTmp);
288  fallthroughBlock = builder.createBlock(funcBody);
289  // Set up conditional branch for (p == T(0)).
290  builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
291  builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
292 
293  // b *= b;
294  // }
295  builder.setInsertionPointToEnd(fallthroughBlock);
296  Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
297  // Pass new values for 'result', 'b' and 'p' to the loop header.
298  builder.create<cf::BranchOp>(
299  ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
300  return funcOp;
301 }
302 
303 /// Convert IPowI into a call to a local function implementing
304 /// the power operation. The local function computes a scalar result,
305 /// so vector forms of IPowI are linearized.
307 IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
308  PatternRewriter &rewriter) const {
309  auto baseType = op.getOperands()[0].getType().dyn_cast<IntegerType>();
310 
311  if (!baseType)
312  return rewriter.notifyMatchFailure(op, "non-integer base operand");
313 
314  // The outlined software implementation must have been already
315  // generated.
316  func::FuncOp elementFunc = getFuncOpCallback(baseType);
317  if (!elementFunc)
318  return rewriter.notifyMatchFailure(op, "missing software implementation");
319 
320  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
321  return success();
322 }
323 
324 namespace {
325 struct ConvertMathToFuncsPass
326  : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
327  ConvertMathToFuncsPass() = default;
328 
329  void runOnOperation() override;
330 
331 private:
332  // Generate outlined implementations for power operations
333  // and store them in powerFuncs map.
334  void preprocessPowOperations();
335 
336  // A map between function types deduced from power operations
337  // and the corresponding outlined software implementations
338  // of these operations.
339  DenseMap<Type, func::FuncOp> powerFuncs;
340 };
341 } // namespace
342 
343 void ConvertMathToFuncsPass::preprocessPowOperations() {
344  ModuleOp module = getOperation();
345 
346  module.walk([&](Operation *op) {
347  TypeSwitch<Operation *>(op).Case<math::IPowIOp>([&](math::IPowIOp op) {
348  Type resultType = getElementTypeOrSelf(op.getResult().getType());
349 
350  // Generate the software implementation of this operation,
351  // if it has not been generated yet.
352  auto entry = powerFuncs.try_emplace(resultType, func::FuncOp{});
353  if (entry.second)
354  entry.first->second = createElementIPowIFunc(&module, resultType);
355  });
356  });
357 }
358 
359 void ConvertMathToFuncsPass::runOnOperation() {
360  ModuleOp module = getOperation();
361 
362  // Create outlined implementations for power operations.
363  preprocessPowOperations();
364 
365  RewritePatternSet patterns(&getContext());
366  patterns.add<VecOpToScalarOp<math::IPowIOp>>(patterns.getContext());
367 
368  // For the given Type Returns FuncOp stored in powerFuncs map.
369  auto getPowerFuncOpByType = [&](Type type) -> func::FuncOp {
370  auto it = powerFuncs.find(type);
371  if (it == powerFuncs.end())
372  return {};
373 
374  return it->second;
375  };
376  patterns.add<IPowIOpLowering>(patterns.getContext(), getPowerFuncOpByType);
377 
378  ConversionTarget target(getContext());
379  target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
380  func::FuncDialect, vector::VectorDialect>();
381  target.addIllegalOp<math::IPowIOp>();
382  if (failed(applyPartialConversion(module, target, std::move(patterns))))
383  signalPassFailure();
384 }
385 
386 std::unique_ptr<Pass> mlir::createConvertMathToFuncsPass() {
387  return std::make_unique<ConvertMathToFuncsPass>();
388 }
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType)
Create linkonce_odr function to implement the power function with the given funcType type inside modu...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:212
MLIRContext * getContext() const
Definition: Builders.h:54
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:388
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:395
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:108
This provides public APIs that all operations should have.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:605
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:517
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isa() const
Definition: Types.h:260
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:93
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Given a set of sizes, compute and return the strides (i.e.
std::unique_ptr< Pass > createConvertMathToFuncsPass()
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
SmallVector< int64_t > delinearize(ArrayRef< int64_t > strides, int64_t linearIndex)
Given the strides together with a linear index in the dimension space, returns the vector-space offse...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:356