MLIR  20.0.0git
ForallToFor.cpp
Go to the documentation of this file.
1 //===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms scf.forall ops into nests of scf.for ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
17 #include "mlir/IR/PatternMatch.h"
18 
19 namespace mlir {
20 #define GEN_PASS_DEF_SCFFORALLTOFORLOOP
21 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22 } // namespace mlir
23 
24 using namespace llvm;
25 using namespace mlir;
26 using scf::LoopNest;
27 
28 LogicalResult
29 mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
31  OpBuilder::InsertionGuard guard(rewriter);
32  rewriter.setInsertionPoint(forallOp);
33 
34  Location loc = forallOp.getLoc();
35  SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
36  SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
37  SmallVector<Value> steps = forallOp.getStep(rewriter);
38  LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
39 
40  SmallVector<Value> ivs = llvm::map_to_vector(
41  loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
42 
43  Block *innermostBlock = loopNest.loops.back().getBody();
44  rewriter.eraseOp(forallOp.getBody()->getTerminator());
45  rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
46  innermostBlock->getTerminator()->getIterator(),
47  ivs);
48  rewriter.eraseOp(forallOp);
49 
50  if (results) {
51  llvm::move(loopNest.loops, std::back_inserter(*results));
52  }
53 
54  return success();
55 }
56 
57 namespace {
58 struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
59  void runOnOperation() override {
60  Operation *parentOp = getOperation();
61  IRRewriter rewriter(parentOp->getContext());
62 
63  parentOp->walk([&](scf::ForallOp forallOp) {
64  if (failed(scf::forallToForLoop(rewriter, forallOp))) {
65  return signalPassFailure();
66  }
67  });
68  }
69 };
70 } // namespace
71 
72 std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
73  return std::make_unique<ForallToForLoop>();
74 }
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Definition: PatternMatch.h:772
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Definition: Operation.h:798
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp, SmallVectorImpl< Operation * > *results=nullptr)
Try converting scf.forall into a set of nested scf.for loops.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e. all loops but the innermost contain only another loop and a terminator.
Definition: SCF.cpp:687
Include the generated interface declarations.
std::unique_ptr< Pass > createForallToForLoopPass()
Creates a pass that converts SCF forall loops to SCF for loops.
Definition: ForallToFor.cpp:72
LoopVector loops
Definition: SCF.h:67