24 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
35 template <
typename OpType>
40 matchAndRewrite(OpType curOp,
typename OpType::Adaptor adaptor,
42 auto newOp = rewriter.
create<OpType>(
43 curOp.getLoc(),
TypeRange(), adaptor.getOperands(), curOp->getAttrs());
45 newOp.getRegion().end());
47 *this->getTypeConverter())))
56 struct RegionLessOpWithVarOperandsConversion
60 matchAndRewrite(T curOp,
typename T::Adaptor adaptor,
64 if (failed(converter->
convertTypes(curOp->getResultTypes(), resTypes)))
67 assert(curOp.getNumVariableOperands() ==
68 curOp.getOperation()->getNumOperands() &&
69 "unexpected non-variable operands");
70 for (
unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
71 Value originalVariableOperand = curOp.getVariableOperand(idx);
72 if (!originalVariableOperand)
74 if (isa<MemRefType>(originalVariableOperand.
getType())) {
77 "memref is not supported yet");
79 convertedOperands.emplace_back(adaptor.getOperands()[idx]);
92 matchAndRewrite(T curOp,
typename T::Adaptor adaptor,
96 if (failed(converter->
convertTypes(curOp->getResultTypes(), resTypes)))
99 assert(curOp.getNumVariableOperands() ==
100 curOp.getOperation()->getNumOperands() &&
101 "unexpected non-variable operands");
102 for (
unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
103 Value originalVariableOperand = curOp.getVariableOperand(idx);
104 if (!originalVariableOperand)
106 if (isa<MemRefType>(originalVariableOperand.
getType())) {
109 "memref is not supported yet");
111 convertedOperands.emplace_back(adaptor.getOperands()[idx]);
113 auto newOp = rewriter.
create<T>(curOp.getLoc(), resTypes, convertedOperands,
116 newOp.getRegion().end());
118 *this->getTypeConverter())))
126 template <
typename T>
130 matchAndRewrite(T curOp,
typename T::Adaptor adaptor,
134 if (failed(converter->
convertTypes(curOp->getResultTypes(), resTypes)))
143 struct AtomicReadOpConversion
147 matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
150 Type curElementType = curOp.getElementType();
151 auto newOp = rewriter.
create<omp::AtomicReadOp>(
152 curOp.getLoc(),
TypeRange(), adaptor.getOperands(), curOp->getAttrs());
154 newOp.setElementTypeAttr(typeAttr);
163 matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor,
168 if (failed(converter->
convertTypes(curOp->getResultTypes(), resTypes)))
175 if (
auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
177 newAttrs.emplace_back(attr.getName(),
TypeAttr::get(newAttr));
179 newAttrs.push_back(attr);
184 curOp, resTypes, adaptor.getOperands(), newAttrs);
189 struct DeclMapperOpConversion
193 matchAndRewrite(omp::DeclareMapperOp curOp, OpAdaptor adaptor,
197 newAttrs.emplace_back(curOp.getSymNameAttrName(), curOp.getSymNameAttr());
198 newAttrs.emplace_back(
199 curOp.getTypeAttrName(),
202 auto newOp = rewriter.
create<omp::DeclareMapperOp>(
203 curOp.getLoc(),
TypeRange(), adaptor.getOperands(), newAttrs);
205 newOp.getRegion().end());
207 *this->getTypeConverter())))
215 template <
typename OpType>
219 void forwardOpAttrs(OpType curOp, OpType newOp)
const {}
222 matchAndRewrite(OpType curOp,
typename OpType::Adaptor adaptor,
224 auto newOp = rewriter.
create<OpType>(
225 curOp.getLoc(),
TypeRange(), curOp.getSymNameAttr(),
227 curOp.getTypeAttr().getValue())));
228 forwardOpAttrs(curOp, newOp);
230 for (
unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
232 newOp.getRegion(idx).end());
234 *this->getTypeConverter())))
244 void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
245 omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp)
const {
246 newOp.setDataSharingType(curOp.getDataSharingType());
253 omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
254 omp::CancelOp, omp::CriticalDeclareOp, omp::DeclareMapperInfoOp,
255 omp::FlushOp, omp::MapBoundsOp, omp::MapInfoOp, omp::OrderedOp,
256 omp::ScanOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
257 omp::TargetUpdateOp, omp::ThreadprivateOp, omp::YieldOp>(
259 return typeConverter.
isLegal(op->getOperandTypes()) &&
260 typeConverter.
isLegal(op->getResultTypes());
263 omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareMapperOp,
264 omp::DeclareReductionOp, omp::DistributeOp, omp::LoopNestOp, omp::LoopOp,
265 omp::MasterOp, omp::OrderedRegionOp, omp::ParallelOp,
266 omp::PrivateClauseOp, omp::SectionOp, omp::SectionsOp, omp::SimdOp,
267 omp::SingleOp, omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp,
268 omp::TaskloopOp, omp::TaskOp, omp::TeamsOp,
270 return std::all_of(op->getRegions().begin(), op->getRegions().end(),
272 return typeConverter.isLegal(®ion);
274 typeConverter.
isLegal(op->getOperandTypes()) &&
275 typeConverter.
isLegal(op->getResultTypes());
278 [&](omp::PrivateClauseOp op) ->
bool {
279 return std::all_of(op->getRegions().begin(), op->getRegions().end(),
281 return typeConverter.isLegal(®ion);
283 typeConverter.
isLegal(op->getOperandTypes()) &&
284 typeConverter.
isLegal(op->getResultTypes()) &&
285 typeConverter.
isLegal(op.getType());
295 [&](omp::MapBoundsType type) ->
Type {
return type; });
298 AtomicReadOpConversion, DeclMapperOpConversion, MapInfoOpConversion,
299 MultiRegionOpConversion<omp::DeclareReductionOp>,
300 MultiRegionOpConversion<omp::PrivateClauseOp>,
301 RegionLessOpConversion<omp::CancellationPointOp>,
302 RegionLessOpConversion<omp::CancelOp>,
303 RegionLessOpConversion<omp::CriticalDeclareOp>,
304 RegionLessOpConversion<omp::DeclareMapperInfoOp>,
305 RegionLessOpConversion<omp::OrderedOp>,
306 RegionLessOpConversion<omp::ScanOp>,
307 RegionLessOpConversion<omp::TargetEnterDataOp>,
308 RegionLessOpConversion<omp::TargetExitDataOp>,
309 RegionLessOpConversion<omp::TargetUpdateOp>,
310 RegionLessOpConversion<omp::YieldOp>,
311 RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
312 RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
313 RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>,
314 RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>,
315 RegionOpConversion<omp::AtomicCaptureOp>,
316 RegionOpConversion<omp::CriticalOp>,
317 RegionOpConversion<omp::DistributeOp>,
318 RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::LoopOp>,
319 RegionOpConversion<omp::MaskedOp>, RegionOpConversion<omp::MasterOp>,
320 RegionOpConversion<omp::OrderedRegionOp>,
321 RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::SectionOp>,
322 RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SimdOp>,
323 RegionOpConversion<omp::SingleOp>, RegionOpConversion<omp::TargetDataOp>,
324 RegionOpConversion<omp::TargetOp>, RegionOpConversion<omp::TaskgroupOp>,
325 RegionOpConversion<omp::TaskloopOp>, RegionOpConversion<omp::TaskOp>,
326 RegionOpConversion<omp::TeamsOp>, RegionOpConversion<omp::WsloopOp>,
327 RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>>(converter);
331 struct ConvertOpenMPToLLVMPass
332 :
public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
335 void runOnOperation()
override;
339 void ConvertOpenMPToLLVMPass::runOnOperation() {
340 auto module = getOperation();
353 target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
354 omp::TaskyieldOp, omp::TerminatorOp>();
367 void loadDependentDialects(
MLIRContext *context)
const final {
368 context->loadDialect<LLVM::LLVMDialect>();
373 void populateConvertToLLVMConversionPatterns(
384 dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true)
Populate the cf.assert to LLVM conversion pattern.
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
Include the generated interface declarations.
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from OpenMP to LLVM.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
const FrozenRewritePatternSet & patterns
void registerConvertOpenMPToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMPatternInterface interface in the OpenMP dialect.
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target, const LLVMTypeConverter &typeConverter)
Configure dynamic conversion legality of regionless operations from OpenMP to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.