28#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
29#include "mlir/Conversion/Passes.h.inc"
40template <
typename... OpTy>
42 if (block.
empty() || llvm::hasSingleElement(block) ||
43 std::next(block.
begin(), 2) != block.
end())
53 if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
56 return isa<OpTy...>(combinerOps[0]) &&
57 isa<scf::ReduceReturnOp>(block.
back()) &&
73 typename CompareOpTy,
typename SelectOpTy,
74 typename Predicate =
decltype(std::declval<CompareOpTy>().getPredicate())>
79 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
80 "only arithmetic and llvm select ops are supported");
83 if (block.
empty() || llvm::hasSingleElement(block) ||
84 std::next(block.
begin(), 2) == block.
end() ||
85 std::next(block.
begin(), 3) != block.
end())
89 auto compare = dyn_cast<CompareOpTy>(block.
front());
90 auto select = dyn_cast<SelectOpTy>(block.
front().getNextNode());
91 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.
back());
92 if (!compare || !select || !terminator)
101 if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
103 }
else if (llvm::is_contained(greaterThanPredicates,
104 compare.getPredicate())) {
110 if (select.getCondition() != compare.getResult())
118 constexpr unsigned kTrueValue = 1;
119 constexpr unsigned kFalseValue = 2;
120 bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
121 select.getOperand(kFalseValue) == compare.getRhs();
122 bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
123 select.getOperand(kFalseValue) == compare.getLhs();
124 if (!sameOperands && !swappedOperands)
127 if (select.getResult() != terminator.getResult())
132 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
133 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
139 return llvm::APFloat::IEEEhalf();
141 return llvm::APFloat::IEEEsingle();
143 return llvm::APFloat::IEEEdouble();
145 return llvm::APFloat::IEEEquad();
147 return llvm::APFloat::BFloat();
149 return llvm::APFloat::x87DoubleExtended();
150 llvm_unreachable(
"unknown float type");
156 if (
auto vecType = dyn_cast<VectorType>(type))
165 auto fltType = cast<FloatType>(elType);
176 auto intType = cast<IntegerType>(elType);
177 unsigned bitwidth = intType.getWidth();
178 auto val =
min ? llvm::APInt::getSignedMinValue(bitwidth)
179 : llvm::APInt::getSignedMaxValue(bitwidth);
189 auto intType = cast<IntegerType>(elType);
190 unsigned bitwidth = intType.getWidth();
192 min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
201static omp::DeclareReductionOp
205 Type type =
reduce.getOperands()[reductionIndex].getType();
206 auto decl = omp::DeclareReductionOp::create(builder,
reduce.getLoc(),
207 "__scf_reduction", type,
212 decl.getInitializerRegion().end(), {type},
213 {reduce.getOperands()[reductionIndex].getLoc()});
216 LLVM::ConstantOp::create(builder,
reduce.getLoc(), type, initValue);
217 omp::YieldOp::create(builder,
reduce.getLoc(), init);
220 &
reduce.getReductions()[reductionIndex].front().back();
221 assert(isa<scf::ReduceReturnOp>(terminator) &&
222 "expected reduce op to be terminated by reduce return");
227 decl.getReductionRegion(),
228 decl.getReductionRegion().end());
235 LLVM::AtomicBinOp atomicKind,
236 omp::DeclareReductionOp decl,
240 auto ptrType = LLVM::LLVMPointerType::get(builder.
getContext());
241 Location reduceOperandLoc =
reduce.getOperands()[reductionIndex].getLoc();
242 builder.
createBlock(&decl.getAtomicReductionRegion(),
243 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
244 {reduceOperandLoc, reduceOperandLoc});
245 Block *atomicBlock = &decl.getAtomicReductionRegion().
back();
247 Value loaded = LLVM::LoadOp::create(builder,
reduce.getLoc(), decl.getType(),
249 LLVM::AtomicRMWOp::create(builder,
reduce.getLoc(), atomicKind,
251 LLVM::AtomicOrdering::monotonic);
279 assert(llvm::hasSingleElement(
reduce.getReductions()[reductionIndex]) &&
280 "expected reduction region to have a single element");
283 Type type =
reduce.getOperands()[reductionIndex].getType();
284 Block &reduction =
reduce.getReductions()[reductionIndex].front();
292 builder, symbolTable,
reduce, reductionIndex,
295 decl,
reduce, reductionIndex)
300 builder, symbolTable,
reduce, reductionIndex,
303 decl,
reduce, reductionIndex)
308 builder, symbolTable,
reduce, reductionIndex,
311 decl,
reduce, reductionIndex)
316 builder, symbolTable,
reduce, reductionIndex,
319 decl,
reduce, reductionIndex)
325 builder, symbolTable,
reduce, reductionIndex,
328 decl,
reduce, reductionIndex)
337 builder, symbolTable,
reduce, reductionIndex,
343 builder, symbolTable,
reduce, reductionIndex,
351 arith::CmpFPredicate>(
352 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
353 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
355 arith::CmpFPredicate>(
356 reduction, {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE},
357 {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, isMin)) {
364 arith::CmpIPredicate>(
365 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
366 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
368 arith::CmpIPredicate>(
369 reduction, {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge},
370 {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, isMin)) {
371 omp::DeclareReductionOp decl =
375 isMin ? LLVM::AtomicBinOp::min
376 : LLVM::AtomicBinOp::max,
377 decl,
reduce, reductionIndex)
383 arith::CmpIPredicate>(
384 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
385 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
387 arith::CmpIPredicate>(
388 reduction, {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge},
389 {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, isMin)) {
390 omp::DeclareReductionOp decl =
394 isMin ? LLVM::AtomicBinOp::umin
395 : LLVM::AtomicBinOp::umax,
396 decl,
reduce, reductionIndex)
406 static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
410 unsigned numThreads = kUseOpenMPDefaultNumThreads)
413 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
420 auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
421 for (
int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
423 ompReductionDecls.push_back(decl);
426 reductionSyms.push_back(
427 SymbolRefAttr::get(rewriter.
getContext(), decl.getSymName()));
434 LLVM::ConstantOp::create(rewriter, loc, rewriter.
getIntegerType(64),
437 reductionVariables.reserve(parallelOp.getNumReductions());
438 auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
439 for (
Value init : parallelOp.getInitVals()) {
441 isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
442 "cannot create a reduction variable if the type is not an LLVM "
444 Value storage = LLVM::AllocaOp::create(rewriter, loc, ptrType,
445 init.getType(), one, 0);
446 LLVM::StoreOp::create(rewriter, loc, init, storage);
447 reductionVariables.push_back(storage);
453 for (
auto [x, y, rD] : llvm::zip_equal(
454 reductionVariables,
reduce.getOperands(), ompReductionDecls)) {
457 Region &redRegion = rD.getReductionRegion();
463 "expect reduction region to have one block");
464 Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
465 Value pvtRedVal = LLVM::LoadOp::create(rewriter,
reduce.getLoc(),
466 rD.getType(), pvtRedVar);
469 builder.setInsertionPoint(
reduce);
472 "expect reduction region to have two arguments");
475 for (
auto &op : redRegion.
getOps()) {
477 if (
auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
478 assert(yieldOp && yieldOp.getResults().size() == 1 &&
479 "expect YieldOp in reduction region to return one result");
480 Value redVal = yieldOp.getResults()[0];
481 LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
490 if (numThreads > 0) {
491 Value numThreadsVar = LLVM::ConstantOp::create(
493 numThreadsVars.push_back(numThreadsVar);
496 auto ompParallel = omp::ParallelOp::create(
505 omp::ClauseProcBindKindAttr{},
519 auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc());
520 if (!reductionVariables.empty()) {
521 wsloopOp.setReductionSymsAttr(
522 ArrayAttr::get(rewriter.
getContext(), reductionSyms));
523 wsloopOp.getReductionVarsMutable().append(reductionVariables);
527 reductionByRef.resize(reductionVariables.size(),
false);
528 wsloopOp.setReductionByref(
531 omp::TerminatorOp::create(rewriter, loc);
536 reductionTypes.reserve(reductionVariables.size());
537 llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
540 &wsloopOp.getRegion(), {}, reductionTypes,
542 parallelOp.getLoc()));
545 auto loopOp = omp::LoopNestOp::create(
546 rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(),
547 parallelOp.getLowerBound(), parallelOp.getUpperBound(),
548 parallelOp.getStep(),
false,
552 loopOp.getRegion().begin());
557 unsigned numLoops = parallelOp.getNumLoops();
560 wsloopOp.getRegion().getArguments());
568 auto scope = memref::AllocaScopeOp::create(
569 rewriter, parallelOp.getLoc(),
TypeRange());
570 omp::YieldOp::create(rewriter, loc,
ValueRange());
574 memref::AllocaScopeReturnOp::create(rewriter, loc,
ValueRange());
580 results.reserve(reductionVariables.size());
581 for (
auto [variable, type] :
582 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
583 Value res = LLVM::LoadOp::create(rewriter, loc, type, variable);
584 results.push_back(res);
593static LogicalResult
applyPatterns(ModuleOp module,
unsigned numThreads) {
595 patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
598 auto status =
module.walk([](Operation *op) {
599 if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) {
600 op->emitError("unconverted operation found");
605 return failure(status.wasInterrupted());
609struct SCFToOpenMPPass
610 :
public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
615 void runOnOperation()
override {
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static Attribute getSplatOrScalarAttr(Type type, Attribute val)
Helper to create a splat attribute for vector types, or return the scalar attribute for scalar types.
static bool supportsAtomic(Type type)
Returns true if the type is supported by llvm.atomicrmw.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...