MLIR  20.0.0git
ForallToParallel.cpp
Go to the documentation of this file.
1 //===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms scf.forall ops into scf.parallel ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/PatternMatch.h"
17 
18 namespace mlir {
19 #define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
20 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
21 } // namespace mlir
22 
23 using namespace mlir;
24 
25 LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
26  scf::ForallOp forallOp,
27  scf::ParallelOp *result) {
28  OpBuilder::InsertionGuard guard(rewriter);
29  rewriter.setInsertionPoint(forallOp);
30 
31  Location loc = forallOp.getLoc();
32  if (!forallOp.getOutputs().empty())
33  return rewriter.notifyMatchFailure(
34  forallOp,
35  "only fully bufferized scf.forall ops can be lowered to scf.parallel");
36 
37  // Convert mixed bounds and steps to SSA values.
38  SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
39  SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
40  SmallVector<Value> steps = forallOp.getStep(rewriter);
41 
42  // Create empty scf.parallel op.
43  auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
44  rewriter.eraseBlock(&parallelOp.getRegion().front());
45  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
46  parallelOp.getRegion().begin());
47  // Replace the terminator.
48  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
49  rewriter.replaceOpWithNewOp<scf::ReduceOp>(
50  parallelOp.getRegion().front().getTerminator());
51 
52  // If the mapping attribute is present, propagate to the new parallelOp.
53  if (forallOp.getMapping())
54  parallelOp->setAttr("mapping", *forallOp.getMapping());
55 
56  // Erase the scf.forall op.
57  rewriter.replaceOp(forallOp, parallelOp);
58 
59  if (result)
60  *result = parallelOp;
61 
62  return success();
63 }
64 
65 namespace {
66 struct ForallToParallelLoop final
67  : public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
68  void runOnOperation() override {
69  Operation *parentOp = getOperation();
70  IRRewriter rewriter(parentOp->getContext());
71 
72  parentOp->walk([&](scf::ForallOp forallOp) {
73  if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
74  return signalPassFailure();
75  }
76  });
77  }
78 };
79 } // namespace
80 
81 std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
82  return std::make_unique<ForallToParallelLoop>();
83 }
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:488
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)
Try converting scf.forall into an scf.parallel loop.
Include the generated interface declarations.
std::unique_ptr< Pass > createForallToParallelLoopPass()
Creates a pass that converts SCF forall loops to SCF parallel loops.