28 #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
29 #include "mlir/Conversion/Passes.h.inc"
40 template <
typename... OpTy>
42 if (block.
empty() || llvm::hasSingleElement(block) ||
43 std::next(block.
begin(), 2) != block.
end())
53 if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
56 return isa<OpTy...>(combinerOps[0]) &&
57 isa<scf::ReduceReturnOp>(block.
back()) &&
73 typename CompareOpTy,
typename SelectOpTy,
74 typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
79 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
80 "only arithmetic and llvm select ops are supported");
83 if (block.
empty() || llvm::hasSingleElement(block) ||
84 std::next(block.
begin(), 2) == block.
end() ||
85 std::next(block.
begin(), 3) != block.
end())
90 auto select = dyn_cast<SelectOpTy>(block.
front().getNextNode());
91 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.
back());
92 if (!
compare || !select || !terminator)
101 if (llvm::is_contained(lessThanPredicates,
compare.getPredicate())) {
103 }
else if (llvm::is_contained(greaterThanPredicates,
110 if (select.getCondition() !=
compare.getResult())
118 constexpr
unsigned kTrueValue = 1;
119 constexpr
unsigned kFalseValue = 2;
120 bool sameOperands = select.getOperand(kTrueValue) ==
compare.getLhs() &&
121 select.getOperand(kFalseValue) ==
compare.getRhs();
122 bool swappedOperands = select.getOperand(kTrueValue) ==
compare.getRhs() &&
123 select.getOperand(kFalseValue) ==
compare.getLhs();
124 if (!sameOperands && !swappedOperands)
127 if (select.getResult() != terminator.getResult())
132 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
133 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
139 return llvm::APFloat::IEEEhalf();
141 return llvm::APFloat::IEEEsingle();
143 return llvm::APFloat::IEEEdouble();
145 return llvm::APFloat::IEEEquad();
147 return llvm::APFloat::BFloat();
149 return llvm::APFloat::x87DoubleExtended();
150 llvm_unreachable(
"unknown float type");
156 auto fltType = cast<FloatType>(type);
165 auto intType = cast<IntegerType>(type);
166 unsigned bitwidth = intType.getWidth();
168 : llvm::APInt::getSignedMaxValue(bitwidth));
175 auto intType = cast<IntegerType>(type);
176 unsigned bitwidth = intType.getWidth();
178 : llvm::APInt::getAllOnes(bitwidth));
185 static omp::DeclareReductionOp
191 "__scf_reduction", type);
195 decl.getInitializerRegion().end(), {type},
203 &
reduce.getReductions()[reductionIndex].front().back();
204 assert(isa<scf::ReduceReturnOp>(terminator) &&
205 "expected reduce op to be terminated by redure return");
208 terminator->getOperands());
210 decl.getReductionRegion(),
211 decl.getReductionRegion().end());
218 LLVM::AtomicBinOp atomicKind,
219 omp::DeclareReductionOp decl,
221 int64_t reductionIndex) {
225 builder.
createBlock(&decl.getAtomicReductionRegion(),
226 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
227 {reduceOperandLoc, reduceOperandLoc});
228 Block *atomicBlock = &decl.getAtomicReductionRegion().
back();
234 LLVM::AtomicOrdering::monotonic);
245 int64_t reductionIndex) {
257 assert(llvm::hasSingleElement(
reduce.getReductions()[reductionIndex]) &&
258 "expected reduction region to have a single element");
262 Block &reduction =
reduce.getReductions()[reductionIndex].front();
263 if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
264 omp::DeclareReductionOp decl =
270 if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
271 omp::DeclareReductionOp decl =
277 if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
278 omp::DeclareReductionOp decl =
284 if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
285 omp::DeclareReductionOp decl =
291 if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
293 builder, symbolTable,
reduce, reductionIndex,
303 if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
307 if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
314 if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
315 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
316 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
317 matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
318 reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
319 {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
323 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
324 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
325 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
326 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
327 reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
328 {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
329 omp::DeclareReductionOp decl =
334 decl,
reduce, reductionIndex);
336 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
337 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
338 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
339 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
340 reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
341 {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
342 omp::DeclareReductionOp decl =
346 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
347 decl,
reduce, reductionIndex);
356 static constexpr
unsigned kUseOpenMPDefaultNumThreads = 0;
360 unsigned numThreads = kUseOpenMPDefaultNumThreads)
363 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
370 auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
371 for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
373 ompReductionDecls.push_back(decl);
376 reductionSyms.push_back(
386 reductionVariables.reserve(parallelOp.getNumReductions());
388 for (
Value init : parallelOp.getInitVals()) {
390 isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
391 "cannot create a reduction variable if the type is not an LLVM "
394 rewriter.
create<LLVM::AllocaOp>(loc, ptrType, init.getType(), one, 0);
395 rewriter.
create<LLVM::StoreOp>(loc, init, storage);
396 reductionVariables.push_back(storage);
402 for (
auto [x, y, rD] : llvm::zip_equal(
403 reductionVariables,
reduce.getOperands(), ompReductionDecls)) {
406 Region &redRegion = rD.getReductionRegion();
412 "expect reduction region to have one block");
413 Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
415 rD.getType(), pvtRedVar);
418 builder.setInsertionPoint(
reduce);
421 "expect reduction region to have two arguments");
424 for (
auto &op : redRegion.
getOps()) {
426 if (
auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
427 assert(yieldOp && yieldOp.getResults().size() == 1 &&
428 "expect YieldOp in reduction region to return one result");
429 Value redVal = yieldOp.getResults()[0];
430 rewriter.
create<LLVM::StoreOp>(loc, redVal, pvtRedVar);
439 if (numThreads > 0) {
440 numThreadsVar = rewriter.
create<LLVM::ConstantOp>(
444 auto ompParallel = rewriter.
create<omp::ParallelOp>(
453 omp::ClauseProcBindKindAttr{},
467 auto wsloopOp = rewriter.
create<omp::WsloopOp>(parallelOp.getLoc());
468 if (!reductionVariables.empty()) {
469 wsloopOp.setReductionSymsAttr(
471 wsloopOp.getReductionVarsMutable().append(reductionVariables);
475 reductionByRef.resize(reductionVariables.size(),
false);
476 wsloopOp.setReductionByref(
479 rewriter.
create<omp::TerminatorOp>(loc);
484 reductionTypes.reserve(reductionVariables.size());
485 llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
488 &wsloopOp.getRegion(), {}, reductionTypes,
490 parallelOp.getLoc()));
493 auto loopOp = rewriter.
create<omp::LoopNestOp>(
494 parallelOp.getLoc(), parallelOp.getLowerBound(),
495 parallelOp.getUpperBound(), parallelOp.getStep());
498 loopOp.getRegion().begin());
503 unsigned numLoops = parallelOp.getNumLoops();
506 wsloopOp.getRegion().getArguments());
514 auto scope = rewriter.
create<memref::AllocaScopeOp>(parallelOp.getLoc(),
526 results.reserve(reductionVariables.size());
527 for (
auto [variable, type] :
528 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
529 Value res = rewriter.
create<LLVM::LoadOp>(loc, type, variable);
530 results.push_back(res);
539 static LogicalResult
applyPatterns(ModuleOp module,
unsigned numThreads) {
541 target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
542 target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
543 memref::MemRefDialect>();
546 patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
552 struct SCFToOpenMPPass
553 :
public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
558 void runOnOperation()
override {
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class describes a specific conversion target.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...