MLIR 22.0.0git
MathToFuncs.cpp
Go to the documentation of this file.
1//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include "mlir/Pass/Pass.h"
23#include "llvm/ADT/DenseMap.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/DebugLog.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33
34#define DEBUG_TYPE "math-to-funcs"
35
36namespace {
37// Pattern to convert vector operations to scalar operations.
38template <typename Op>
39struct VecOpToScalarOp : public OpRewritePattern<Op> {
40public:
42
43 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
44};
45
46// Callback type for getting pre-generated FuncOp implementing
47// an operation of the given type.
48using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>;
49
50// Pattern to convert scalar IPowIOp into a call of outlined
51// software implementation.
52class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
53public:
54 IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
55 : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
56
57 /// Convert IPowI into a call to a local function implementing
58 /// the power operation. The local function computes a scalar result,
59 /// so vector forms of IPowI are linearized.
60 LogicalResult matchAndRewrite(math::IPowIOp op,
61 PatternRewriter &rewriter) const final;
62
63private:
64 GetFuncCallbackTy getFuncOpCallback;
65};
66
67// Pattern to convert scalar FPowIOp into a call of outlined
68// software implementation.
69class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
70public:
71 FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
72 : OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {}
73
74 /// Convert FPowI into a call to a local function implementing
75 /// the power operation. The local function computes a scalar result,
76 /// so vector forms of FPowI are linearized.
77 LogicalResult matchAndRewrite(math::FPowIOp op,
78 PatternRewriter &rewriter) const final;
79
80private:
81 GetFuncCallbackTy getFuncOpCallback;
82};
83
84// Pattern to convert scalar ctlz into a call of outlined software
85// implementation.
86class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> {
87public:
88 CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
90 getFuncOpCallback(cb) {}
91
92 /// Convert ctlz into a call to a local function implementing
93 /// the count leading zeros operation.
94 LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
95 PatternRewriter &rewriter) const final;
96
97private:
98 GetFuncCallbackTy getFuncOpCallback;
99};
100} // namespace
101
102template <typename Op>
103LogicalResult
104VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
105 Type opType = op.getType();
106 Location loc = op.getLoc();
107 auto vecType = dyn_cast<VectorType>(opType);
108
109 if (!vecType)
110 return rewriter.notifyMatchFailure(op, "not a vector operation");
111 if (!vecType.hasRank())
112 return rewriter.notifyMatchFailure(op, "unknown vector rank");
113 ArrayRef<int64_t> shape = vecType.getShape();
114 int64_t numElements = vecType.getNumElements();
115
116 Type resultElementType = vecType.getElementType();
117 Attribute initValueAttr;
118 if (isa<FloatType>(resultElementType))
119 initValueAttr = FloatAttr::get(resultElementType, 0.0);
120 else
121 initValueAttr = IntegerAttr::get(resultElementType, 0);
122 Value result = arith::ConstantOp::create(
123 rewriter, loc, DenseElementsAttr::get(vecType, initValueAttr));
125 for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
126 SmallVector<int64_t> positions = delinearize(linearIndex, strides);
127 SmallVector<Value> operands;
128 for (Value input : op->getOperands())
129 operands.push_back(
130 vector::ExtractOp::create(rewriter, loc, input, positions));
131 Value scalarOp =
132 Op::create(rewriter, loc, vecType.getElementType(), operands);
133 result =
134 vector::InsertOp::create(rewriter, loc, scalarOp, result, positions);
135 }
136 rewriter.replaceOp(op, result);
137 return success();
138}
139
140static FunctionType getElementalFuncTypeForOp(Operation *op) {
141 SmallVector<Type, 1> resultTys(op->getNumResults());
142 SmallVector<Type, 2> inputTys(op->getNumOperands());
143 std::transform(op->result_type_begin(), op->result_type_end(),
144 resultTys.begin(),
145 [](Type ty) { return getElementTypeOrSelf(ty); });
146 std::transform(op->operand_type_begin(), op->operand_type_end(),
147 inputTys.begin(),
148 [](Type ty) { return getElementTypeOrSelf(ty); });
149 return FunctionType::get(op->getContext(), inputTys, resultTys);
150}
151
152/// Create linkonce_odr function to implement the power function with
153/// the given \p elementType type inside \p module. The \p elementType
154/// must be IntegerType, an the created function has
155/// 'IntegerType (*)(IntegerType, IntegerType)' function type.
156///
157/// template <typename T>
158/// T __mlir_math_ipowi_*(T b, T p) {
159/// if (p == T(0))
160/// return T(1);
161/// if (p < T(0)) {
162/// if (b == T(0))
163/// return T(1) / T(0); // trigger div-by-zero
164/// if (b == T(1))
165/// return T(1);
166/// if (b == T(-1)) {
167/// if (p & T(1))
168/// return T(-1);
169/// return T(1);
170/// }
171/// return T(0);
172/// }
173/// T result = T(1);
174/// while (true) {
175/// if (p & T(1))
176/// result *= b;
177/// p >>= T(1);
178/// if (p == T(0))
179/// return result;
180/// b *= b;
181/// }
182/// }
183static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
184 assert(isa<IntegerType>(elementType) &&
185 "non-integer element type for IPowIOp");
186
187 ImplicitLocOpBuilder builder =
188 ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
189
190 std::string funcName("__mlir_math_ipowi");
191 llvm::raw_string_ostream nameOS(funcName);
192 nameOS << '_' << elementType;
193
194 FunctionType funcType = FunctionType::get(
195 builder.getContext(), {elementType, elementType}, elementType);
196 auto funcOp = func::FuncOp::create(builder, funcName, funcType);
197 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
198 Attribute linkage =
199 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
200 funcOp->setAttr("llvm.linkage", linkage);
201 funcOp.setPrivate();
202
203 Block *entryBlock = funcOp.addEntryBlock();
204 Region *funcBody = entryBlock->getParent();
205
206 Value bArg = funcOp.getArgument(0);
207 Value pArg = funcOp.getArgument(1);
208 builder.setInsertionPointToEnd(entryBlock);
209 Value zeroValue = arith::ConstantOp::create(
210 builder, elementType, builder.getIntegerAttr(elementType, 0));
211 Value oneValue = arith::ConstantOp::create(
212 builder, elementType, builder.getIntegerAttr(elementType, 1));
213 Value minusOneValue = arith::ConstantOp::create(
214 builder, elementType,
215 builder.getIntegerAttr(elementType,
216 APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
217 /*isSigned=*/true)));
218
219 // if (p == T(0))
220 // return T(1);
221 auto pIsZero =
222 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, zeroValue);
223 Block *thenBlock = builder.createBlock(funcBody);
224 func::ReturnOp::create(builder, oneValue);
225 Block *fallthroughBlock = builder.createBlock(funcBody);
226 // Set up conditional branch for (p == T(0)).
227 builder.setInsertionPointToEnd(pIsZero->getBlock());
228 cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
229
230 // if (p < T(0)) {
231 builder.setInsertionPointToEnd(fallthroughBlock);
232 auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
233 zeroValue);
234 // if (b == T(0))
235 builder.createBlock(funcBody);
236 auto bIsZero =
237 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, zeroValue);
238 // return T(1) / T(0);
239 thenBlock = builder.createBlock(funcBody);
240 func::ReturnOp::create(
241 builder,
242 arith::DivSIOp::create(builder, oneValue, zeroValue).getResult());
243 fallthroughBlock = builder.createBlock(funcBody);
244 // Set up conditional branch for (b == T(0)).
245 builder.setInsertionPointToEnd(bIsZero->getBlock());
246 cf::CondBranchOp::create(builder, bIsZero, thenBlock, fallthroughBlock);
247
248 // if (b == T(1))
249 builder.setInsertionPointToEnd(fallthroughBlock);
250 auto bIsOne =
251 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, oneValue);
252 // return T(1);
253 thenBlock = builder.createBlock(funcBody);
254 func::ReturnOp::create(builder, oneValue);
255 fallthroughBlock = builder.createBlock(funcBody);
256 // Set up conditional branch for (b == T(1)).
257 builder.setInsertionPointToEnd(bIsOne->getBlock());
258 cf::CondBranchOp::create(builder, bIsOne, thenBlock, fallthroughBlock);
259
260 // if (b == T(-1)) {
261 builder.setInsertionPointToEnd(fallthroughBlock);
262 auto bIsMinusOne = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
263 bArg, minusOneValue);
264 // if (p & T(1))
265 builder.createBlock(funcBody);
266 auto pIsOdd = arith::CmpIOp::create(
267 builder, arith::CmpIPredicate::ne,
268 arith::AndIOp::create(builder, pArg, oneValue), zeroValue);
269 // return T(-1);
270 thenBlock = builder.createBlock(funcBody);
271 func::ReturnOp::create(builder, minusOneValue);
272 fallthroughBlock = builder.createBlock(funcBody);
273 // Set up conditional branch for (p & T(1)).
274 builder.setInsertionPointToEnd(pIsOdd->getBlock());
275 cf::CondBranchOp::create(builder, pIsOdd, thenBlock, fallthroughBlock);
276
277 // return T(1);
278 // } // b == T(-1)
279 builder.setInsertionPointToEnd(fallthroughBlock);
280 func::ReturnOp::create(builder, oneValue);
281 fallthroughBlock = builder.createBlock(funcBody);
282 // Set up conditional branch for (b == T(-1)).
283 builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
284 cf::CondBranchOp::create(builder, bIsMinusOne, pIsOdd->getBlock(),
285 fallthroughBlock);
286
287 // return T(0);
288 // } // (p < T(0))
289 builder.setInsertionPointToEnd(fallthroughBlock);
290 func::ReturnOp::create(builder, zeroValue);
291 Block *loopHeader = builder.createBlock(
292 funcBody, funcBody->end(), {elementType, elementType, elementType},
293 {builder.getLoc(), builder.getLoc(), builder.getLoc()});
294 // Set up conditional branch for (p < T(0)).
295 builder.setInsertionPointToEnd(pIsNeg->getBlock());
296 // Set initial values of 'result', 'b' and 'p' for the loop.
297 cf::CondBranchOp::create(builder, pIsNeg, bIsZero->getBlock(), loopHeader,
298 ValueRange{oneValue, bArg, pArg});
299
300 // T result = T(1);
301 // while (true) {
302 // if (p & T(1))
303 // result *= b;
304 // p >>= T(1);
305 // if (p == T(0))
306 // return result;
307 // b *= b;
308 // }
309 Value resultTmp = loopHeader->getArgument(0);
310 Value baseTmp = loopHeader->getArgument(1);
311 Value powerTmp = loopHeader->getArgument(2);
312 builder.setInsertionPointToEnd(loopHeader);
313
314 // if (p & T(1))
315 auto powerTmpIsOdd = arith::CmpIOp::create(
316 builder, arith::CmpIPredicate::ne,
317 arith::AndIOp::create(builder, powerTmp, oneValue), zeroValue);
318 thenBlock = builder.createBlock(funcBody);
319 // result *= b;
320 Value newResultTmp = arith::MulIOp::create(builder, resultTmp, baseTmp);
321 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
322 builder.getLoc());
323 builder.setInsertionPointToEnd(thenBlock);
324 cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
325 // Set up conditional branch for (p & T(1)).
326 builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
327 cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
328 resultTmp);
329 // Merged 'result'.
330 newResultTmp = fallthroughBlock->getArgument(0);
331
332 // p >>= T(1);
333 builder.setInsertionPointToEnd(fallthroughBlock);
334 Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, oneValue);
335
336 // if (p == T(0))
337 auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
338 newPowerTmp, zeroValue);
339 // return result;
340 thenBlock = builder.createBlock(funcBody);
341 func::ReturnOp::create(builder, newResultTmp);
342 fallthroughBlock = builder.createBlock(funcBody);
343 // Set up conditional branch for (p == T(0)).
344 builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
345 cf::CondBranchOp::create(builder, newPowerIsZero, thenBlock,
346 fallthroughBlock);
347
348 // b *= b;
349 // }
350 builder.setInsertionPointToEnd(fallthroughBlock);
351 Value newBaseTmp = arith::MulIOp::create(builder, baseTmp, baseTmp);
352 // Pass new values for 'result', 'b' and 'p' to the loop header.
353 cf::BranchOp::create(
354 builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
355 return funcOp;
356}
357
358/// Convert IPowI into a call to a local function implementing
359/// the power operation. The local function computes a scalar result,
360/// so vector forms of IPowI are linearized.
361LogicalResult
362IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
363 PatternRewriter &rewriter) const {
364 auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
365
366 if (!baseType)
367 return rewriter.notifyMatchFailure(op, "non-integer base operand");
368
369 // The outlined software implementation must have been already
370 // generated.
371 func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
372 if (!elementFunc)
373 return rewriter.notifyMatchFailure(op, "missing software implementation");
374
375 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
376 return success();
377}
378
379/// Create linkonce_odr function to implement the power function with
380/// the given \p funcType type inside \p module. The \p funcType must be
381/// 'FloatType (*)(FloatType, IntegerType)' function type.
382///
383/// template <typename T>
384/// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
385/// if (p == Tp{0})
386/// return Tb{1};
387/// bool isNegativePower{p < Tp{0}}
388/// bool isMin{p == std::numeric_limits<Tp>::min()};
389/// if (isMin) {
390/// p = std::numeric_limits<Tp>::max();
391/// } else if (isNegativePower) {
392/// p = -p;
393/// }
394/// Tb result = Tb{1};
395/// Tb origBase = Tb{b};
396/// while (true) {
397/// if (p & Tp{1})
398/// result *= b;
399/// p >>= Tp{1};
400/// if (p == Tp{0})
401/// break;
402/// b *= b;
403/// }
404/// if (isMin) {
405/// result *= origBase;
406/// }
407/// if (isNegativePower) {
408/// result = Tb{1} / result;
409/// }
410/// return result;
411/// }
412static func::FuncOp createElementFPowIFunc(ModuleOp *module,
413 FunctionType funcType) {
414 auto baseType = cast<FloatType>(funcType.getInput(0));
415 auto powType = cast<IntegerType>(funcType.getInput(1));
416 ImplicitLocOpBuilder builder =
417 ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
418
419 std::string funcName("__mlir_math_fpowi");
420 llvm::raw_string_ostream nameOS(funcName);
421 nameOS << '_' << baseType;
422 nameOS << '_' << powType;
423 auto funcOp = func::FuncOp::create(builder, funcName, funcType);
424 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
425 Attribute linkage =
426 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
427 funcOp->setAttr("llvm.linkage", linkage);
428 funcOp.setPrivate();
429
430 Block *entryBlock = funcOp.addEntryBlock();
431 Region *funcBody = entryBlock->getParent();
432
433 Value bArg = funcOp.getArgument(0);
434 Value pArg = funcOp.getArgument(1);
435 builder.setInsertionPointToEnd(entryBlock);
436 Value oneBValue = arith::ConstantOp::create(
437 builder, baseType, builder.getFloatAttr(baseType, 1.0));
438 Value zeroPValue = arith::ConstantOp::create(
439 builder, powType, builder.getIntegerAttr(powType, 0));
440 Value onePValue = arith::ConstantOp::create(
441 builder, powType, builder.getIntegerAttr(powType, 1));
442 Value minPValue = arith::ConstantOp::create(
443 builder, powType,
444 builder.getIntegerAttr(
445 powType, llvm::APInt::getSignedMinValue(powType.getWidth())));
446 Value maxPValue = arith::ConstantOp::create(
447 builder, powType,
448 builder.getIntegerAttr(
449 powType, llvm::APInt::getSignedMaxValue(powType.getWidth())));
450
451 // if (p == Tp{0})
452 // return Tb{1};
453 auto pIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg,
454 zeroPValue);
455 Block *thenBlock = builder.createBlock(funcBody);
456 func::ReturnOp::create(builder, oneBValue);
457 Block *fallthroughBlock = builder.createBlock(funcBody);
458 // Set up conditional branch for (p == Tp{0}).
459 builder.setInsertionPointToEnd(pIsZero->getBlock());
460 cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock);
461
462 builder.setInsertionPointToEnd(fallthroughBlock);
463 // bool isNegativePower{p < Tp{0}}
464 auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg,
465 zeroPValue);
466 // bool isMin{p == std::numeric_limits<Tp>::min()};
467 auto pIsMin =
468 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, minPValue);
469
470 // if (isMin) {
471 // p = std::numeric_limits<Tp>::max();
472 // } else if (isNegativePower) {
473 // p = -p;
474 // }
475 Value negP = arith::SubIOp::create(builder, zeroPValue, pArg);
476 auto pInit = arith::SelectOp::create(builder, pIsNeg, negP, pArg);
477 pInit = arith::SelectOp::create(builder, pIsMin, maxPValue, pInit);
478
479 // Tb result = Tb{1};
480 // Tb origBase = Tb{b};
481 // while (true) {
482 // if (p & Tp{1})
483 // result *= b;
484 // p >>= Tp{1};
485 // if (p == Tp{0})
486 // break;
487 // b *= b;
488 // }
489 Block *loopHeader = builder.createBlock(
490 funcBody, funcBody->end(), {baseType, baseType, powType},
491 {builder.getLoc(), builder.getLoc(), builder.getLoc()});
492 // Set initial values of 'result', 'b' and 'p' for the loop.
493 builder.setInsertionPointToEnd(pInit->getBlock());
494 cf::BranchOp::create(builder, loopHeader, ValueRange{oneBValue, bArg, pInit});
495
496 // Create loop body.
497 Value resultTmp = loopHeader->getArgument(0);
498 Value baseTmp = loopHeader->getArgument(1);
499 Value powerTmp = loopHeader->getArgument(2);
500 builder.setInsertionPointToEnd(loopHeader);
501
502 // if (p & Tp{1})
503 auto powerTmpIsOdd = arith::CmpIOp::create(
504 builder, arith::CmpIPredicate::ne,
505 arith::AndIOp::create(builder, powerTmp, onePValue), zeroPValue);
506 thenBlock = builder.createBlock(funcBody);
507 // result *= b;
508 Value newResultTmp = arith::MulFOp::create(builder, resultTmp, baseTmp);
509 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
510 builder.getLoc());
511 builder.setInsertionPointToEnd(thenBlock);
512 cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
513 // Set up conditional branch for (p & Tp{1}).
514 builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
515 cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock,
516 resultTmp);
517 // Merged 'result'.
518 newResultTmp = fallthroughBlock->getArgument(0);
519
520 // p >>= Tp{1};
521 builder.setInsertionPointToEnd(fallthroughBlock);
522 Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, onePValue);
523
524 // if (p == Tp{0})
525 auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq,
526 newPowerTmp, zeroPValue);
527 // break;
528 //
529 // The conditional branch is finalized below with a jump to
530 // the loop exit block.
531 fallthroughBlock = builder.createBlock(funcBody);
532
533 // b *= b;
534 // }
535 builder.setInsertionPointToEnd(fallthroughBlock);
536 Value newBaseTmp = arith::MulFOp::create(builder, baseTmp, baseTmp);
537 // Pass new values for 'result', 'b' and 'p' to the loop header.
538 cf::BranchOp::create(
539 builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
540
541 // Set up conditional branch for early loop exit:
542 // if (p == Tp{0})
543 // break;
544 Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
545 builder.getLoc());
546 builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
547 cf::CondBranchOp::create(builder, newPowerIsZero, loopExit, newResultTmp,
548 fallthroughBlock, ValueRange{});
549
550 // if (isMin) {
551 // result *= origBase;
552 // }
553 newResultTmp = loopExit->getArgument(0);
554 thenBlock = builder.createBlock(funcBody);
555 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
556 builder.getLoc());
557 builder.setInsertionPointToEnd(loopExit);
558 cf::CondBranchOp::create(builder, pIsMin, thenBlock, fallthroughBlock,
559 newResultTmp);
560 builder.setInsertionPointToEnd(thenBlock);
561 newResultTmp = arith::MulFOp::create(builder, newResultTmp, bArg);
562 cf::BranchOp::create(builder, newResultTmp, fallthroughBlock);
563
564 /// if (isNegativePower) {
565 /// result = Tb{1} / result;
566 /// }
567 newResultTmp = fallthroughBlock->getArgument(0);
568 thenBlock = builder.createBlock(funcBody);
569 Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
570 builder.getLoc());
571 builder.setInsertionPointToEnd(fallthroughBlock);
572 cf::CondBranchOp::create(builder, pIsNeg, thenBlock, returnBlock,
573 newResultTmp);
574 builder.setInsertionPointToEnd(thenBlock);
575 newResultTmp = arith::DivFOp::create(builder, oneBValue, newResultTmp);
576 cf::BranchOp::create(builder, newResultTmp, returnBlock);
577
578 // return result;
579 builder.setInsertionPointToEnd(returnBlock);
580 func::ReturnOp::create(builder, returnBlock->getArgument(0));
581
582 return funcOp;
583}
584
585/// Convert FPowI into a call to a local function implementing
586/// the power operation. The local function computes a scalar result,
587/// so vector forms of FPowI are linearized.
588LogicalResult
589FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
590 PatternRewriter &rewriter) const {
591 if (isa<VectorType>(op.getType()))
592 return rewriter.notifyMatchFailure(op, "non-scalar operation");
593
594 FunctionType funcType = getElementalFuncTypeForOp(op);
595
596 // The outlined software implementation must have been already
597 // generated.
598 func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
599 if (!elementFunc)
600 return rewriter.notifyMatchFailure(op, "missing software implementation");
601
602 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
603 return success();
604}
605
606/// Create function to implement the ctlz function the given \p elementType type
607/// inside \p module. The \p elementType must be IntegerType, an the created
608/// function has 'IntegerType (*)(IntegerType)' function type.
609///
610/// template <typename T>
611/// T __mlir_math_ctlz_*(T x) {
612/// bits = sizeof(x) * 8;
613/// if (x == 0)
614/// return bits;
615///
616/// uint32_t n = 0;
617/// for (int i = 1; i < bits; ++i) {
618/// if (x < 0) continue;
619/// n++;
620/// x <<= 1;
621/// }
622/// return n;
623/// }
624///
625/// Converts to (for i32):
626///
627/// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
628/// %c_32 = arith.constant 32 : index
629/// %c_0 = arith.constant 0 : i32
630/// %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1
631/// %out = scf.if %arg_eq_zero {
632/// scf.yield %c_32 : i32
633/// } else {
634/// %c_1index = arith.constant 1 : index
635/// %c_1i32 = arith.constant 1 : i32
636/// %n = arith.constant 0 : i32
637/// %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index
638/// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
639/// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
640/// %yield_val = scf.if %cond {
641/// scf.yield %arg_iter, %n_iter : i32, i32
642/// } else {
643/// %arg_next = arith.shli %arg_iter, %c_1i32 : i32
644/// %n_next = arith.addi %n_iter, %c_1i32 : i32
645/// scf.yield %arg_next, %n_next : i32, i32
646/// }
647/// scf.yield %yield_val: i32, i32
648/// }
649/// scf.yield %n_out : i32
650/// }
651/// return %out: i32
652/// }
653static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
654 if (!isa<IntegerType>(elementType)) {
655 LDBG() << "non-integer element type for CtlzFunc; type was: "
656 << elementType;
657 llvm_unreachable("non-integer element type");
658 }
659 int64_t bitWidth = elementType.getIntOrFloatBitWidth();
660
661 Location loc = module->getLoc();
662 ImplicitLocOpBuilder builder =
663 ImplicitLocOpBuilder::atBlockEnd(loc, module->getBody());
664
665 std::string funcName("__mlir_math_ctlz");
666 llvm::raw_string_ostream nameOS(funcName);
667 nameOS << '_' << elementType;
668 FunctionType funcType =
669 FunctionType::get(builder.getContext(), {elementType}, elementType);
670 auto funcOp = func::FuncOp::create(builder, funcName, funcType);
671
672 // LinkonceODR ensures that there is only one implementation of this function
673 // across all math.ctlz functions that are lowered in this way.
674 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
675 Attribute linkage =
676 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
677 funcOp->setAttr("llvm.linkage", linkage);
678 funcOp.setPrivate();
679
680 // set the insertion point to the start of the function
681 Block *funcBody = funcOp.addEntryBlock();
682 builder.setInsertionPointToStart(funcBody);
683
684 Value arg = funcOp.getArgument(0);
685 Type indexType = builder.getIndexType();
686 Value bitWidthValue = arith::ConstantOp::create(
687 builder, elementType, builder.getIntegerAttr(elementType, bitWidth));
688 Value zeroValue = arith::ConstantOp::create(
689 builder, elementType, builder.getIntegerAttr(elementType, 0));
690
691 Value inputEqZero =
692 arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, arg, zeroValue);
693
694 // if input == 0, return bit width, else enter loop.
695 scf::IfOp ifOp =
696 scf::IfOp::create(builder, elementType, inputEqZero,
697 /*addThenBlock=*/true, /*addElseBlock=*/true);
698 auto thenBuilder = ifOp.getThenBodyBuilder();
699 scf::YieldOp::create(thenBuilder, loc, bitWidthValue);
700
701 auto elseBuilder =
702 ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
703
704 Value oneIndex = arith::ConstantOp::create(elseBuilder, indexType,
705 elseBuilder.getIndexAttr(1));
706 Value oneValue = arith::ConstantOp::create(
707 elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 1));
708 Value bitWidthIndex = arith::ConstantOp::create(
709 elseBuilder, indexType, elseBuilder.getIndexAttr(bitWidth));
710 Value nValue = arith::ConstantOp::create(
711 elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 0));
712
713 auto loop = scf::ForOp::create(
714 elseBuilder, oneIndex, bitWidthIndex, oneIndex,
715 // Initial values for two loop induction variables, the arg which is being
716 // shifted left in each iteration, and the n value which tracks the count
717 // of leading zeros.
718 ValueRange{arg, nValue},
719 // Callback to build the body of the for loop
720 // if (arg < 0) {
721 // continue;
722 // } else {
723 // n++;
724 // arg <<= 1;
725 // }
726 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
727 Value argIter = args[0];
728 Value nIter = args[1];
729
730 Value argIsNonNegative = arith::CmpIOp::create(
731 b, loc, arith::CmpIPredicate::slt, argIter, zeroValue);
732 scf::IfOp ifOp = scf::IfOp::create(
733 b, loc, argIsNonNegative,
734 [&](OpBuilder &b, Location loc) {
735 // If arg is negative, continue (effectively, break)
736 scf::YieldOp::create(b, loc, ValueRange{argIter, nIter});
737 },
738 [&](OpBuilder &b, Location loc) {
739 // Otherwise, increment n and shift arg left.
740 Value nNext = arith::AddIOp::create(b, loc, nIter, oneValue);
741 Value argNext = arith::ShLIOp::create(b, loc, argIter, oneValue);
742 scf::YieldOp::create(b, loc, ValueRange{argNext, nNext});
743 });
744 scf::YieldOp::create(b, loc, ifOp.getResults());
745 });
746 scf::YieldOp::create(elseBuilder, loop.getResult(1));
747
748 func::ReturnOp::create(builder, ifOp.getResult(0));
749 return funcOp;
750}
751
752/// Convert ctlz into a call to a local function implementing the ctlz
753/// operation.
754LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
755 PatternRewriter &rewriter) const {
756 if (isa<VectorType>(op.getType()))
757 return rewriter.notifyMatchFailure(op, "non-scalar operation");
758
759 Type type = getElementTypeOrSelf(op.getResult().getType());
760 func::FuncOp elementFunc = getFuncOpCallback(op, type);
761 if (!elementFunc)
762 return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
763 diag << "Missing software implementation for op " << op->getName()
764 << " and type " << type;
765 });
766
767 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperand());
768 return success();
769}
770
771namespace {
772struct ConvertMathToFuncsPass
773 : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
774 ConvertMathToFuncsPass() = default;
775 ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
776 : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {}
777
778 void runOnOperation() override;
779
780private:
781 // Return true, if this FPowI operation must be converted
782 // because the width of its exponent's type is greater than
783 // or equal to minWidthOfFPowIExponent option value.
784 bool isFPowIConvertible(math::FPowIOp op);
785
786 // Reture true, if operation is integer type.
787 bool isConvertible(Operation *op);
788
789 // Generate outlined implementations for power operations
790 // and store them in funcImpls map.
791 void generateOpImplementations();
792
793 // A map between pairs of (operation, type) deduced from operations that this
794 // pass will convert, and the corresponding outlined software implementations
795 // of these operations for the given type.
796 DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls;
797};
798} // namespace
799
800bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
801 auto expTy =
802 dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType()));
803 return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
804}
805
806bool ConvertMathToFuncsPass::isConvertible(Operation *op) {
807 return isa<IntegerType>(getElementTypeOrSelf(op->getResult(0).getType()));
808}
809
810void ConvertMathToFuncsPass::generateOpImplementations() {
811 ModuleOp module = getOperation();
812
813 module.walk([&](Operation *op) {
814 TypeSwitch<Operation *>(op)
815 .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
816 if (!convertCtlz || !isConvertible(op))
817 return;
818 Type resultType = getElementTypeOrSelf(op.getResult().getType());
819
820 // Generate the software implementation of this operation,
821 // if it has not been generated yet.
822 auto key = std::pair(op->getName(), resultType);
823 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
824 if (entry.second)
825 entry.first->second = createCtlzFunc(&module, resultType);
826 })
827 .Case<math::IPowIOp>([&](math::IPowIOp op) {
828 if (!isConvertible(op))
829 return;
830
831 Type resultType = getElementTypeOrSelf(op.getResult().getType());
832
833 // Generate the software implementation of this operation,
834 // if it has not been generated yet.
835 auto key = std::pair(op->getName(), resultType);
836 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
837 if (entry.second)
838 entry.first->second = createElementIPowIFunc(&module, resultType);
839 })
840 .Case<math::FPowIOp>([&](math::FPowIOp op) {
841 if (!isFPowIConvertible(op))
842 return;
843
844 FunctionType funcType = getElementalFuncTypeForOp(op);
845
846 // Generate the software implementation of this operation,
847 // if it has not been generated yet.
848 // FPowI implementations are mapped via the FunctionType
849 // created from the operation's result and operands.
850 auto key = std::pair(op->getName(), funcType);
851 auto entry = funcImpls.try_emplace(key, func::FuncOp{});
852 if (entry.second)
853 entry.first->second = createElementFPowIFunc(&module, funcType);
854 });
855 });
856}
857
858void ConvertMathToFuncsPass::runOnOperation() {
859 ModuleOp module = getOperation();
860
861 // Create outlined implementations for power operations.
862 generateOpImplementations();
863
864 RewritePatternSet patterns(&getContext());
865 patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
866 VecOpToScalarOp<math::CountLeadingZerosOp>>(
867 patterns.getContext());
868
869 // For the given Type Returns FuncOp stored in funcImpls map.
870 auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
871 auto it = funcImpls.find(std::pair(op->getName(), type));
872 if (it == funcImpls.end())
873 return {};
874
875 return it->second;
876 };
877 patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
878 getFuncOpByType);
879
880 if (convertCtlz)
881 patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType);
882
883 ConversionTarget target(getContext());
884 target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
885 func::FuncDialect, scf::SCFDialect,
886 vector::VectorDialect>();
887
888 target.addDynamicallyLegalOp<math::IPowIOp>(
889 [this](math::IPowIOp op) { return !isConvertible(op); });
890 if (convertCtlz) {
891 target.addDynamicallyLegalOp<math::CountLeadingZerosOp>(
892 [this](math::CountLeadingZerosOp op) { return !isConvertible(op); });
893 }
894 target.addDynamicallyLegalOp<math::FPowIOp>(
895 [this](math::FPowIOp op) { return !isFPowIConvertible(op); });
896 if (failed(applyPartialConversion(module, target, std::move(patterns))))
897 signalPassFailure();
898}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType)
Create linkonce_odr function to implement the power function with the given elementType type inside m...
static FunctionType getElementalFuncTypeForOp(Operation *op)
static func::FuncOp createElementFPowIFunc(ModuleOp *module, FunctionType funcType)
Create linkonce_odr function to implement the power function with the given funcType type inside modu...
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType)
Create function to implement the ctlz function with the given elementType type inside module.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:663
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition Builders.h:647
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Block * getBlock() const
Returns the current block of the builder.
Definition Builders.h:448
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
operand_type_iterator operand_type_end()
Definition Operation.h:396
unsigned getNumOperands()
Definition Operation.h:346
result_type_iterator result_type_end()
Definition Operation.h:427
result_type_iterator result_type_begin()
Definition Operation.h:426
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
operand_type_iterator operand_type_begin()
Definition Operation.h:395
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator end()
Definition Region.h:56
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...