29 #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
30 #include "mlir/Conversion/Passes.h.inc"
41 template <
typename... OpTy>
43 if (block.
empty() || llvm::hasSingleElement(block) ||
44 std::next(block.
begin(), 2) != block.
end())
54 if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
57 return isa<OpTy...>(combinerOps[0]) &&
58 isa<scf::ReduceReturnOp>(block.
back()) &&
74 typename CompareOpTy,
typename SelectOpTy,
75 typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
80 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
81 "only arithmetic and llvm select ops are supported");
84 if (block.
empty() || llvm::hasSingleElement(block) ||
85 std::next(block.
begin(), 2) == block.
end() ||
86 std::next(block.
begin(), 3) != block.
end())
91 auto select = dyn_cast<SelectOpTy>(block.
front().getNextNode());
92 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.
back());
93 if (!
compare || !select || !terminator)
102 if (llvm::is_contained(lessThanPredicates,
compare.getPredicate())) {
104 }
else if (llvm::is_contained(greaterThanPredicates,
111 if (select.getCondition() !=
compare.getResult())
119 constexpr
unsigned kTrueValue = 1;
120 constexpr
unsigned kFalseValue = 2;
121 bool sameOperands = select.getOperand(kTrueValue) ==
compare.getLhs() &&
122 select.getOperand(kFalseValue) ==
compare.getRhs();
123 bool swappedOperands = select.getOperand(kTrueValue) ==
compare.getRhs() &&
124 select.getOperand(kFalseValue) ==
compare.getLhs();
125 if (!sameOperands && !swappedOperands)
128 if (select.getResult() != terminator.getResult())
133 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
134 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
140 return llvm::APFloat::IEEEhalf();
142 return llvm::APFloat::IEEEsingle();
144 return llvm::APFloat::IEEEdouble();
146 return llvm::APFloat::IEEEquad();
148 return llvm::APFloat::BFloat();
150 return llvm::APFloat::x87DoubleExtended();
151 llvm_unreachable(
"unknown float type");
157 auto fltType = cast<FloatType>(type);
166 auto intType = cast<IntegerType>(type);
167 unsigned bitwidth = intType.getWidth();
169 : llvm::APInt::getSignedMaxValue(bitwidth));
176 auto intType = cast<IntegerType>(type);
177 unsigned bitwidth = intType.getWidth();
179 : llvm::APInt::getAllOnes(bitwidth));
190 auto decl = builder.
create<omp::ReductionDeclareOp>(
196 decl.getInitializerRegion().end(), {type},
204 assert(isa<scf::ReduceReturnOp>(terminator) &&
205 "expected reduce op to be terminated by redure return");
208 terminator->getOperands());
210 decl.getReductionRegion().end());
217 bool useOpaquePointers) {
218 if (useOpaquePointers)
226 LLVM::AtomicBinOp atomicKind,
227 omp::ReductionDeclareOp decl,
229 bool useOpaquePointers) {
234 builder.
createBlock(&decl.getAtomicReductionRegion(),
235 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
236 {reduceOperandLoc, reduceOperandLoc});
237 Block *atomicBlock = &decl.getAtomicReductionRegion().
back();
243 LLVM::AtomicOrdering::monotonic);
254 bool useOpaquePointers) {
266 assert(llvm::hasSingleElement(
reduce.getRegion()) &&
267 "expected reduction region to have a single element");
272 if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
278 if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
284 if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
290 if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
296 if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
298 builder, symbolTable,
reduce,
308 if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
312 if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
319 if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
320 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
321 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
322 matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
323 reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
324 {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
328 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
329 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
330 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
331 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
332 reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
333 {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
338 decl,
reduce, useOpaquePointers);
340 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
341 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
342 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
343 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
344 reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
345 {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
349 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
350 decl,
reduce, useOpaquePointers);
360 bool useOpaquePointers;
362 ParallelOpLowering(
MLIRContext *context,
bool useOpaquePointers)
364 useOpaquePointers(useOpaquePointers) {}
372 for (
auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
373 omp::ReductionDeclareOp decl =
377 reductionDeclSymbols.push_back(
387 reductionVariables.reserve(parallelOp.getNumReductions());
388 for (
Value init : parallelOp.getInitVals()) {
390 isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
391 "cannot create a reduction variable if the type is not an LLVM "
395 init.getType(), one, 0);
396 rewriter.
create<LLVM::StoreOp>(loc, init, storage);
397 reductionVariables.push_back(storage);
404 llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) {
406 scf::ReduceOp reduceOp = std::get<0>(pair);
409 reduceOp, reduceOp.getOperand(), std::get<1>(pair));
413 auto ompParallel = rewriter.
create<omp::ParallelOp>(loc);
422 auto loop = rewriter.
create<omp::WsLoopOp>(
423 parallelOp.getLoc(), parallelOp.getLowerBound(),
424 parallelOp.getUpperBound(), parallelOp.getStep());
425 rewriter.
create<omp::TerminatorOp>(loc);
428 loop.getRegion().begin());
431 loop.getRegion().begin()->begin());
435 auto scope = rewriter.
create<memref::AllocaScopeOp>(parallelOp.getLoc(),
440 auto oldYield = cast<scf::YieldOp>(scopeBlock->getTerminator());
443 oldYield, oldYield->getOperands());
444 if (!reductionVariables.empty()) {
445 loop.setReductionsAttr(
447 loop.getReductionVarsMutable().append(reductionVariables);
454 results.reserve(reductionVariables.size());
455 for (
auto [variable, type] :
456 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
457 Value res = rewriter.
create<LLVM::LoadOp>(loc, type, variable);
458 results.push_back(res);
469 target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
470 target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
471 memref::MemRefDialect>();
474 patterns.add<ParallelOpLowering>(module.getContext(), useOpaquePointers);
480 struct SCFToOpenMPPass
481 :
public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
486 void runOnOperation()
override {
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static LLVM::LLVMPointerType getPointerType(Type elementType, bool useOpaquePointers)
Returns an LLVM pointer type with the given element type, or an opaque pointer if 'useOpaquePointers'...
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::ReductionDeclareOp decl, scf::ReduceOp reduce, bool useOpaquePointers)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, bool useOpaquePointers)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static omp::ReductionDeclareOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
BlockArgListType getArguments()
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class describes a specific conversion target.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...