MLIR 22.0.0git
SCFToOpenMP.cpp
Go to the documentation of this file.
1//===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert scf.parallel operations into OpenMP
10// parallel loops.
11//
12//===----------------------------------------------------------------------===//
13
15
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Pass/Pass.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33
34/// Matches a block containing a "simple" reduction. The expected shape of the
35/// block is as follows.
36///
37/// ^bb(%arg0, %arg1):
38/// %0 = OpTy(%arg0, %arg1)
39/// scf.reduce.return %0
40template <typename... OpTy>
41static bool matchSimpleReduction(Block &block) {
42 if (block.empty() || llvm::hasSingleElement(block) ||
43 std::next(block.begin(), 2) != block.end())
44 return false;
45
46 if (block.getNumArguments() != 2)
47 return false;
48
50 Value reducedVal = matchReduction({block.getArguments()[1]},
51 /*redPos=*/0, combinerOps);
52
53 if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
54 return false;
55
56 return isa<OpTy...>(combinerOps[0]) &&
57 isa<scf::ReduceReturnOp>(block.back()) &&
58 block.front().getOperands() == block.getArguments();
59}
60
61/// Matches a block containing a select-based min/max reduction. The types of
62/// select and compare operations are provided as template arguments. The
63/// comparison predicates suitable for min and max are provided as function
64/// arguments. If a reduction is matched, `ifMin` will be set if the reduction
65/// compute the minimum and unset if it computes the maximum, otherwise it
66/// remains unmodified. The expected shape of the block is as follows.
67///
68/// ^bb(%arg0, %arg1):
69/// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
70/// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here.
71/// scf.reduce.return %1
72template <
73 typename CompareOpTy, typename SelectOpTy,
74 typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
75static bool
77 ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
78 static_assert(
79 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
80 "only arithmetic and llvm select ops are supported");
81
82 // Expect exactly three operations in the block.
83 if (block.empty() || llvm::hasSingleElement(block) ||
84 std::next(block.begin(), 2) == block.end() ||
85 std::next(block.begin(), 3) != block.end())
86 return false;
87
88 // Check op kinds.
89 auto compare = dyn_cast<CompareOpTy>(block.front());
90 auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
91 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
92 if (!compare || !select || !terminator)
93 return false;
94
95 // Block arguments must be compared.
96 if (compare->getOperands() != block.getArguments())
97 return false;
98
99 // Detect whether the comparison is less-than or greater-than, otherwise bail.
100 bool isLess;
101 if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
102 isLess = true;
103 } else if (llvm::is_contained(greaterThanPredicates,
104 compare.getPredicate())) {
105 isLess = false;
106 } else {
107 return false;
108 }
109
110 if (select.getCondition() != compare.getResult())
111 return false;
112
113 // Detect if the operands are swapped between cmpf and select. Match the
114 // comparison type with the requested type or with the opposite of the
115 // requested type if the operands are swapped. Use generic accessors because
116 // std and LLVM versions of select have different operand names but identical
117 // positions.
118 constexpr unsigned kTrueValue = 1;
119 constexpr unsigned kFalseValue = 2;
120 bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
121 select.getOperand(kFalseValue) == compare.getRhs();
122 bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
123 select.getOperand(kFalseValue) == compare.getLhs();
124 if (!sameOperands && !swappedOperands)
125 return false;
126
127 if (select.getResult() != terminator.getResult())
128 return false;
129
130 // The reduction is a min if it uses less-than predicates with same operands
131 // or greather-than predicates with swapped operands. Similarly for max.
132 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
133 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
134}
135
136/// Returns the float semantics for the given float type.
137static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
138 if (type.isF16())
139 return llvm::APFloat::IEEEhalf();
140 if (type.isF32())
141 return llvm::APFloat::IEEEsingle();
142 if (type.isF64())
143 return llvm::APFloat::IEEEdouble();
144 if (type.isF128())
145 return llvm::APFloat::IEEEquad();
146 if (type.isBF16())
147 return llvm::APFloat::BFloat();
148 if (type.isF80())
149 return llvm::APFloat::x87DoubleExtended();
150 llvm_unreachable("unknown float type");
151}
152
153/// Returns an attribute with the minimum (if `min` is set) or the maximum value
154/// (otherwise) for the given float type.
156 auto fltType = cast<FloatType>(type);
157 return FloatAttr::get(
158 type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
159}
160
161/// Returns an attribute with the signed integer minimum (if `min` is set) or
162/// the maximum value (otherwise) for the given integer type, regardless of its
163/// signedness semantics (only the width is considered).
165 auto intType = cast<IntegerType>(type);
166 unsigned bitwidth = intType.getWidth();
167 return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
168 : llvm::APInt::getSignedMaxValue(bitwidth));
169}
170
171/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
172/// the maximum value (otherwise) for the given integer type, regardless of its
173/// signedness semantics (only the width is considered).
175 auto intType = cast<IntegerType>(type);
176 unsigned bitwidth = intType.getWidth();
177 return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
178 : llvm::APInt::getAllOnes(bitwidth));
179}
180
181/// Creates an OpenMP reduction declaration and inserts it into the provided
182/// symbol table. The declaration has a constant initializer with the neutral
183/// value `initValue`, and the `reductionIndex`-th reduction combiner carried
184/// over from `reduce`.
185static omp::DeclareReductionOp
187 scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
188 OpBuilder::InsertionGuard guard(builder);
189 Type type = reduce.getOperands()[reductionIndex].getType();
190 auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(),
191 "__scf_reduction", type,
192 /*byref_element_type=*/{});
193 symbolTable.insert(decl);
194
195 builder.createBlock(&decl.getInitializerRegion(),
196 decl.getInitializerRegion().end(), {type},
197 {reduce.getOperands()[reductionIndex].getLoc()});
198 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
199 Value init =
200 LLVM::ConstantOp::create(builder, reduce.getLoc(), type, initValue);
201 omp::YieldOp::create(builder, reduce.getLoc(), init);
202
203 Operation *terminator =
204 &reduce.getReductions()[reductionIndex].front().back();
205 assert(isa<scf::ReduceReturnOp>(terminator) &&
206 "expected reduce op to be terminated by redure return");
207 builder.setInsertionPoint(terminator);
208 builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
209 terminator->getOperands());
210 builder.inlineRegionBefore(reduce.getReductions()[reductionIndex],
211 decl.getReductionRegion(),
212 decl.getReductionRegion().end());
213 return decl;
214}
215
216/// Adds an atomic reduction combiner to the given OpenMP reduction declaration
217/// using llvm.atomicrmw of the given kind.
218static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
219 LLVM::AtomicBinOp atomicKind,
220 omp::DeclareReductionOp decl,
221 scf::ReduceOp reduce,
222 int64_t reductionIndex) {
223 OpBuilder::InsertionGuard guard(builder);
224 auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
225 Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc();
226 builder.createBlock(&decl.getAtomicReductionRegion(),
227 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
228 {reduceOperandLoc, reduceOperandLoc});
229 Block *atomicBlock = &decl.getAtomicReductionRegion().back();
230 builder.setInsertionPointToEnd(atomicBlock);
231 Value loaded = LLVM::LoadOp::create(builder, reduce.getLoc(), decl.getType(),
232 atomicBlock->getArgument(1));
233 LLVM::AtomicRMWOp::create(builder, reduce.getLoc(), atomicKind,
234 atomicBlock->getArgument(0), loaded,
235 LLVM::AtomicOrdering::monotonic);
236 omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>());
237 return decl;
238}
239
240/// Creates an OpenMP reduction declaration that corresponds to the given SCF
241/// reduction and returns it. Recognizes common reductions in order to identify
242/// the neutral value, necessary for the OpenMP declaration. If the reduction
243/// cannot be recognized, returns null.
244static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
245 scf::ReduceOp reduce,
246 int64_t reductionIndex) {
248 SymbolTable symbolTable(container);
249
250 // Insert reduction declarations in the symbol-table ancestor before the
251 // ancestor of the current insertion point.
252 Operation *insertionPoint = reduce;
253 while (insertionPoint->getParentOp() != container)
254 insertionPoint = insertionPoint->getParentOp();
255 OpBuilder::InsertionGuard guard(builder);
256 builder.setInsertionPoint(insertionPoint);
257
258 assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
259 "expected reduction region to have a single element");
260
261 // Match simple binary reductions that can be expressed with atomicrmw.
262 Type type = reduce.getOperands()[reductionIndex].getType();
263 Block &reduction = reduce.getReductions()[reductionIndex].front();
265 omp::DeclareReductionOp decl =
266 createDecl(builder, symbolTable, reduce, reductionIndex,
267 builder.getFloatAttr(type, 0.0));
268 return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
269 reductionIndex);
270 }
272 omp::DeclareReductionOp decl =
273 createDecl(builder, symbolTable, reduce, reductionIndex,
274 builder.getIntegerAttr(type, 0));
275 return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
276 reductionIndex);
277 }
279 omp::DeclareReductionOp decl =
280 createDecl(builder, symbolTable, reduce, reductionIndex,
281 builder.getIntegerAttr(type, 0));
282 return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
283 reductionIndex);
284 }
286 omp::DeclareReductionOp decl =
287 createDecl(builder, symbolTable, reduce, reductionIndex,
288 builder.getIntegerAttr(type, 0));
289 return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
290 reductionIndex);
291 }
293 omp::DeclareReductionOp decl = createDecl(
294 builder, symbolTable, reduce, reductionIndex,
295 builder.getIntegerAttr(
296 type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
297 return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
298 reductionIndex);
299 }
300
301 // Match simple binary reductions that cannot be expressed with atomicrmw.
302 // TODO: add atomic region using cmpxchg (which needs atomic load to be
303 // available as an op).
305 return createDecl(builder, symbolTable, reduce, reductionIndex,
306 builder.getFloatAttr(type, 1.0));
307 }
309 return createDecl(builder, symbolTable, reduce, reductionIndex,
310 builder.getIntegerAttr(type, 1));
311 }
312
313 // Match select-based min/max reductions.
314 bool isMin;
316 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
317 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
319 reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
320 {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
321 return createDecl(builder, symbolTable, reduce, reductionIndex,
322 minMaxValueForFloat(type, !isMin));
323 }
325 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
326 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
328 reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
329 {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
330 omp::DeclareReductionOp decl =
331 createDecl(builder, symbolTable, reduce, reductionIndex,
332 minMaxValueForSignedInt(type, !isMin));
333 return addAtomicRMW(builder,
334 isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
335 decl, reduce, reductionIndex);
336 }
338 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
339 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
341 reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
342 {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
343 omp::DeclareReductionOp decl =
344 createDecl(builder, symbolTable, reduce, reductionIndex,
345 minMaxValueForUnsignedInt(type, !isMin));
346 return addAtomicRMW(
347 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
348 decl, reduce, reductionIndex);
349 }
350
351 return nullptr;
352}
353
354namespace {
355
356struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
357 static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
358 unsigned numThreads;
359
360 ParallelOpLowering(MLIRContext *context,
361 unsigned numThreads = kUseOpenMPDefaultNumThreads)
362 : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
363
364 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
365 PatternRewriter &rewriter) const override {
366 // Declare reductions.
367 // TODO: consider checking it here is already a compatible reduction
368 // declaration and use it instead of redeclaring.
369 SmallVector<Attribute> reductionSyms;
370 SmallVector<omp::DeclareReductionOp> ompReductionDecls;
371 auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
372 for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
373 omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
374 ompReductionDecls.push_back(decl);
375 if (!decl)
376 return failure();
377 reductionSyms.push_back(
378 SymbolRefAttr::get(rewriter.getContext(), decl.getSymName()));
379 }
380
381 // Allocate reduction variables. Make sure the we don't overflow the stack
382 // with local `alloca`s by saving and restoring the stack pointer.
383 Location loc = parallelOp.getLoc();
384 Value one =
385 LLVM::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(64),
386 rewriter.getI64IntegerAttr(1));
387 SmallVector<Value> reductionVariables;
388 reductionVariables.reserve(parallelOp.getNumReductions());
389 auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
390 for (Value init : parallelOp.getInitVals()) {
391 assert((LLVM::isCompatibleType(init.getType()) ||
392 isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
393 "cannot create a reduction variable if the type is not an LLVM "
394 "pointer element");
395 Value storage = LLVM::AllocaOp::create(rewriter, loc, ptrType,
396 init.getType(), one, 0);
397 LLVM::StoreOp::create(rewriter, loc, init, storage);
398 reductionVariables.push_back(storage);
399 }
400
401 // Replace the reduction operations contained in this loop. Must be done
402 // here rather than in a separate pattern to have access to the list of
403 // reduction variables.
404 for (auto [x, y, rD] : llvm::zip_equal(
405 reductionVariables, reduce.getOperands(), ompReductionDecls)) {
406 OpBuilder::InsertionGuard guard(rewriter);
407 rewriter.setInsertionPoint(reduce);
408 Region &redRegion = rD.getReductionRegion();
409 // The SCF dialect by definition contains only structured operations
410 // and hence the SCF reduction region will contain a single block.
411 // The ompReductionDecls region is a copy of the SCF reduction region
412 // and hence has the same property.
413 assert(redRegion.hasOneBlock() &&
414 "expect reduction region to have one block");
415 Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
416 Value pvtRedVal = LLVM::LoadOp::create(rewriter, reduce.getLoc(),
417 rD.getType(), pvtRedVar);
418 // Make a copy of the reduction combiner region in the body
419 mlir::OpBuilder builder(rewriter.getContext());
420 builder.setInsertionPoint(reduce);
421 mlir::IRMapping mapper;
422 assert(redRegion.getNumArguments() == 2 &&
423 "expect reduction region to have two arguments");
424 mapper.map(redRegion.getArgument(0), pvtRedVal);
425 mapper.map(redRegion.getArgument(1), y);
426 for (auto &op : redRegion.getOps()) {
427 Operation *cloneOp = builder.clone(op, mapper);
428 if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
429 assert(yieldOp && yieldOp.getResults().size() == 1 &&
430 "expect YieldOp in reduction region to return one result");
431 Value redVal = yieldOp.getResults()[0];
432 LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
433 rewriter.eraseOp(yieldOp);
434 break;
435 }
436 }
437 }
438 rewriter.eraseOp(reduce);
439
440 Value numThreadsVar;
441 if (numThreads > 0) {
442 numThreadsVar = LLVM::ConstantOp::create(
443 rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
444 }
445 // Create the parallel wrapper.
446 auto ompParallel = omp::ParallelOp::create(
447 rewriter, loc,
448 /* allocate_vars = */ llvm::SmallVector<Value>{},
449 /* allocator_vars = */ llvm::SmallVector<Value>{},
450 /* if_expr = */ Value{},
451 /* num_threads = */ numThreadsVar,
452 /* private_vars = */ ValueRange(),
453 /* private_syms = */ nullptr,
454 /* private_needs_barrier = */ nullptr,
455 /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
456 /* reduction_mod = */ nullptr,
457 /* reduction_vars = */ llvm::SmallVector<Value>{},
458 /* reduction_byref = */ DenseBoolArrayAttr{},
459 /* reduction_syms = */ ArrayAttr{});
460 {
461
462 OpBuilder::InsertionGuard guard(rewriter);
463 rewriter.createBlock(&ompParallel.getRegion());
464
465 // Replace the loop.
466 {
467 OpBuilder::InsertionGuard allocaGuard(rewriter);
468 // Create worksharing loop wrapper.
469 auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc());
470 if (!reductionVariables.empty()) {
471 wsloopOp.setReductionSymsAttr(
472 ArrayAttr::get(rewriter.getContext(), reductionSyms));
473 wsloopOp.getReductionVarsMutable().append(reductionVariables);
474 llvm::SmallVector<bool> reductionByRef;
475 // false because these reductions always reduce scalars and so do
476 // not need to pass by reference
477 reductionByRef.resize(reductionVariables.size(), false);
478 wsloopOp.setReductionByref(
479 DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef));
480 }
481 omp::TerminatorOp::create(rewriter, loc); // omp.parallel terminator.
482
483 // The wrapper's entry block arguments will define the reduction
484 // variables.
485 llvm::SmallVector<mlir::Type> reductionTypes;
486 reductionTypes.reserve(reductionVariables.size());
487 llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
488 [](mlir::Value v) { return v.getType(); });
489 rewriter.createBlock(
490 &wsloopOp.getRegion(), {}, reductionTypes,
491 llvm::SmallVector<mlir::Location>(reductionVariables.size(),
492 parallelOp.getLoc()));
493
494 // Create loop nest and populate region with contents of scf.parallel.
495 auto loopOp = omp::LoopNestOp::create(
496 rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(),
497 parallelOp.getLowerBound(), parallelOp.getUpperBound(),
498 parallelOp.getStep(), /*loop_inclusive=*/false,
499 /*tile_sizes=*/nullptr);
500
501 rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
502 loopOp.getRegion().begin());
503
504 // Remove reduction-related block arguments from omp.loop_nest and
505 // redirect uses to the corresponding omp.wsloop block argument.
506 mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
507 unsigned numLoops = parallelOp.getNumLoops();
508 rewriter.replaceAllUsesWith(
509 loopOpEntryBlock.getArguments().drop_front(numLoops),
510 wsloopOp.getRegion().getArguments());
511 loopOpEntryBlock.eraseArguments(
512 numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
513
514 Block *ops =
515 rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
516 rewriter.setInsertionPointToStart(&loopOpEntryBlock);
517
518 auto scope = memref::AllocaScopeOp::create(
519 rewriter, parallelOp.getLoc(), TypeRange());
520 omp::YieldOp::create(rewriter, loc, ValueRange());
521 Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
522 rewriter.mergeBlocks(ops, scopeBlock);
523 rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
524 memref::AllocaScopeReturnOp::create(rewriter, loc, ValueRange());
525 }
526 }
527
528 // Load loop results.
529 SmallVector<Value> results;
530 results.reserve(reductionVariables.size());
531 for (auto [variable, type] :
532 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
533 Value res = LLVM::LoadOp::create(rewriter, loc, type, variable);
534 results.push_back(res);
535 }
536 rewriter.replaceOp(parallelOp, results);
537
538 return success();
539 }
540};
541
542/// Applies the conversion patterns in the given function.
543static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
544 RewritePatternSet patterns(module.getContext());
545 patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
546 FrozenRewritePatternSet frozen(std::move(patterns));
547 walkAndApplyPatterns(module, frozen);
548 auto status = module.walk([](Operation *op) {
549 if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) {
550 op->emitError("unconverted operation found");
551 return WalkResult::interrupt();
552 }
553 return WalkResult::advance();
554 });
555 return failure(status.wasInterrupted());
556}
557
558/// A pass converting SCF operations to OpenMP operations.
559struct SCFToOpenMPPass
561
562 using Base::Base;
563
564 /// Pass entry point.
565 void runOnOperation() override {
566 if (failed(applyPatterns(getOperation(), numThreads)))
567 signalPassFailure();
568 }
569};
570
571} // namespace
return success()
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
ArrayAttr()
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Operation & front()
Definition Block.h:153
Operation & back()
Definition Block.h:152
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition Block.cpp:201
BlockArgListType getArguments()
Definition Block.h:87
iterator end()
Definition Block.h:144
iterator begin()
Definition Block.h:143
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
MLIRContext * getContext() const
Definition Builders.h:56
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator_range< OpIterator > getOps()
Definition Region.h:172
unsigned getNumArguments()
Definition Region.h:123
BlockArgument getArgument(unsigned i)
Definition Region.h:124
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...