15 #include "llvm/ADT/SmallSet.h"
16 #include "llvm/Support/CommandLine.h"
17 #include "llvm/Support/Debug.h"
20 #define GEN_PASS_DEF_TESTSCFPARALLELLOOPCOLLAPSING
21 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24 #define DEBUG_TYPE "parallel-loop-collapsing"
29 struct TestSCFParallelLoopCollapsing
30 :
public impl::TestSCFParallelLoopCollapsingBase<
31 TestSCFParallelLoopCollapsing> {
33 void runOnOperation()
override {
42 if (!clCollapsedIndices0.empty())
43 combinedLoops.push_back(clCollapsedIndices0);
44 if (!clCollapsedIndices1.empty()) {
45 if (clCollapsedIndices0.empty()) {
47 <<
"collapsed-indices-1 specified but not collapsed-indices-0";
51 combinedLoops.push_back(clCollapsedIndices1);
53 if (!clCollapsedIndices2.empty()) {
54 if (clCollapsedIndices1.empty()) {
56 <<
"collapsed-indices-2 specified but not collapsed-indices-1";
60 combinedLoops.push_back(clCollapsedIndices2);
63 if (combinedLoops.empty()) {
64 llvm::errs() <<
"No collapsed-indices were specified. This pass is only "
65 "for testing and does not automatically collapse all "
66 "parallel loops or similar.";
73 llvm::SmallSet<unsigned, 8> flattenedCombinedLoops;
74 unsigned maxCollapsedIndex = 0;
75 for (
auto &loops : combinedLoops) {
76 for (
auto &loop : loops) {
77 flattenedCombinedLoops.insert(loop);
78 maxCollapsedIndex =
std::max(maxCollapsedIndex, loop);
82 if (maxCollapsedIndex != flattenedCombinedLoops.size() - 1 ||
83 !flattenedCombinedLoops.contains(maxCollapsedIndex)) {
85 <<
"collapsed-indices arguments must include all values [0,N).";
94 module->
walk([&](scf::ParallelOp op) {
95 if (flattenedCombinedLoops.size() != op.getNumLoops()) {
96 op.emitOpError(
"has ")
98 <<
" iter args while this limited functionality testing pass was "
99 "configured only for loops with exactly "
100 << flattenedCombinedLoops.size() <<
" iter args.";
110 return std::make_unique<TestSCFParallelLoopCollapsing>();
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Include the generated interface declarations.
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned >> combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
std::unique_ptr< Pass > createTestSCFParallelLoopCollapsingPass()
Creates a pass that transforms a single ParallelLoop over N induction variables into another Parallel...