MLIR  16.0.0git
ParallelLoopTiling.cpp
Go to the documentation of this file.
1 //===- ParallelLoopTiling.cpp - Tiles scf.parallel ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop tiling on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
20 
21 using namespace mlir;
22 using namespace mlir::scf;
23 
24 /// Tile a parallel loop of the form
25 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
26 /// step (%arg4, %arg5)
27 ///
28 /// into
29 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
30 /// step (%arg4*tileSize[0],
31 /// %arg5*tileSize[1])
32 /// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0)
33 /// min(%arg5*tileSize[1], %arg3-%i1))
34 /// step (%arg4, %arg5)
35 ///
36 /// or, when no-min-max-bounds is true, into
37 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
38 /// step (%arg4*tileSize[0],
39 /// %arg5*tileSize[1])
40 /// scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0],
41 /// %arg5*tileSize[1])
42 /// step (%arg4, %arg5)
43 /// %inbound = (%j0 * %arg4 + %i0 < %arg2) &&
44 /// (%j1 * %arg5 + %i1 < %arg3)
45 /// scf.if (%inbound)
46 /// ....
47 ///
48 /// where the uses of %i0 and %i1 in the loop body are replaced by
49 /// %i0 + j0 and %i1 + %j1.
50 ///
51 /// The old loop is replaced with the new one.
52 std::pair<ParallelOp, ParallelOp>
54  bool noMinMaxBounds) {
55  OpBuilder b(op);
56  auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
57  SmallVector<Value, 2> tileSizeConstants;
58  tileSizeConstants.reserve(op.getUpperBound().size());
59  for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) {
60  if (i < tileSizes.size())
61  tileSizeConstants.push_back(
62  b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i]));
63  else
64  // Just pick 1 for the remaining dimensions.
65  tileSizeConstants.push_back(
66  b.create<arith::ConstantIndexOp>(op.getLoc(), 1));
67  }
68 
69  // Create the outer loop with adjusted steps.
70  SmallVector<Value, 2> newSteps;
71  newSteps.reserve(op.getStep().size());
72  for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) {
73  newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step),
74  std::get<1>(step)));
75  }
76  auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.getLowerBound(),
77  op.getUpperBound(), newSteps);
78  b.setInsertionPointToStart(outerLoop.getBody());
79 
80  // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
81  auto minMap = AffineMap::get(
82  /*dimCount=*/3, /*symbolCount=*/0,
83  {getAffineDimExpr(/*position=*/0, b.getContext()),
84  getAffineDimExpr(/*position=*/1, b.getContext()) -
85  getAffineDimExpr(/*position=*/2, b.getContext())},
86  b.getContext());
87 
88  // Create the inner loop with adjusted bounds.
89  SmallVector<Value, 2> newBounds;
90  newBounds.reserve(op.getUpperBound().size());
91  bool needInboundCheck = false;
92  for (auto [lowerBound, upperBound, newStep, iv, step, tileSizeConstant] :
93  llvm::zip(outerLoop.getLowerBound(), outerLoop.getUpperBound(),
94  outerLoop.getStep(), outerLoop.getInductionVars(),
95  op.getStep(), tileSizeConstants)) {
96  // Collect the statically known loop bounds
97  auto lowerBoundConstant =
98  dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
99  auto upperBoundConstant =
100  dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
101  auto stepConstant =
102  dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
103  auto tileSize =
104  cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value();
105  // If the loop bounds and the loop step are constant and if the number of
106  // loop iterations is an integer multiple of the tile size, we use a static
107  // bound for the inner loop.
108  if (lowerBoundConstant && upperBoundConstant && stepConstant) {
109  auto numIterations = llvm::divideCeil(upperBoundConstant.value() -
110  lowerBoundConstant.value(),
111  stepConstant.value());
112  if (numIterations % tileSize == 0) {
113  newBounds.push_back(newStep);
114  continue;
115  }
116  }
117 
118  // For InboundCheck mode, just use the variable outer step
119  if (noMinMaxBounds) {
120  newBounds.push_back(newStep);
121  needInboundCheck = true;
122  continue;
123  }
124 
125  // Otherwise, we dynamically compute the bound for
126  // each iteration of the outer loop.
127  newBounds.push_back(
128  b.create<AffineMinOp>(op.getLoc(), b.getIndexType(), minMap,
129  ValueRange{newStep, upperBound, iv}));
130  }
131  auto innerLoop = b.create<ParallelOp>(
132  op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
133  op.getStep());
134 
135  if (noMinMaxBounds && needInboundCheck) {
136  b.setInsertionPointToStart(innerLoop.getBody());
137  // Insert in-bound check
138  Value inbound =
139  b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1));
140  for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
141  llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(),
142  innerLoop.getInductionVars(), innerLoop.getStep())) {
143  // %in_bound = %in_bound &&
144  // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound)
145  Value index = b.create<arith::AddIOp>(
146  op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep),
147  outerIV);
148  Value dimInbound = b.create<arith::CmpIOp>(
149  op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound);
150  inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound);
151  }
152  auto ifInbound = b.create<IfOp>(op.getLoc(),
153  /*resultTypes*/ ArrayRef<Type>{}, inbound,
154  /*hasElseRegion*/ false);
155  ifInbound.getThenRegion().takeBody(op.getRegion());
156  Block &thenBlock = ifInbound.getThenRegion().front();
157  b.setInsertionPointToStart(innerLoop.getBody());
158  for (const auto &ivs : llvm::enumerate(llvm::zip(
159  innerLoop.getInductionVars(), outerLoop.getInductionVars()))) {
160  auto newIndex = b.create<arith::AddIOp>(
161  op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value()));
162  thenBlock.getArgument(ivs.index())
163  .replaceAllUsesExcept(newIndex, newIndex);
164  }
165  thenBlock.eraseArguments(llvm::to_vector<4>(
166  llvm::seq((unsigned)0, thenBlock.getNumArguments())));
167  } else {
168  innerLoop.getRegion().takeBody(op.getRegion());
169  b.setInsertionPointToStart(innerLoop.getBody());
170  for (auto ivs : llvm::zip(innerLoop.getInductionVars(),
171  outerLoop.getInductionVars())) {
172  Value innerIndex = std::get<0>(ivs);
173  auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs),
174  std::get<1>(ivs));
175  innerIndex.replaceAllUsesExcept(newIndex, newIndex);
176  }
177  }
178 
179  op.erase();
180  return std::make_pair(outerLoop, innerLoop);
181 }
182 
183 namespace {
184 struct ParallelLoopTiling
185  : public SCFParallelLoopTilingBase<ParallelLoopTiling> {
186  ParallelLoopTiling() = default;
187  explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes,
188  bool noMinMaxBounds = false) {
189  this->tileSizes = tileSizes;
190  this->noMinMaxBounds = noMinMaxBounds;
191  }
192 
193  void runOnOperation() override {
194  auto *parentOp = getOperation();
195  SmallVector<ParallelOp, 2> innermostPloops;
196  getInnermostParallelLoops(parentOp, innermostPloops);
197  for (ParallelOp ploop : innermostPloops) {
198  // FIXME: Add reduction support.
199  if (ploop.getNumReductions() == 0)
200  tileParallelLoop(ploop, tileSizes, noMinMaxBounds);
201  }
202  }
203 };
204 } // namespace
205 
206 std::unique_ptr<Pass>
208  bool noMinMaxBounds) {
209  return std::make_unique<ParallelLoopTiling>(tileSizes, noMinMaxBounds);
210 }
Include the generated interface declarations.
MLIRContext * getContext() const
Definition: Builders.h:54
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
Definition: Utils.cpp:266
Specialization of arith.constant op that returns an integer value.
Definition: Arithmetic.h:43
Block represents an ordered list of Operations.
Definition: Block.h:29
Operation & front()
Definition: Block.h:144
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
static constexpr const bool value
void eraseArguments(ArrayRef< unsigned > argIndices)
Erases the arguments listed in argIndices and removes them from the argument list.
Definition: Block.cpp:189
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumArguments()
Definition: Block.h:119
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:489
std::pair< ParallelOp, ParallelOp > tileParallelLoop(ParallelOp op, llvm::ArrayRef< int64_t > tileSizes, bool noMinMaxBounds)
Tile a parallel loop of the form scf.parallel (i0, i1) = (arg0, arg1) to (arg2, arg3) step (arg4...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:377
IndexType getIndexType()
Definition: Builders.cpp:48
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:80
std::unique_ptr< Pass > createParallelLoopTilingPass(llvm::ArrayRef< int64_t > tileSize={}, bool noMinMaxBounds=false)
Creates a pass which tiles innermost parallel loops.
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
void replaceAllUsesExcept(Value newValue, const SmallPtrSetImpl< Operation *> &exceptions) const
Replace all uses of 'this' value with 'newValue', updating anything in the IR that uses 'this' to use...
Definition: Value.cpp:61