29 #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
30 #include "mlir/Conversion/Passes.h.inc"
41 template <
typename... OpTy>
43 if (block.
empty() || llvm::hasSingleElement(block) ||
44 std::next(block.
begin(), 2) != block.
end())
54 if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
57 return isa<OpTy...>(combinerOps[0]) &&
58 isa<scf::ReduceReturnOp>(block.
back()) &&
74 typename CompareOpTy,
typename SelectOpTy,
75 typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
80 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
81 "only arithmetic and llvm select ops are supported");
84 if (block.
empty() || llvm::hasSingleElement(block) ||
85 std::next(block.
begin(), 2) == block.
end() ||
86 std::next(block.
begin(), 3) != block.
end())
91 auto select = dyn_cast<SelectOpTy>(block.
front().getNextNode());
92 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.
back());
93 if (!
compare || !select || !terminator)
102 if (llvm::is_contained(lessThanPredicates,
compare.getPredicate())) {
104 }
else if (llvm::is_contained(greaterThanPredicates,
111 if (select.getCondition() !=
compare.getResult())
119 constexpr
unsigned kTrueValue = 1;
120 constexpr
unsigned kFalseValue = 2;
121 bool sameOperands = select.getOperand(kTrueValue) ==
compare.getLhs() &&
122 select.getOperand(kFalseValue) ==
compare.getRhs();
123 bool swappedOperands = select.getOperand(kTrueValue) ==
compare.getRhs() &&
124 select.getOperand(kFalseValue) ==
compare.getLhs();
125 if (!sameOperands && !swappedOperands)
128 if (select.getResult() != terminator.getResult())
// The reduction is a min when a less-than compare keeps its operands in the
// same order, or a greater-than compare has them swapped. The match succeeds
// for a min or for either of the two remaining (max) combinations. All terms
// are boolean, so logical && is used throughout.
133 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
134 return isMin || (isLess && swappedOperands) || (!isLess && sameOperands);
140 return llvm::APFloat::IEEEhalf();
142 return llvm::APFloat::IEEEsingle();
144 return llvm::APFloat::IEEEdouble();
146 return llvm::APFloat::IEEEquad();
148 return llvm::APFloat::BFloat();
150 return llvm::APFloat::x87DoubleExtended();
151 llvm_unreachable(
"unknown float type");
157 auto fltType = cast<FloatType>(type);
166 auto intType = cast<IntegerType>(type);
167 unsigned bitwidth = intType.getWidth();
169 : llvm::APInt::getSignedMaxValue(bitwidth));
176 auto intType = cast<IntegerType>(type);
177 unsigned bitwidth = intType.getWidth();
179 : llvm::APInt::getAllOnes(bitwidth));
186 static omp::DeclareReductionOp
192 "__scf_reduction", type);
196 decl.getInitializerRegion().end(), {type},
204 &
reduce.getReductions()[reductionIndex].front().back();
// The last operation of the reduction block must be scf.reduce.return; its
// operands are forwarded when the combiner region is built.
205 assert(isa<scf::ReduceReturnOp>(terminator) &&
206 "expected reduce op to be terminated by reduce return");
209 terminator->getOperands());
211 decl.getReductionRegion(),
212 decl.getReductionRegion().end());
219 LLVM::AtomicBinOp atomicKind,
220 omp::DeclareReductionOp decl,
222 int64_t reductionIndex) {
226 builder.
createBlock(&decl.getAtomicReductionRegion(),
227 decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
228 {reduceOperandLoc, reduceOperandLoc});
229 Block *atomicBlock = &decl.getAtomicReductionRegion().
back();
235 LLVM::AtomicOrdering::monotonic);
246 int64_t reductionIndex) {
258 assert(llvm::hasSingleElement(
reduce.getReductions()[reductionIndex]) &&
259 "expected reduction region to have a single element");
263 Block &reduction =
reduce.getReductions()[reductionIndex].front();
264 if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
265 omp::DeclareReductionOp decl =
271 if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
272 omp::DeclareReductionOp decl =
278 if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
279 omp::DeclareReductionOp decl =
285 if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
286 omp::DeclareReductionOp decl =
292 if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
294 builder, symbolTable,
reduce, reductionIndex,
304 if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
308 if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
315 if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
316 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
317 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
318 matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
319 reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
320 {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
324 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
325 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
326 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
327 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
328 reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
329 {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
330 omp::DeclareReductionOp decl =
335 decl,
reduce, reductionIndex);
// Unsigned integer min/max expressed as compare + select. For both the arith
// and the LLVM dialect forms, the less-than predicates are {ult, ule} and the
// greater-than predicates are {ugt, uge}. The two sets must be disjoint: the
// matcher tests the less-than set first, so a greater-than predicate listed
// there would be classified as less-than and flip min/max.
337 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
338 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
339 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
340 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
341 reduction, {LLVM::ICmpPredicate::ult, LLVM::ICmpPredicate::ule},
342 {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
343 omp::DeclareReductionOp decl =
347 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
348 decl,
reduce, reductionIndex);
357 static constexpr
unsigned kUseOpenMPDefaultNumThreads = 0;
361 unsigned numThreads = kUseOpenMPDefaultNumThreads)
364 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
371 auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
372 for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
374 ompReductionDecls.push_back(decl);
377 reductionSyms.push_back(
387 reductionVariables.reserve(parallelOp.getNumReductions());
389 for (
Value init : parallelOp.getInitVals()) {
391 isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
392 "cannot create a reduction variable if the type is not an LLVM "
395 rewriter.
create<LLVM::AllocaOp>(loc, ptrType, init.getType(), one, 0);
396 rewriter.
create<LLVM::StoreOp>(loc, init, storage);
397 reductionVariables.push_back(storage);
403 for (
auto [x, y, rD] : llvm::zip_equal(
404 reductionVariables,
reduce.getOperands(), ompReductionDecls)) {
407 Region &redRegion = rD.getReductionRegion();
413 "expect reduction region to have one block");
414 Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
416 rD.getType(), pvtRedVar);
419 builder.setInsertionPoint(
reduce);
422 "expect reduction region to have two arguments");
425 for (
auto &op : redRegion.
getOps()) {
427 if (
auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
428 assert(yieldOp && yieldOp.getResults().size() == 1 &&
429 "expect YieldOp in reduction region to return one result");
430 Value redVal = yieldOp.getResults()[0];
431 rewriter.
create<LLVM::StoreOp>(loc, redVal, pvtRedVar);
440 if (numThreads > 0) {
441 numThreadsVar = rewriter.
create<LLVM::ConstantOp>(
445 auto ompParallel = rewriter.
create<omp::ParallelOp>(
453 omp::ClauseProcBindKindAttr{},
466 auto wsloopOp = rewriter.
create<omp::WsloopOp>(parallelOp.getLoc());
467 if (!reductionVariables.empty()) {
468 wsloopOp.setReductionSymsAttr(
470 wsloopOp.getReductionVarsMutable().append(reductionVariables);
474 reductionByRef.resize(reductionVariables.size(),
false);
475 wsloopOp.setReductionByref(
478 rewriter.
create<omp::TerminatorOp>(loc);
483 reductionTypes.reserve(reductionVariables.size());
484 llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
487 &wsloopOp.getRegion(), {}, reductionTypes,
489 parallelOp.getLoc()));
492 auto loopOp = rewriter.
create<omp::LoopNestOp>(
493 parallelOp.getLoc(), parallelOp.getLowerBound(),
494 parallelOp.getUpperBound(), parallelOp.getStep());
497 loopOp.getRegion().begin());
502 unsigned numLoops = parallelOp.getNumLoops();
505 wsloopOp.getRegion().getArguments());
513 auto scope = rewriter.
create<memref::AllocaScopeOp>(parallelOp.getLoc(),
525 results.reserve(reductionVariables.size());
526 for (
auto [variable, type] :
527 llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
528 Value res = rewriter.
create<LLVM::LoadOp>(loc, type, variable);
529 results.push_back(res);
538 static LogicalResult
applyPatterns(ModuleOp module,
unsigned numThreads) {
540 target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
541 target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
542 memref::MemRefDialect>();
545 patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
551 struct SCFToOpenMPPass
552 :
public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
557 void runOnOperation()
override {
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::DeclareReductionOp decl, scf::ReduceOp reduce, int64_t reductionIndex)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm....
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce, int64_t reductionIndex)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it.
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
static omp::DeclareReductionOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class describes a specific conversion target.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
bool hasOneBlock()
Return true if this region has exactly one block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...