MLIR  21.0.0git
SCFToOpenMP.cpp
Go to the documentation of this file.
1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert scf.parallel operations into OpenMP
10 // parallel loops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
23 #include "mlir/IR/SymbolTable.h"
24 #include "mlir/Pass/Pass.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 
34 /// Matches a block containing a "simple" reduction. The expected shape of the
35 /// block is as follows.
36 ///
37 /// ^bb(%arg0, %arg1):
38 /// %0 = OpTy(%arg0, %arg1)
39 /// scf.reduce.return %0
40 template <typename... OpTy>
41 static bool matchSimpleReduction(Block &block) {
42  if (block.empty() || llvm::hasSingleElement(block) ||
43  std::next(block.begin(), 2) != block.end())
44  return false;
45 
46  if (block.getNumArguments() != 2)
47  return false;
48 
49  SmallVector<Operation *, 4> combinerOps;
50  Value reducedVal = matchReduction({block.getArguments()[1]},
51  /*redPos=*/0, combinerOps);
52 
53  if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
54  return false;
55 
56  return isa<OpTy...>(combinerOps[0]) &&
57  isa<scf::ReduceReturnOp>(block.back()) &&
58  block.front().getOperands() == block.getArguments();
59 }
60 
61 /// Matches a block containing a select-based min/max reduction. The types of
62 /// select and compare operations are provided as template arguments. The
63 /// comparison predicates suitable for min and max are provided as function
64 /// arguments. If a reduction is matched, `ifMin` will be set if the reduction
65 /// compute the minimum and unset if it computes the maximum, otherwise it
66 /// remains unmodified. The expected shape of the block is as follows.
67 ///
68 /// ^bb(%arg0, %arg1):
69 /// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
70 /// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here.
71 /// scf.reduce.return %1
72 template <
73  typename CompareOpTy, typename SelectOpTy,
74  typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
75 static bool
76 matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,
77  ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
78  static_assert(
79  llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
80  "only arithmetic and llvm select ops are supported");
81 
82  // Expect exactly three operations in the block.
83  if (block.empty() || llvm::hasSingleElement(block) ||
84  std::next(block.begin(), 2) == block.end() ||
85  std::next(block.begin(), 3) != block.end())
86  return false;
87 
88  // Check op kinds.
89  auto compare = dyn_cast<CompareOpTy>(block.front());
90  auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
91  auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
92  if (!compare || !select || !terminator)
93  return false;
94 
95  // Block arguments must be compared.
96  if (compare->getOperands() != block.getArguments())
97  return false;
98 
99  // Detect whether the comparison is less-than or greater-than, otherwise bail.
100  bool isLess;
101  if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
102  isLess = true;
103  } else if (llvm::is_contained(greaterThanPredicates,
104  compare.getPredicate())) {
105  isLess = false;
106  } else {
107  return false;
108  }
109 
110  if (select.getCondition() != compare.getResult())
111  return false;
112 
113  // Detect if the operands are swapped between cmpf and select. Match the
114  // comparison type with the requested type or with the opposite of the
115  // requested type if the operands are swapped. Use generic accessors because
116  // std and LLVM versions of select have different operand names but identical
117  // positions.
118  constexpr unsigned kTrueValue = 1;
119  constexpr unsigned kFalseValue = 2;
120  bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
121  select.getOperand(kFalseValue) == compare.getRhs();
122  bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
123  select.getOperand(kFalseValue) == compare.getLhs();
124  if (!sameOperands && !swappedOperands)
125  return false;
126 
127  if (select.getResult() != terminator.getResult())
128  return false;
129 
130  // The reduction is a min if it uses less-than predicates with same operands
131  // or greather-than predicates with swapped operands. Similarly for max.
132  isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
133  return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
134 }
135 
136 /// Returns the float semantics for the given float type.
137 static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
138  if (type.isF16())
139  return llvm::APFloat::IEEEhalf();
140  if (type.isF32())
141  return llvm::APFloat::IEEEsingle();
142  if (type.isF64())
143  return llvm::APFloat::IEEEdouble();
144  if (type.isF128())
145  return llvm::APFloat::IEEEquad();
146  if (type.isBF16())
147  return llvm::APFloat::BFloat();
148  if (type.isF80())
149  return llvm::APFloat::x87DoubleExtended();
150  llvm_unreachable("unknown float type");
151 }
152 
153 /// Returns an attribute with the minimum (if `min` is set) or the maximum value
154 /// (otherwise) for the given float type.
155 static Attribute minMaxValueForFloat(Type type, bool min) {
156  auto fltType = cast<FloatType>(type);
157  return FloatAttr::get(
158  type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
159 }
160 
161 /// Returns an attribute with the signed integer minimum (if `min` is set) or
162 /// the maximum value (otherwise) for the given integer type, regardless of its
163 /// signedness semantics (only the width is considered).
165  auto intType = cast<IntegerType>(type);
166  unsigned bitwidth = intType.getWidth();
167  return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
168  : llvm::APInt::getSignedMaxValue(bitwidth));
169 }
170 
171 /// Returns an attribute with the unsigned integer minimum (if `min` is set) or
172 /// the maximum value (otherwise) for the given integer type, regardless of its
173 /// signedness semantics (only the width is considered).
175  auto intType = cast<IntegerType>(type);
176  unsigned bitwidth = intType.getWidth();
177  return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
178  : llvm::APInt::getAllOnes(bitwidth));
179 }
180 
181 /// Creates an OpenMP reduction declaration and inserts it into the provided
182 /// symbol table. The declaration has a constant initializer with the neutral
183 /// value `initValue`, and the `reductionIndex`-th reduction combiner carried
184 /// over from `reduce`.
185 static omp::DeclareReductionOp
186 createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
187  scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
188  OpBuilder::InsertionGuard guard(builder);
189  Type type = reduce.getOperands()[reductionIndex].getType();
190  auto decl = builder.create<omp::DeclareReductionOp>(reduce.getLoc(),
191  "__scf_reduction", type);
192  symbolTable.insert(decl);
193 
194  builder.createBlock(&decl.getInitializerRegion(),
195  decl.getInitializerRegion().end(), {type},
196  {reduce.getOperands()[reductionIndex].getLoc()});
197  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
198  Value init =
199  builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
200  builder.create<omp::YieldOp>(reduce.getLoc(), init);
201 
202  Operation *terminator =
203  &reduce.getReductions()[reductionIndex].front().back();
204  assert(isa<scf::ReduceReturnOp>(terminator) &&
205  "expected reduce op to be terminated by redure return");
206  builder.setInsertionPoint(terminator);
207  builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
208  terminator->getOperands());
209  builder.inlineRegionBefore(reduce.getReductions()[reductionIndex],
210  decl.getReductionRegion(),
211  decl.getReductionRegion().end());
212  return decl;
213 }
214 
215 /// Adds an atomic reduction combiner to the given OpenMP reduction declaration
216 /// using llvm.atomicrmw of the given kind.
217 static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
218  LLVM::AtomicBinOp atomicKind,
219  omp::DeclareReductionOp decl,
220  scf::ReduceOp reduce,
221  int64_t reductionIndex) {
222  OpBuilder::InsertionGuard guard(builder);
223  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
224  Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc();
225  builder.createBlock(&decl.getAtomicReductionRegion(),
226  decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
227  {reduceOperandLoc, reduceOperandLoc});
228  Block *atomicBlock = &decl.getAtomicReductionRegion().back();
229  builder.setInsertionPointToEnd(atomicBlock);
230  Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(),
231  atomicBlock->getArgument(1));
232  builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind,
233  atomicBlock->getArgument(0), loaded,
234  LLVM::AtomicOrdering::monotonic);
235  builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>());
236  return decl;
237 }
238 
239 /// Creates an OpenMP reduction declaration that corresponds to the given SCF
240 /// reduction and returns it. Recognizes common reductions in order to identify
241 /// the neutral value, necessary for the OpenMP declaration. If the reduction
242 /// cannot be recognized, returns null.
243 static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
244  scf::ReduceOp reduce,
245  int64_t reductionIndex) {
247  SymbolTable symbolTable(container);
248 
249  // Insert reduction declarations in the symbol-table ancestor before the
250  // ancestor of the current insertion point.
251  Operation *insertionPoint = reduce;
252  while (insertionPoint->getParentOp() != container)
253  insertionPoint = insertionPoint->getParentOp();
254  OpBuilder::InsertionGuard guard(builder);
255  builder.setInsertionPoint(insertionPoint);
256 
257  assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
258  "expected reduction region to have a single element");
259 
260  // Match simple binary reductions that can be expressed with atomicrmw.
261  Type type = reduce.getOperands()[reductionIndex].getType();
262  Block &reduction = reduce.getReductions()[reductionIndex].front();
263  if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
264  omp::DeclareReductionOp decl =
265  createDecl(builder, symbolTable, reduce, reductionIndex,
266  builder.getFloatAttr(type, 0.0));
267  return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
268  reductionIndex);
269  }
270  if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
271  omp::DeclareReductionOp decl =
272  createDecl(builder, symbolTable, reduce, reductionIndex,
273  builder.getIntegerAttr(type, 0));
274  return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
275  reductionIndex);
276  }
277  if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
278  omp::DeclareReductionOp decl =
279  createDecl(builder, symbolTable, reduce, reductionIndex,
280  builder.getIntegerAttr(type, 0));
281  return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
282  reductionIndex);
283  }
284  if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
285  omp::DeclareReductionOp decl =
286  createDecl(builder, symbolTable, reduce, reductionIndex,
287  builder.getIntegerAttr(type, 0));
288  return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
289  reductionIndex);
290  }
291  if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
292  omp::DeclareReductionOp decl = createDecl(
293  builder, symbolTable, reduce, reductionIndex,
294  builder.getIntegerAttr(
295  type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
296  return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
297  reductionIndex);
298  }
299 
300  // Match simple binary reductions that cannot be expressed with atomicrmw.
301  // TODO: add atomic region using cmpxchg (which needs atomic load to be
302  // available as an op).
303  if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
304  return createDecl(builder, symbolTable, reduce, reductionIndex,
305  builder.getFloatAttr(type, 1.0));
306  }
307  if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
308  return createDecl(builder, symbolTable, reduce, reductionIndex,
309  builder.getIntegerAttr(type, 1));
310  }
311 
312  // Match select-based min/max reductions.
313  bool isMin;
314  if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
315  reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
316  {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
317  matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
318  reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
319  {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
320  return createDecl(builder, symbolTable, reduce, reductionIndex,
321  minMaxValueForFloat(type, !isMin));
322  }
323  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
324  reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
325  {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
326  matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
327  reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
328  {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
329  omp::DeclareReductionOp decl =
330  createDecl(builder, symbolTable, reduce, reductionIndex,
331  minMaxValueForSignedInt(type, !isMin));
332  return addAtomicRMW(builder,
334  decl, reduce, reductionIndex);
335  }
336  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
337  reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
338  {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
339  matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
340  reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
341  {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
342  omp::DeclareReductionOp decl =
343  createDecl(builder, symbolTable, reduce, reductionIndex,
344  minMaxValueForUnsignedInt(type, !isMin));
345  return addAtomicRMW(
346  builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
347  decl, reduce, reductionIndex);
348  }
349 
350  return nullptr;
351 }
352 
353 namespace {
354 
355 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
356  static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
357  unsigned numThreads;
358 
359  ParallelOpLowering(MLIRContext *context,
360  unsigned numThreads = kUseOpenMPDefaultNumThreads)
361  : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
362 
363  LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
364  PatternRewriter &rewriter) const override {
365  // Declare reductions.
366  // TODO: consider checking it here is already a compatible reduction
367  // declaration and use it instead of redeclaring.
368  SmallVector<Attribute> reductionSyms;
369  SmallVector<omp::DeclareReductionOp> ompReductionDecls;
370  auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
371  for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
372  omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
373  ompReductionDecls.push_back(decl);
374  if (!decl)
375  return failure();
376  reductionSyms.push_back(
377  SymbolRefAttr::get(rewriter.getContext(), decl.getSymName()));
378  }
379 
380  // Allocate reduction variables. Make sure the we don't overflow the stack
381  // with local `alloca`s by saving and restoring the stack pointer.
382  Location loc = parallelOp.getLoc();
383  Value one = rewriter.create<LLVM::ConstantOp>(
384  loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
385  SmallVector<Value> reductionVariables;
386  reductionVariables.reserve(parallelOp.getNumReductions());
387  auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
388  for (Value init : parallelOp.getInitVals()) {
389  assert((LLVM::isCompatibleType(init.getType()) ||
390  isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
391  "cannot create a reduction variable if the type is not an LLVM "
392  "pointer element");
393  Value storage =
394  rewriter.create<LLVM::AllocaOp>(loc, ptrType, init.getType(), one, 0);
395  rewriter.create<LLVM::StoreOp>(loc, init, storage);
396  reductionVariables.push_back(storage);
397  }
398 
399  // Replace the reduction operations contained in this loop. Must be done
400  // here rather than in a separate pattern to have access to the list of
401  // reduction variables.
402  for (auto [x, y, rD] : llvm::zip_equal(
403  reductionVariables, reduce.getOperands(), ompReductionDecls)) {
404  OpBuilder::InsertionGuard guard(rewriter);
405  rewriter.setInsertionPoint(reduce);
406  Region &redRegion = rD.getReductionRegion();
407  // The SCF dialect by definition contains only structured operations
408  // and hence the SCF reduction region will contain a single block.
409  // The ompReductionDecls region is a copy of the SCF reduction region
410  // and hence has the same property.
411  assert(redRegion.hasOneBlock() &&
412  "expect reduction region to have one block");
413  Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
414  Value pvtRedVal = rewriter.create<LLVM::LoadOp>(reduce.getLoc(),
415  rD.getType(), pvtRedVar);
416  // Make a copy of the reduction combiner region in the body
417  mlir::OpBuilder builder(rewriter.getContext());
418  builder.setInsertionPoint(reduce);
419  mlir::IRMapping mapper;
420  assert(redRegion.getNumArguments() == 2 &&
421  "expect reduction region to have two arguments");
422  mapper.map(redRegion.getArgument(0), pvtRedVal);
423  mapper.map(redRegion.getArgument(1), y);
424  for (auto &op : redRegion.getOps()) {
425  Operation *cloneOp = builder.clone(op, mapper);
426  if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
427  assert(yieldOp && yieldOp.getResults().size() == 1 &&
428  "expect YieldOp in reduction region to return one result");
429  Value redVal = yieldOp.getResults()[0];
430  rewriter.create<LLVM::StoreOp>(loc, redVal, pvtRedVar);
431  rewriter.eraseOp(yieldOp);
432  break;
433  }
434  }
435  }
436  rewriter.eraseOp(reduce);
437 
438  Value numThreadsVar;
439  if (numThreads > 0) {
440  numThreadsVar = rewriter.create<LLVM::ConstantOp>(
441  loc, rewriter.getI32IntegerAttr(numThreads));
442  }
443  // Create the parallel wrapper.
444  auto ompParallel = rewriter.create<omp::ParallelOp>(
445  loc,
446  /* allocate_vars = */ llvm::SmallVector<Value>{},
447  /* allocator_vars = */ llvm::SmallVector<Value>{},
448  /* if_expr = */ Value{},
449  /* num_threads = */ numThreadsVar,
450  /* private_vars = */ ValueRange(),
451  /* private_syms = */ nullptr,
452  /* private_needs_barrier = */ nullptr,
453  /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
454  /* reduction_mod = */ nullptr,
455  /* reduction_vars = */ llvm::SmallVector<Value>{},
456  /* reduction_byref = */ DenseBoolArrayAttr{},
457  /* reduction_syms = */ ArrayAttr{});
458  {
459 
460  OpBuilder::InsertionGuard guard(rewriter);
461  rewriter.createBlock(&ompParallel.getRegion());
462 
463  // Replace the loop.
464  {
465  OpBuilder::InsertionGuard allocaGuard(rewriter);
466  // Create worksharing loop wrapper.
467  auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
468  if (!reductionVariables.empty()) {
469  wsloopOp.setReductionSymsAttr(
470  ArrayAttr::get(rewriter.getContext(), reductionSyms));
471  wsloopOp.getReductionVarsMutable().append(reductionVariables);
472  llvm::SmallVector<bool> reductionByRef;
473  // false because these reductions always reduce scalars and so do
474  // not need to pass by reference
475  reductionByRef.resize(reductionVariables.size(), false);
476  wsloopOp.setReductionByref(
477  DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef));
478  }
479  rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
480 
481  // The wrapper's entry block arguments will define the reduction
482  // variables.
483  llvm::SmallVector<mlir::Type> reductionTypes;
484  reductionTypes.reserve(reductionVariables.size());
485  llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
486  [](mlir::Value v) { return v.getType(); });
487  rewriter.createBlock(
488  &wsloopOp.getRegion(), {}, reductionTypes,
489  llvm::SmallVector<mlir::Location>(reductionVariables.size(),
490  parallelOp.getLoc()));
491 
492  // Create loop nest and populate region with contents of scf.parallel.
493  auto loopOp = rewriter.create<omp::LoopNestOp>(
494  parallelOp.getLoc(), parallelOp.getLowerBound(),
495  parallelOp.getUpperBound(), parallelOp.getStep());
496 
497  rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
498  loopOp.getRegion().begin());
499 
500  // Remove reduction-related block arguments from omp.loop_nest and
501  // redirect uses to the corresponding omp.wsloop block argument.
502  mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
503  unsigned numLoops = parallelOp.getNumLoops();
504  rewriter.replaceAllUsesWith(
505  loopOpEntryBlock.getArguments().drop_front(numLoops),
506  wsloopOp.getRegion().getArguments());
507  loopOpEntryBlock.eraseArguments(
508  numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
509 
510  Block *ops =
511  rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
512  rewriter.setInsertionPointToStart(&loopOpEntryBlock);
513 
514  auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
515  TypeRange());
516  rewriter.create<omp::YieldOp>(loc, ValueRange());
517  Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
518  rewriter.mergeBlocks(ops, scopeBlock);
519  rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
520  rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
521  }
522  }
523 
524  // Load loop results.
525  SmallVector<Value> results;
526  results.reserve(reductionVariables.size());
527  for (auto [variable, type] :
528  llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
529  Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable);
530  results.push_back(res);
531  }
532  rewriter.replaceOp(parallelOp, results);
533 
534  return success();
535  }
536 };
537 
538 /// Applies the conversion patterns in the given function.
539 static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
540  ConversionTarget target(*module.getContext());
541  target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
542  target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
543  memref::MemRefDialect>();
544 
545  RewritePatternSet patterns(module.getContext());
546  patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
547  FrozenRewritePatternSet frozen(std::move(patterns));
548  return applyPartialConversion(module, target, frozen);
549 }
550 
551 /// A pass converting SCF operations to OpenMP operations.
552 struct SCFToOpenMPPass
553  : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
554 
555  using Base::Base;
556 
557  /// Pass entry point.
558  void runOnOperation() override {
559  if (failed(applyPatterns(getOperation(), numThreads)))
560  signalPassFailure();
561  }
562 };
563 
564 } // namespace
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2921
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
Definition: SCFToOpenMP.cpp:41
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
Definition: SCFToOpenMP.cpp:76
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:201
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
MLIRContext * getContext() const
Definition: Builders.h:55
This class describes a specific conversion target.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:718
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
unsigned getNumArguments()
Definition: Region.h:123
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:601
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314