MLIR 22.0.0git
LoopRangeFolding.cpp
Go to the documentation of this file.
//===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop range folding: `arith.addi` and `arith.muli`
// operations applied to the induction variable of an `scf.for` are folded
// into the loop's bounds (and, for multiplication, its step).
//
//===----------------------------------------------------------------------===//
12
14
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_SCFFORLOOPRANGEFOLDING
23#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::scf;
28
29namespace {
30struct ForLoopRangeFolding
31 : public impl::SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
32 void runOnOperation() override;
33};
34} // namespace
35
36void ForLoopRangeFolding::runOnOperation() {
37 getOperation()->walk([&](ForOp op) {
38 Value indVar = op.getInductionVar();
39
40 auto canBeFolded = [&](Value value) {
41 return op.isDefinedOutsideOfLoop(value) || value == indVar;
42 };
43
44 // Fold until a fixed point is reached
45 while (true) {
46
47 // If the induction variable is used more than once, we can't fold its
48 // arith ops into the loop range
49 if (!indVar.hasOneUse())
50 break;
51
52 Operation *user = *indVar.getUsers().begin();
53 if (!isa<arith::AddIOp, arith::MulIOp>(user))
54 break;
55
56 if (!llvm::all_of(user->getOperands(), canBeFolded))
57 break;
58
59 OpBuilder b(op);
60 IRMapping lbMap;
61 lbMap.map(indVar, op.getLowerBound());
62 IRMapping ubMap;
63 ubMap.map(indVar, op.getUpperBound());
64 IRMapping stepMap;
65 stepMap.map(indVar, op.getStep());
66
67 if (isa<arith::AddIOp>(user)) {
68 Operation *lbFold = b.clone(*user, lbMap);
69 Operation *ubFold = b.clone(*user, ubMap);
70
71 op.setLowerBound(lbFold->getResult(0));
72 op.setUpperBound(ubFold->getResult(0));
73
74 } else if (isa<arith::MulIOp>(user)) {
75 Operation *lbFold = b.clone(*user, lbMap);
76 Operation *ubFold = b.clone(*user, ubMap);
77 Operation *stepFold = b.clone(*user, stepMap);
78
79 op.setLowerBound(lbFold->getResult(0));
80 op.setUpperBound(ubFold->getResult(0));
81 op.setStep(stepFold->getResult(0));
82 }
83
84 ValueRange wrapIndvar(indVar);
85 user->replaceAllUsesWith(wrapIndvar);
86 user->erase();
87 }
88 });
89}
90
91std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
92 return std::make_unique<ForLoopRangeFolding>();
93}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition Operation.h:272
void erase()
Remove this operation from its parent block and delete it.
user_range getUsers() const
Definition Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Include the generated interface declarations.
std::unique_ptr< Pass > createForLoopRangeFoldingPass()
Creates a pass which folds arith ops on induction variable into loop range.