MLIR 22.0.0git
ParallelForToNestedFors.cpp
Go to the documentation of this file.
1//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Transforms scf.parallel ops into nests of scf.for ops.
10//
11//===----------------------------------------------------------------------===//
12
17#include "llvm/Support/Debug.h"
18
19namespace mlir {
20#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
21#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22} // namespace mlir
23
24#define DEBUG_TYPE "parallel-for-to-nested-fors"
25using namespace mlir;
26
27FailureOr<scf::LoopNest>
29 scf::ParallelOp parallelOp) {
30
31 if (!parallelOp.getResults().empty())
32 return rewriter.notifyMatchFailure(
33 parallelOp, "Currently scf.parallel to scf.for conversion doesn't "
34 "support scf.parallel with results.");
35
36 rewriter.setInsertionPoint(parallelOp);
37
38 Location loc = parallelOp.getLoc();
39 SmallVector<Value> lowerBounds = parallelOp.getLowerBound();
40 SmallVector<Value> upperBounds = parallelOp.getUpperBound();
41 SmallVector<Value> steps = parallelOp.getStep();
42
43 assert(lowerBounds.size() == upperBounds.size() &&
44 lowerBounds.size() == steps.size() &&
45 "Mismatched parallel loop bounds");
46
47 scf::LoopNest loopNest =
48 scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);
49
50 SmallVector<Value> newInductionVars = llvm::map_to_vector(
51 loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
52 Block *linearizedBody = loopNest.loops.back().getBody();
53 Block *parallelBody = parallelOp.getBody();
54 rewriter.eraseOp(parallelBody->getTerminator());
55 rewriter.inlineBlockBefore(parallelBody, linearizedBody->getTerminator(),
56 newInductionVars);
57 rewriter.eraseOp(parallelOp);
58 return loopNest;
59}
60
61namespace {
62struct ParallelForToNestedFors final
63 : public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
64 void runOnOperation() override {
65 Operation *parentOp = getOperation();
66 IRRewriter rewriter(parentOp->getContext());
67
68 parentOp->walk(
69 [&](scf::ParallelOp parallelOp) {
70 if (failed(scf::parallelForToNestedFors(rewriter, parallelOp))) {
71 LLVM_DEBUG(
72 llvm::dbgs()
73 << "Failed to convert scf.parallel to nested scf.for ops for:\n"
74 << parallelOp << "\n");
75 return WalkResult::advance();
76 }
77 return WalkResult::advance();
78 });
79 }
80};
81} // namespace
82
84 return std::make_unique<ParallelForToNestedFors>();
85}
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
static WalkResult advance()
Definition WalkResult.h:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition SCF.cpp:837
FailureOr< scf::LoopNest > parallelForToNestedFors(RewriterBase &rewriter, ParallelOp parallelOp)
Try converting an scf.parallel loop (without results) into a nest of scf.for loops.
Include the generated interface declarations.
std::unique_ptr< Pass > createParallelForToNestedForsPass()
Creates a pass that converts SCF parallel loops to nested SCF for loops.
LoopVector loops
Definition SCF.h:67