28#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
29#include "mlir/Conversion/Passes.h.inc"
40template <
typename... OpTy>
42 if (block.
empty() || llvm::hasSingleElement(block) ||
43 std::next(block.
begin(), 2) != block.
end())
53 if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
56 return isa<OpTy...>(combinerOps[0]) &&
57 isa<scf::ReduceReturnOp>(block.
back()) &&
73 typename CompareOpTy,
typename SelectOpTy,
74 typename Predicate =
decltype(std::declval<CompareOpTy>().getPredicate())>
79 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
80 "only arithmetic and llvm select ops are supported");
83 if (block.
empty() || llvm::hasSingleElement(block) ||
84 std::next(block.
begin(), 2) == block.
end() ||
85 std::next(block.
begin(), 3) != block.
end())
89 auto compare = dyn_cast<CompareOpTy>(block.
front());
90 auto select = dyn_cast<SelectOpTy>(block.
front().getNextNode());
91 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.
back());
92 if (!compare || !select || !terminator)
101 if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
103 }
else if (llvm::is_contained(greaterThanPredicates,
104 compare.getPredicate())) {
110 if (select.getCondition() != compare.getResult())
118 constexpr unsigned kTrueValue = 1;
119 constexpr unsigned kFalseValue = 2;
120 bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
121 select.getOperand(kFalseValue) == compare.getRhs();
122 bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
123 select.getOperand(kFalseValue) == compare.getLhs();
124 if (!sameOperands && !swappedOperands)
127 if (select.getResult() != terminator.getResult())
132 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
133 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
139 return llvm::APFloat::IEEEhalf();
141 return llvm::APFloat::IEEEsingle();
143 return llvm::APFloat::IEEEdouble();
145 return llvm::APFloat::IEEEquad();
147 return llvm::APFloat::BFloat();
149 return llvm::APFloat::x87DoubleExtended();
150 llvm_unreachable(
"unknown float type");
156 if (
auto vecType = dyn_cast<VectorType>(type))
165 auto fltType = cast<FloatType>(elType);
176 auto intType = cast<IntegerType>(elType);
177 unsigned bitwidth = intType.getWidth();
178 auto val =
min ? llvm::APInt::getSignedMinValue(bitwidth)
179 : llvm::APInt::getSignedMaxValue(bitwidth);
189 auto intType = cast<IntegerType>(elType);
190 unsigned bitwidth = intType.getWidth();
192 min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
201static omp::DeclareReductionOp
205 Type type =
reduce.getOperands()[reductionIndex].getType();
206 auto decl = omp::DeclareReductionOp::create(builder,
reduce.getLoc(),
207 "__scf_reduction", type,
212 decl.getInitializerRegion().end(), {type},
213 {reduce.getOperands()[reductionIndex].getLoc()});
216 LLVM::ConstantOp::create(builder,
reduce.getLoc(), type, initValue);
217 omp::YieldOp::create(builder,
reduce.getLoc(), init);
220 &
reduce.getReductions()[reductionIndex].front().back();
221 assert(isa<scf::ReduceReturnOp>(terminator) &&
222 "expected reduce op to be terminated by reduce return");
227 decl.getReductionRegion(),
228 decl.getReductionRegion().end());
235 LLVM::AtomicBinOp atomicKind,
236 omp::DeclareReductionOp decl,
240 auto ptrType = LLVM::LLVMPointerType::get(builder.
getContext());
241 Location reduceOperandLoc =
reduce.getOperands()[reductionIndex].getLoc();
242 builder.
createBlock(&decl.getAtomicReductionRegion(),
243 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
244 {reduceOperandLoc, reduceOperandLoc});
245 Block *atomicBlock = &decl.getAtomicReductionRegion().
back();
247 Value loaded = LLVM::LoadOp::create(builder,
reduce.getLoc(), decl.getType(),
249 LLVM::AtomicRMWOp::create(builder,
reduce.getLoc(), atomicKind,
251 LLVM::AtomicOrdering::monotonic);
279 assert(llvm::hasSingleElement(
reduce.getReductions()[reductionIndex]) &&
280 "expected reduction region to have a single element");
283 Type type =
reduce.getOperands()[reductionIndex].getType();
284 Block &reduction =
reduce.getReductions()[reductionIndex].front();
292 builder, symbolTable,
reduce, reductionIndex,
295 decl,
reduce, reductionIndex)
300 builder, symbolTable,
reduce, reductionIndex,
303 decl,
reduce, reductionIndex)
308 builder, symbolTable,
reduce, reductionIndex,
311 decl,
reduce, reductionIndex)
316 builder, symbolTable,
reduce, reductionIndex,
319 decl,
reduce, reductionIndex)
325 builder, symbolTable,
reduce, reductionIndex,
328 decl,
reduce, reductionIndex)
337 builder, symbolTable,
reduce, reductionIndex,
343 builder, symbolTable,
reduce, reductionIndex,
351 arith::CmpFPredicate>(
352 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
353 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
355 arith::CmpFPredicate>(
356 reduction, {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE},
357 {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, isMin)) {
364 arith::CmpIPredicate>(
365 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
366 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
368 arith::CmpIPredicate>(
369 reduction, {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge},
370 {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, isMin)) {
371 omp::DeclareReductionOp decl =
375 isMin ? LLVM::AtomicBinOp::min
376 : LLVM::AtomicBinOp::max,
377 decl,
reduce, reductionIndex)
383 arith::CmpIPredicate>(
384 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
385 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
387 arith::CmpIPredicate>(
388 reduction, {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge},
389 {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, isMin)) {
390 omp::DeclareReductionOp decl =
394 isMin ? LLVM::AtomicBinOp::umin
395 : LLVM::AtomicBinOp::umax,
396 decl,
reduce, reductionIndex)
406 static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
410 unsigned numThreads = kUseOpenMPDefaultNumThreads)
413 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
418 for (
Value init : parallelOp.getInitVals()) {
420 !isa<LLVM::PointerElementTypeInterface>(init.getType()))
422 parallelOp,
"reduction init type is not an LLVM-compatible type");
430 auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
431 for (
int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
433 ompReductionDecls.push_back(decl);
436 reductionSyms.push_back(
437 SymbolRefAttr::get(rewriter.
getContext(), decl.getSymName()));
444 LLVM::ConstantOp::create(rewriter, loc, rewriter.
getIntegerType(64),
447 reductionVariables.reserve(parallelOp.getNumReductions());
448 auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
449 for (
Value init : parallelOp.getInitVals()) {
450 Value storage = LLVM::AllocaOp::create(rewriter, loc, ptrType,
451 init.getType(), one, 0);
452 LLVM::StoreOp::create(rewriter, loc, init, storage);
453 reductionVariables.push_back(storage);
459 for (
auto [x, y, rD] : llvm::zip_equal(
460 reductionVariables,
reduce.getOperands(), ompReductionDecls)) {
463 Region &redRegion = rD.getReductionRegion();
469 "expect reduction region to have one block");
470 Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
471 Value pvtRedVal = LLVM::LoadOp::create(rewriter,
reduce.getLoc(),
472 rD.getType(), pvtRedVar);
475 builder.setInsertionPoint(
reduce);
478 "expect reduction region to have two arguments");
481 for (
auto &op : redRegion.
getOps()) {
483 if (
auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
484 assert(yieldOp && yieldOp.getResults().size() == 1 &&
485 "expect YieldOp in reduction region to return one result");
486 Value redVal = yieldOp.getResults()[0];
487 LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
496 if (numThreads > 0) {
497 Value numThreadsVar = LLVM::ConstantOp::create(
499 numThreadsVars.push_back(numThreadsVar);
502 auto ompParallel = omp::ParallelOp::create(
511 omp::ClauseProcBindKindAttr{},
525 auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc());
526 if (!reductionVariables.empty()) {
527 wsloopOp.setReductionSymsAttr(
528 ArrayAttr::get(rewriter.
getContext(), reductionSyms));
529 wsloopOp.getReductionVarsMutable().append(reductionVariables);
533 reductionByRef.resize(reductionVariables.size(),
false);
534 wsloopOp.setReductionByref(
537 omp::TerminatorOp::create(rewriter, loc);
542 reductionTypes.reserve(reductionVariables.size());
543 llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
546 &wsloopOp.getRegion(), {}, reductionTypes,
548 parallelOp.getLoc()));
551 auto loopOp = omp::LoopNestOp::create(
552 rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(),
553 parallelOp.getLowerBound(), parallelOp.getUpperBound(),
554 parallelOp.getStep(),
false,
558 loopOp.getRegion().begin());
563 unsigned numLoops = parallelOp.getNumLoops();
566 wsloopOp.getRegion().getArguments());
574 auto scope = memref::AllocaScopeOp::create(
575 rewriter, parallelOp.getLoc(),
TypeRange());
576 omp::YieldOp::create(rewriter, loc,
ValueRange());
580 memref::AllocaScopeReturnOp::create(rewriter, loc,
ValueRange());
586 results.reserve(reductionVariables.size());
587 for (
auto [variable, type] :
588 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
589 Value res = LLVM::LoadOp::create(rewriter, loc, type, variable);
590 results.push_back(res);
599static LogicalResult
applyPatterns(ModuleOp module,
unsigned numThreads) {
601 patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
604 auto status =
module.walk([](Operation *op) {
605 if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) {
606 op->emitError("unconverted operation found");
611 return failure(status.wasInterrupted());
615struct SCFToOpenMPPass
616 :
public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
621 void runOnOperation()
override {
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region ®ion, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static Attribute getSplatOrScalarAttr(Type type, Attribute val)
Helper to create a splat attribute for vector types, or return the scalar attribute for scalar types.
static bool supportsAtomic(Type type)
Returns true if the type is supported by llvm.atomicrmw.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...