MLIR 23.0.0git
SCFToOpenMP.cpp
Go to the documentation of this file.
1//===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert scf.parallel operations into OpenMP
10// parallel loops.
11//
12//===----------------------------------------------------------------------===//
13
15
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Pass/Pass.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33
34/// Matches a block containing a "simple" reduction. The expected shape of the
35/// block is as follows.
36///
37/// ^bb(%arg0, %arg1):
38/// %0 = OpTy(%arg0, %arg1)
39/// scf.reduce.return %0
///
/// `OpTy...` lists the op types accepted for the single combiner. The block
/// must take exactly two arguments, hold exactly two operations (combiner and
/// `scf.reduce.return`), and the combiner's operands must be exactly
/// (%arg0, %arg1) in that order. Returns true on a match.
// NOTE(review): the declaration of `combinerOps` (original source line 49) is
// missing from this extract; given `matchReduction`'s
// `SmallVectorImpl<Operation *> &` parameter it is presumably a
// `SmallVector<Operation *>` local. Confirm against the real file.
40template <typename... OpTy>
41static bool matchSimpleReduction(Block &block) {
  // Exactly two operations: the combiner followed by the terminator.
42 if (block.empty() || llvm::hasSingleElement(block) ||
43 std::next(block.begin(), 2) != block.end())
44 return false;
45
46 if (block.getNumArguments() != 2)
47 return false;
48
  // Run the generic reduction matcher on %arg1 (redPos 0). Require that it
  // yields a block argument and reports exactly one combiner op.
50 Value reducedVal = matchReduction({block.getArguments()[1]},
51 /*redPos=*/0, combinerOps);
52
53 if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
54 return false;
55
  // The combiner must be one of OpTy..., the block must end in
  // scf.reduce.return, and the first op must consume the block arguments
  // in order.
56 return isa<OpTy...>(combinerOps[0]) &&
57 isa<scf::ReduceReturnOp>(block.back()) &&
58 block.front().getOperands() == block.getArguments();
59}
60
61/// Matches a block containing a select-based min/max reduction. The types of
62/// select and compare operations are provided as template arguments. The
63/// comparison predicates suitable for min and max are provided as function
64/// arguments. If a reduction is matched, `isMin` will be set if the reduction
65/// computes the minimum and unset if it computes the maximum, otherwise it
66/// remains unmodified. The expected shape of the block is as follows.
67///
68/// ^bb(%arg0, %arg1):
69/// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
70/// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here.
71/// scf.reduce.return %1
///
/// `lessThanPredicates` / `greaterThanPredicates` classify the compare
/// predicate. The predicate class, combined with whether the select operands
/// are in the same or swapped order, decides between min and max.
// NOTE(review): the declarator line (original source line 76) is missing from
// this extract; per the file index it reads
// `matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,`.
72template <
73 typename CompareOpTy, typename SelectOpTy,
74 typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
75static bool
77 ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
78 static_assert(
79 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
80 "only arithmetic and llvm select ops are supported");
81
82 // Expect exactly three operations in the block.
83 if (block.empty() || llvm::hasSingleElement(block) ||
84 std::next(block.begin(), 2) == block.end() ||
85 std::next(block.begin(), 3) != block.end())
86 return false;
87
88 // Check op kinds.
89 auto compare = dyn_cast<CompareOpTy>(block.front());
90 auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
91 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
92 if (!compare || !select || !terminator)
93 return false;
94
95 // Block arguments must be compared.
96 if (compare->getOperands() != block.getArguments())
97 return false;
98
99 // Detect whether the comparison is less-than or greater-than, otherwise bail.
100 bool isLess;
101 if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
102 isLess = true;
103 } else if (llvm::is_contained(greaterThanPredicates,
104 compare.getPredicate())) {
105 isLess = false;
106 } else {
107 return false;
108 }
109
  // The select must be driven by the comparison result.
110 if (select.getCondition() != compare.getResult())
111 return false;
112
113 // Detect if the operands are swapped between cmpf and select. Match the
114 // comparison type with the requested type or with the opposite of the
115 // requested type if the operands are swapped. Use generic accessors because
116 // arith and LLVM versions of select have different operand names but identical
117 // positions.
118 constexpr unsigned kTrueValue = 1;
119 constexpr unsigned kFalseValue = 2;
120 bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
121 select.getOperand(kFalseValue) == compare.getRhs();
122 bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
123 select.getOperand(kFalseValue) == compare.getLhs();
124 if (!sameOperands && !swappedOperands)
125 return false;
126
  // The select result must be what the block returns.
127 if (select.getResult() != terminator.getResult())
128 return false;
129
130 // The reduction is a min if it uses less-than predicates with same operands
131 // or greater-than predicates with swapped operands. Similarly for max.
132 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
  // NOTE(review): the compare operands are the two distinct block arguments,
  // so exactly one of sameOperands/swappedOperands holds here. This return
  // expression is therefore always true. The `&` on bools gives the same
  // result as `&&` would.
133 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
134}
135
136/// Returns the float semantics for the given float type.
137static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
138 if (type.isF16())
139 return llvm::APFloat::IEEEhalf();
140 if (type.isF32())
141 return llvm::APFloat::IEEEsingle();
142 if (type.isF64())
143 return llvm::APFloat::IEEEdouble();
144 if (type.isF128())
145 return llvm::APFloat::IEEEquad();
146 if (type.isBF16())
147 return llvm::APFloat::BFloat();
148 if (type.isF80())
149 return llvm::APFloat::x87DoubleExtended();
150 llvm_unreachable("unknown float type");
151}
152
153/// Helper to create a splat attribute for vector types, or return the scalar
154/// attribute for scalar types.
///
/// For a `VectorType`, returns a `DenseElementsAttr` of that vector type built
/// from the single value `val`. For any other type, returns `val` unchanged.
// NOTE(review): the signature line (original source line 155) is missing from
// this extract; per the file index it declares
// `static Attribute getSplatOrScalarAttr(Type type, Attribute val)`.
156 if (auto vecType = dyn_cast<VectorType>(type))
157 return DenseElementsAttr::get(vecType, val);
158 return val;
159}
160
161/// Returns an attribute with the minimum (if `min` is set) or the maximum value
162/// (otherwise) for the given float type.
///
/// `type` may be a float scalar or a vector of floats. The element type
/// determines the value, and vectors get a splat. The value is produced by
/// `APFloat::getLargest`, i.e. the largest finite magnitude, negated when
/// `min` is set. It is not an infinity.
// NOTE(review): the signature line (original source line 163) is missing from
// this extract; per the file index it declares
// `static Attribute minMaxValueForFloat(Type type, bool min)`.
164 Type elType = getElementTypeOrSelf(type);
165 auto fltType = cast<FloatType>(elType);
166 auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
167
168 return getSplatOrScalarAttr(type, FloatAttr::get(elType, val));
169}
170
171/// Returns an attribute with the signed integer minimum (if `min` is set) or
172/// the maximum value (otherwise) for the given integer type, regardless of its
173/// signedness semantics (only the width is considered).
///
/// `type` may be an integer scalar or a vector of integers. Vectors get a
/// splat of the per-element value.
// NOTE(review): the signature line (original source line 174) is missing from
// this extract; per the file index it declares
// `static Attribute minMaxValueForSignedInt(Type type, bool min)`.
175 Type elType = getElementTypeOrSelf(type);
176 auto intType = cast<IntegerType>(elType);
177 unsigned bitwidth = intType.getWidth();
  // Two's-complement extremes for `bitwidth` bits.
178 auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
179 : llvm::APInt::getSignedMaxValue(bitwidth);
180
181 return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
182}
183
184/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
185/// the maximum value (otherwise) for the given integer type, regardless of its
186/// signedness semantics (only the width is considered).
///
/// `type` may be an integer scalar or a vector of integers. Vectors get a
/// splat of the per-element value.
// NOTE(review): the signature line (original source line 187) is missing from
// this extract; per the file index it declares
// `static Attribute minMaxValueForUnsignedInt(Type type, bool min)`.
188 Type elType = getElementTypeOrSelf(type);
189 auto intType = cast<IntegerType>(elType);
190 unsigned bitwidth = intType.getWidth();
  // Unsigned extremes: all-zero bits for the minimum, all-one bits for the
  // maximum.
191 auto val =
192 min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
193
194 return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
195}
196
197/// Creates an OpenMP reduction declaration and inserts it into the provided
198/// symbol table. The declaration has a constant initializer with the neutral
199/// value `initValue`, and the `reductionIndex`-th reduction combiner carried
200/// over from `reduce`.
///
/// The combiner region is moved (not copied) out of `reduce` into the new
/// declaration, after its `scf.reduce.return` is rewritten to `omp.yield`. The
/// builder's insertion point is restored on return.
// NOTE(review): the line naming the function and its first two parameters
// (original source line 202) is missing from this extract; per the file index
// it is `createDecl(PatternRewriter &builder, SymbolTable &symbolTable,`.
201static omp::DeclareReductionOp
203 scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
204 OpBuilder::InsertionGuard guard(builder);
205 Type type = reduce.getOperands()[reductionIndex].getType();
  // Create the declaration at the current insertion point under a base name.
  // SymbolTable::insert renames it as needed to avoid collisions.
206 auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(),
207 "__scf_reduction", type,
208 /*byref_element_type=*/{});
209 symbolTable.insert(decl);
210
  // Initializer region: one argument of the reduced type. It yields the
  // constant `initValue`.
211 builder.createBlock(&decl.getInitializerRegion(),
212 decl.getInitializerRegion().end(), {type},
213 {reduce.getOperands()[reductionIndex].getLoc()});
214 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
215 Value init =
216 LLVM::ConstantOp::create(builder, reduce.getLoc(), type, initValue);
217 omp::YieldOp::create(builder, reduce.getLoc(), init);
218
  // Combiner region: swap the SCF terminator for omp.yield (same operands).
  // Then move the whole region into the declaration.
219 Operation *terminator =
220 &reduce.getReductions()[reductionIndex].front().back();
221 assert(isa<scf::ReduceReturnOp>(terminator) &&
222 "expected reduce op to be terminated by reduce return");
223 builder.setInsertionPoint(terminator);
224 builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
225 terminator->getOperands());
226 builder.inlineRegionBefore(reduce.getReductions()[reductionIndex],
227 decl.getReductionRegion(),
228 decl.getReductionRegion().end());
229 return decl;
230}
231
232/// Adds an atomic reduction combiner to the given OpenMP reduction declaration
233/// using llvm.atomicrmw of the given kind.
234static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
235 LLVM::AtomicBinOp atomicKind,
236 omp::DeclareReductionOp decl,
237 scf::ReduceOp reduce,
238 int64_t reductionIndex) {
239 OpBuilder::InsertionGuard guard(builder);
240 auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
241 Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc();
242 builder.createBlock(&decl.getAtomicReductionRegion(),
243 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
244 {reduceOperandLoc, reduceOperandLoc});
245 Block *atomicBlock = &decl.getAtomicReductionRegion().back();
246 builder.setInsertionPointToEnd(atomicBlock);
247 Value loaded = LLVM::LoadOp::create(builder, reduce.getLoc(), decl.getType(),
248 atomicBlock->getArgument(1));
249 LLVM::AtomicRMWOp::create(builder, reduce.getLoc(), atomicKind,
250 atomicBlock->getArgument(0), loaded,
251 LLVM::AtomicOrdering::monotonic);
252 omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>());
253 return decl;
254}
255
256/// Returns true if the type is supported by llvm.atomicrmw.
257/// LLVM IR currently does not support atomic operations on vector types.
258/// See LLVM Language Reference Manual on 'atomicrmw'.
259static bool supportsAtomic(Type type) { return !isa<VectorType>(type); }
260
261/// Creates an OpenMP reduction declaration that corresponds to the given SCF
262/// reduction and returns it. Recognizes common reductions in order to identify
263/// the neutral value, necessary for the OpenMP declaration. If the reduction
264/// cannot be recognized, returns null.
///
/// On success the `reductionIndex`-th combiner region of `reduce` has been
/// moved into the new declaration (see createDecl). If no pattern matches,
/// nothing is created. Where supportsAtomic(type) allows it, add/or/xor/and,
/// signed and unsigned integer min/max, and float add also get an atomic
/// combiner. Float min/max and the multiplication-style branches (neutral
/// value 1) do not.
// NOTE(review): this extract is missing original source line 268, presumably
// `Operation *container = SymbolTable::getNearestSymbolTable(reduce);` (see
// the file index), and the `if (matchSimpleReduction<...>(reduction)) {`
// header of each branch that ends in a bare `}` below (original lines 290,
// 298, 306, 314, 322, 335, 341).
265static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
266 scf::ReduceOp reduce,
267 int64_t reductionIndex) {
269 SymbolTable symbolTable(container);
270
271 // Insert reduction declarations into the symbol-table op `container`, right
272 // before the ancestor of `reduce` that is a direct child of `container`.
273 Operation *insertionPoint = reduce;
274 while (insertionPoint->getParentOp() != container)
275 insertionPoint = insertionPoint->getParentOp();
276 OpBuilder::InsertionGuard guard(builder);
277 builder.setInsertionPoint(insertionPoint);
278
279 assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
280 "expected reduction region to have a single element");
281
282 // Match simple binary reductions that can be expressed with atomicrmw.
283 Type type = reduce.getOperands()[reductionIndex].getType();
284 Block &reduction = reduce.getReductions()[reductionIndex].front();
285
286 // Build neutral constants from the element type; vectors get a splat.
287 Type elType = getElementTypeOrSelf(type);
288
289 // Arithmetic Reductions
  // [branch header missing] Neutral value 0.0; atomic kind `fadd`.
291 omp::DeclareReductionOp decl = createDecl(
292 builder, symbolTable, reduce, reductionIndex,
293 getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 0.0)));
294 return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
295 decl, reduce, reductionIndex)
296 : decl;
297 }
  // [branch header missing] Neutral value 0; atomic kind `add`.
299 omp::DeclareReductionOp decl = createDecl(
300 builder, symbolTable, reduce, reductionIndex,
301 getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
302 return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
303 decl, reduce, reductionIndex)
304 : decl;
305 }
  // [branch header missing] Neutral value 0; atomic kind `_or`.
307 omp::DeclareReductionOp decl = createDecl(
308 builder, symbolTable, reduce, reductionIndex,
309 getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
310 return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
311 decl, reduce, reductionIndex)
312 : decl;
313 }
  // [branch header missing] Neutral value 0; atomic kind `_xor`.
315 omp::DeclareReductionOp decl = createDecl(
316 builder, symbolTable, reduce, reductionIndex,
317 getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
318 return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
319 decl, reduce, reductionIndex)
320 : decl;
321 }
  // [branch header missing] Neutral value all-ones of the element width;
  // atomic kind `_and`.
323 APInt allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
324 omp::DeclareReductionOp decl = createDecl(
325 builder, symbolTable, reduce, reductionIndex,
326 getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, allOnes)));
327 return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
328 decl, reduce, reductionIndex)
329 : decl;
330 }
331
332 // Match simple binary reductions that cannot be expressed with atomicrmw.
333 // TODO: add atomic region using cmpxchg (which needs atomic load to be
334 // available as an op).
  // [branch header missing] Float branch with neutral value 1.0, presumably
  // multiplication. No atomic region.
336 return createDecl(
337 builder, symbolTable, reduce, reductionIndex,
338 getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 1.0)));
339 }
340
  // [branch header missing] Integer branch with neutral value 1, presumably
  // multiplication. No atomic region.
342 return createDecl(
343 builder, symbolTable, reduce, reductionIndex,
344 getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 1)));
345 }
346
347 // Match select-based min/max reductions.
  // `isMin` is only read after one of the matchSelectReduction calls below
  // returned true, which assigns it.
  // NOTE(review): in each `||` pair the second call passes the same four
  // predicates with the less/greater lists swapped. Predicate classification
  // does not affect whether matchSelectReduction succeeds, so the second call
  // can only succeed when the first already did. It looks redundant.
  // The neutral value is the opposite extreme (`!isMin`), e.g. the maximum
  // for a min reduction.
348 bool isMin;
349 // Floating Point Min/Max
350 if (matchSelectReduction<arith::CmpFOp, arith::SelectOp,
351 arith::CmpFPredicate>(
352 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
353 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
354 matchSelectReduction<arith::CmpFOp, arith::SelectOp,
355 arith::CmpFPredicate>(
356 reduction, {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE},
357 {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, isMin)) {
358 return createDecl(builder, symbolTable, reduce, reductionIndex,
359 minMaxValueForFloat(type, !isMin));
360 }
361
362 // Integer Min/Max
363 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
364 arith::CmpIPredicate>(
365 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
366 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
367 matchSelectReduction<arith::CmpIOp, arith::SelectOp,
368 arith::CmpIPredicate>(
369 reduction, {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge},
370 {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, isMin)) {
371 omp::DeclareReductionOp decl =
372 createDecl(builder, symbolTable, reduce, reductionIndex,
373 minMaxValueForSignedInt(type, !isMin));
374 return supportsAtomic(type) ? addAtomicRMW(builder,
375 isMin ? LLVM::AtomicBinOp::min
376 : LLVM::AtomicBinOp::max,
377 decl, reduce, reductionIndex)
378 : decl;
379 }
380
381 // Unsigned Integer Min/Max
382 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
383 arith::CmpIPredicate>(
384 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
385 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
386 matchSelectReduction<arith::CmpIOp, arith::SelectOp,
387 arith::CmpIPredicate>(
388 reduction, {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge},
389 {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, isMin)) {
390 omp::DeclareReductionOp decl =
391 createDecl(builder, symbolTable, reduce, reductionIndex,
392 minMaxValueForUnsignedInt(type, !isMin));
393 return supportsAtomic(type) ? addAtomicRMW(builder,
394 isMin ? LLVM::AtomicBinOp::umin
395 : LLVM::AtomicBinOp::umax,
396 decl, reduce, reductionIndex)
397 : decl;
398 }
399
  // Unrecognized reduction: nothing was created for this index.
400 return nullptr;
401}
402
403namespace {
404
405struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
406 static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
407 unsigned numThreads;
408
409 ParallelOpLowering(MLIRContext *context,
410 unsigned numThreads = kUseOpenMPDefaultNumThreads)
411 : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
412
413 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
414 PatternRewriter &rewriter) const override {
415 // Declare reductions.
416 // TODO: consider checking it here is already a compatible reduction
417 // declaration and use it instead of redeclaring.
418 SmallVector<Attribute> reductionSyms;
419 SmallVector<omp::DeclareReductionOp> ompReductionDecls;
420 auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
421 for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
422 omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
423 ompReductionDecls.push_back(decl);
424 if (!decl)
425 return failure();
426 reductionSyms.push_back(
427 SymbolRefAttr::get(rewriter.getContext(), decl.getSymName()));
428 }
429
430 // Allocate reduction variables. Make sure the we don't overflow the stack
431 // with local `alloca`s by saving and restoring the stack pointer.
432 Location loc = parallelOp.getLoc();
433 Value one =
434 LLVM::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(64),
435 rewriter.getI64IntegerAttr(1));
436 SmallVector<Value> reductionVariables;
437 reductionVariables.reserve(parallelOp.getNumReductions());
438 auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
439 for (Value init : parallelOp.getInitVals()) {
440 assert((LLVM::isCompatibleType(init.getType()) ||
441 isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
442 "cannot create a reduction variable if the type is not an LLVM "
443 "pointer element");
444 Value storage = LLVM::AllocaOp::create(rewriter, loc, ptrType,
445 init.getType(), one, 0);
446 LLVM::StoreOp::create(rewriter, loc, init, storage);
447 reductionVariables.push_back(storage);
448 }
449
450 // Replace the reduction operations contained in this loop. Must be done
451 // here rather than in a separate pattern to have access to the list of
452 // reduction variables.
453 for (auto [x, y, rD] : llvm::zip_equal(
454 reductionVariables, reduce.getOperands(), ompReductionDecls)) {
455 OpBuilder::InsertionGuard guard(rewriter);
456 rewriter.setInsertionPoint(reduce);
457 Region &redRegion = rD.getReductionRegion();
458 // The SCF dialect by definition contains only structured operations
459 // and hence the SCF reduction region will contain a single block.
460 // The ompReductionDecls region is a copy of the SCF reduction region
461 // and hence has the same property.
462 assert(redRegion.hasOneBlock() &&
463 "expect reduction region to have one block");
464 Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
465 Value pvtRedVal = LLVM::LoadOp::create(rewriter, reduce.getLoc(),
466 rD.getType(), pvtRedVar);
467 // Make a copy of the reduction combiner region in the body
468 mlir::OpBuilder builder(rewriter.getContext());
469 builder.setInsertionPoint(reduce);
470 mlir::IRMapping mapper;
471 assert(redRegion.getNumArguments() == 2 &&
472 "expect reduction region to have two arguments");
473 mapper.map(redRegion.getArgument(0), pvtRedVal);
474 mapper.map(redRegion.getArgument(1), y);
475 for (auto &op : redRegion.getOps()) {
476 Operation *cloneOp = builder.clone(op, mapper);
477 if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
478 assert(yieldOp && yieldOp.getResults().size() == 1 &&
479 "expect YieldOp in reduction region to return one result");
480 Value redVal = yieldOp.getResults()[0];
481 LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
482 rewriter.eraseOp(yieldOp);
483 break;
484 }
485 }
486 }
487 rewriter.eraseOp(reduce);
488
489 SmallVector<Value> numThreadsVars;
490 if (numThreads > 0) {
491 Value numThreadsVar = LLVM::ConstantOp::create(
492 rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
493 numThreadsVars.push_back(numThreadsVar);
494 }
495 // Create the parallel wrapper.
496 auto ompParallel = omp::ParallelOp::create(
497 rewriter, loc,
498 /* allocate_vars = */ llvm::SmallVector<Value>{},
499 /* allocator_vars = */ llvm::SmallVector<Value>{},
500 /* if_expr = */ Value{},
501 /* num_threads_vars = */ numThreadsVars,
502 /* private_vars = */ ValueRange(),
503 /* private_syms = */ nullptr,
504 /* private_needs_barrier = */ nullptr,
505 /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
506 /* reduction_mod = */ nullptr,
507 /* reduction_vars = */ llvm::SmallVector<Value>{},
508 /* reduction_byref = */ DenseBoolArrayAttr{},
509 /* reduction_syms = */ ArrayAttr{});
510 {
511
512 OpBuilder::InsertionGuard guard(rewriter);
513 rewriter.createBlock(&ompParallel.getRegion());
514
515 // Replace the loop.
516 {
517 OpBuilder::InsertionGuard allocaGuard(rewriter);
518 // Create worksharing loop wrapper.
519 auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc());
520 if (!reductionVariables.empty()) {
521 wsloopOp.setReductionSymsAttr(
522 ArrayAttr::get(rewriter.getContext(), reductionSyms));
523 wsloopOp.getReductionVarsMutable().append(reductionVariables);
524 llvm::SmallVector<bool> reductionByRef;
525 // false because these reductions always reduce scalars and so do
526 // not need to pass by reference
527 reductionByRef.resize(reductionVariables.size(), false);
528 wsloopOp.setReductionByref(
529 DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef));
530 }
531 omp::TerminatorOp::create(rewriter, loc); // omp.parallel terminator.
532
533 // The wrapper's entry block arguments will define the reduction
534 // variables.
535 llvm::SmallVector<mlir::Type> reductionTypes;
536 reductionTypes.reserve(reductionVariables.size());
537 llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
538 [](mlir::Value v) { return v.getType(); });
539 rewriter.createBlock(
540 &wsloopOp.getRegion(), {}, reductionTypes,
541 llvm::SmallVector<mlir::Location>(reductionVariables.size(),
542 parallelOp.getLoc()));
543
544 // Create loop nest and populate region with contents of scf.parallel.
545 auto loopOp = omp::LoopNestOp::create(
546 rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(),
547 parallelOp.getLowerBound(), parallelOp.getUpperBound(),
548 parallelOp.getStep(), /*loop_inclusive=*/false,
549 /*tile_sizes=*/nullptr);
550
551 rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
552 loopOp.getRegion().begin());
553
554 // Remove reduction-related block arguments from omp.loop_nest and
555 // redirect uses to the corresponding omp.wsloop block argument.
556 mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
557 unsigned numLoops = parallelOp.getNumLoops();
558 rewriter.replaceAllUsesWith(
559 loopOpEntryBlock.getArguments().drop_front(numLoops),
560 wsloopOp.getRegion().getArguments());
561 loopOpEntryBlock.eraseArguments(
562 numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
563
564 Block *ops =
565 rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
566 rewriter.setInsertionPointToStart(&loopOpEntryBlock);
567
568 auto scope = memref::AllocaScopeOp::create(
569 rewriter, parallelOp.getLoc(), TypeRange());
570 omp::YieldOp::create(rewriter, loc, ValueRange());
571 Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
572 rewriter.mergeBlocks(ops, scopeBlock);
573 rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
574 memref::AllocaScopeReturnOp::create(rewriter, loc, ValueRange());
575 }
576 }
577
578 // Load loop results.
579 SmallVector<Value> results;
580 results.reserve(reductionVariables.size());
581 for (auto [variable, type] :
582 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
583 Value res = LLVM::LoadOp::create(rewriter, loc, type, variable);
584 results.push_back(res);
585 }
586 rewriter.replaceOp(parallelOp, results);
587
588 return success();
589 }
590};
591
592/// Applies the conversion patterns in the given function.
593static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
594 RewritePatternSet patterns(module.getContext());
595 patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
596 FrozenRewritePatternSet frozen(std::move(patterns));
597 walkAndApplyPatterns(module, frozen);
598 auto status = module.walk([](Operation *op) {
599 if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) {
600 op->emitError("unconverted operation found");
601 return WalkResult::interrupt();
602 }
603 return WalkResult::advance();
604 });
605 return failure(status.wasInterrupted());
606}
607
608/// A pass converting SCF operations to OpenMP operations.
609struct SCFToOpenMPPass
610 : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
611
612 using Base::Base;
613
614 /// Pass entry point.
615 void runOnOperation() override {
616 if (failed(applyPatterns(getOperation(), numThreads)))
617 signalPassFailure();
618 }
619};
620
621} // namespace
return success()
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
ArrayAttr()
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static Attribute getSplatOrScalarAttr(Type type, Attribute val)
Helper to create a splat attribute for vector types, or return the scalar attribute for scalar types.
static bool supportsAtomic(Type type)
Returns true if the type is supported by llvm.atomicrmw.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition Block.cpp:206
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator_range< OpIterator > getOps()
Definition Region.h:172
unsigned getNumArguments()
Definition Region.h:123
BlockArgument getArgument(unsigned i)
Definition Region.h:124
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...