MLIR 23.0.0git
OpenMPToLLVM.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
20#include "mlir/Pass/Pass.h"
21
22namespace mlir {
23#define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
24#include "mlir/Conversion/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28
29namespace {
30
31/// A pattern that converts the result and operand types, attributes, and region
32/// arguments of an OpenMP operation to the LLVM dialect.
33///
34/// Attributes are copied verbatim by default, and only translated if they are
35/// type attributes.
36///
37/// Region bodies, if any, are not modified and expected to either be processed
38/// by the conversion infrastructure or already contain ops compatible with LLVM
39/// dialect types.
40template <typename T>
41struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
42 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
43
44 OpenMPOpConversion(LLVMTypeConverter &typeConverter,
45 PatternBenefit benefit = 1)
46 : ConvertOpToLLVMPattern<T>(typeConverter, benefit) {
47 // Operations using CanonicalLoopInfoType are lowered only by
48 // mlir::translateModuleToLLVMIR() using the OpenMPIRBuilder. Until then,
49 // the type and operations using it must be preserved.
50 typeConverter.addConversion(
51 [&](::mlir::omp::CanonicalLoopInfoType type) { return type; });
52 }
53
54 LogicalResult
55 matchAndRewrite(T op, typename T::Adaptor adaptor,
56 ConversionPatternRewriter &rewriter) const override {
57 // Translate result types.
58 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
59 SmallVector<Type> resTypes;
60 if (failed(converter->convertTypes(op->getResultTypes(), resTypes)))
61 return failure();
62
63 // Translate type attributes.
64 // They are kept unmodified except if they are type attributes.
65 SmallVector<NamedAttribute> convertedAttrs;
66 for (NamedAttribute attr : op->getAttrs()) {
67 if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
68 Type convertedType = converter->convertType(typeAttr.getValue());
69 if (!convertedType)
70 return rewriter.notifyMatchFailure(
71 op, "failed to convert type in attribute");
72 convertedAttrs.emplace_back(attr.getName(),
73 TypeAttr::get(convertedType));
74 } else {
75 convertedAttrs.push_back(attr);
76 }
77 }
78
79 // Translate operands.
80 SmallVector<Value> convertedOperands;
81 convertedOperands.reserve(op->getNumOperands());
82 for (auto [originalOperand, convertedOperand] :
83 llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
84 if (!originalOperand)
85 return failure();
86
87 // TODO: Revisit whether we need to trigger an error specifically for this
88 // set of operations. Consider removing this check or updating the list.
89 if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
90 omp::FlushOp, omp::MapBoundsOp,
91 omp::ThreadprivateOp>::value) {
92 if (isa<MemRefType>(originalOperand.getType())) {
93 // TODO: Support memref type in variable operands
94 return rewriter.notifyMatchFailure(op, "memref is not supported yet");
95 }
96 }
97 convertedOperands.push_back(convertedOperand);
98 }
99
100 // Create new operation.
101 auto newOp = T::create(rewriter, op.getLoc(), resTypes, convertedOperands,
102 convertedAttrs);
103
104 // Translate regions.
105 for (auto [originalRegion, convertedRegion] :
106 llvm::zip_equal(op->getRegions(), newOp->getRegions())) {
107 rewriter.inlineRegionBefore(originalRegion, convertedRegion,
108 convertedRegion.end());
109 if (failed(rewriter.convertRegionTypes(&convertedRegion,
110 *this->getTypeConverter())))
111 return failure();
112 }
113
114 // Delete old operation and replace result uses with those of the new one.
115 rewriter.replaceOp(op, newOp->getResults());
116 return success();
117 }
118};
119
120} // namespace
121
123 ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
124 target.addDynamicallyLegalOp<
125#define GET_OP_LIST
126#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
127 >([&](Operation *op) {
128 return typeConverter.isLegal(op->getOperandTypes()) &&
129 typeConverter.isLegal(op->getResultTypes()) &&
130 llvm::all_of(op->getRegions(),
131 [&](Region &region) {
132 return typeConverter.isLegal(&region);
133 }) &&
134 llvm::all_of(op->getAttrs(), [&](NamedAttribute attr) {
135 auto typeAttr = dyn_cast<TypeAttr>(attr.getValue());
136 return !typeAttr || typeConverter.isLegal(typeAttr.getValue());
137 });
138 });
139}
140
141/// Add an `OpenMPOpConversion<T>` conversion pattern for each operation type
142/// passed as template argument.
143template <typename... Ts>
144static inline RewritePatternSet &
146 RewritePatternSet &patterns) {
147 return patterns.add<OpenMPOpConversion<Ts>...>(converter);
148}
149
151 RewritePatternSet &patterns) {
152 // This type is allowed when converting OpenMP to LLVM Dialect, it carries
153 // bounds information for map clauses and the operation and type are
154 // discarded on lowering to LLVM-IR from the OpenMP dialect.
155 converter.addConversion(
156 [&](omp::MapBoundsType type) -> Type { return type; });
157
158 // Add conversions for all OpenMP operations.
160#define GET_OP_LIST
161#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
162 >(converter, patterns);
163}
164
165namespace {
166struct ConvertOpenMPToLLVMPass
167 : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
168 using Base::Base;
169
170 void runOnOperation() override;
171};
172} // namespace
173
174void ConvertOpenMPToLLVMPass::runOnOperation() {
175 auto module = getOperation();
176
177 // Convert to OpenMP operations with LLVM IR dialect
178 RewritePatternSet patterns(&getContext());
179 LLVMTypeConverter converter(&getContext());
180 arith::populateArithToLLVMConversionPatterns(converter, patterns);
184 populateFuncToLLVMConversionPatterns(converter, patterns);
185 populateOpenMPToLLVMConversionPatterns(converter, patterns);
186
187 LLVMConversionTarget target(getContext());
188 target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
189 omp::TaskyieldOp, omp::TerminatorOp>();
191 if (failed(applyPartialConversion(module, target, std::move(patterns))))
192 signalPassFailure();
193}
194
195//===----------------------------------------------------------------------===//
196// ConvertToLLVMPatternInterface implementation
197//===----------------------------------------------------------------------===//
198namespace {
199/// Implement the interface to convert OpenMP to LLVM.
200struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
202 void loadDependentDialects(MLIRContext *context) const final {
203 context->loadDialect<LLVM::LLVMDialect>();
204 }
205
206 /// Hook for derived dialect interface to provide conversion patterns
207 /// and mark dialect legal for the conversion target.
208 void populateConvertToLLVMConversionPatterns(
209 ConversionTarget &target, LLVMTypeConverter &typeConverter,
210 RewritePatternSet &patterns) const final {
212 populateOpenMPToLLVMConversionPatterns(typeConverter, patterns);
213 }
214};
215} // namespace
216
218 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
219 dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
220 });
221}
return success()
b getContext())
static RewritePatternSet & addOpenMPOpConversions(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Add an OpenMPOpConversion<T> conversion pattern for each operation type passed as template argument.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition Pattern.h:227
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:29
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true, SymbolTableCollection *symbolTables=nullptr)
Populate the cf.assert to LLVM conversion pattern.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from OpenMP to LLVM.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
void registerConvertOpenMPToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMPatternInterface interface in the OpenMP dialect.
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target, const LLVMTypeConverter &typeConverter)
Configure dynamic conversion legality of regionless operations from OpenMP to LLVM.