MLIR  19.0.0git
MathToFuncs.cpp
Go to the documentation of this file.
1 //===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Pass/Pass.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "math-to-funcs"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37 
38 namespace {
39 // Pattern to convert vector operations to scalar operations.
40 template <typename Op>
41 struct VecOpToScalarOp : public OpRewritePattern<Op> {
42 public:
44 
45  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
46 };
47 
48 // Callback type for getting pre-generated FuncOp implementing
49 // an operation of the given type.
50 using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>;
51 
52 // Pattern to convert scalar IPowIOp into a call of outlined
53 // software implementation.
54 class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
55 public:
56  IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
57  : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
58 
59  /// Convert IPowI into a call to a local function implementing
60  /// the power operation. The local function computes a scalar result,
61  /// so vector forms of IPowI are linearized.
62  LogicalResult matchAndRewrite(math::IPowIOp op,
63  PatternRewriter &rewriter) const final;
64 
65 private:
66  GetFuncCallbackTy getFuncOpCallback;
67 };
68 
69 // Pattern to convert scalar FPowIOp into a call of outlined
70 // software implementation.
71 class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
72 public:
73  FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
74  : OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {}
75 
76  /// Convert FPowI into a call to a local function implementing
77  /// the power operation. The local function computes a scalar result,
78  /// so vector forms of FPowI are linearized.
79  LogicalResult matchAndRewrite(math::FPowIOp op,
80  PatternRewriter &rewriter) const final;
81 
82 private:
83  GetFuncCallbackTy getFuncOpCallback;
84 };
85 
86 // Pattern to convert scalar ctlz into a call of outlined software
87 // implementation.
88 class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> {
89 public:
90  CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
92  getFuncOpCallback(cb) {}
93 
94  /// Convert ctlz into a call to a local function implementing
95  /// the count leading zeros operation.
96  LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
97  PatternRewriter &rewriter) const final;
98 
99 private:
100  GetFuncCallbackTy getFuncOpCallback;
101 };
102 } // namespace
103 
104 template <typename Op>
106 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
107  Type opType = op.getType();
108  Location loc = op.getLoc();
109  auto vecType = dyn_cast<VectorType>(opType);
110 
111  if (!vecType)
112  return rewriter.notifyMatchFailure(op, "not a vector operation");
113  if (!vecType.hasRank())
114  return rewriter.notifyMatchFailure(op, "unknown vector rank");
115  ArrayRef<int64_t> shape = vecType.getShape();
116  int64_t numElements = vecType.getNumElements();
117 
118  Type resultElementType = vecType.getElementType();
119  Attribute initValueAttr;
120  if (isa<FloatType>(resultElementType))
121  initValueAttr = FloatAttr::get(resultElementType, 0.0);
122  else
123  initValueAttr = IntegerAttr::get(resultElementType, 0);
124  Value result = rewriter.create<arith::ConstantOp>(
125  loc, DenseElementsAttr::get(vecType, initValueAttr));
126  SmallVector<int64_t> strides = computeStrides(shape);
127  for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
128  SmallVector<int64_t> positions = delinearize(linearIndex, strides);
129  SmallVector<Value> operands;
130  for (Value input : op->getOperands())
131  operands.push_back(
132  rewriter.create<vector::ExtractOp>(loc, input, positions));
133  Value scalarOp =
134  rewriter.create<Op>(loc, vecType.getElementType(), operands);
135  result =
136  rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
137  }
138  rewriter.replaceOp(op, result);
139  return success();
140 }
141 
142 static FunctionType getElementalFuncTypeForOp(Operation *op) {
143  SmallVector<Type, 1> resultTys(op->getNumResults());
144  SmallVector<Type, 2> inputTys(op->getNumOperands());
145  std::transform(op->result_type_begin(), op->result_type_end(),
146  resultTys.begin(),
147  [](Type ty) { return getElementTypeOrSelf(ty); });
148  std::transform(op->operand_type_begin(), op->operand_type_end(),
149  inputTys.begin(),
150  [](Type ty) { return getElementTypeOrSelf(ty); });
151  return FunctionType::get(op->getContext(), inputTys, resultTys);
152 }
153 
/// Create a linkonce_odr function implementing integer exponentiation for the
/// given \p elementType inside \p module. The \p elementType must be an
/// IntegerType, and the created function has the
/// 'IntegerType (*)(IntegerType, IntegerType)' function type. The function is
/// named "__mlir_math_ipowi_<elementType>", is private, and is emitted as an
/// unstructured CFG (cf.br / cf.cond_br) equivalent to:
///
/// template <typename T>
/// T __mlir_math_ipowi_*(T b, T p) {
///   if (p == T(0))
///     return T(1);
///   if (p < T(0)) {
///     if (b == T(0))
///       return T(1) / T(0); // trigger div-by-zero
///     if (b == T(1))
///       return T(1);
///     if (b == T(-1)) {
///       if (p & T(1))
///         return T(-1);
///       return T(1);
///     }
///     return T(0);
///   }
///   T result = T(1);
///   while (true) {
///     if (p & T(1))
///       result *= b;
///     p >>= T(1);
///     if (p == T(0))
///       return result;
///     b *= b;
///   }
/// }
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
  assert(isa<IntegerType>(elementType) &&
         "non-integer element type for IPowIOp");

  // The function is appended to the end of the module body, and every op built
  // through this builder uses the module's location.
  ImplicitLocOpBuilder builder =
      ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());

  std::string funcName("__mlir_math_ipowi");
  llvm::raw_string_ostream nameOS(funcName);
  nameOS << '_' << elementType;

  FunctionType funcType = FunctionType::get(
      builder.getContext(), {elementType, elementType}, elementType);
  auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
  // Request linkonce_odr linkage (consumed when lowering to LLVM) and keep the
  // symbol private to the module.
  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
  Attribute linkage =
      LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
  funcOp->setAttr("llvm.linkage", linkage);
  funcOp.setPrivate();

  Block *entryBlock = funcOp.addEntryBlock();
  Region *funcBody = entryBlock->getParent();

  Value bArg = funcOp.getArgument(0);
  Value pArg = funcOp.getArgument(1);
  // Constants shared by all blocks are materialised in the entry block, which
  // dominates the rest of the function.
  builder.setInsertionPointToEnd(entryBlock);
  Value zeroValue = builder.create<arith::ConstantOp>(
      elementType, builder.getIntegerAttr(elementType, 0));
  Value oneValue = builder.create<arith::ConstantOp>(
      elementType, builder.getIntegerAttr(elementType, 1));
  Value minusOneValue = builder.create<arith::ConstantOp>(
      elementType,
      builder.getIntegerAttr(elementType,
                             APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
                                   /*isSigned=*/true)));

  // Each `if` below follows the same recipe: emit the condition in the current
  // block, create the "then" and "fallthrough" blocks, then return to the
  // condition's block to emit the cf.cond_br terminator.

  // if (p == T(0))
  //   return T(1);
  auto pIsZero =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
  Block *thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(oneValue);
  Block *fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (p == T(0)).
  builder.setInsertionPointToEnd(pIsZero->getBlock());
  builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);

  // if (p < T(0)) {
  builder.setInsertionPointToEnd(fallthroughBlock);
  // p == 0 has already returned above, so `sle` here is equivalent to the
  // `p < 0` test of the pseudo-code.
  auto pIsNeg =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
  //   if (b == T(0))
  // New block holding the negative-power handling; it becomes the true
  // destination of the (p < 0) branch set up further below.
  builder.createBlock(funcBody);
  auto bIsZero =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
  //     return T(1) / T(0);
  thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(
      builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
  fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (b == T(0)).
  builder.setInsertionPointToEnd(bIsZero->getBlock());
  builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);

  //   if (b == T(1))
  builder.setInsertionPointToEnd(fallthroughBlock);
  auto bIsOne =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
  //     return T(1);
  thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(oneValue);
  fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (b == T(1)).
  builder.setInsertionPointToEnd(bIsOne->getBlock());
  builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);

  //   if (b == T(-1)) {
  builder.setInsertionPointToEnd(fallthroughBlock);
  auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                                   bArg, minusOneValue);
  //     if (p & T(1))
  // New block for the odd-power test; it is the true destination of the
  // (b == -1) branch set up further below.
  builder.createBlock(funcBody);
  auto pIsOdd = builder.create<arith::CmpIOp>(
      arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
      zeroValue);
  //       return T(-1);
  thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(minusOneValue);
  fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (p & T(1)).
  builder.setInsertionPointToEnd(pIsOdd->getBlock());
  builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);

  //     return T(1);
  //   } // b == T(-1)
  builder.setInsertionPointToEnd(fallthroughBlock);
  builder.create<func::ReturnOp>(oneValue);
  fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (b == T(-1)).
  builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
  builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
                                   fallthroughBlock);

  //   return T(0);
  // } // (p < T(0))
  builder.setInsertionPointToEnd(fallthroughBlock);
  builder.create<func::ReturnOp>(zeroValue);
  // Loop header carries the loop-carried values ('result', 'b', 'p') as block
  // arguments.
  Block *loopHeader = builder.createBlock(
      funcBody, funcBody->end(), {elementType, elementType, elementType},
      {builder.getLoc(), builder.getLoc(), builder.getLoc()});
  // Set up conditional branch for (p < T(0)).
  builder.setInsertionPointToEnd(pIsNeg->getBlock());
  // Set initial values of 'result', 'b' and 'p' for the loop.
  builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
                                   ValueRange{oneValue, bArg, pArg});

  // T result = T(1);
  // while (true) {
  //   if (p & T(1))
  //     result *= b;
  //   p >>= T(1);
  //   if (p == T(0))
  //     return result;
  //   b *= b;
  // }
  Value resultTmp = loopHeader->getArgument(0);
  Value baseTmp = loopHeader->getArgument(1);
  Value powerTmp = loopHeader->getArgument(2);
  builder.setInsertionPointToEnd(loopHeader);

  // if (p & T(1))
  auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
      arith::CmpIPredicate::ne,
      builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
  thenBlock = builder.createBlock(funcBody);
  // result *= b;
  Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
  // The merge block takes the (possibly updated) 'result' as its argument.
  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
                                         builder.getLoc());
  builder.setInsertionPointToEnd(thenBlock);
  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
  // Set up conditional branch for (p & T(1)).
  builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
  builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
                                   resultTmp);
  // Merged 'result'.
  newResultTmp = fallthroughBlock->getArgument(0);

  // p >>= T(1);
  builder.setInsertionPointToEnd(fallthroughBlock);
  // The loop is entered only with p > 0 (p == 0 returned earlier and p < 0
  // took the other branch), so a logical shift matches `p >>= 1`.
  Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);

  // if (p == T(0))
  auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                                      newPowerTmp, zeroValue);
  //   return result;
  thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(newResultTmp);
  fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (p == T(0)).
  builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
  builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);

  // b *= b;
  // }
  builder.setInsertionPointToEnd(fallthroughBlock);
  Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
  // Pass new values for 'result', 'b' and 'p' to the loop header.
  builder.create<cf::BranchOp>(
      ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
  return funcOp;
}
357 
358 /// Convert IPowI into a call to a local function implementing
359 /// the power operation. The local function computes a scalar result,
360 /// so vector forms of IPowI are linearized.
362 IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
363  PatternRewriter &rewriter) const {
364  auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
365 
366  if (!baseType)
367  return rewriter.notifyMatchFailure(op, "non-integer base operand");
368 
369  // The outlined software implementation must have been already
370  // generated.
371  func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
372  if (!elementFunc)
373  return rewriter.notifyMatchFailure(op, "missing software implementation");
374 
375  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
376  return success();
377 }
378 
/// Create a linkonce_odr function implementing floating-point base / integer
/// exponent power for the given \p funcType inside \p module. The \p funcType
/// must be the 'FloatType (*)(FloatType, IntegerType)' function type. The
/// function is named "__mlir_math_fpowi_<baseType>_<powType>", is private, and
/// is emitted as an unstructured CFG (cf.br / cf.cond_br) equivalent to:
///
/// template <typename Tb, typename Tp>
/// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
///   if (p == Tp{0})
///     return Tb{1};
///   bool isNegativePower{p < Tp{0}};
///   bool isMin{p == std::numeric_limits<Tp>::min()};
///   if (isMin) {
///     p = std::numeric_limits<Tp>::max();
///   } else if (isNegativePower) {
///     p = -p;
///   }
///   Tb result = Tb{1};
///   Tb origBase = Tb{b};
///   while (true) {
///     if (p & Tp{1})
///       result *= b;
///     p >>= Tp{1};
///     if (p == Tp{0})
///       break;
///     b *= b;
///   }
///   if (isMin) {
///     result *= origBase;
///   }
///   if (isNegativePower) {
///     result = Tb{1} / result;
///   }
///   return result;
/// }
static func::FuncOp createElementFPowIFunc(ModuleOp *module,
                                           FunctionType funcType) {
  auto baseType = cast<FloatType>(funcType.getInput(0));
  auto powType = cast<IntegerType>(funcType.getInput(1));
  // The function is appended to the end of the module body, and every op built
  // through this builder uses the module's location.
  ImplicitLocOpBuilder builder =
      ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());

  std::string funcName("__mlir_math_fpowi");
  llvm::raw_string_ostream nameOS(funcName);
  nameOS << '_' << baseType;
  nameOS << '_' << powType;
  auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
  // Request linkonce_odr linkage (consumed when lowering to LLVM) and keep the
  // symbol private to the module.
  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
  Attribute linkage =
      LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
  funcOp->setAttr("llvm.linkage", linkage);
  funcOp.setPrivate();

  Block *entryBlock = funcOp.addEntryBlock();
  Region *funcBody = entryBlock->getParent();

  Value bArg = funcOp.getArgument(0);
  Value pArg = funcOp.getArgument(1);
  // Constants shared by all blocks are materialised in the entry block.
  builder.setInsertionPointToEnd(entryBlock);
  Value oneBValue = builder.create<arith::ConstantOp>(
      baseType, builder.getFloatAttr(baseType, 1.0));
  Value zeroPValue = builder.create<arith::ConstantOp>(
      powType, builder.getIntegerAttr(powType, 0));
  Value onePValue = builder.create<arith::ConstantOp>(
      powType, builder.getIntegerAttr(powType, 1));
  Value minPValue = builder.create<arith::ConstantOp>(
      powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
                                                   powType.getWidth())));
  Value maxPValue = builder.create<arith::ConstantOp>(
      powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
                                                   powType.getWidth())));

  // if (p == Tp{0})
  //   return Tb{1};
  auto pIsZero =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
  Block *thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(oneBValue);
  Block *fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (p == Tp{0}).
  builder.setInsertionPointToEnd(pIsZero->getBlock());
  builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);

  builder.setInsertionPointToEnd(fallthroughBlock);
  // bool isNegativePower{p < Tp{0}};
  // p == 0 has already returned above, so `sle` is equivalent to `p < 0` here.
  auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
                                              zeroPValue);
  // bool isMin{p == std::numeric_limits<Tp>::min()};
  auto pIsMin =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);

  // if (isMin) {
  //   p = std::numeric_limits<Tp>::max();
  // } else if (isNegativePower) {
  //   p = -p;
  // }
  // Implemented branch-free with selects. `0 - p` wraps for p == min, but that
  // case is overridden by the second select, which picks max instead.
  Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg);
  auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg);
  pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit);

  // Tb result = Tb{1};
  // Tb origBase = Tb{b};
  // while (true) {
  //   if (p & Tp{1})
  //     result *= b;
  //   p >>= Tp{1};
  //   if (p == Tp{0})
  //     break;
  //   b *= b;
  // }
  // Loop header carries the loop-carried values ('result', 'b', 'p') as block
  // arguments. 'origBase' is not materialised; the original argument bArg is
  // reused after the loop.
  Block *loopHeader = builder.createBlock(
      funcBody, funcBody->end(), {baseType, baseType, powType},
      {builder.getLoc(), builder.getLoc(), builder.getLoc()});
  // Set initial values of 'result', 'b' and 'p' for the loop.
  builder.setInsertionPointToEnd(pInit->getBlock());
  builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit});

  // Create loop body.
  Value resultTmp = loopHeader->getArgument(0);
  Value baseTmp = loopHeader->getArgument(1);
  Value powerTmp = loopHeader->getArgument(2);
  builder.setInsertionPointToEnd(loopHeader);

  // if (p & Tp{1})
  auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
      arith::CmpIPredicate::ne,
      builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
  thenBlock = builder.createBlock(funcBody);
  // result *= b;
  Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp);
  // The merge block takes the (possibly updated) 'result' as its argument.
  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
                                         builder.getLoc());
  builder.setInsertionPointToEnd(thenBlock);
  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
  // Set up conditional branch for (p & Tp{1}).
  builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
  builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
                                   resultTmp);
  // Merged 'result'.
  newResultTmp = fallthroughBlock->getArgument(0);

  // p >>= Tp{1};
  builder.setInsertionPointToEnd(fallthroughBlock);
  // pInit is never negative (|p|, max, or a positive p), so a logical shift
  // matches `p >>= 1`.
  Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue);

  // if (p == Tp{0})
  auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                                      newPowerTmp, zeroPValue);
  //   break;
  //
  // The conditional branch is finalized below with a jump to
  // the loop exit block.
  fallthroughBlock = builder.createBlock(funcBody);

  // b *= b;
  // }
  builder.setInsertionPointToEnd(fallthroughBlock);
  Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp);
  // Pass new values for 'result', 'b' and 'p' to the loop header.
  builder.create<cf::BranchOp>(
      ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);

  // Set up conditional branch for early loop exit:
  // if (p == Tp{0})
  //   break;
  Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
                                        builder.getLoc());
  builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
  builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
                                   fallthroughBlock, ValueRange{});

  // if (isMin) {
  //   result *= origBase;
  // }
  // For p == min the loop computed b^max == b^(|min| - 1), so one more factor
  // of the original base yields b^|min|.
  newResultTmp = loopExit->getArgument(0);
  thenBlock = builder.createBlock(funcBody);
  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
                                         builder.getLoc());
  builder.setInsertionPointToEnd(loopExit);
  builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
                                   newResultTmp);
  builder.setInsertionPointToEnd(thenBlock);
  newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg);
  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);

  // if (isNegativePower) {
  //   result = Tb{1} / result;
  // }
  newResultTmp = fallthroughBlock->getArgument(0);
  thenBlock = builder.createBlock(funcBody);
  Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
                                           builder.getLoc());
  builder.setInsertionPointToEnd(fallthroughBlock);
  builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
                                   newResultTmp);
  builder.setInsertionPointToEnd(thenBlock);
  newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp);
  builder.create<cf::BranchOp>(newResultTmp, returnBlock);

  // return result;
  builder.setInsertionPointToEnd(returnBlock);
  builder.create<func::ReturnOp>(returnBlock->getArgument(0));

  return funcOp;
}
582 
583 /// Convert FPowI into a call to a local function implementing
584 /// the power operation. The local function computes a scalar result,
585 /// so vector forms of FPowI are linearized.
587 FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
588  PatternRewriter &rewriter) const {
589  if (dyn_cast<VectorType>(op.getType()))
590  return rewriter.notifyMatchFailure(op, "non-scalar operation");
591 
592  FunctionType funcType = getElementalFuncTypeForOp(op);
593 
594  // The outlined software implementation must have been already
595  // generated.
596  func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
597  if (!elementFunc)
598  return rewriter.notifyMatchFailure(op, "missing software implementation");
599 
600  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
601  return success();
602 }
603 
604 /// Create function to implement the ctlz function the given \p elementType type
605 /// inside \p module. The \p elementType must be IntegerType, an the created
606 /// function has 'IntegerType (*)(IntegerType)' function type.
607 ///
608 /// template <typename T>
609 /// T __mlir_math_ctlz_*(T x) {
610 /// bits = sizeof(x) * 8;
611 /// if (x == 0)
612 /// return bits;
613 ///
614 /// uint32_t n = 0;
615 /// for (int i = 1; i < bits; ++i) {
616 /// if (x < 0) continue;
617 /// n++;
618 /// x <<= 1;
619 /// }
620 /// return n;
621 /// }
622 ///
623 /// Converts to (for i32):
624 ///
625 /// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
626 /// %c_32 = arith.constant 32 : index
627 /// %c_0 = arith.constant 0 : i32
628 /// %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1
629 /// %out = scf.if %arg_eq_zero {
630 /// scf.yield %c_32 : i32
631 /// } else {
632 /// %c_1index = arith.constant 1 : index
633 /// %c_1i32 = arith.constant 1 : i32
634 /// %n = arith.constant 0 : i32
635 /// %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index
636 /// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
637 /// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
638 /// %yield_val = scf.if %cond {
639 /// scf.yield %arg_iter, %n_iter : i32, i32
640 /// } else {
641 /// %arg_next = arith.shli %arg_iter, %c_1i32 : i32
642 /// %n_next = arith.addi %n_iter, %c_1i32 : i32
643 /// scf.yield %arg_next, %n_next : i32, i32
644 /// }
645 /// scf.yield %yield_val: i32, i32
646 /// }
647 /// scf.yield %n_out : i32
648 /// }
649 /// return %out: i32
650 /// }
651 static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
652  if (!isa<IntegerType>(elementType)) {
653  LLVM_DEBUG({
654  DBGS() << "non-integer element type for CtlzFunc; type was: ";
655  elementType.print(llvm::dbgs());
656  });
657  llvm_unreachable("non-integer element type");
658  }
659  int64_t bitWidth = elementType.getIntOrFloatBitWidth();
660 
661  Location loc = module->getLoc();
662  ImplicitLocOpBuilder builder =
663  ImplicitLocOpBuilder::atBlockEnd(loc, module->getBody());
664 
665  std::string funcName("__mlir_math_ctlz");
666  llvm::raw_string_ostream nameOS(funcName);
667  nameOS << '_' << elementType;
668  FunctionType funcType =
669  FunctionType::get(builder.getContext(), {elementType}, elementType);
670  auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
671 
672  // LinkonceODR ensures that there is only one implementation of this function
673  // across all math.ctlz functions that are lowered in this way.
674  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
675  Attribute linkage =
676  LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
677  funcOp->setAttr("llvm.linkage", linkage);
678  funcOp.setPrivate();
679 
680  // set the insertion point to the start of the function
681  Block *funcBody = funcOp.addEntryBlock();
682  builder.setInsertionPointToStart(funcBody);
683 
684  Value arg = funcOp.getArgument(0);
685  Type indexType = builder.getIndexType();
686  Value bitWidthValue = builder.create<arith::ConstantOp>(
687  elementType, builder.getIntegerAttr(elementType, bitWidth));
688  Value zeroValue = builder.create<arith::ConstantOp>(
689  elementType, builder.getIntegerAttr(elementType, 0));
690 
691  Value inputEqZero =
692  builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue);
693 
694  // if input == 0, return bit width, else enter loop.
695  scf::IfOp ifOp = builder.create<scf::IfOp>(
696  elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true);
697  ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
698 
699  auto elseBuilder =
700  ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
701 
702  Value oneIndex = elseBuilder.create<arith::ConstantOp>(
703  indexType, elseBuilder.getIndexAttr(1));
704  Value oneValue = elseBuilder.create<arith::ConstantOp>(
705  elementType, elseBuilder.getIntegerAttr(elementType, 1));
706  Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
707  indexType, elseBuilder.getIndexAttr(bitWidth));
708  Value nValue = elseBuilder.create<arith::ConstantOp>(
709  elementType, elseBuilder.getIntegerAttr(elementType, 0));
710 
711  auto loop = elseBuilder.create<scf::ForOp>(
712  oneIndex, bitWidthIndex, oneIndex,
713  // Initial values for two loop induction variables, the arg which is being
714  // shifted left in each iteration, and the n value which tracks the count
715  // of leading zeros.
716  ValueRange{arg, nValue},
717  // Callback to build the body of the for loop
718  // if (arg < 0) {
719  // continue;
720  // } else {
721  // n++;
722  // arg <<= 1;
723  // }
724  [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
725  Value argIter = args[0];
726  Value nIter = args[1];
727 
728  Value argIsNonNegative = b.create<arith::CmpIOp>(
729  loc, arith::CmpIPredicate::slt, argIter, zeroValue);
730  scf::IfOp ifOp = b.create<scf::IfOp>(
731  loc, argIsNonNegative,
732  [&](OpBuilder &b, Location loc) {
733  // If arg is negative, continue (effectively, break)
734  b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter});
735  },
736  [&](OpBuilder &b, Location loc) {
737  // Otherwise, increment n and shift arg left.
738  Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue);
739  Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue);
740  b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext});
741  });
742  b.create<scf::YieldOp>(loc, ifOp.getResults());
743  });
744  elseBuilder.create<scf::YieldOp>(loop.getResult(1));
745 
746  builder.create<func::ReturnOp>(ifOp.getResult(0));
747  return funcOp;
748 }
749 
750 /// Convert ctlz into a call to a local function implementing the ctlz
751 /// operation.
752 LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
753  PatternRewriter &rewriter) const {
754  if (dyn_cast<VectorType>(op.getType()))
755  return rewriter.notifyMatchFailure(op, "non-scalar operation");
756 
758  func::FuncOp elementFunc = getFuncOpCallback(op, type);
759  if (!elementFunc)
760  return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
761  diag << "Missing software implementation for op " << op->getName()
762  << " and type " << type;
763  });
764 
765  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperand());
766  return success();
767 }
768 
769 namespace {
770 struct ConvertMathToFuncsPass
771  : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
772  ConvertMathToFuncsPass() = default;
773  ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
774  : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {}
775 
776  void runOnOperation() override;
777 
778 private:
779  // Return true, if this FPowI operation must be converted
780  // because the width of its exponent's type is greater than
781  // or equal to minWidthOfFPowIExponent option value.
782  bool isFPowIConvertible(math::FPowIOp op);
783 
784  // Generate outlined implementations for power operations
785  // and store them in funcImpls map.
786  void generateOpImplementations();
787 
788  // A map between pairs of (operation, type) deduced from operations that this
789  // pass will convert, and the corresponding outlined software implementations
790  // of these operations for the given type.
791  DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls;
792 };
793 } // namespace
794 
795 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
796  auto expTy =
797  dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType()));
798  return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
799 }
800 
801 void ConvertMathToFuncsPass::generateOpImplementations() {
802  ModuleOp module = getOperation();
803 
804  module.walk([&](Operation *op) {
806  .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
807  if (!convertCtlz)
808  return;
809  Type resultType = getElementTypeOrSelf(op.getResult().getType());
810 
811  // Generate the software implementation of this operation,
812  // if it has not been generated yet.
813  auto key = std::pair(op->getName(), resultType);
814  auto entry = funcImpls.try_emplace(key, func::FuncOp{});
815  if (entry.second)
816  entry.first->second = createCtlzFunc(&module, resultType);
817  })
818  .Case<math::IPowIOp>([&](math::IPowIOp op) {
819  Type resultType = getElementTypeOrSelf(op.getResult().getType());
820 
821  // Generate the software implementation of this operation,
822  // if it has not been generated yet.
823  auto key = std::pair(op->getName(), resultType);
824  auto entry = funcImpls.try_emplace(key, func::FuncOp{});
825  if (entry.second)
826  entry.first->second = createElementIPowIFunc(&module, resultType);
827  })
828  .Case<math::FPowIOp>([&](math::FPowIOp op) {
829  if (!isFPowIConvertible(op))
830  return;
831 
832  FunctionType funcType = getElementalFuncTypeForOp(op);
833 
834  // Generate the software implementation of this operation,
835  // if it has not been generated yet.
836  // FPowI implementations are mapped via the FunctionType
837  // created from the operation's result and operands.
838  auto key = std::pair(op->getName(), funcType);
839  auto entry = funcImpls.try_emplace(key, func::FuncOp{});
840  if (entry.second)
841  entry.first->second = createElementFPowIFunc(&module, funcType);
842  });
843  });
844 }
845 
846 void ConvertMathToFuncsPass::runOnOperation() {
847  ModuleOp module = getOperation();
848 
849  // Create outlined implementations for power operations.
850  generateOpImplementations();
851 
852  RewritePatternSet patterns(&getContext());
853  patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
854  VecOpToScalarOp<math::CountLeadingZerosOp>>(
855  patterns.getContext());
856 
857  // For the given Type Returns FuncOp stored in funcImpls map.
858  auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
859  auto it = funcImpls.find(std::pair(op->getName(), type));
860  if (it == funcImpls.end())
861  return {};
862 
863  return it->second;
864  };
865  patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
866  getFuncOpByType);
867 
868  if (convertCtlz)
869  patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType);
870 
871  ConversionTarget target(getContext());
872  target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
873  func::FuncDialect, scf::SCFDialect,
874  vector::VectorDialect>();
875 
876  target.addIllegalOp<math::IPowIOp>();
877  if (convertCtlz)
878  target.addIllegalOp<math::CountLeadingZerosOp>();
879  target.addDynamicallyLegalOp<math::FPowIOp>(
880  [this](math::FPowIOp op) { return !isFPowIConvertible(op); });
881  if (failed(applyPartialConversion(module, target, std::move(patterns))))
882  signalPassFailure();
883 }
static MLIRContext * getContext(OpFoldResult val)
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType)
Create linkonce_odr function to implement the power function with the given elementType type inside m...
static FunctionType getElementalFuncTypeForOp(Operation *op)
static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType)
Create linkonce_odr function to implement the power function with the given funcType type inside modu...
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType)
Create a function implementing the ctlz operation for the given elementType type inside module.
#define DBGS()
Definition: MathToFuncs.cpp:36
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
operand_type_iterator operand_type_end()
Definition: Operation.h:391
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
result_type_iterator result_type_end()
Definition: Operation.h:422
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
result_type_iterator result_type_begin()
Definition: Operation.h:421
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
operand_type_iterator operand_type_begin()
Definition: Operation.h:390
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
void print(raw_ostream &os) const
Print the current type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358