15 #include "../PassDetail.h" 35 template <
typename... OpTy>
37 if (block.
empty() || llvm::hasSingleElement(block) ||
38 std::next(block.
begin(), 2) != block.
end())
49 combinerOps.size() != 1)
52 return isa<OpTy...>(combinerOps[0]) &&
53 isa<scf::ReduceReturnOp>(block.
back()) &&
69 typename CompareOpTy,
typename SelectOpTy,
70 typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
76 "only arithmetic and llvm select ops are supported");
79 if (block.
empty() || llvm::hasSingleElement(block) ||
80 std::next(block.
begin(), 2) == block.
end() ||
81 std::next(block.
begin(), 3) != block.
end())
86 auto select = dyn_cast<SelectOpTy>(block.
front().getNextNode());
87 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.
back());
88 if (!
compare || !select || !terminator)
97 if (llvm::is_contained(lessThanPredicates,
compare.getPredicate())) {
99 }
else if (llvm::is_contained(greaterThanPredicates,
106 if (select.getCondition() !=
compare.getResult())
114 constexpr
unsigned kTrueValue = 1;
115 constexpr
unsigned kFalseValue = 2;
116 bool sameOperands = select.getOperand(kTrueValue) ==
compare.getLhs() &&
117 select.getOperand(kFalseValue) ==
compare.getRhs();
118 bool swappedOperands = select.getOperand(kTrueValue) ==
compare.getRhs() &&
119 select.getOperand(kFalseValue) ==
compare.getLhs();
120 if (!sameOperands && !swappedOperands)
123 if (select.getResult() != terminator.getResult())
128 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
129 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
135 return llvm::APFloat::IEEEhalf();
137 return llvm::APFloat::IEEEsingle();
139 return llvm::APFloat::IEEEdouble();
141 return llvm::APFloat::IEEEquad();
143 return llvm::APFloat::BFloat();
145 return llvm::APFloat::x87DoubleExtended();
146 llvm_unreachable(
"unknown float type");
153 return FloatAttr::get(
161 auto intType = type.
cast<IntegerType>();
162 unsigned bitwidth = intType.getWidth();
163 return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
164 : llvm::APInt::getSignedMaxValue(bitwidth));
171 auto intType = type.
cast<IntegerType>();
172 unsigned bitwidth = intType.getWidth();
173 return IntegerAttr::get(type, min ? llvm::APInt::getNullValue(bitwidth)
174 : llvm::APInt::getAllOnesValue(bitwidth));
182 scf::ReduceOp reduce,
185 auto decl = builder.
create<omp::ReductionDeclareOp>(
186 reduce.getLoc(),
"__scf_reduction", reduce.getOperand().getType());
189 Type type = reduce.getOperand().getType();
190 builder.
createBlock(&decl.initializerRegion(), decl.initializerRegion().end(),
191 {type}, {reduce.getOperand().getLoc()});
194 builder.
create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
195 builder.
create<omp::YieldOp>(reduce.getLoc(), init);
198 assert(isa<scf::ReduceReturnOp>(terminator) &&
199 "expected reduce op to be terminated by redure return");
202 terminator->getOperands());
204 decl.reductionRegion().end());
211 LLVM::AtomicBinOp atomicKind,
212 omp::ReductionDeclareOp decl,
213 scf::ReduceOp reduce) {
215 Type type = reduce.getOperand().getType();
217 Location reduceOperandLoc = reduce.getOperand().getLoc();
219 decl.atomicReductionRegion().end(), {ptrType, ptrType},
220 {reduceOperandLoc, reduceOperandLoc});
221 Block *atomicBlock = &decl.atomicReductionRegion().
back();
223 Value loaded = builder.
create<LLVM::LoadOp>(reduce.getLoc(),
225 builder.
create<LLVM::AtomicRMWOp>(reduce.getLoc(), type, atomicKind,
227 LLVM::AtomicOrdering::monotonic);
237 scf::ReduceOp reduce) {
249 assert(llvm::hasSingleElement(reduce.getRegion()) &&
250 "expected reduction region to have a single element");
253 Type type = reduce.getOperand().getType();
254 Block &reduction = reduce.getRegion().
front();
255 if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
256 omp::ReductionDeclareOp decl =
createDecl(builder, symbolTable, reduce,
258 return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce);
260 if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
261 omp::ReductionDeclareOp decl =
createDecl(builder, symbolTable, reduce,
263 return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce);
265 if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
266 omp::ReductionDeclareOp decl =
createDecl(builder, symbolTable, reduce,
268 return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce);
270 if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
271 omp::ReductionDeclareOp decl =
createDecl(builder, symbolTable, reduce,
273 return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce);
275 if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
277 builder, symbolTable, reduce,
280 return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce);
286 if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
287 return createDecl(builder, symbolTable, reduce,
293 if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
294 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
295 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
296 matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
297 reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
298 {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
299 return createDecl(builder, symbolTable, reduce,
302 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
303 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
304 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
305 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
306 reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
307 {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
314 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
315 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
316 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
317 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
318 reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
319 {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
323 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
341 for (
auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
345 reductionDeclSymbols.push_back(
346 SymbolRefAttr::get(rewriter.
getContext(), decl.sym_name()));
355 reductionVariables.reserve(parallelOp.getNumReductions());
356 for (
Value init : parallelOp.getInitVals()) {
358 init.getType().isa<LLVM::PointerElementTypeInterface>()) &&
359 "cannot create a reduction variable if the type is not an LLVM " 363 rewriter.
create<LLVM::StoreOp>(loc, init, storage);
364 reductionVariables.push_back(storage);
371 llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) {
373 scf::ReduceOp reduceOp = std::get<0>(pair);
376 reduceOp, reduceOp.getOperand(), std::get<1>(pair));
380 auto ompParallel = rewriter.
create<omp::ParallelOp>(loc);
389 auto loop = rewriter.
create<omp::WsLoopOp>(
390 parallelOp.getLoc(), parallelOp.getLowerBound(),
391 parallelOp.getUpperBound(), parallelOp.getStep());
392 rewriter.
create<omp::TerminatorOp>(loc);
395 loop.region().begin());
398 loop.region().begin()->begin());
402 auto scope = rewriter.
create<memref::AllocaScopeOp>(parallelOp.getLoc(),
407 auto oldYield = cast<scf::YieldOp>(scopeBlock->getTerminator());
410 oldYield, oldYield->getOperands());
411 if (!reductionVariables.empty()) {
413 ArrayAttr::get(rewriter.
getContext(), reductionDeclSymbols));
414 loop.reduction_varsMutable().append(reductionVariables);
421 results.reserve(reductionVariables.size());
422 for (
Value variable : reductionVariables) {
423 Value res = rewriter.
create<LLVM::LoadOp>(loc, variable);
424 results.push_back(res);
435 target.
addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
436 target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
437 memref::MemRefDialect>();
440 patterns.
add<ParallelOpLowering>(module.getContext());
446 struct SCFToOpenMPPass :
public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
448 void runOnOperation()
override {
457 return std::make_unique<SCFToOpenMPPass>();
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
MLIRContext * getContext() const
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Operation is a basic unit of execution within MLIR.
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
operand_range getOperands()
Returns an iterator on the underlying Value's.
This class represents a frozen set of patterns that can be processed by a pattern applicator...
Block represents an ordered list of Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block...
static Attribute minMaxValueForUnsignedInt(Type type, bool min)
Returns an attribute with the unsigned integer minimum (if min is set) or the maximum value (otherwis...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
std::unique_ptr< OperationPass< ModuleOp > > createConvertSCFToOpenMPPass()
static bool matchSelectReduction(Block &block, ArrayRef< Predicate > lessThanPredicates, ArrayRef< Predicate > greaterThanPredicates, bool &isMin)
Matches a block containing a select-based min/max reduction.
static Attribute minMaxValueForFloat(Type type, bool min)
Returns an attribute with the minimum (if min is set) or the maximum value (otherwise) for the given ...
virtual void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
BlockArgument getArgument(unsigned i)
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
static constexpr const bool value
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
static bool matchSimpleReduction(Block &block)
Matches a block containing a "simple" reduction.
FloatAttr getFloatAttr(Type type, double value)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
unsigned getNumArguments()
Attributes are known-constant values of operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
IntegerType getIntegerType(unsigned width)
static LLVMPointerType get(MLIRContext *context, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
This class provides an abstraction over the various different ranges of value types.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation *> &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
BlockArgListType getArguments()
This class represents an argument of a Block.
static omp::ReductionDeclareOp createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, Attribute initValue)
Creates an OpenMP reduction declaration and inserts it into the provided symbol table.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
RAII guard to reset the insertion point of the builder when destroyed.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, scf::ReduceOp reduce)
Creates an OpenMP reduction declaration that corresponds to the given SCF reduction and returns it...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
static void applyPatterns(Region ®ion, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::ReductionDeclareOp decl, scf::ReduceOp reduce)
Adds an atomic reduction combiner to the given OpenMP reduction declaration using llvm...
This class describes a specific conversion target.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it...
int compare(Fraction x, Fraction y)
Three-way comparison between two fractions.
static Attribute minMaxValueForSignedInt(Type type, bool min)
Returns an attribute with the signed integer minimum (if min is set) or the maximum value (otherwise)...
static const llvm::fltSemantics & fltSemanticsForType(FloatType type)
Returns the float semantics for the given float type.
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block 'source' into the end of block 'dest'.
This class helps build Operations.
This class provides an abstraction over the different types of ranges over Values.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)