MLIR  22.0.0git
ParallelForToNestedFors.cpp
Go to the documentation of this file.
1 //===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms SCF.ParallelOp to nested scf.for ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/Support/Debug.h"
18 
19 namespace mlir {
20 #define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
21 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22 } // namespace mlir
23 
24 #define DEBUG_TYPE "parallel-for-to-nested-fors"
25 using namespace mlir;
26 
27 FailureOr<scf::LoopNest>
29  scf::ParallelOp parallelOp) {
30 
31  if (!parallelOp.getResults().empty())
32  return rewriter.notifyMatchFailure(
33  parallelOp, "Currently scf.parallel to scf.for conversion doesn't "
34  "support scf.parallel with results.");
35 
36  rewriter.setInsertionPoint(parallelOp);
37 
38  Location loc = parallelOp.getLoc();
39  SmallVector<Value> lowerBounds = parallelOp.getLowerBound();
40  SmallVector<Value> upperBounds = parallelOp.getUpperBound();
41  SmallVector<Value> steps = parallelOp.getStep();
42 
43  assert(lowerBounds.size() == upperBounds.size() &&
44  lowerBounds.size() == steps.size() &&
45  "Mismatched parallel loop bounds");
46 
48  scf::LoopNest loopNest =
49  scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);
50 
51  SmallVector<Value> newInductionVars = llvm::map_to_vector(
52  loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
53  Block *linearizedBody = loopNest.loops.back().getBody();
54  Block *parallelBody = parallelOp.getBody();
55  rewriter.eraseOp(parallelBody->getTerminator());
56  rewriter.inlineBlockBefore(parallelBody, linearizedBody->getTerminator(),
57  newInductionVars);
58  rewriter.eraseOp(parallelOp);
59  return loopNest;
60 }
61 
62 namespace {
63 struct ParallelForToNestedFors final
64  : public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
65  void runOnOperation() override {
66  Operation *parentOp = getOperation();
67  IRRewriter rewriter(parentOp->getContext());
68 
69  parentOp->walk(
70  [&](scf::ParallelOp parallelOp) {
71  if (failed(scf::parallelForToNestedFors(rewriter, parallelOp))) {
72  LLVM_DEBUG(
73  llvm::dbgs()
74  << "Failed to convert scf.parallel to nested scf.for ops for:\n"
75  << parallelOp << "\n");
76  return WalkResult::advance();
77  }
78  return WalkResult::advance();
79  });
80  }
81 };
82 } // namespace
83 
84 std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() {
85  return std::make_unique<ParallelForToNestedFors>();
86 }
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:764
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
static WalkResult advance()
Definition: WalkResult.h:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
FailureOr< scf::LoopNest > parallelForToNestedFors(RewriterBase &rewriter, ParallelOp parallelOp)
Try converting scf.parallel into a nest of scf.for loops.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:707
Include the generated interface declarations.
std::unique_ptr< Pass > createParallelForToNestedForsPass()
Creates a pass that converts SCF parallel loops to nested SCF for loops.
LoopVector loops
Definition: SCF.h:67