MLIR  20.0.0git
SCFToOpenMP.cpp
Go to the documentation of this file.
1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert scf.parallel operations into OpenMP
10 // parallel loops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
24 #include "mlir/IR/SymbolTable.h"
25 #include "mlir/Pass/Pass.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 /// Matches a block containing a "simple" reduction. The expected shape of the
36 /// block is as follows.
37 ///
38 /// ^bb(%arg0, %arg1):
39 /// %0 = OpTy(%arg0, %arg1)
40 /// scf.reduce.return %0
41 template <typename... OpTy>
42 static bool matchSimpleReduction(Block &block) {
43  if (block.empty() || llvm::hasSingleElement(block) ||
44  std::next(block.begin(), 2) != block.end())
45  return false;
46 
47  if (block.getNumArguments() != 2)
48  return false;
49 
50  SmallVector<Operation *, 4> combinerOps;
51  Value reducedVal = matchReduction({block.getArguments()[1]},
52  /*redPos=*/0, combinerOps);
53 
54  if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
55  return false;
56 
57  return isa<OpTy...>(combinerOps[0]) &&
58  isa<scf::ReduceReturnOp>(block.back()) &&
59  block.front().getOperands() == block.getArguments();
60 }
61 
62 /// Matches a block containing a select-based min/max reduction. The types of
63 /// select and compare operations are provided as template arguments. The
64 /// comparison predicates suitable for min and max are provided as function
65 /// arguments. If a reduction is matched, `ifMin` will be set if the reduction
66 /// compute the minimum and unset if it computes the maximum, otherwise it
67 /// remains unmodified. The expected shape of the block is as follows.
68 ///
69 /// ^bb(%arg0, %arg1):
70 /// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
71 /// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here.
72 /// scf.reduce.return %1
73 template <
74  typename CompareOpTy, typename SelectOpTy,
75  typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
76 static bool
77 matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,
78  ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
79  static_assert(
80  llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
81  "only arithmetic and llvm select ops are supported");
82 
83  // Expect exactly three operations in the block.
84  if (block.empty() || llvm::hasSingleElement(block) ||
85  std::next(block.begin(), 2) == block.end() ||
86  std::next(block.begin(), 3) != block.end())
87  return false;
88 
89  // Check op kinds.
90  auto compare = dyn_cast<CompareOpTy>(block.front());
91  auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
92  auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
93  if (!compare || !select || !terminator)
94  return false;
95 
96  // Block arguments must be compared.
97  if (compare->getOperands() != block.getArguments())
98  return false;
99 
100  // Detect whether the comparison is less-than or greater-than, otherwise bail.
101  bool isLess;
102  if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
103  isLess = true;
104  } else if (llvm::is_contained(greaterThanPredicates,
105  compare.getPredicate())) {
106  isLess = false;
107  } else {
108  return false;
109  }
110 
111  if (select.getCondition() != compare.getResult())
112  return false;
113 
114  // Detect if the operands are swapped between cmpf and select. Match the
115  // comparison type with the requested type or with the opposite of the
116  // requested type if the operands are swapped. Use generic accessors because
117  // std and LLVM versions of select have different operand names but identical
118  // positions.
119  constexpr unsigned kTrueValue = 1;
120  constexpr unsigned kFalseValue = 2;
121  bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
122  select.getOperand(kFalseValue) == compare.getRhs();
123  bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
124  select.getOperand(kFalseValue) == compare.getLhs();
125  if (!sameOperands && !swappedOperands)
126  return false;
127 
128  if (select.getResult() != terminator.getResult())
129  return false;
130 
131  // The reduction is a min if it uses less-than predicates with same operands
132  // or greather-than predicates with swapped operands. Similarly for max.
133  isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
134  return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
135 }
136 
137 /// Returns the float semantics for the given float type.
138 static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
139  if (type.isF16())
140  return llvm::APFloat::IEEEhalf();
141  if (type.isF32())
142  return llvm::APFloat::IEEEsingle();
143  if (type.isF64())
144  return llvm::APFloat::IEEEdouble();
145  if (type.isF128())
146  return llvm::APFloat::IEEEquad();
147  if (type.isBF16())
148  return llvm::APFloat::BFloat();
149  if (type.isF80())
150  return llvm::APFloat::x87DoubleExtended();
151  llvm_unreachable("unknown float type");
152 }
153 
154 /// Returns an attribute with the minimum (if `min` is set) or the maximum value
155 /// (otherwise) for the given float type.
156 static Attribute minMaxValueForFloat(Type type, bool min) {
157  auto fltType = cast<FloatType>(type);
158  return FloatAttr::get(
159  type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
160 }
161 
162 /// Returns an attribute with the signed integer minimum (if `min` is set) or
163 /// the maximum value (otherwise) for the given integer type, regardless of its
164 /// signedness semantics (only the width is considered).
166  auto intType = cast<IntegerType>(type);
167  unsigned bitwidth = intType.getWidth();
168  return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
169  : llvm::APInt::getSignedMaxValue(bitwidth));
170 }
171 
172 /// Returns an attribute with the unsigned integer minimum (if `min` is set) or
173 /// the maximum value (otherwise) for the given integer type, regardless of its
174 /// signedness semantics (only the width is considered).
176  auto intType = cast<IntegerType>(type);
177  unsigned bitwidth = intType.getWidth();
178  return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
179  : llvm::APInt::getAllOnes(bitwidth));
180 }
181 
182 /// Creates an OpenMP reduction declaration and inserts it into the provided
183 /// symbol table. The declaration has a constant initializer with the neutral
184 /// value `initValue`, and the `reductionIndex`-th reduction combiner carried
185 /// over from `reduce`.
186 static omp::DeclareReductionOp
187 createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
188  scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
189  OpBuilder::InsertionGuard guard(builder);
190  Type type = reduce.getOperands()[reductionIndex].getType();
191  auto decl = builder.create<omp::DeclareReductionOp>(reduce.getLoc(),
192  "__scf_reduction", type);
193  symbolTable.insert(decl);
194 
195  builder.createBlock(&decl.getInitializerRegion(),
196  decl.getInitializerRegion().end(), {type},
197  {reduce.getOperands()[reductionIndex].getLoc()});
198  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
199  Value init =
200  builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
201  builder.create<omp::YieldOp>(reduce.getLoc(), init);
202 
203  Operation *terminator =
204  &reduce.getReductions()[reductionIndex].front().back();
205  assert(isa<scf::ReduceReturnOp>(terminator) &&
206  "expected reduce op to be terminated by redure return");
207  builder.setInsertionPoint(terminator);
208  builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
209  terminator->getOperands());
210  builder.inlineRegionBefore(reduce.getReductions()[reductionIndex],
211  decl.getReductionRegion(),
212  decl.getReductionRegion().end());
213  return decl;
214 }
215 
216 /// Adds an atomic reduction combiner to the given OpenMP reduction declaration
217 /// using llvm.atomicrmw of the given kind.
218 static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
219  LLVM::AtomicBinOp atomicKind,
220  omp::DeclareReductionOp decl,
221  scf::ReduceOp reduce,
222  int64_t reductionIndex) {
223  OpBuilder::InsertionGuard guard(builder);
224  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
225  Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc();
226  builder.createBlock(&decl.getAtomicReductionRegion(),
227  decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
228  {reduceOperandLoc, reduceOperandLoc});
229  Block *atomicBlock = &decl.getAtomicReductionRegion().back();
230  builder.setInsertionPointToEnd(atomicBlock);
231  Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(),
232  atomicBlock->getArgument(1));
233  builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind,
234  atomicBlock->getArgument(0), loaded,
235  LLVM::AtomicOrdering::monotonic);
236  builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>());
237  return decl;
238 }
239 
240 /// Creates an OpenMP reduction declaration that corresponds to the given SCF
241 /// reduction and returns it. Recognizes common reductions in order to identify
242 /// the neutral value, necessary for the OpenMP declaration. If the reduction
243 /// cannot be recognized, returns null.
244 static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
245  scf::ReduceOp reduce,
246  int64_t reductionIndex) {
248  SymbolTable symbolTable(container);
249 
250  // Insert reduction declarations in the symbol-table ancestor before the
251  // ancestor of the current insertion point.
252  Operation *insertionPoint = reduce;
253  while (insertionPoint->getParentOp() != container)
254  insertionPoint = insertionPoint->getParentOp();
255  OpBuilder::InsertionGuard guard(builder);
256  builder.setInsertionPoint(insertionPoint);
257 
258  assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
259  "expected reduction region to have a single element");
260 
261  // Match simple binary reductions that can be expressed with atomicrmw.
262  Type type = reduce.getOperands()[reductionIndex].getType();
263  Block &reduction = reduce.getReductions()[reductionIndex].front();
264  if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
265  omp::DeclareReductionOp decl =
266  createDecl(builder, symbolTable, reduce, reductionIndex,
267  builder.getFloatAttr(type, 0.0));
268  return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
269  reductionIndex);
270  }
271  if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
272  omp::DeclareReductionOp decl =
273  createDecl(builder, symbolTable, reduce, reductionIndex,
274  builder.getIntegerAttr(type, 0));
275  return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
276  reductionIndex);
277  }
278  if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
279  omp::DeclareReductionOp decl =
280  createDecl(builder, symbolTable, reduce, reductionIndex,
281  builder.getIntegerAttr(type, 0));
282  return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
283  reductionIndex);
284  }
285  if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
286  omp::DeclareReductionOp decl =
287  createDecl(builder, symbolTable, reduce, reductionIndex,
288  builder.getIntegerAttr(type, 0));
289  return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
290  reductionIndex);
291  }
292  if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
293  omp::DeclareReductionOp decl = createDecl(
294  builder, symbolTable, reduce, reductionIndex,
295  builder.getIntegerAttr(
296  type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
297  return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
298  reductionIndex);
299  }
300 
301  // Match simple binary reductions that cannot be expressed with atomicrmw.
302  // TODO: add atomic region using cmpxchg (which needs atomic load to be
303  // available as an op).
304  if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
305  return createDecl(builder, symbolTable, reduce, reductionIndex,
306  builder.getFloatAttr(type, 1.0));
307  }
308  if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
309  return createDecl(builder, symbolTable, reduce, reductionIndex,
310  builder.getIntegerAttr(type, 1));
311  }
312 
313  // Match select-based min/max reductions.
314  bool isMin;
315  if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
316  reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
317  {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
318  matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
319  reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
320  {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
321  return createDecl(builder, symbolTable, reduce, reductionIndex,
322  minMaxValueForFloat(type, !isMin));
323  }
324  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
325  reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
326  {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
327  matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
328  reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
329  {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
330  omp::DeclareReductionOp decl =
331  createDecl(builder, symbolTable, reduce, reductionIndex,
332  minMaxValueForSignedInt(type, !isMin));
333  return addAtomicRMW(builder,
335  decl, reduce, reductionIndex);
336  }
337  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
338  reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
339  {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
340  matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
341  reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
342  {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
343  omp::DeclareReductionOp decl =
344  createDecl(builder, symbolTable, reduce, reductionIndex,
345  minMaxValueForUnsignedInt(type, !isMin));
346  return addAtomicRMW(
347  builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
348  decl, reduce, reductionIndex);
349  }
350 
351  return nullptr;
352 }
353 
354 namespace {
355 
356 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
357  static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
358  unsigned numThreads;
359 
360  ParallelOpLowering(MLIRContext *context,
361  unsigned numThreads = kUseOpenMPDefaultNumThreads)
362  : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
363 
364  LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
365  PatternRewriter &rewriter) const override {
366  // Declare reductions.
367  // TODO: consider checking it here is already a compatible reduction
368  // declaration and use it instead of redeclaring.
369  SmallVector<Attribute> reductionDeclSymbols;
370  SmallVector<omp::DeclareReductionOp> ompReductionDecls;
371  auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
372  for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
373  omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
374  ompReductionDecls.push_back(decl);
375  if (!decl)
376  return failure();
377  reductionDeclSymbols.push_back(
378  SymbolRefAttr::get(rewriter.getContext(), decl.getSymName()));
379  }
380 
381  // Allocate reduction variables. Make sure the we don't overflow the stack
382  // with local `alloca`s by saving and restoring the stack pointer.
383  Location loc = parallelOp.getLoc();
384  Value one = rewriter.create<LLVM::ConstantOp>(
385  loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
386  SmallVector<Value> reductionVariables;
387  reductionVariables.reserve(parallelOp.getNumReductions());
388  auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
389  for (Value init : parallelOp.getInitVals()) {
390  assert((LLVM::isCompatibleType(init.getType()) ||
391  isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
392  "cannot create a reduction variable if the type is not an LLVM "
393  "pointer element");
394  Value storage =
395  rewriter.create<LLVM::AllocaOp>(loc, ptrType, init.getType(), one, 0);
396  rewriter.create<LLVM::StoreOp>(loc, init, storage);
397  reductionVariables.push_back(storage);
398  }
399 
400  // Replace the reduction operations contained in this loop. Must be done
401  // here rather than in a separate pattern to have access to the list of
402  // reduction variables.
403  for (auto [x, y, rD] : llvm::zip_equal(
404  reductionVariables, reduce.getOperands(), ompReductionDecls)) {
405  OpBuilder::InsertionGuard guard(rewriter);
406  rewriter.setInsertionPoint(reduce);
407  Region &redRegion = rD.getReductionRegion();
408  // The SCF dialect by definition contains only structured operations
409  // and hence the SCF reduction region will contain a single block.
410  // The ompReductionDecls region is a copy of the SCF reduction region
411  // and hence has the same property.
412  assert(redRegion.hasOneBlock() &&
413  "expect reduction region to have one block");
414  Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
415  Value pvtRedVal = rewriter.create<LLVM::LoadOp>(reduce.getLoc(),
416  rD.getType(), pvtRedVar);
417  // Make a copy of the reduction combiner region in the body
418  mlir::OpBuilder builder(rewriter.getContext());
419  builder.setInsertionPoint(reduce);
420  mlir::IRMapping mapper;
421  assert(redRegion.getNumArguments() == 2 &&
422  "expect reduction region to have two arguments");
423  mapper.map(redRegion.getArgument(0), pvtRedVal);
424  mapper.map(redRegion.getArgument(1), y);
425  for (auto &op : redRegion.getOps()) {
426  Operation *cloneOp = builder.clone(op, mapper);
427  if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
428  assert(yieldOp && yieldOp.getResults().size() == 1 &&
429  "expect YieldOp in reduction region to return one result");
430  Value redVal = yieldOp.getResults()[0];
431  rewriter.create<LLVM::StoreOp>(loc, redVal, pvtRedVar);
432  rewriter.eraseOp(yieldOp);
433  break;
434  }
435  }
436  }
437  rewriter.eraseOp(reduce);
438 
439  Value numThreadsVar;
440  if (numThreads > 0) {
441  numThreadsVar = rewriter.create<LLVM::ConstantOp>(
442  loc, rewriter.getI32IntegerAttr(numThreads));
443  }
444  // Create the parallel wrapper.
445  auto ompParallel = rewriter.create<omp::ParallelOp>(
446  loc,
447  /* if_expr = */ Value{},
448  /* num_threads_var = */ numThreadsVar,
449  /* allocate_vars = */ llvm::SmallVector<Value>{},
450  /* allocators_vars = */ llvm::SmallVector<Value>{},
451  /* reduction_vars = */ llvm::SmallVector<Value>{},
452  /* reduction_vars_isbyref = */ DenseBoolArrayAttr{},
453  /* reductions = */ ArrayAttr{},
454  /* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
455  /* private_vars = */ ValueRange(),
456  /* privatizers = */ nullptr);
457  {
458 
459  OpBuilder::InsertionGuard guard(rewriter);
460  rewriter.createBlock(&ompParallel.getRegion());
461 
462  // Replace the loop.
463  {
464  OpBuilder::InsertionGuard allocaGuard(rewriter);
465  // Create worksharing loop wrapper.
466  auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
467  if (!reductionVariables.empty()) {
468  wsloopOp.setReductionsAttr(
469  ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
470  wsloopOp.getReductionVarsMutable().append(reductionVariables);
471  llvm::SmallVector<bool> byRefVec;
472  // false because these reductions always reduce scalars and so do
473  // not need to pass by reference
474  byRefVec.resize(reductionVariables.size(), false);
475  wsloopOp.setReductionVarsByref(
476  DenseBoolArrayAttr::get(rewriter.getContext(), byRefVec));
477  }
478  rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
479 
480  // The wrapper's entry block arguments will define the reduction
481  // variables.
482  llvm::SmallVector<mlir::Type> reductionTypes;
483  reductionTypes.reserve(reductionVariables.size());
484  llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
485  [](mlir::Value v) { return v.getType(); });
486  rewriter.createBlock(
487  &wsloopOp.getRegion(), {}, reductionTypes,
488  llvm::SmallVector<mlir::Location>(reductionVariables.size(),
489  parallelOp.getLoc()));
490 
491  rewriter.setInsertionPoint(
492  rewriter.create<omp::TerminatorOp>(parallelOp.getLoc()));
493 
494  // Create loop nest and populate region with contents of scf.parallel.
495  auto loopOp = rewriter.create<omp::LoopNestOp>(
496  parallelOp.getLoc(), parallelOp.getLowerBound(),
497  parallelOp.getUpperBound(), parallelOp.getStep());
498 
499  rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
500  loopOp.getRegion().begin());
501 
502  // Remove reduction-related block arguments from omp.loop_nest and
503  // redirect uses to the corresponding omp.wsloop block argument.
504  mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
505  unsigned numLoops = parallelOp.getNumLoops();
506  rewriter.replaceAllUsesWith(
507  loopOpEntryBlock.getArguments().drop_front(numLoops),
508  wsloopOp.getRegion().getArguments());
509  loopOpEntryBlock.eraseArguments(
510  numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
511 
512  Block *ops =
513  rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
514  rewriter.setInsertionPointToStart(&loopOpEntryBlock);
515 
516  auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
517  TypeRange());
518  rewriter.create<omp::YieldOp>(loc, ValueRange());
519  Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
520  rewriter.mergeBlocks(ops, scopeBlock);
521  rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
522  rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
523  }
524  }
525 
526  // Load loop results.
527  SmallVector<Value> results;
528  results.reserve(reductionVariables.size());
529  for (auto [variable, type] :
530  llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
531  Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable);
532  results.push_back(res);
533  }
534  rewriter.replaceOp(parallelOp, results);
535 
536  return success();
537  }
538 };
539 
540 /// Applies the conversion patterns in the given function.
541 static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
542  ConversionTarget target(*module.getContext());
543  target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
544  target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
545  memref::MemRefDialect>();
546 
547  RewritePatternSet patterns(module.getContext());
548  patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
549  FrozenRewritePatternSet frozen(std::move(patterns));
550  return applyPartialConversion(module, target, frozen);
551 }
552 
553 /// A pass converting SCF operations to OpenMP operations.
554 struct SCFToOpenMPPass
555  : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
556 
557  using Base::Base;
558 
559  /// Pass entry point.
560  void runOnOperation() override {
561  if (failed(applyPatterns(getOperation(), numThreads)))
562  signalPassFailure();
563  }
564 };
565 
566 } // namespace
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2672
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
Definition: SCFToOpenMP.cpp:42
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
Definition: SCFToOpenMP.cpp:77
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
bool empty()
Definition: Block.h:146
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation & back()
Definition: Block.h:150
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:200
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator end()
Definition: Block.h:142
iterator begin()
Definition: Block.h:141
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:220
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:242
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:265
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:132
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:91
MLIRContext * getContext() const
Definition: Builders.h:55
This class describes a specific conversion target.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:434
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:439
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:441
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
unsigned getNumArguments()
Definition: Region.h:123
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:53
bool isF32() const
Definition: Types.cpp:52
bool isF128() const
Definition: Types.cpp:55
bool isF16() const
Definition: Types.cpp:50
bool isF80() const
Definition: Types.cpp:54
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
bool isBF16() const
Definition: Types.cpp:49
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:856
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:67
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358