MLIR 23.0.0git
LoopRangeFolding.cpp
Go to the documentation of this file.
1//===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop range folding.
10//
11//===----------------------------------------------------------------------===//
12
14
20#include "mlir/IR/IRMapping.h"
21
22namespace mlir {
23#define GEN_PASS_DEF_SCFFORLOOPRANGEFOLDING
24#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28using namespace mlir::scf;
29
30namespace {
31struct ForLoopRangeFolding
32 : public impl::SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
33 void runOnOperation() override;
34};
35} // namespace
36
37void ForLoopRangeFolding::runOnOperation() {
38 getOperation()->walk([&](ForOp op) {
39 Value indVar = op.getInductionVar();
40
41 auto canBeFolded = [&](Value value) {
42 return op.isDefinedOutsideOfLoop(value) || value == indVar;
43 };
44
45 // Fold until a fixed point is reached
46 while (true) {
47
48 // If the induction variable is used more than once, we can't fold its
49 // arith ops into the loop range
50 if (!indVar.hasOneUse())
51 break;
52
53 Operation *user = *indVar.getUsers().begin();
54 if (!isa<arith::AddIOp, arith::MulIOp>(user))
55 break;
56
57 if (!llvm::all_of(user->getOperands(), canBeFolded))
58 break;
59
60 OpBuilder b(op);
61 IRMapping lbMap;
62 lbMap.map(indVar, op.getLowerBound());
63 IRMapping ubMap;
64 ubMap.map(indVar, op.getUpperBound());
65 IRMapping stepMap;
66 stepMap.map(indVar, op.getStep());
67
68 if (isa<arith::AddIOp>(user)) {
69 Operation *lbFold = b.clone(*user, lbMap);
70 Operation *ubFold = b.clone(*user, ubMap);
71
72 op.setLowerBound(lbFold->getResult(0));
73 op.setUpperBound(ubFold->getResult(0));
74
75 } else if (auto mulOp = dyn_cast<arith::MulIOp>(user)) {
76 // Only fold if the multiplier is a known strictly positive constant.
77 // Multiplying by zero or a negative value would produce an invalid
78 // step (scf.for requires a strictly positive step).
79 Value multiplier =
80 (mulOp.getLhs() == indVar) ? mulOp.getRhs() : mulOp.getLhs();
81 std::optional<int64_t> multiplierVal = getConstantIntValue(multiplier);
82 if (!multiplierVal || *multiplierVal <= 0)
83 break;
84
85 Operation *lbFold = b.clone(*user, lbMap);
86 Operation *ubFold = b.clone(*user, ubMap);
87 Operation *stepFold = b.clone(*user, stepMap);
88
89 op.setLowerBound(lbFold->getResult(0));
90 op.setUpperBound(ubFold->getResult(0));
91 op.setStep(stepFold->getResult(0));
92 }
93
94 ValueRange wrapIndvar(indVar);
95 user->replaceAllUsesWith(wrapIndvar);
96 user->erase();
97 }
98 });
99}
100
101std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
102 return std::make_unique<ForLoopRangeFolding>();
103}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition Operation.h:298
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
void erase()
Remove this operation from its parent block and delete it.
user_range getUsers() const
Definition Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::unique_ptr< Pass > createForLoopRangeFoldingPass()
Creates a pass which folds arith ops on induction variable into loop range.