28#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
29#include "mlir/Conversion/Passes.h.inc"
40template <
typename... OpTy>
42 if (block.
empty() || llvm::hasSingleElement(block) ||
43 std::next(block.
begin(), 2) != block.
end())
53 if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
56 return isa<OpTy...>(combinerOps[0]) &&
57 isa<scf::ReduceReturnOp>(block.
back()) &&
73 typename CompareOpTy,
typename SelectOpTy,
74 typename Predicate =
decltype(std::declval<CompareOpTy>().getPredicate())>
79 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
80 "only arithmetic and llvm select ops are supported");
83 if (block.
empty() || llvm::hasSingleElement(block) ||
84 std::next(block.
begin(), 2) == block.
end() ||
85 std::next(block.
begin(), 3) != block.
end())
89 auto compare = dyn_cast<CompareOpTy>(block.
front());
90 auto select = dyn_cast<SelectOpTy>(block.
front().getNextNode());
91 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.
back());
92 if (!compare || !select || !terminator)
101 if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
103 }
else if (llvm::is_contained(greaterThanPredicates,
104 compare.getPredicate())) {
110 if (select.getCondition() != compare.getResult())
118 constexpr unsigned kTrueValue = 1;
119 constexpr unsigned kFalseValue = 2;
120 bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
121 select.getOperand(kFalseValue) == compare.getRhs();
122 bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
123 select.getOperand(kFalseValue) == compare.getLhs();
124 if (!sameOperands && !swappedOperands)
127 if (select.getResult() != terminator.getResult())
132 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
133 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
139 return llvm::APFloat::IEEEhalf();
141 return llvm::APFloat::IEEEsingle();
143 return llvm::APFloat::IEEEdouble();
145 return llvm::APFloat::IEEEquad();
147 return llvm::APFloat::BFloat();
149 return llvm::APFloat::x87DoubleExtended();
150 llvm_unreachable(
"unknown float type");
156 auto fltType = cast<FloatType>(type);
157 return FloatAttr::get(
165 auto intType = cast<IntegerType>(type);
166 unsigned bitwidth = intType.getWidth();
167 return IntegerAttr::get(type,
min ? llvm::APInt::getSignedMinValue(bitwidth)
168 : llvm::APInt::getSignedMaxValue(bitwidth));
175 auto intType = cast<IntegerType>(type);
176 unsigned bitwidth = intType.getWidth();
177 return IntegerAttr::get(type,
min ? llvm::APInt::getZero(bitwidth)
178 : llvm::APInt::getAllOnes(bitwidth));
185static omp::DeclareReductionOp
189 Type type =
reduce.getOperands()[reductionIndex].getType();
190 auto decl = omp::DeclareReductionOp::create(builder,
reduce.getLoc(),
191 "__scf_reduction", type,
196 decl.getInitializerRegion().end(), {type},
197 {reduce.getOperands()[reductionIndex].getLoc()});
200 LLVM::ConstantOp::create(builder,
reduce.getLoc(), type, initValue);
201 omp::YieldOp::create(builder,
reduce.getLoc(), init);
204 &
reduce.getReductions()[reductionIndex].front().back();
205 assert(isa<scf::ReduceReturnOp>(terminator) &&
206 "expected reduce op to be terminated by redure return");
211 decl.getReductionRegion(),
212 decl.getReductionRegion().end());
219 LLVM::AtomicBinOp atomicKind,
220 omp::DeclareReductionOp decl,
224 auto ptrType = LLVM::LLVMPointerType::get(builder.
getContext());
225 Location reduceOperandLoc =
reduce.getOperands()[reductionIndex].getLoc();
226 builder.
createBlock(&decl.getAtomicReductionRegion(),
227 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
228 {reduceOperandLoc, reduceOperandLoc});
229 Block *atomicBlock = &decl.getAtomicReductionRegion().
back();
231 Value loaded = LLVM::LoadOp::create(builder,
reduce.getLoc(), decl.getType(),
233 LLVM::AtomicRMWOp::create(builder,
reduce.getLoc(), atomicKind,
235 LLVM::AtomicOrdering::monotonic);
258 assert(llvm::hasSingleElement(
reduce.getReductions()[reductionIndex]) &&
259 "expected reduction region to have a single element");
262 Type type =
reduce.getOperands()[reductionIndex].getType();
263 Block &reduction =
reduce.getReductions()[reductionIndex].front();
265 omp::DeclareReductionOp decl =
272 omp::DeclareReductionOp decl =
279 omp::DeclareReductionOp decl =
286 omp::DeclareReductionOp decl =
294 builder, symbolTable,
reduce, reductionIndex,
316 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
317 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
319 reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
320 {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
325 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
326 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
328 reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
329 {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
330 omp::DeclareReductionOp decl =
334 isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
335 decl,
reduce, reductionIndex);
338 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
339 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
341 reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
342 {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
343 omp::DeclareReductionOp decl =
347 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
348 decl,
reduce, reductionIndex);
/// Sentinel thread count (0). The lowering only materializes an explicit
/// thread-count constant when numThreads > 0, so this value presumably leaves
/// the choice to the OpenMP runtime default — confirm against omp.parallel docs.
357 static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
361 unsigned numThreads = kUseOpenMPDefaultNumThreads)
364 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
371 auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
372 for (
int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
374 ompReductionDecls.push_back(decl);
377 reductionSyms.push_back(
378 SymbolRefAttr::get(rewriter.
getContext(), decl.getSymName()));
385 LLVM::ConstantOp::create(rewriter, loc, rewriter.
getIntegerType(64),
388 reductionVariables.reserve(parallelOp.getNumReductions());
389 auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
390 for (
Value init : parallelOp.getInitVals()) {
392 isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
393 "cannot create a reduction variable if the type is not an LLVM "
395 Value storage = LLVM::AllocaOp::create(rewriter, loc, ptrType,
396 init.getType(), one, 0);
397 LLVM::StoreOp::create(rewriter, loc, init, storage);
398 reductionVariables.push_back(storage);
404 for (
auto [x, y, rD] : llvm::zip_equal(
405 reductionVariables,
reduce.getOperands(), ompReductionDecls)) {
408 Region &redRegion = rD.getReductionRegion();
414 "expect reduction region to have one block");
415 Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
416 Value pvtRedVal = LLVM::LoadOp::create(rewriter,
reduce.getLoc(),
417 rD.getType(), pvtRedVar);
420 builder.setInsertionPoint(
reduce);
423 "expect reduction region to have two arguments");
426 for (
auto &op : redRegion.
getOps()) {
428 if (
auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
429 assert(yieldOp && yieldOp.getResults().size() == 1 &&
430 "expect YieldOp in reduction region to return one result");
431 Value redVal = yieldOp.getResults()[0];
432 LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
441 if (numThreads > 0) {
442 numThreadsVar = LLVM::ConstantOp::create(
446 auto ompParallel = omp::ParallelOp::create(
455 omp::ClauseProcBindKindAttr{},
469 auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc());
470 if (!reductionVariables.empty()) {
471 wsloopOp.setReductionSymsAttr(
472 ArrayAttr::get(rewriter.
getContext(), reductionSyms));
473 wsloopOp.getReductionVarsMutable().append(reductionVariables);
477 reductionByRef.resize(reductionVariables.size(),
false);
478 wsloopOp.setReductionByref(
481 omp::TerminatorOp::create(rewriter, loc);
486 reductionTypes.reserve(reductionVariables.size());
487 llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
490 &wsloopOp.getRegion(), {}, reductionTypes,
492 parallelOp.getLoc()));
495 auto loopOp = omp::LoopNestOp::create(
496 rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(),
497 parallelOp.getLowerBound(), parallelOp.getUpperBound(),
498 parallelOp.getStep(),
false,
502 loopOp.getRegion().begin());
507 unsigned numLoops = parallelOp.getNumLoops();
510 wsloopOp.getRegion().getArguments());
518 auto scope = memref::AllocaScopeOp::create(
519 rewriter, parallelOp.getLoc(),
TypeRange());
520 omp::YieldOp::create(rewriter, loc,
ValueRange());
524 memref::AllocaScopeReturnOp::create(rewriter, loc,
ValueRange());
530 results.reserve(reductionVariables.size());
531 for (
auto [variable, type] :
532 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
533 Value res = LLVM::LoadOp::create(rewriter, loc, type, variable);
534 results.push_back(res);
543static LogicalResult
applyPatterns(ModuleOp module,
unsigned numThreads) {
545 patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
548 auto status =
module.walk([](Operation *op) {
549 if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) {
550 op->emitError("unconverted operation found");
555 return failure(status.wasInterrupted());
559struct SCFToOpenMPPass
565 void runOnOperation()
override {
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...