MLIR  22.0.0git
MathToFuncs.cpp
Go to the documentation of this file.
1 //===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Pass/Pass.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/DebugLog.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 
34 #define DEBUG_TYPE "math-to-funcs"
35 
36 namespace {
37 // Pattern to convert vector operations to scalar operations.
38 template <typename Op>
39 struct VecOpToScalarOp : public OpRewritePattern<Op> {
40 public:
42 
43  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
44 };
45 
46 // Callback type for getting pre-generated FuncOp implementing
47 // an operation of the given type.
48 using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>;
49 
50 // Pattern to convert scalar IPowIOp into a call of outlined
51 // software implementation.
52 class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
53 public:
54  IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
55  : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
56 
57  /// Convert IPowI into a call to a local function implementing
58  /// the power operation. The local function computes a scalar result,
59  /// so vector forms of IPowI are linearized.
60  LogicalResult matchAndRewrite(math::IPowIOp op,
61  PatternRewriter &rewriter) const final;
62 
63 private:
64  GetFuncCallbackTy getFuncOpCallback;
65 };
66 
67 // Pattern to convert scalar FPowIOp into a call of outlined
68 // software implementation.
69 class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
70 public:
71  FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
72  : OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {}
73 
74  /// Convert FPowI into a call to a local function implementing
75  /// the power operation. The local function computes a scalar result,
76  /// so vector forms of FPowI are linearized.
77  LogicalResult matchAndRewrite(math::FPowIOp op,
78  PatternRewriter &rewriter) const final;
79 
80 private:
81  GetFuncCallbackTy getFuncOpCallback;
82 };
83 
84 // Pattern to convert scalar ctlz into a call of outlined software
85 // implementation.
86 class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> {
87 public:
88  CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
90  getFuncOpCallback(cb) {}
91 
92  /// Convert ctlz into a call to a local function implementing
93  /// the count leading zeros operation.
94  LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
95  PatternRewriter &rewriter) const final;
96 
97 private:
98  GetFuncCallbackTy getFuncOpCallback;
99 };
100 } // namespace
101 
102 template <typename Op>
103 LogicalResult
104 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
105  Type opType = op.getType();
106  Location loc = op.getLoc();
107  auto vecType = dyn_cast<VectorType>(opType);
108 
109  if (!vecType)
110  return rewriter.notifyMatchFailure(op, "not a vector operation");
111  if (!vecType.hasRank())
112  return rewriter.notifyMatchFailure(op, "unknown vector rank");
113  ArrayRef<int64_t> shape = vecType.getShape();
114  int64_t numElements = vecType.getNumElements();
115 
116  Type resultElementType = vecType.getElementType();
117  Attribute initValueAttr;
118  if (isa<FloatType>(resultElementType))
119  initValueAttr = FloatAttr::get(resultElementType, 0.0);
120  else
121  initValueAttr = IntegerAttr::get(resultElementType, 0);
122  Value result = arith::ConstantOp::create(
123  rewriter, loc, DenseElementsAttr::get(vecType, initValueAttr));
124  SmallVector<int64_t> strides = computeStrides(shape);
125  for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
126  SmallVector<int64_t> positions = delinearize(linearIndex, strides);
127  SmallVector<Value> operands;
128  for (Value input : op->getOperands())
129  operands.push_back(
130  vector::ExtractOp::create(rewriter, loc, input, positions));
131  Value scalarOp =
132  Op::create(rewriter, loc, vecType.getElementType(), operands);
133  result =
134  vector::InsertOp::create(rewriter, loc, scalarOp, result, positions);
135  }
136  rewriter.replaceOp(op, result);
137  return success();
138 }
139 
140 static FunctionType getElementalFuncTypeForOp(Operation *op) {
141  SmallVector<Type, 1> resultTys(op->getNumResults());
142  SmallVector<Type, 2> inputTys(op->getNumOperands());
143  std::transform(op->result_type_begin(), op->result_type_end(),
144  resultTys.begin(),
145  [](Type ty) { return getElementTypeOrSelf(ty); });
146  std::transform(op->operand_type_begin(), op->operand_type_end(),
147  inputTys.begin(),
148  [](Type ty) { return getElementTypeOrSelf(ty); });
149  return FunctionType::get(op->getContext(), inputTys, resultTys);
150 }
151 
152 /// Create linkonce_odr function to implement the power function with
153 /// the given \p elementType type inside \p module. The \p elementType
154 /// must be IntegerType, an the created function has
155 /// 'IntegerType (*)(IntegerType, IntegerType)' function type.
156 ///
157 /// template <typename T>
158 /// T __mlir_math_ipowi_*(T b, T p) {
159 /// if (p == T(0))
160 /// return T(1);
161 /// if (p < T(0)) {
162 /// if (b == T(0))
163 /// return T(1) / T(0); // trigger div-by-zero
164 /// if (b == T(1))
165 /// return T(1);
166 /// if (b == T(-1)) {
167 /// if (p & T(1))
168 /// return T(-1);
169 /// return T(1);
170 /// }
171 /// return T(0);
172 /// }
173 /// T result = T(1);
174 /// while (true) {
175 /// if (p & T(1))
176 /// result *= b;
177 /// p >>= T(1);
178 /// if (p == T(0))
179 /// return result;
180 /// b *= b;
181 /// }
182 /// }
183 static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
184  assert(isa<IntegerType>(elementType) &&
185  "non-integer element type for IPowIOp");
186 
187  ImplicitLocOpBuilder builder =
188  ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
189 
190  std::string funcName("__mlir_math_ipowi");
191  llvm::raw_string_ostream nameOS(funcName);
192  nameOS << '_' << elementType;
193 
194  FunctionType funcType = FunctionType::get(
195  builder.getContext(), {elementType, elementType}, elementType);
196  auto funcOp = func::FuncOp::create(builder, funcName, funcType);
197  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
198  Attribute linkage =
199  LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
200  funcOp->setAttr("llvm.linkage", linkage);
201  funcOp.setPrivate();
202 
203  Block *entryBlock = funcOp.addEntryBlock();
204  Region *funcBody = entryBlock->getParent();
205 
206  Value bArg = funcOp.getArgument(0);
207  Value pArg = funcOp.getArgument(1);
208  builder.setInsertionPointToEnd(entryBlock);
209  Value zeroValue = arith::ConstantOp::create(
210  builder, elementType, builder.getIntegerAttr(elementType, 0));
211  Value oneValue = arith::ConstantOp::create(
212  builder, elementType, builder.getIntegerAttr(elementType, 1));
213  Value minusOneValue = arith::ConstantOp::create(
214  builder, elementType,
215  builder.getIntegerAttr(elementType,
216  APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
217  /*isSigned=*/true)));
218 
219  // if (p == T(0))
220  // return T(1);
221  auto pIsZero =
222  arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, zeroValue);
223  Block *thenBlock = builder.createBlock(funcBody);
224  func::ReturnOp::create(builder, oneValue);
225  Block *fallthroughBlock = builder.createBlock(funcBody);
226  // Set up conditional branch for (p == T(0)).
227  builder.setInsertionPointToEnd(pIsZero->getBlock());
228  cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
229 
230  // if (p < T(0)) {
231  builder.setInsertionPointToEnd(fallthroughBlock);
232  auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
233  zeroValue);
234  // if (b == T(0))
235  builder.createBlock(funcBody);
236  auto bIsZero =
237  arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, zeroValue);
238  // return T(1) / T(0);
239  thenBlock = builder.createBlock(funcBody);
240  func::ReturnOp::create(
241  builder,
242  arith::DivSIOp::create(builder, oneValue, zeroValue).getResult());
243  fallthroughBlock = builder.createBlock(funcBody);
244  // Set up conditional branch for (b == T(0)).
245  builder.setInsertionPointToEnd(bIsZero->getBlock());
246  cf::CondBranchOp::create(builder, bIsZero, thenBlock, fallthroughBlock);
247 
248  // if (b == T(1))
249  builder.setInsertionPointToEnd(fallthroughBlock);
250  auto bIsOne =
251  arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, oneValue);
252  // return T(1);
253  thenBlock = builder.createBlock(funcBody);
254  func::ReturnOp::create(builder, oneValue);
255  fallthroughBlock = builder.createBlock(funcBody);
256  // Set up conditional branch for (b == T(1)).
257  builder.setInsertionPointToEnd(bIsOne->getBlock());
258  cf::CondBranchOp::create(builder, bIsOne, thenBlock, fallthroughBlock);
259 
260  // if (b == T(-1)) {
261  builder.setInsertionPointToEnd(fallthroughBlock);
262  auto bIsMinusOne = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
263  bArg, minusOneValue);
264  // if (p & T(1))
265  builder.createBlock(funcBody);
266  auto pIsOdd = arith::CmpIOp::create(
267  builder, arith::CmpIPredicate::ne,
268  arith::AndIOp::create(builder, pArg, oneValue), zeroValue);
269  // return T(-1);
270  thenBlock = builder.createBlock(funcBody);
271  func::ReturnOp::create(builder, minusOneValue);
272  fallthroughBlock = builder.createBlock(funcBody);
273  // Set up conditional branch for (p & T(1)).
274  builder.setInsertionPointToEnd(pIsOdd->getBlock());
275  cf::CondBranchOp::create(builder, pIsOdd, thenBlock, fallthroughBlock);
276 
277  // return T(1);
278  // } // b == T(-1)
279  builder.setInsertionPointToEnd(fallthroughBlock);
280  func::ReturnOp::create(builder, oneValue);
281  fallthroughBlock = builder.createBlock(funcBody);
282  // Set up conditional branch for (b == T(-1)).
283  builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
284  cf::CondBranchOp::create(builder, bIsMinusOne, pIsOdd->getBlock(),
285  fallthroughBlock);
286 
287  // return T(0);
288  // } // (p < T(0))
289  builder.setInsertionPointToEnd(fallthroughBlock);
290  func::ReturnOp::create(builder, zeroValue);
291  Block *loopHeader = builder.createBlock(
292  funcBody, funcBody->end(), {elementType, elementType, elementType},
293  {builder.getLoc(), builder.getLoc(), builder.getLoc()});
294  // Set up conditional branch for (p < T(0)).
295  builder.setInsertionPointToEnd(pIsNeg->getBlock());
296  // Set initial values of 'result', 'b' and 'p' for the loop.
297  cf::CondBranchOp::create(builder, pIsNeg, bIsZero->getBlock(), loopHeader,
298  ValueRange{oneValue, bArg, pArg});
299 
300  // T result = T(1);
301  // while (true) {
302  // if (p & T(1))
303  // result *= b;
304  // p >>= T(1);
305  // if (p == T(0))
306  // return result;
307  // b *= b;
308  // }
309  Value resultTmp = loopHeader->getArgument(0);
310  Value baseTmp = loopHeader->getArgument(1);
311  Value powerTmp = loopHeader->getArgument(2);
312  builder.setInsertionPointToEnd(loopHeader);
313 
314  // if (p & T(1))
315  auto powerTmpIsOdd = arith::CmpIOp::create(
316  builder, arith::CmpIPredicate::ne,
317  arith::AndIOp::create(builder, powerTmp, oneValue), zeroValue);
318  thenBlock = builder.createBlock(funcBody);
319  // result *= b;
320  Value newResultTmp = arith::MulIOp::create(builder, resultTmp, baseTmp);
321  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
322  builder.getLoc());
323  builder.setInsertionPointToEnd(thenBlock);
324  cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
325  // Set up conditional branch for (p & T(1)).
326  builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
327  cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
328  resultTmp);
329  // Merged 'result'.
330  newResultTmp = fallthroughBlock->getArgument(0);
331 
332  // p >>= T(1);
333  builder.setInsertionPointToEnd(fallthroughBlock);
334  Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, oneValue);
335 
336  // if (p == T(0))
337  auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
338  newPowerTmp, zeroValue);
339  // return result;
340  thenBlock = builder.createBlock(funcBody);
341  func::ReturnOp::create(builder, newResultTmp);
342  fallthroughBlock = builder.createBlock(funcBody);
343  // Set up conditional branch for (p == T(0)).
344  builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
345  cf::CondBranchOp::create(builder, newPowerIsZero, thenBlock,
346  fallthroughBlock);
347 
348  // b *= b;
349  // }
350  builder.setInsertionPointToEnd(fallthroughBlock);
351  Value newBaseTmp = arith::MulIOp::create(builder, baseTmp, baseTmp);
352  // Pass new values for 'result', 'b' and 'p' to the loop header.
353  cf::BranchOp::create(
354  builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
355  return funcOp;
356 }
357 
358 /// Convert IPowI into a call to a local function implementing
359 /// the power operation. The local function computes a scalar result,
360 /// so vector forms of IPowI are linearized.
361 LogicalResult
362 IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
363  PatternRewriter &rewriter) const {
364  auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
365 
366  if (!baseType)
367  return rewriter.notifyMatchFailure(op, "non-integer base operand");
368 
369  // The outlined software implementation must have been already
370  // generated.
371  func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
372  if (!elementFunc)
373  return rewriter.notifyMatchFailure(op, "missing software implementation");
374 
375  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
376  return success();
377 }
378 
379 /// Create linkonce_odr function to implement the power function with
380 /// the given \p funcType type inside \p module. The \p funcType must be
381 /// 'FloatType (*)(FloatType, IntegerType)' function type.
382 ///
383 /// template <typename T>
384 /// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
385 /// if (p == Tp{0})
386 /// return Tb{1};
387 /// bool isNegativePower{p < Tp{0}}
388 /// bool isMin{p == std::numeric_limits<Tp>::min()};
389 /// if (isMin) {
390 /// p = std::numeric_limits<Tp>::max();
391 /// } else if (isNegativePower) {
392 /// p = -p;
393 /// }
394 /// Tb result = Tb{1};
395 /// Tb origBase = Tb{b};
396 /// while (true) {
397 /// if (p & Tp{1})
398 /// result *= b;
399 /// p >>= Tp{1};
400 /// if (p == Tp{0})
401 /// break;
402 /// b *= b;
403 /// }
404 /// if (isMin) {
405 /// result *= origBase;
406 /// }
407 /// if (isNegativePower) {
408 /// result = Tb{1} / result;
409 /// }
410 /// return result;
411 /// }
412 static func::FuncOp createElementFPowIFunc(ModuleOp *module,
413  FunctionType funcType) {
414  auto baseType = cast<FloatType>(funcType.getInput(0));
415  auto powType = cast<IntegerType>(funcType.getInput(1));
416  ImplicitLocOpBuilder builder =
417  ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
418 
419  std::string funcName("__mlir_math_fpowi");
420  llvm::raw_string_ostream nameOS(funcName);
421  nameOS << '_' << baseType;
422  nameOS << '_' << powType;
423  auto funcOp = func::FuncOp::create(builder, funcName, funcType);
424  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
425  Attribute linkage =
426  LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
427  funcOp->setAttr("llvm.linkage", linkage);
428  funcOp.setPrivate();
429 
430  Block *entryBlock = funcOp.addEntryBlock();
431  Region *funcBody = entryBlock->getParent();
432 
433  Value bArg = funcOp.getArgument(0);
434  Value pArg = funcOp.getArgument(1);
435  builder.setInsertionPointToEnd(entryBlock);
436  Value oneBValue = arith::ConstantOp::create(
437  builder, baseType, builder.getFloatAttr(baseType, 1.0));
438  Value zeroPValue = arith::ConstantOp::create(
439  builder, powType, builder.getIntegerAttr(powType, 0));
440  Value onePValue = arith::ConstantOp::create(
441  builder, powType, builder.getIntegerAttr(powType, 1));
442  Value minPValue = arith::ConstantOp::create(
443  builder, powType,
444  builder.getIntegerAttr(
445  powType, llvm::APInt::getSignedMinValue(powType.getWidth())));
446  Value maxPValue = arith::ConstantOp::create(
447  builder, powType,
448  builder.getIntegerAttr(
449  powType, llvm::APInt::getSignedMaxValue(powType.getWidth())));
450 
451  // if (p == Tp{0})
452  // return Tb{1};
453  auto pIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg,
454  zeroPValue);
455  Block *thenBlock = builder.createBlock(funcBody);
456  func::ReturnOp::create(builder, oneBValue);
457  Block *fallthroughBlock = builder.createBlock(funcBody);
458  // Set up conditional branch for (p == Tp{0}).
459  builder.setInsertionPointToEnd(pIsZero->getBlock());
460  cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
461 
462  builder.setInsertionPointToEnd(fallthroughBlock);
463  // bool isNegativePower{p < Tp{0}}
464  auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
465  zeroPValue);
466  // bool isMin{p == std::numeric_limits<Tp>::min()};
467  auto pIsMin =
468  arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, minPValue);
469 
470  // if (isMin) {
471  // p = std::numeric_limits<Tp>::max();
472  // } else if (isNegativePower) {
473  // p = -p;
474  // }
475  Value negP = arith::SubIOp::create(builder, zeroPValue, pArg);
476  auto pInit = arith::SelectOp::create(builder, pIsNeg, negP, pArg);
477  pInit = arith::SelectOp::create(builder, pIsMin, maxPValue, pInit);
478 
479  // Tb result = Tb{1};
480  // Tb origBase = Tb{b};
481  // while (true) {
482  // if (p & Tp{1})
483  // result *= b;
484  // p >>= Tp{1};
485  // if (p == Tp{0})
486  // break;
487  // b *= b;
488  // }
489  Block *loopHeader = builder.createBlock(
490  funcBody, funcBody->end(), {baseType, baseType, powType},
491  {builder.getLoc(), builder.getLoc(), builder.getLoc()});
492  // Set initial values of 'result', 'b' and 'p' for the loop.
493  builder.setInsertionPointToEnd(pInit->getBlock());
494  cf::BranchOp::create(builder, loopHeader, ValueRange{oneBValue, bArg, pInit});
495 
496  // Create loop body.
497  Value resultTmp = loopHeader->getArgument(0);
498  Value baseTmp = loopHeader->getArgument(1);
499  Value powerTmp = loopHeader->getArgument(2);
500  builder.setInsertionPointToEnd(loopHeader);
501 
502  // if (p & Tp{1})
503  auto powerTmpIsOdd = arith::CmpIOp::create(
504  builder, arith::CmpIPredicate::ne,
505  arith::AndIOp::create(builder, powerTmp, onePValue), zeroPValue);
506  thenBlock = builder.createBlock(funcBody);
507  // result *= b;
508  Value newResultTmp = arith::MulFOp::create(builder, resultTmp, baseTmp);
509  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
510  builder.getLoc());
511  builder.setInsertionPointToEnd(thenBlock);
512  cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
513  // Set up conditional branch for (p & Tp{1}).
514  builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
515  cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
516  resultTmp);
517  // Merged 'result'.
518  newResultTmp = fallthroughBlock->getArgument(0);
519 
520  // p >>= Tp{1};
521  builder.setInsertionPointToEnd(fallthroughBlock);
522  Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, onePValue);
523 
524  // if (p == Tp{0})
525  auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
526  newPowerTmp, zeroPValue);
527  // break;
528  //
529  // The conditional branch is finalized below with a jump to
530  // the loop exit block.
531  fallthroughBlock = builder.createBlock(funcBody);
532 
533  // b *= b;
534  // }
535  builder.setInsertionPointToEnd(fallthroughBlock);
536  Value newBaseTmp = arith::MulFOp::create(builder, baseTmp, baseTmp);
537  // Pass new values for 'result', 'b' and 'p' to the loop header.
538  cf::BranchOp::create(
539  builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
540 
541  // Set up conditional branch for early loop exit:
542  // if (p == Tp{0})
543  // break;
544  Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
545  builder.getLoc());
546  builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
547  cf::CondBranchOp::create(builder, newPowerIsZero, loopExit, newResultTmp,
548  fallthroughBlock, ValueRange{});
549 
550  // if (isMin) {
551  // result *= origBase;
552  // }
553  newResultTmp = loopExit->getArgument(0);
554  thenBlock = builder.createBlock(funcBody);
555  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
556  builder.getLoc());
557  builder.setInsertionPointToEnd(loopExit);
558  cf::CondBranchOp::create(builder, pIsMin, thenBlock, fallthroughBlock,
559  newResultTmp);
560  builder.setInsertionPointToEnd(thenBlock);
561  newResultTmp = arith::MulFOp::create(builder, newResultTmp, bArg);
562  cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
563 
564  /// if (isNegativePower) {
565  /// result = Tb{1} / result;
566  /// }
567  newResultTmp = fallthroughBlock->getArgument(0);
568  thenBlock = builder.createBlock(funcBody);
569  Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
570  builder.getLoc());
571  builder.setInsertionPointToEnd(fallthroughBlock);
572  cf::CondBranchOp::create(builder, pIsNeg, thenBlock, returnBlock,
573  newResultTmp);
574  builder.setInsertionPointToEnd(thenBlock);
575  newResultTmp = arith::DivFOp::create(builder, oneBValue, newResultTmp);
576  cf::BranchOp::create(builder, newResultTmp, returnBlock);
577 
578  // return result;
579  builder.setInsertionPointToEnd(returnBlock);
580  func::ReturnOp::create(builder, returnBlock->getArgument(0));
581 
582  return funcOp;
583 }
584 
585 /// Convert FPowI into a call to a local function implementing
586 /// the power operation. The local function computes a scalar result,
587 /// so vector forms of FPowI are linearized.
588 LogicalResult
589 FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
590  PatternRewriter &rewriter) const {
591  if (isa<VectorType>(op.getType()))
592  return rewriter.notifyMatchFailure(op, "non-scalar operation");
593 
594  FunctionType funcType = getElementalFuncTypeForOp(op);
595 
596  // The outlined software implementation must have been already
597  // generated.
598  func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
599  if (!elementFunc)
600  return rewriter.notifyMatchFailure(op, "missing software implementation");
601 
602  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
603  return success();
604 }
605 
606 /// Create function to implement the ctlz function the given \p elementType type
607 /// inside \p module. The \p elementType must be IntegerType, an the created
608 /// function has 'IntegerType (*)(IntegerType)' function type.
609 ///
610 /// template <typename T>
611 /// T __mlir_math_ctlz_*(T x) {
612 /// bits = sizeof(x) * 8;
613 /// if (x == 0)
614 /// return bits;
615 ///
616 /// uint32_t n = 0;
617 /// for (int i = 1; i < bits; ++i) {
618 /// if (x < 0) continue;
619 /// n++;
620 /// x <<= 1;
621 /// }
622 /// return n;
623 /// }
624 ///
625 /// Converts to (for i32):
626 ///
627 /// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
628 /// %c_32 = arith.constant 32 : index
629 /// %c_0 = arith.constant 0 : i32
630 /// %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1
631 /// %out = scf.if %arg_eq_zero {
632 /// scf.yield %c_32 : i32
633 /// } else {
634 /// %c_1index = arith.constant 1 : index
635 /// %c_1i32 = arith.constant 1 : i32
636 /// %n = arith.constant 0 : i32
637 /// %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index
638 /// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
639 /// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
640 /// %yield_val = scf.if %cond {
641 /// scf.yield %arg_iter, %n_iter : i32, i32
642 /// } else {
643 /// %arg_next = arith.shli %arg_iter, %c_1i32 : i32
644 /// %n_next = arith.addi %n_iter, %c_1i32 : i32
645 /// scf.yield %arg_next, %n_next : i32, i32
646 /// }
647 /// scf.yield %yield_val: i32, i32
648 /// }
649 /// scf.yield %n_out : i32
650 /// }
651 /// return %out: i32
652 /// }
653 static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
654  if (!isa<IntegerType>(elementType)) {
655  LDBG() << "non-integer element type for CtlzFunc; type was: "
656  << elementType;
657  llvm_unreachable("non-integer element type");
658  }
659  int64_t bitWidth = elementType.getIntOrFloatBitWidth();
660 
661  Location loc = module->getLoc();
662  ImplicitLocOpBuilder builder =
663  ImplicitLocOpBuilder::atBlockEnd(loc, module->getBody());
664 
665  std::string funcName("__mlir_math_ctlz");
666  llvm::raw_string_ostream nameOS(funcName);
667  nameOS << '_' << elementType;
668  FunctionType funcType =
669  FunctionType::get(builder.getContext(), {elementType}, elementType);
670  auto funcOp = func::FuncOp::create(builder, funcName, funcType);
671 
672  // LinkonceODR ensures that there is only one implementation of this function
673  // across all math.ctlz functions that are lowered in this way.
674  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
675  Attribute linkage =
676  LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
677  funcOp->setAttr("llvm.linkage", linkage);
678  funcOp.setPrivate();
679 
680  // set the insertion point to the start of the function
681  Block *funcBody = funcOp.addEntryBlock();
682  builder.setInsertionPointToStart(funcBody);
683 
684  Value arg = funcOp.getArgument(0);
685  Type indexType = builder.getIndexType();
686  Value bitWidthValue = arith::ConstantOp::create(
687  builder, elementType, builder.getIntegerAttr(elementType, bitWidth));
688  Value zeroValue = arith::ConstantOp::create(
689  builder, elementType, builder.getIntegerAttr(elementType, 0));
690 
691  Value inputEqZero =
692  arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, arg, zeroValue);
693 
694  // if input == 0, return bit width, else enter loop.
695  scf::IfOp ifOp =
696  scf::IfOp::create(builder, elementType, inputEqZero,
697  /*addThenBlock=*/true, /*addElseBlock=*/true);
698  auto thenBuilder = ifOp.getThenBodyBuilder();
699  scf::YieldOp::create(thenBuilder, loc, bitWidthValue);
700 
701  auto elseBuilder =
702  ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
703 
704  Value oneIndex = arith::ConstantOp::create(elseBuilder, indexType,
705  elseBuilder.getIndexAttr(1));
706  Value oneValue = arith::ConstantOp::create(
707  elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 1));
708  Value bitWidthIndex = arith::ConstantOp::create(
709  elseBuilder, indexType, elseBuilder.getIndexAttr(bitWidth));
710  Value nValue = arith::ConstantOp::create(
711  elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 0));
712 
713  auto loop = scf::ForOp::create(
714  elseBuilder, oneIndex, bitWidthIndex, oneIndex,
715  // Initial values for two loop induction variables, the arg which is being
716  // shifted left in each iteration, and the n value which tracks the count
717  // of leading zeros.
718  ValueRange{arg, nValue},
719  // Callback to build the body of the for loop
720  // if (arg < 0) {
721  // continue;
722  // } else {
723  // n++;
724  // arg <<= 1;
725  // }
726  [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
727  Value argIter = args[0];
728  Value nIter = args[1];
729 
730  Value argIsNonNegative = arith::CmpIOp::create(
731  b, loc, arith::CmpIPredicate::slt, argIter, zeroValue);
732  scf::IfOp ifOp = scf::IfOp::create(
733  b, loc, argIsNonNegative,
734  [&](OpBuilder &b, Location loc) {
735  // If arg is negative, continue (effectively, break)
736  scf::YieldOp::create(b, loc, ValueRange{argIter, nIter});
737  },
738  [&](OpBuilder &b, Location loc) {
739  // Otherwise, increment n and shift arg left.
740  Value nNext = arith::AddIOp::create(b, loc, nIter, oneValue);
741  Value argNext = arith::ShLIOp::create(b, loc, argIter, oneValue);
742  scf::YieldOp::create(b, loc, ValueRange{argNext, nNext});
743  });
744  scf::YieldOp::create(b, loc, ifOp.getResults());
745  });
746  scf::YieldOp::create(elseBuilder, loop.getResult(1));
747 
748  func::ReturnOp::create(builder, ifOp.getResult(0));
749  return funcOp;
750 }
751 
752 /// Convert ctlz into a call to a local function implementing the ctlz
753 /// operation.
754 LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
755  PatternRewriter &rewriter) const {
756  if (isa<VectorType>(op.getType()))
757  return rewriter.notifyMatchFailure(op, "non-scalar operation");
758 
759  Type type = getElementTypeOrSelf(op.getResult().getType());
760  func::FuncOp elementFunc = getFuncOpCallback(op, type);
761  if (!elementFunc)
762  return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
763  diag << "Missing software implementation for op " << op->getName()
764  << " and type " << type;
765  });
766 
767  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperand());
768  return success();
769 }
770 
771 namespace {
772 struct ConvertMathToFuncsPass
773  : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
774  ConvertMathToFuncsPass() = default;
775  ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
776  : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {}
777 
778  void runOnOperation() override;
779 
780 private:
781  // Return true, if this FPowI operation must be converted
782  // because the width of its exponent's type is greater than
783  // or equal to minWidthOfFPowIExponent option value.
784  bool isFPowIConvertible(math::FPowIOp op);
785 
786  // Reture true, if operation is integer type.
787  bool isConvertible(Operation *op);
788 
789  // Generate outlined implementations for power operations
790  // and store them in funcImpls map.
791  void generateOpImplementations();
792 
793  // A map between pairs of (operation, type) deduced from operations that this
794  // pass will convert, and the corresponding outlined software implementations
795  // of these operations for the given type.
796  DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls;
797 };
798 } // namespace
799 
800 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
801  auto expTy =
802  dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType()));
803  return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
804 }
805 
806 bool ConvertMathToFuncsPass::isConvertible(Operation *op) {
807  return isa<IntegerType>(getElementTypeOrSelf(op->getResult(0).getType()));
808 }
809 
810 void ConvertMathToFuncsPass::generateOpImplementations() {
811  ModuleOp module = getOperation();
812 
813  module.walk([&](Operation *op) {
815  .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
816  if (!convertCtlz || !isConvertible(op))
817  return;
818  Type resultType = getElementTypeOrSelf(op.getResult().getType());
819 
820  // Generate the software implementation of this operation,
821  // if it has not been generated yet.
822  auto key = std::pair(op->getName(), resultType);
823  auto entry = funcImpls.try_emplace(key, func::FuncOp{});
824  if (entry.second)
825  entry.first->second = createCtlzFunc(&module, resultType);
826  })
827  .Case<math::IPowIOp>([&](math::IPowIOp op) {
828  if (!isConvertible(op))
829  return;
830 
831  Type resultType = getElementTypeOrSelf(op.getResult().getType());
832 
833  // Generate the software implementation of this operation,
834  // if it has not been generated yet.
835  auto key = std::pair(op->getName(), resultType);
836  auto entry = funcImpls.try_emplace(key, func::FuncOp{});
837  if (entry.second)
838  entry.first->second = createElementIPowIFunc(&module, resultType);
839  })
840  .Case<math::FPowIOp>([&](math::FPowIOp op) {
841  if (!isFPowIConvertible(op))
842  return;
843 
844  FunctionType funcType = getElementalFuncTypeForOp(op);
845 
846  // Generate the software implementation of this operation,
847  // if it has not been generated yet.
848  // FPowI implementations are mapped via the FunctionType
849  // created from the operation's result and operands.
850  auto key = std::pair(op->getName(), funcType);
851  auto entry = funcImpls.try_emplace(key, func::FuncOp{});
852  if (entry.second)
853  entry.first->second = createElementFPowIFunc(&module, funcType);
854  });
855  });
856 }
857 
858 void ConvertMathToFuncsPass::runOnOperation() {
859  ModuleOp module = getOperation();
860 
861  // Create outlined implementations for power operations.
862  generateOpImplementations();
863 
865  patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
866  VecOpToScalarOp<math::CountLeadingZerosOp>>(
867  patterns.getContext());
868 
869  // For the given Type Returns FuncOp stored in funcImpls map.
870  auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
871  auto it = funcImpls.find(std::pair(op->getName(), type));
872  if (it == funcImpls.end())
873  return {};
874 
875  return it->second;
876  };
877  patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
878  getFuncOpByType);
879 
880  if (convertCtlz)
881  patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType);
882 
883  ConversionTarget target(getContext());
884  target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
885  func::FuncDialect, scf::SCFDialect,
886  vector::VectorDialect>();
887 
888  target.addDynamicallyLegalOp<math::IPowIOp>(
889  [this](math::IPowIOp op) { return !isConvertible(op); });
890  if (convertCtlz) {
891  target.addDynamicallyLegalOp<math::CountLeadingZerosOp>(
892  [this](math::CountLeadingZerosOp op) { return !isConvertible(op); });
893  }
894  target.addDynamicallyLegalOp<math::FPowIOp>(
895  [this](math::FPowIOp op) { return !isFPowIConvertible(op); });
896  if (failed(applyPartialConversion(module, target, std::move(patterns))))
897  signalPassFailure();
898 }
static MLIRContext * getContext(OpFoldResult val)
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType)
Create linkonce_odr function to implement the power function with the given elementType type inside module.
static FunctionType getElementalFuncTypeForOp(Operation *op)
static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType)
Create linkonce_odr function to implement the power function with the given funcType type inside module.
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType)
Create function to implement the ctlz function for the given elementType type inside module.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:50
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:621
Location getLoc() const
Accessors for the implied location.
Definition: Builders.h:654
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition: Builders.h:638
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:446
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:129
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
operand_type_iterator operand_type_end()
Definition: Operation.h:396
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumOperands()
Definition: Operation.h:346
result_type_iterator result_type_end()
Definition: Operation.h:427
result_type_iterator result_type_begin()
Definition: Operation.h:426
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
operand_type_iterator operand_type_begin()
Definition: Operation.h:395
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314