MLIR  20.0.0git
LoopRangeFolding.cpp
Go to the documentation of this file.
1 //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop range folding.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
19 #include "mlir/IR/IRMapping.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_SCFFORLOOPRANGEFOLDING
23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::scf;
28 
29 namespace {
30 struct ForLoopRangeFolding
31  : public impl::SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
32  void runOnOperation() override;
33 };
34 } // namespace
35 
36 void ForLoopRangeFolding::runOnOperation() {
37  getOperation()->walk([&](ForOp op) {
38  Value indVar = op.getInductionVar();
39 
40  auto canBeFolded = [&](Value value) {
41  return op.isDefinedOutsideOfLoop(value) || value == indVar;
42  };
43 
44  // Fold until a fixed point is reached
45  while (true) {
46 
47  // If the induction variable is used more than once, we can't fold its
48  // arith ops into the loop range
49  if (!indVar.hasOneUse())
50  break;
51 
52  Operation *user = *indVar.getUsers().begin();
53  if (!isa<arith::AddIOp, arith::MulIOp>(user))
54  break;
55 
56  if (!llvm::all_of(user->getOperands(), canBeFolded))
57  break;
58 
59  OpBuilder b(op);
60  IRMapping lbMap;
61  lbMap.map(indVar, op.getLowerBound());
62  IRMapping ubMap;
63  ubMap.map(indVar, op.getUpperBound());
64  IRMapping stepMap;
65  stepMap.map(indVar, op.getStep());
66 
67  if (isa<arith::AddIOp>(user)) {
68  Operation *lbFold = b.clone(*user, lbMap);
69  Operation *ubFold = b.clone(*user, ubMap);
70 
71  op.setLowerBound(lbFold->getResult(0));
72  op.setUpperBound(ubFold->getResult(0));
73 
74  } else if (isa<arith::MulIOp>(user)) {
75  Operation *ubFold = b.clone(*user, ubMap);
76  Operation *stepFold = b.clone(*user, stepMap);
77 
78  op.setUpperBound(ubFold->getResult(0));
79  op.setStep(stepFold->getResult(0));
80  }
81 
82  ValueRange wrapIndvar(indVar);
83  user->replaceAllUsesWith(wrapIndvar);
84  user->erase();
85  }
86  });
87 }
88 
89 std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
90  return std::make_unique<ForLoopRangeFolding>();
91 }
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class helps build Operations.
Definition: Builders.h:215
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
operand_range getOperands()
Returns a range over the underlying Values.
Definition: Operation.h:373
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Include the generated interface declarations.
std::unique_ptr< Pass > createForLoopRangeFoldingPass()
Creates a pass which folds arith ops on induction variable into loop range.