28#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
29#include "mlir/Conversion/Passes.h.inc"
40template <
typename... OpTy>
42 if (block.
empty() || llvm::hasSingleElement(block) ||
43 std::next(block.
begin(), 2) != block.
end())
53 if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
56 return isa<OpTy...>(combinerOps[0]) &&
57 isa<scf::ReduceReturnOp>(block.
back()) &&
73 typename CompareOpTy,
typename SelectOpTy,
74 typename Predicate =
decltype(std::declval<CompareOpTy>().getPredicate())>
79 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
80 "only arithmetic and llvm select ops are supported");
83 if (block.
empty() || llvm::hasSingleElement(block) ||
84 std::next(block.
begin(), 2) == block.
end() ||
85 std::next(block.
begin(), 3) != block.
end())
89 auto compare = dyn_cast<CompareOpTy>(block.
front());
90 auto select = dyn_cast<SelectOpTy>(block.
front().getNextNode());
91 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.
back());
92 if (!compare || !select || !terminator)
101 if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
103 }
else if (llvm::is_contained(greaterThanPredicates,
104 compare.getPredicate())) {
110 if (select.getCondition() != compare.getResult())
118 constexpr unsigned kTrueValue = 1;
119 constexpr unsigned kFalseValue = 2;
120 bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
121 select.getOperand(kFalseValue) == compare.getRhs();
122 bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
123 select.getOperand(kFalseValue) == compare.getLhs();
124 if (!sameOperands && !swappedOperands)
127 if (select.getResult() != terminator.getResult())
132 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
133 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
139 return llvm::APFloat::IEEEhalf();
141 return llvm::APFloat::IEEEsingle();
143 return llvm::APFloat::IEEEdouble();
145 return llvm::APFloat::IEEEquad();
147 return llvm::APFloat::BFloat();
149 return llvm::APFloat::x87DoubleExtended();
150 llvm_unreachable(
"unknown float type");
156 auto fltType = cast<FloatType>(type);
157 return FloatAttr::get(
165 auto intType = cast<IntegerType>(type);
166 unsigned bitwidth = intType.getWidth();
167 return IntegerAttr::get(type,
min ? llvm::APInt::getSignedMinValue(bitwidth)
168 : llvm::APInt::getSignedMaxValue(bitwidth));
175 auto intType = cast<IntegerType>(type);
176 unsigned bitwidth = intType.getWidth();
177 return IntegerAttr::get(type,
min ? llvm::APInt::getZero(bitwidth)
178 : llvm::APInt::getAllOnes(bitwidth));
185static omp::DeclareReductionOp
189 Type type =
reduce.getOperands()[reductionIndex].getType();
190 auto decl = omp::DeclareReductionOp::create(builder,
reduce.getLoc(),
191 "__scf_reduction", type);
195 decl.getInitializerRegion().end(), {type},
196 {reduce.getOperands()[reductionIndex].getLoc()});
199 LLVM::ConstantOp::create(builder,
reduce.getLoc(), type, initValue);
200 omp::YieldOp::create(builder,
reduce.getLoc(), init);
203 &
reduce.getReductions()[reductionIndex].front().back();
204 assert(isa<scf::ReduceReturnOp>(terminator) &&
205 "expected reduce op to be terminated by redure return");
210 decl.getReductionRegion(),
211 decl.getReductionRegion().end());
218 LLVM::AtomicBinOp atomicKind,
219 omp::DeclareReductionOp decl,
223 auto ptrType = LLVM::LLVMPointerType::get(builder.
getContext());
224 Location reduceOperandLoc =
reduce.getOperands()[reductionIndex].getLoc();
225 builder.
createBlock(&decl.getAtomicReductionRegion(),
226 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
227 {reduceOperandLoc, reduceOperandLoc});
228 Block *atomicBlock = &decl.getAtomicReductionRegion().
back();
230 Value loaded = LLVM::LoadOp::create(builder,
reduce.getLoc(), decl.getType(),
232 LLVM::AtomicRMWOp::create(builder,
reduce.getLoc(), atomicKind,
234 LLVM::AtomicOrdering::monotonic);
257 assert(llvm::hasSingleElement(
reduce.getReductions()[reductionIndex]) &&
258 "expected reduction region to have a single element");
261 Type type =
reduce.getOperands()[reductionIndex].getType();
262 Block &reduction =
reduce.getReductions()[reductionIndex].front();
264 omp::DeclareReductionOp decl =
271 omp::DeclareReductionOp decl =
278 omp::DeclareReductionOp decl =
285 omp::DeclareReductionOp decl =
293 builder, symbolTable,
reduce, reductionIndex,
315 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
316 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
318 reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
319 {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
324 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
325 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
327 reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
328 {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
329 omp::DeclareReductionOp decl =
333 isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
334 decl,
reduce, reductionIndex);
337 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
338 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
340 reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
341 {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
342 omp::DeclareReductionOp decl =
346 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
347 decl,
reduce, reductionIndex);
356 static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
360 unsigned numThreads = kUseOpenMPDefaultNumThreads)
363 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
370 auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
371 for (
int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
373 ompReductionDecls.push_back(decl);
376 reductionSyms.push_back(
377 SymbolRefAttr::get(rewriter.
getContext(), decl.getSymName()));
384 LLVM::ConstantOp::create(rewriter, loc, rewriter.
getIntegerType(64),
387 reductionVariables.reserve(parallelOp.getNumReductions());
388 auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
389 for (
Value init : parallelOp.getInitVals()) {
391 isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
392 "cannot create a reduction variable if the type is not an LLVM "
394 Value storage = LLVM::AllocaOp::create(rewriter, loc, ptrType,
395 init.getType(), one, 0);
396 LLVM::StoreOp::create(rewriter, loc, init, storage);
397 reductionVariables.push_back(storage);
403 for (
auto [x, y, rD] : llvm::zip_equal(
404 reductionVariables,
reduce.getOperands(), ompReductionDecls)) {
407 Region &redRegion = rD.getReductionRegion();
413 "expect reduction region to have one block");
414 Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
415 Value pvtRedVal = LLVM::LoadOp::create(rewriter,
reduce.getLoc(),
416 rD.getType(), pvtRedVar);
419 builder.setInsertionPoint(
reduce);
422 "expect reduction region to have two arguments");
425 for (
auto &op : redRegion.
getOps()) {
427 if (
auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
428 assert(yieldOp && yieldOp.getResults().size() == 1 &&
429 "expect YieldOp in reduction region to return one result");
430 Value redVal = yieldOp.getResults()[0];
431 LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
440 if (numThreads > 0) {
441 numThreadsVar = LLVM::ConstantOp::create(
445 auto ompParallel = omp::ParallelOp::create(
454 omp::ClauseProcBindKindAttr{},
468 auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc());
469 if (!reductionVariables.empty()) {
470 wsloopOp.setReductionSymsAttr(
471 ArrayAttr::get(rewriter.
getContext(), reductionSyms));
472 wsloopOp.getReductionVarsMutable().append(reductionVariables);
476 reductionByRef.resize(reductionVariables.size(),
false);
477 wsloopOp.setReductionByref(
480 omp::TerminatorOp::create(rewriter, loc);
485 reductionTypes.reserve(reductionVariables.size());
486 llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
489 &wsloopOp.getRegion(), {}, reductionTypes,
491 parallelOp.getLoc()));
494 auto loopOp = omp::LoopNestOp::create(
495 rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(),
496 parallelOp.getLowerBound(), parallelOp.getUpperBound(),
497 parallelOp.getStep(),
false,
501 loopOp.getRegion().begin());
506 unsigned numLoops = parallelOp.getNumLoops();
509 wsloopOp.getRegion().getArguments());
517 auto scope = memref::AllocaScopeOp::create(
518 rewriter, parallelOp.getLoc(),
TypeRange());
519 omp::YieldOp::create(rewriter, loc,
ValueRange());
523 memref::AllocaScopeReturnOp::create(rewriter, loc,
ValueRange());
529 results.reserve(reductionVariables.size());
530 for (
auto [variable, type] :
531 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
532 Value res = LLVM::LoadOp::create(rewriter, loc, type, variable);
533 results.push_back(res);
542static LogicalResult
applyPatterns(ModuleOp module,
unsigned numThreads) {
544 patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
547 auto status =
module.walk([](Operation *op) {
548 if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) {
549 op->emitError("unconverted operation found");
554 return failure(status.wasInterrupted());
558struct SCFToOpenMPPass
564 void runOnOperation()
override {
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...