MLIR  22.0.0git
ParallelLoopTiling.cpp
Go to the documentation of this file.
1 //===- ParallelLoopTiling.cpp - Tiles scf.parallel ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop tiling on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_SCFPARALLELLOOPTILING
23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::scf;
28 
29 /// Tile a parallel loop of the form
30 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
31 /// step (%arg4, %arg5)
32 ///
33 /// into
34 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
35 /// step (%arg4*tileSize[0],
36 /// %arg5*tileSize[1])
37 /// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0)
38 /// min(%arg5*tileSize[1], %arg3-%i1))
39 /// step (%arg4, %arg5)
40 ///
41 /// or, when no-min-max-bounds is true, into
42 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
43 /// step (%arg4*tileSize[0],
44 /// %arg5*tileSize[1])
45 /// scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0],
46 /// %arg5*tileSize[1])
47 /// step (%arg4, %arg5)
48 /// %inbound = (%j0 * %arg4 + %i0 < %arg2) &&
49 /// (%j1 * %arg5 + %i1 < %arg3)
50 /// scf.if (%inbound)
51 /// ....
52 ///
53 /// where the uses of %i0 and %i1 in the loop body are replaced by
54 /// %i0 + j0 and %i1 + %j1.
55 ///
56 /// The old loop is replaced with the new one.
57 std::pair<ParallelOp, ParallelOp>
58 mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
59  bool noMinMaxBounds) {
60  OpBuilder b(op);
61  auto zero = arith::ConstantIndexOp::create(b, op.getLoc(), 0);
62  SmallVector<Value, 2> tileSizeConstants;
63  tileSizeConstants.reserve(op.getUpperBound().size());
64  for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) {
65  if (i < tileSizes.size())
66  tileSizeConstants.push_back(
67  arith::ConstantIndexOp::create(b, op.getLoc(), tileSizes[i]));
68  else
69  // Just pick 1 for the remaining dimensions.
70  tileSizeConstants.push_back(
71  arith::ConstantIndexOp::create(b, op.getLoc(), 1));
72  }
73 
74  // Create the outer loop with adjusted steps.
75  SmallVector<Value, 2> newSteps;
76  newSteps.reserve(op.getStep().size());
77  for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) {
78  newSteps.push_back(arith::MulIOp::create(b, op.getLoc(), std::get<0>(step),
79  std::get<1>(step)));
80  }
81  auto outerLoop = ParallelOp::create(b, op.getLoc(), op.getLowerBound(),
82  op.getUpperBound(), newSteps);
83  b.setInsertionPointToStart(outerLoop.getBody());
84 
85  // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
86  auto minMap = AffineMap::get(
87  /*dimCount=*/3, /*symbolCount=*/0,
88  {getAffineDimExpr(/*position=*/0, b.getContext()),
89  getAffineDimExpr(/*position=*/1, b.getContext()) -
90  getAffineDimExpr(/*position=*/2, b.getContext())},
91  b.getContext());
92 
93  // Create the inner loop with adjusted bounds.
94  SmallVector<Value, 2> newBounds;
95  newBounds.reserve(op.getUpperBound().size());
96  bool needInboundCheck = false;
97  for (auto [lowerBound, upperBound, newStep, iv, step, tileSizeConstant] :
98  llvm::zip(outerLoop.getLowerBound(), outerLoop.getUpperBound(),
99  outerLoop.getStep(), outerLoop.getInductionVars(),
100  op.getStep(), tileSizeConstants)) {
101  // Collect the statically known loop bounds
102  auto lowerBoundConstant =
103  lowerBound.getDefiningOp<arith::ConstantIndexOp>();
104  auto upperBoundConstant =
105  upperBound.getDefiningOp<arith::ConstantIndexOp>();
106  auto stepConstant = step.getDefiningOp<arith::ConstantIndexOp>();
107  auto tileSize =
108  cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value();
109  // If the loop bounds and the loop step are constant and if the number of
110  // loop iterations is an integer multiple of the tile size, we use a static
111  // bound for the inner loop.
112  if (lowerBoundConstant && upperBoundConstant && stepConstant) {
113  auto numIterations = llvm::divideCeil(upperBoundConstant.value() -
114  lowerBoundConstant.value(),
115  stepConstant.value());
116  if (numIterations % tileSize == 0) {
117  newBounds.push_back(newStep);
118  continue;
119  }
120  }
121 
122  // For InboundCheck mode, just use the variable outer step
123  if (noMinMaxBounds) {
124  newBounds.push_back(newStep);
125  needInboundCheck = true;
126  continue;
127  }
128 
129  // Otherwise, we dynamically compute the bound for
130  // each iteration of the outer loop.
131  newBounds.push_back(
132  affine::AffineMinOp::create(b, op.getLoc(), b.getIndexType(), minMap,
133  ValueRange{newStep, upperBound, iv}));
134  }
135  auto innerLoop = ParallelOp::create(
136  b, op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
137  op.getStep());
138 
139  if (noMinMaxBounds && needInboundCheck) {
140  b.setInsertionPointToStart(innerLoop.getBody());
141  // Insert in-bound check
142  Value inbound =
143  arith::ConstantIntOp::create(b, op.getLoc(), b.getIntegerType(1), 1);
144  for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
145  llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(),
146  innerLoop.getInductionVars(), innerLoop.getStep())) {
147  // %in_bound = %in_bound &&
148  // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound)
149  Value index = arith::AddIOp::create(
150  b, op.getLoc(),
151  arith::MulIOp::create(b, op.getLoc(), innerIV, innerStep), outerIV);
152  Value dimInbound = arith::CmpIOp::create(
153  b, op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound);
154  inbound = arith::AndIOp::create(b, op.getLoc(), inbound, dimInbound);
155  }
156  auto ifInbound = IfOp::create(b, op.getLoc(),
157  /*resultTypes*/ ArrayRef<Type>{}, inbound,
158  /*hasElseRegion*/ false);
159  ifInbound.getThenRegion().takeBody(op.getRegion());
160  Block &thenBlock = ifInbound.getThenRegion().front();
161  // Replace the scf.reduce terminator with an scf.yield terminator.
162  Operation *reduceOp = thenBlock.getTerminator();
163  b.setInsertionPointToEnd(&thenBlock);
164  scf::YieldOp::create(b, reduceOp->getLoc());
165  reduceOp->erase();
166  b.setInsertionPointToStart(innerLoop.getBody());
167  for (const auto &ivs : llvm::enumerate(llvm::zip(
168  innerLoop.getInductionVars(), outerLoop.getInductionVars()))) {
169  auto newIndex = arith::AddIOp::create(
170  b, op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value()));
171  thenBlock.getArgument(ivs.index())
172  .replaceAllUsesExcept(newIndex, newIndex);
173  }
174  thenBlock.eraseArguments(0, thenBlock.getNumArguments());
175  } else {
176  innerLoop.getRegion().takeBody(op.getRegion());
177  b.setInsertionPointToStart(innerLoop.getBody());
178  for (auto ivs : llvm::zip(innerLoop.getInductionVars(),
179  outerLoop.getInductionVars())) {
180  Value innerIndex = std::get<0>(ivs);
181  auto newIndex = arith::AddIOp::create(b, op.getLoc(), std::get<0>(ivs),
182  std::get<1>(ivs));
183  innerIndex.replaceAllUsesExcept(newIndex, newIndex);
184  }
185  }
186 
187  op.erase();
188  return std::make_pair(outerLoop, innerLoop);
189 }
190 
191 namespace {
192 struct ParallelLoopTiling
193  : public impl::SCFParallelLoopTilingBase<ParallelLoopTiling> {
194  ParallelLoopTiling() = default;
195  explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes,
196  bool noMinMaxBounds = false) {
197  this->tileSizes = tileSizes;
198  this->noMinMaxBounds = noMinMaxBounds;
199  }
200 
201  void runOnOperation() override {
202  for (auto tileSize : tileSizes)
203  if (tileSize == 0) {
205  "tile size cannot be 0");
206  return signalPassFailure();
207  }
208  auto *parentOp = getOperation();
209  SmallVector<ParallelOp, 2> innermostPloops;
210  getInnermostParallelLoops(parentOp, innermostPloops);
211  for (ParallelOp ploop : innermostPloops) {
212  // FIXME: Add reduction support.
213  if (ploop.getNumReductions() == 0)
214  tileParallelLoop(ploop, tileSizes, noMinMaxBounds);
215  }
216  }
217 };
218 } // namespace
219 
220 std::unique_ptr<Pass>
222  bool noMinMaxBounds) {
223  return std::make_unique<ParallelLoopTiling>(tileSizes, noMinMaxBounds);
224 }
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:201
Operation & front()
Definition: Block.h:153
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:50
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
MLIRContext & getContext()
Return the MLIR context for the current operation being transformed.
Definition: Pass.h:177
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void replaceAllUsesExcept(Value newValue, const SmallPtrSetImpl< Operation * > &exceptions)
Replace all uses of 'this' value with 'newValue', updating anything in the IR that uses 'this' to use...
Definition: Value.cpp:71
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition: ArithOps.cpp:258
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
std::pair< ParallelOp, ParallelOp > tileParallelLoop(ParallelOp op, llvm::ArrayRef< int64_t > tileSizes, bool noMinMaxBounds)
Tile a parallel loop of the form scf.parallel (i0, i1) = (arg0, arg1) to (arg2, arg3) step (arg4,...
Include the generated interface declarations.
std::unique_ptr< Pass > createParallelLoopTilingPass(llvm::ArrayRef< int64_t > tileSize={}, bool noMinMaxBounds=false)
Creates a pass which tiles innermost parallel loops.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
Definition: Utils.cpp:240
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619