MLIR 22.0.0git
SCFToOpenMP.cpp
Go to the documentation of this file.
1//===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert scf.parallel operations into OpenMP
10// parallel loops.
11//
12//===----------------------------------------------------------------------===//
13
15
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Pass/Pass.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33
34/// Matches a block containing a "simple" reduction. The expected shape of the
35/// block is as follows.
36///
37/// ^bb(%arg0, %arg1):
38/// %0 = OpTy(%arg0, %arg1)
39/// scf.reduce.return %0
40template <typename... OpTy>
41static bool matchSimpleReduction(Block &block) {
42 if (block.empty() || llvm::hasSingleElement(block) ||
43 std::next(block.begin(), 2) != block.end())
44 return false;
45
46 if (block.getNumArguments() != 2)
47 return false;
48
50 Value reducedVal = matchReduction({block.getArguments()[1]},
51 /*redPos=*/0, combinerOps);
52
53 if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
54 return false;
55
56 return isa<OpTy...>(combinerOps[0]) &&
57 isa<scf::ReduceReturnOp>(block.back()) &&
58 block.front().getOperands() == block.getArguments();
59}
60
61/// Matches a block containing a select-based min/max reduction. The types of
62/// select and compare operations are provided as template arguments. The
63/// comparison predicates suitable for min and max are provided as function
64/// arguments. If a reduction is matched, `ifMin` will be set if the reduction
65/// compute the minimum and unset if it computes the maximum, otherwise it
66/// remains unmodified. The expected shape of the block is as follows.
67///
68/// ^bb(%arg0, %arg1):
69/// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
70/// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here.
71/// scf.reduce.return %1
72template <
73 typename CompareOpTy, typename SelectOpTy,
74 typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
75static bool
77 ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
78 static_assert(
79 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
80 "only arithmetic and llvm select ops are supported");
81
82 // Expect exactly three operations in the block.
83 if (block.empty() || llvm::hasSingleElement(block) ||
84 std::next(block.begin(), 2) == block.end() ||
85 std::next(block.begin(), 3) != block.end())
86 return false;
87
88 // Check op kinds.
89 auto compare = dyn_cast<CompareOpTy>(block.front());
90 auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
91 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
92 if (!compare || !select || !terminator)
93 return false;
94
95 // Block arguments must be compared.
96 if (compare->getOperands() != block.getArguments())
97 return false;
98
99 // Detect whether the comparison is less-than or greater-than, otherwise bail.
100 bool isLess;
101 if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
102 isLess = true;
103 } else if (llvm::is_contained(greaterThanPredicates,
104 compare.getPredicate())) {
105 isLess = false;
106 } else {
107 return false;
108 }
109
110 if (select.getCondition() != compare.getResult())
111 return false;
112
113 // Detect if the operands are swapped between cmpf and select. Match the
114 // comparison type with the requested type or with the opposite of the
115 // requested type if the operands are swapped. Use generic accessors because
116 // std and LLVM versions of select have different operand names but identical
117 // positions.
118 constexpr unsigned kTrueValue = 1;
119 constexpr unsigned kFalseValue = 2;
120 bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
121 select.getOperand(kFalseValue) == compare.getRhs();
122 bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
123 select.getOperand(kFalseValue) == compare.getLhs();
124 if (!sameOperands && !swappedOperands)
125 return false;
126
127 if (select.getResult() != terminator.getResult())
128 return false;
129
130 // The reduction is a min if it uses less-than predicates with same operands
131 // or greather-than predicates with swapped operands. Similarly for max.
132 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
133 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
134}
135
136/// Returns the float semantics for the given float type.
137static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
138 if (type.isF16())
139 return llvm::APFloat::IEEEhalf();
140 if (type.isF32())
141 return llvm::APFloat::IEEEsingle();
142 if (type.isF64())
143 return llvm::APFloat::IEEEdouble();
144 if (type.isF128())
145 return llvm::APFloat::IEEEquad();
146 if (type.isBF16())
147 return llvm::APFloat::BFloat();
148 if (type.isF80())
149 return llvm::APFloat::x87DoubleExtended();
150 llvm_unreachable("unknown float type");
151}
152
153/// Returns an attribute with the minimum (if `min` is set) or the maximum value
154/// (otherwise) for the given float type.
156 auto fltType = cast<FloatType>(type);
157 return FloatAttr::get(
158 type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
159}
160
161/// Returns an attribute with the signed integer minimum (if `min` is set) or
162/// the maximum value (otherwise) for the given integer type, regardless of its
163/// signedness semantics (only the width is considered).
165 auto intType = cast<IntegerType>(type);
166 unsigned bitwidth = intType.getWidth();
167 return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
168 : llvm::APInt::getSignedMaxValue(bitwidth));
169}
170
171/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
172/// the maximum value (otherwise) for the given integer type, regardless of its
173/// signedness semantics (only the width is considered).
175 auto intType = cast<IntegerType>(type);
176 unsigned bitwidth = intType.getWidth();
177 return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
178 : llvm::APInt::getAllOnes(bitwidth));
179}
180
181/// Creates an OpenMP reduction declaration and inserts it into the provided
182/// symbol table. The declaration has a constant initializer with the neutral
183/// value `initValue`, and the `reductionIndex`-th reduction combiner carried
184/// over from `reduce`.
185static omp::DeclareReductionOp
187 scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
188 OpBuilder::InsertionGuard guard(builder);
189 Type type = reduce.getOperands()[reductionIndex].getType();
190 auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(),
191 "__scf_reduction", type);
192 symbolTable.insert(decl);
193
194 builder.createBlock(&decl.getInitializerRegion(),
195 decl.getInitializerRegion().end(), {type},
196 {reduce.getOperands()[reductionIndex].getLoc()});
197 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
198 Value init =
199 LLVM::ConstantOp::create(builder, reduce.getLoc(), type, initValue);
200 omp::YieldOp::create(builder, reduce.getLoc(), init);
201
202 Operation *terminator =
203 &reduce.getReductions()[reductionIndex].front().back();
204 assert(isa<scf::ReduceReturnOp>(terminator) &&
205 "expected reduce op to be terminated by redure return");
206 builder.setInsertionPoint(terminator);
207 builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
208 terminator->getOperands());
209 builder.inlineRegionBefore(reduce.getReductions()[reductionIndex],
210 decl.getReductionRegion(),
211 decl.getReductionRegion().end());
212 return decl;
213}
214
215/// Adds an atomic reduction combiner to the given OpenMP reduction declaration
216/// using llvm.atomicrmw of the given kind.
217static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
218 LLVM::AtomicBinOp atomicKind,
219 omp::DeclareReductionOp decl,
220 scf::ReduceOp reduce,
221 int64_t reductionIndex) {
222 OpBuilder::InsertionGuard guard(builder);
223 auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
224 Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc();
225 builder.createBlock(&decl.getAtomicReductionRegion(),
226 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
227 {reduceOperandLoc, reduceOperandLoc});
228 Block *atomicBlock = &decl.getAtomicReductionRegion().back();
229 builder.setInsertionPointToEnd(atomicBlock);
230 Value loaded = LLVM::LoadOp::create(builder, reduce.getLoc(), decl.getType(),
231 atomicBlock->getArgument(1));
232 LLVM::AtomicRMWOp::create(builder, reduce.getLoc(), atomicKind,
233 atomicBlock->getArgument(0), loaded,
234 LLVM::AtomicOrdering::monotonic);
235 omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>());
236 return decl;
237}
238
239/// Creates an OpenMP reduction declaration that corresponds to the given SCF
240/// reduction and returns it. Recognizes common reductions in order to identify
241/// the neutral value, necessary for the OpenMP declaration. If the reduction
242/// cannot be recognized, returns null.
243static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
244 scf::ReduceOp reduce,
245 int64_t reductionIndex) {
247 SymbolTable symbolTable(container);
248
249 // Insert reduction declarations in the symbol-table ancestor before the
250 // ancestor of the current insertion point.
251 Operation *insertionPoint = reduce;
252 while (insertionPoint->getParentOp() != container)
253 insertionPoint = insertionPoint->getParentOp();
254 OpBuilder::InsertionGuard guard(builder);
255 builder.setInsertionPoint(insertionPoint);
256
257 assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
258 "expected reduction region to have a single element");
259
260 // Match simple binary reductions that can be expressed with atomicrmw.
261 Type type = reduce.getOperands()[reductionIndex].getType();
262 Block &reduction = reduce.getReductions()[reductionIndex].front();
264 omp::DeclareReductionOp decl =
265 createDecl(builder, symbolTable, reduce, reductionIndex,
266 builder.getFloatAttr(type, 0.0));
267 return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
268 reductionIndex);
269 }
271 omp::DeclareReductionOp decl =
272 createDecl(builder, symbolTable, reduce, reductionIndex,
273 builder.getIntegerAttr(type, 0));
274 return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
275 reductionIndex);
276 }
278 omp::DeclareReductionOp decl =
279 createDecl(builder, symbolTable, reduce, reductionIndex,
280 builder.getIntegerAttr(type, 0));
281 return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
282 reductionIndex);
283 }
285 omp::DeclareReductionOp decl =
286 createDecl(builder, symbolTable, reduce, reductionIndex,
287 builder.getIntegerAttr(type, 0));
288 return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
289 reductionIndex);
290 }
292 omp::DeclareReductionOp decl = createDecl(
293 builder, symbolTable, reduce, reductionIndex,
294 builder.getIntegerAttr(
295 type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
296 return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
297 reductionIndex);
298 }
299
300 // Match simple binary reductions that cannot be expressed with atomicrmw.
301 // TODO: add atomic region using cmpxchg (which needs atomic load to be
302 // available as an op).
304 return createDecl(builder, symbolTable, reduce, reductionIndex,
305 builder.getFloatAttr(type, 1.0));
306 }
308 return createDecl(builder, symbolTable, reduce, reductionIndex,
309 builder.getIntegerAttr(type, 1));
310 }
311
312 // Match select-based min/max reductions.
313 bool isMin;
315 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
316 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
318 reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
319 {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
320 return createDecl(builder, symbolTable, reduce, reductionIndex,
321 minMaxValueForFloat(type, !isMin));
322 }
324 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
325 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
327 reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
328 {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
329 omp::DeclareReductionOp decl =
330 createDecl(builder, symbolTable, reduce, reductionIndex,
331 minMaxValueForSignedInt(type, !isMin));
332 return addAtomicRMW(builder,
333 isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
334 decl, reduce, reductionIndex);
335 }
337 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
338 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
340 reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
341 {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
342 omp::DeclareReductionOp decl =
343 createDecl(builder, symbolTable, reduce, reductionIndex,
344 minMaxValueForUnsignedInt(type, !isMin));
345 return addAtomicRMW(
346 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
347 decl, reduce, reductionIndex);
348 }
349
350 return nullptr;
351}
352
353namespace {
354
355struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
356 static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
357 unsigned numThreads;
358
359 ParallelOpLowering(MLIRContext *context,
360 unsigned numThreads = kUseOpenMPDefaultNumThreads)
361 : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
362
363 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
364 PatternRewriter &rewriter) const override {
365 // Declare reductions.
366 // TODO: consider checking it here is already a compatible reduction
367 // declaration and use it instead of redeclaring.
368 SmallVector<Attribute> reductionSyms;
369 SmallVector<omp::DeclareReductionOp> ompReductionDecls;
370 auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
371 for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
372 omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
373 ompReductionDecls.push_back(decl);
374 if (!decl)
375 return failure();
376 reductionSyms.push_back(
377 SymbolRefAttr::get(rewriter.getContext(), decl.getSymName()));
378 }
379
380 // Allocate reduction variables. Make sure the we don't overflow the stack
381 // with local `alloca`s by saving and restoring the stack pointer.
382 Location loc = parallelOp.getLoc();
383 Value one =
384 LLVM::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(64),
385 rewriter.getI64IntegerAttr(1));
386 SmallVector<Value> reductionVariables;
387 reductionVariables.reserve(parallelOp.getNumReductions());
388 auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
389 for (Value init : parallelOp.getInitVals()) {
390 assert((LLVM::isCompatibleType(init.getType()) ||
391 isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
392 "cannot create a reduction variable if the type is not an LLVM "
393 "pointer element");
394 Value storage = LLVM::AllocaOp::create(rewriter, loc, ptrType,
395 init.getType(), one, 0);
396 LLVM::StoreOp::create(rewriter, loc, init, storage);
397 reductionVariables.push_back(storage);
398 }
399
400 // Replace the reduction operations contained in this loop. Must be done
401 // here rather than in a separate pattern to have access to the list of
402 // reduction variables.
403 for (auto [x, y, rD] : llvm::zip_equal(
404 reductionVariables, reduce.getOperands(), ompReductionDecls)) {
405 OpBuilder::InsertionGuard guard(rewriter);
406 rewriter.setInsertionPoint(reduce);
407 Region &redRegion = rD.getReductionRegion();
408 // The SCF dialect by definition contains only structured operations
409 // and hence the SCF reduction region will contain a single block.
410 // The ompReductionDecls region is a copy of the SCF reduction region
411 // and hence has the same property.
412 assert(redRegion.hasOneBlock() &&
413 "expect reduction region to have one block");
414 Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
415 Value pvtRedVal = LLVM::LoadOp::create(rewriter, reduce.getLoc(),
416 rD.getType(), pvtRedVar);
417 // Make a copy of the reduction combiner region in the body
418 mlir::OpBuilder builder(rewriter.getContext());
419 builder.setInsertionPoint(reduce);
420 mlir::IRMapping mapper;
421 assert(redRegion.getNumArguments() == 2 &&
422 "expect reduction region to have two arguments");
423 mapper.map(redRegion.getArgument(0), pvtRedVal);
424 mapper.map(redRegion.getArgument(1), y);
425 for (auto &op : redRegion.getOps()) {
426 Operation *cloneOp = builder.clone(op, mapper);
427 if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
428 assert(yieldOp && yieldOp.getResults().size() == 1 &&
429 "expect YieldOp in reduction region to return one result");
430 Value redVal = yieldOp.getResults()[0];
431 LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
432 rewriter.eraseOp(yieldOp);
433 break;
434 }
435 }
436 }
437 rewriter.eraseOp(reduce);
438
439 Value numThreadsVar;
440 if (numThreads > 0) {
441 numThreadsVar = LLVM::ConstantOp::create(
442 rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
443 }
444 // Create the parallel wrapper.
445 auto ompParallel = omp::ParallelOp::create(
446 rewriter, loc,
447 /* allocate_vars = */ llvm::SmallVector<Value>{},
448 /* allocator_vars = */ llvm::SmallVector<Value>{},
449 /* if_expr = */ Value{},
450 /* num_threads = */ numThreadsVar,
451 /* private_vars = */ ValueRange(),
452 /* private_syms = */ nullptr,
453 /* private_needs_barrier = */ nullptr,
454 /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
455 /* reduction_mod = */ nullptr,
456 /* reduction_vars = */ llvm::SmallVector<Value>{},
457 /* reduction_byref = */ DenseBoolArrayAttr{},
458 /* reduction_syms = */ ArrayAttr{});
459 {
460
461 OpBuilder::InsertionGuard guard(rewriter);
462 rewriter.createBlock(&ompParallel.getRegion());
463
464 // Replace the loop.
465 {
466 OpBuilder::InsertionGuard allocaGuard(rewriter);
467 // Create worksharing loop wrapper.
468 auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc());
469 if (!reductionVariables.empty()) {
470 wsloopOp.setReductionSymsAttr(
471 ArrayAttr::get(rewriter.getContext(), reductionSyms));
472 wsloopOp.getReductionVarsMutable().append(reductionVariables);
473 llvm::SmallVector<bool> reductionByRef;
474 // false because these reductions always reduce scalars and so do
475 // not need to pass by reference
476 reductionByRef.resize(reductionVariables.size(), false);
477 wsloopOp.setReductionByref(
478 DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef));
479 }
480 omp::TerminatorOp::create(rewriter, loc); // omp.parallel terminator.
481
482 // The wrapper's entry block arguments will define the reduction
483 // variables.
484 llvm::SmallVector<mlir::Type> reductionTypes;
485 reductionTypes.reserve(reductionVariables.size());
486 llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
487 [](mlir::Value v) { return v.getType(); });
488 rewriter.createBlock(
489 &wsloopOp.getRegion(), {}, reductionTypes,
490 llvm::SmallVector<mlir::Location>(reductionVariables.size(),
491 parallelOp.getLoc()));
492
493 // Create loop nest and populate region with contents of scf.parallel.
494 auto loopOp = omp::LoopNestOp::create(
495 rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(),
496 parallelOp.getLowerBound(), parallelOp.getUpperBound(),
497 parallelOp.getStep(), /*loop_inclusive=*/false,
498 /*tile_sizes=*/nullptr);
499
500 rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
501 loopOp.getRegion().begin());
502
503 // Remove reduction-related block arguments from omp.loop_nest and
504 // redirect uses to the corresponding omp.wsloop block argument.
505 mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
506 unsigned numLoops = parallelOp.getNumLoops();
507 rewriter.replaceAllUsesWith(
508 loopOpEntryBlock.getArguments().drop_front(numLoops),
509 wsloopOp.getRegion().getArguments());
510 loopOpEntryBlock.eraseArguments(
511 numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
512
513 Block *ops =
514 rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
515 rewriter.setInsertionPointToStart(&loopOpEntryBlock);
516
517 auto scope = memref::AllocaScopeOp::create(
518 rewriter, parallelOp.getLoc(), TypeRange());
519 omp::YieldOp::create(rewriter, loc, ValueRange());
520 Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
521 rewriter.mergeBlocks(ops, scopeBlock);
522 rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
523 memref::AllocaScopeReturnOp::create(rewriter, loc, ValueRange());
524 }
525 }
526
527 // Load loop results.
528 SmallVector<Value> results;
529 results.reserve(reductionVariables.size());
530 for (auto [variable, type] :
531 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
532 Value res = LLVM::LoadOp::create(rewriter, loc, type, variable);
533 results.push_back(res);
534 }
535 rewriter.replaceOp(parallelOp, results);
536
537 return success();
538 }
539};
540
541/// Applies the conversion patterns in the given function.
542static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
543 RewritePatternSet patterns(module.getContext());
544 patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
545 FrozenRewritePatternSet frozen(std::move(patterns));
546 walkAndApplyPatterns(module, frozen);
547 auto status = module.walk([](Operation *op) {
548 if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) {
549 op->emitError("unconverted operation found");
550 return WalkResult::interrupt();
551 }
552 return WalkResult::advance();
553 });
554 return failure(status.wasInterrupted());
555}
556
557/// A pass converting SCF operations to OpenMP operations.
558struct SCFToOpenMPPass
560
561 using Base::Base;
562
563 /// Pass entry point.
564 void runOnOperation() override {
565 if (failed(applyPatterns(getOperation(), numThreads)))
566 signalPassFailure();
567 }
568};
569
570} // namespace
return success()
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
ArrayAttr()
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Operation & front()
Definition Block.h:153
Operation & back()
Definition Block.h:152
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition Block.cpp:201
BlockArgListType getArguments()
Definition Block.h:87
iterator end()
Definition Block.h:144
iterator begin()
Definition Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
MLIRContext * getContext() const
Definition Builders.h:56
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator_range< OpIterator > getOps()
Definition Region.h:172
unsigned getNumArguments()
Definition Region.h:123
BlockArgument getArgument(unsigned i)
Definition Region.h:124
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...