MLIR  19.0.0git
ForallToFor.cpp
Go to the documentation of this file.
1 //===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms scf.forall ops into nests of scf.for ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
17 #include "mlir/IR/PatternMatch.h"
18 
19 namespace mlir {
20 #define GEN_PASS_DEF_SCFFORALLTOFORLOOP
21 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22 } // namespace mlir
23 
24 using namespace llvm;
25 using namespace mlir;
26 using scf::ForallOp;
27 using scf::ForOp;
28 using scf::LoopNest;
29 
31 mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
33  OpBuilder::InsertionGuard guard(rewriter);
34  rewriter.setInsertionPoint(forallOp);
35 
36  Location loc = forallOp.getLoc();
38  rewriter, loc, forallOp.getMixedLowerBound());
40  rewriter, loc, forallOp.getMixedUpperBound());
41  SmallVector<Value> steps =
42  getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
43  LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
44 
45  SmallVector<Value> ivs = llvm::map_to_vector(
46  loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
47 
48  Block *innermostBlock = loopNest.loops.back().getBody();
49  rewriter.eraseOp(forallOp.getBody()->getTerminator());
50  rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
51  innermostBlock->getTerminator()->getIterator(),
52  ivs);
53  rewriter.eraseOp(forallOp);
54 
55  if (results) {
56  llvm::move(loopNest.loops, std::back_inserter(*results));
57  }
58 
59  return success();
60 }
61 
62 namespace {
63 struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
64  void runOnOperation() override {
65  Operation *parentOp = getOperation();
66  IRRewriter rewriter(parentOp->getContext());
67 
68  parentOp->walk([&](scf::ForallOp forallOp) {
69  if (failed(scf::forallToForLoop(rewriter, forallOp))) {
70  return signalPassFailure();
71  }
72  });
73  }
74 };
75 } // namespace
76 
77 std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
78  return std::make_unique<ForallToForLoop>();
79 }
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
Include the generated interface declarations.
Definition: CallGraph.h:229
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp, SmallVectorImpl< Operation * > *results=nullptr)
Try converting scf.forall into a set of nested scf.for loops.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e. all loops but the innermost contain only another loop and a terminator.
Definition: SCF.cpp:687
Include the generated interface declarations.
std::unique_ptr< Pass > createForallToForLoopPass()
Creates a pass that converts SCF forall loops to SCF for loops.
Definition: ForallToFor.cpp:77
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:103
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LoopVector loops
Definition: SCF.h:73