MLIR 23.0.0git
OpenMPToLLVM.cpp
Go to the documentation of this file.
1//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
20#include "mlir/Pass/Pass.h"
21
22namespace mlir {
23#define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
24#include "mlir/Conversion/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28
29namespace {
30
31/// A pattern that converts the result and operand types, attributes, and region
32/// arguments of an OpenMP operation to the LLVM dialect.
33///
34/// Attributes are copied verbatim by default, and only translated if they are
35/// type attributes.
36///
37/// Region bodies, if any, are not modified and expected to either be processed
38/// by the conversion infrastructure or already contain ops compatible with LLVM
39/// dialect types.
40template <typename T>
41struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
42 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
43
44 OpenMPOpConversion(LLVMTypeConverter &typeConverter,
45 PatternBenefit benefit = 1)
46 : ConvertOpToLLVMPattern<T>(typeConverter, benefit) {
47 // Operations using CanonicalLoopInfoType are lowered only by
48 // mlir::translateModuleToLLVMIR() using the OpenMPIRBuilder. Until then,
49 // the type and operations using it must be preserved.
50 typeConverter.addConversion(
51 [&](::mlir::omp::CanonicalLoopInfoType type) { return type; });
52 }
53
54 LogicalResult
55 matchAndRewrite(T op, typename T::Adaptor adaptor,
56 ConversionPatternRewriter &rewriter) const override {
57 // Translate result types.
58 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
59 SmallVector<Type> resTypes;
60 if (failed(converter->convertTypes(op->getResultTypes(), resTypes)))
61 return failure();
62
63 // Translate type attributes.
64 // They are kept unmodified except if they are type attributes.
65 SmallVector<NamedAttribute> convertedAttrs;
66 for (NamedAttribute attr : op->getAttrs()) {
67 if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
68 Type convertedType = converter->convertType(typeAttr.getValue());
69 if (!convertedType)
70 return rewriter.notifyMatchFailure(
71 op, "failed to convert type in attribute");
72 convertedAttrs.emplace_back(attr.getName(),
73 TypeAttr::get(convertedType));
74 } else {
75 convertedAttrs.push_back(attr);
76 }
77 }
78
79 // Translate operands.
80 SmallVector<Value> convertedOperands;
81 convertedOperands.reserve(op->getNumOperands());
82 for (auto [originalOperand, convertedOperand] :
83 llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
84 if (!originalOperand)
85 return failure();
86
87 // TODO: Revisit whether we need to trigger an error specifically for this
88 // set of operations. Consider removing this check or updating the list.
89 if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
90 omp::FlushOp, omp::MapBoundsOp,
91 omp::ThreadprivateOp>::value) {
92 if (isa<MemRefType>(originalOperand.getType())) {
93 // TODO: Support memref type in variable operands
94 return rewriter.notifyMatchFailure(op, "memref is not supported yet");
95 }
96 }
97 convertedOperands.push_back(convertedOperand);
98 }
99
100 // Create new operation.
101 auto newOp = T::create(rewriter, op.getLoc(), resTypes, convertedOperands,
102 convertedAttrs);
103
104 // Translate regions.
105 for (auto [originalRegion, convertedRegion] :
106 llvm::zip_equal(op->getRegions(), newOp->getRegions())) {
107 rewriter.inlineRegionBefore(originalRegion, convertedRegion,
108 convertedRegion.end());
109 if (failed(rewriter.convertRegionTypes(&convertedRegion,
110 *this->getTypeConverter())))
111 return failure();
112 }
113
114 // Delete old operation and replace result uses with those of the new one.
115 rewriter.replaceOp(op, newOp->getResults());
116 return success();
117 }
118};
119
120} // namespace
121
123 ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
124 target.addDynamicallyLegalOp<
125#define GET_OP_LIST
126#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
127 >([&](Operation *op) {
128 return typeConverter.isLegal(op->getOperandTypes()) &&
129 typeConverter.isLegal(op->getResultTypes()) &&
130 llvm::all_of(op->getRegions(),
131 [&](Region &region) {
132 return typeConverter.isLegal(&region);
133 }) &&
134 llvm::all_of(op->getAttrs(), [&](NamedAttribute attr) {
135 auto typeAttr = dyn_cast<TypeAttr>(attr.getValue());
136 return !typeAttr || typeConverter.isLegal(typeAttr.getValue());
137 });
138 });
139}
140
141/// Add an `OpenMPOpConversion<T>` conversion pattern for each operation type
142/// passed as template argument.
143template <typename... Ts>
144static inline RewritePatternSet &
146 RewritePatternSet &patterns) {
147 return patterns.add<OpenMPOpConversion<Ts>...>(converter);
148}
149
151 RewritePatternSet &patterns) {
152 // This type is allowed when converting OpenMP to LLVM Dialect, it carries
153 // bounds information for map clauses and the operation and type are
154 // discarded on lowering to LLVM-IR from the OpenMP dialect.
155 converter.addConversion(
156 [&](omp::MapBoundsType type) -> Type { return type; });
157 converter.addConversion(
158 [&](omp::AffinityEntryType type) -> Type { return type; });
159 converter.addConversion([&](omp::IteratedType type) -> Type { return type; });
160
161 // Add conversions for all OpenMP operations.
163#define GET_OP_LIST
164#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
165 >(converter, patterns);
166}
167
168namespace {
169struct ConvertOpenMPToLLVMPass
170 : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
171 using Base::Base;
172
173 void runOnOperation() override;
174};
175} // namespace
176
177void ConvertOpenMPToLLVMPass::runOnOperation() {
178 auto module = getOperation();
179
180 // Convert to OpenMP operations with LLVM IR dialect
181 RewritePatternSet patterns(&getContext());
182 LLVMTypeConverter converter(&getContext());
183 arith::populateArithToLLVMConversionPatterns(converter, patterns);
187 populateFuncToLLVMConversionPatterns(converter, patterns);
188 populateOpenMPToLLVMConversionPatterns(converter, patterns);
189
190 LLVMConversionTarget target(getContext());
191 target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
192 omp::TaskyieldOp, omp::TerminatorOp>();
194 if (failed(applyPartialConversion(module, target, std::move(patterns))))
195 signalPassFailure();
196}
197
198//===----------------------------------------------------------------------===//
199// ConvertToLLVMPatternInterface implementation
200//===----------------------------------------------------------------------===//
201namespace {
202/// Implement the interface to convert OpenMP to LLVM.
203struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
204 OpenMPToLLVMDialectInterface(Dialect *dialect)
205 : ConvertToLLVMPatternInterface(dialect) {}
206
207 void loadDependentDialects(MLIRContext *context) const final {
208 context->loadDialect<LLVM::LLVMDialect>();
209 }
210
211 /// Hook for derived dialect interface to provide conversion patterns
212 /// and mark dialect legal for the conversion target.
213 void populateConvertToLLVMConversionPatterns(
214 ConversionTarget &target, LLVMTypeConverter &typeConverter,
215 RewritePatternSet &patterns) const final {
217 populateOpenMPToLLVMConversionPatterns(typeConverter, patterns);
218 }
219};
220} // namespace
221
223 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
224 dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
225 });
226}
return success()
b getContext())
static RewritePatternSet & addOpenMPOpConversions(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Add an OpenMPOpConversion<T> conversion pattern for each operation type passed as template argument.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition Pattern.h:227
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:29
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true, SymbolTableCollection *symbolTables=nullptr)
Populate the cf.assert to LLVM conversion pattern.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from OpenMP to LLVM.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
void registerConvertOpenMPToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMPatternInterface interface in the OpenMP dialect.
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target, const LLVMTypeConverter &typeConverter)
Configure dynamic conversion legality of regionless operations from OpenMP to LLVM.