MLIR  21.0.0git
OpenMPToLLVM.cpp
Go to the documentation of this file.
1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
21 #include "mlir/Pass/Pass.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 /// A pattern that converts the result and operand types, attributes, and region
33 /// arguments of an OpenMP operation to the LLVM dialect.
34 ///
35 /// Attributes are copied verbatim by default, and only translated if they are
36 /// type attributes.
37 ///
38 /// Region bodies, if any, are not modified and expected to either be processed
39 /// by the conversion infrastructure or already contain ops compatible with LLVM
40 /// dialect types.
41 template <typename T>
42 struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
44 
45  LogicalResult
46  matchAndRewrite(T op, typename T::Adaptor adaptor,
47  ConversionPatternRewriter &rewriter) const override {
48  // Translate result types.
50  SmallVector<Type> resTypes;
51  if (failed(converter->convertTypes(op->getResultTypes(), resTypes)))
52  return failure();
53 
54  // Translate type attributes.
55  // They are kept unmodified except if they are type attributes.
56  SmallVector<NamedAttribute> convertedAttrs;
57  for (NamedAttribute attr : op->getAttrs()) {
58  if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
59  Type convertedType = converter->convertType(typeAttr.getValue());
60  convertedAttrs.emplace_back(attr.getName(),
61  TypeAttr::get(convertedType));
62  } else {
63  convertedAttrs.push_back(attr);
64  }
65  }
66 
67  // Translate operands.
68  SmallVector<Value> convertedOperands;
69  convertedOperands.reserve(op->getNumOperands());
70  for (auto [originalOperand, convertedOperand] :
71  llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
72  if (!originalOperand)
73  return failure();
74 
75  // TODO: Revisit whether we need to trigger an error specifically for this
76  // set of operations. Consider removing this check or updating the list.
77  if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
78  omp::FlushOp, omp::MapBoundsOp,
79  omp::ThreadprivateOp>::value) {
80  if (isa<MemRefType>(originalOperand.getType())) {
81  // TODO: Support memref type in variable operands
82  return rewriter.notifyMatchFailure(op, "memref is not supported yet");
83  }
84  }
85  convertedOperands.push_back(convertedOperand);
86  }
87 
88  // Create new operation.
89  auto newOp = rewriter.create<T>(op.getLoc(), resTypes, convertedOperands,
90  convertedAttrs);
91 
92  // Translate regions.
93  for (auto [originalRegion, convertedRegion] :
94  llvm::zip_equal(op->getRegions(), newOp->getRegions())) {
95  rewriter.inlineRegionBefore(originalRegion, convertedRegion,
96  convertedRegion.end());
97  if (failed(rewriter.convertRegionTypes(&convertedRegion,
98  *this->getTypeConverter())))
99  return failure();
100  }
101 
102  // Delete old operation and replace result uses with those of the new one.
103  rewriter.replaceOp(op, newOp->getResults());
104  return success();
105  }
106 };
107 
108 } // namespace
109 
111  ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
112  target.addDynamicallyLegalOp<
113 #define GET_OP_LIST
114 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
115  >([&](Operation *op) {
116  return typeConverter.isLegal(op->getOperandTypes()) &&
117  typeConverter.isLegal(op->getResultTypes()) &&
118  llvm::all_of(op->getRegions(),
119  [&](Region &region) {
120  return typeConverter.isLegal(&region);
121  }) &&
122  llvm::all_of(op->getAttrs(), [&](NamedAttribute attr) {
123  auto typeAttr = dyn_cast<TypeAttr>(attr.getValue());
124  return !typeAttr || typeConverter.isLegal(typeAttr.getValue());
125  });
126  });
127 }
128 
129 /// Add an `OpenMPOpConversion<T>` conversion pattern for each operation type
130 /// passed as template argument.
131 template <typename... Ts>
132 static inline RewritePatternSet &
135  return patterns.add<OpenMPOpConversion<Ts>...>(converter);
136 }
137 
140  // This type is allowed when converting OpenMP to LLVM Dialect, it carries
141  // bounds information for map clauses and the operation and type are
142  // discarded on lowering to LLVM-IR from the OpenMP dialect.
143  converter.addConversion(
144  [&](omp::MapBoundsType type) -> Type { return type; });
145 
146  // Add conversions for all OpenMP operations.
148 #define GET_OP_LIST
149 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
150  >(converter, patterns);
151 }
152 
153 namespace {
154 struct ConvertOpenMPToLLVMPass
155  : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
156  using Base::Base;
157 
158  void runOnOperation() override;
159 };
160 } // namespace
161 
// Applies a partial conversion to the operation returned by getOperation(),
// using an LLVMTypeConverter and the OpenMP legality rules configured by
// configureOpenMPToLLVMConversionLegality. Signals pass failure if the
// conversion fails.
162 void ConvertOpenMPToLLVMPass::runOnOperation() {
163  auto module = getOperation();
164 
165  // Convert to OpenMP operations with LLVM IR dialect
// NOTE(review): the declaration of `patterns` (moved into
// applyPartialConversion below) is not visible in this extract; presumably a
// RewritePatternSet populated with the OpenMP patterns — confirm against the
// full source.
167  LLVMTypeConverter converter(&getContext());
174 
// NOTE(review): the declaration of `target` (the ConversionTarget used below)
// is likewise not visible here — confirm against the full source.
176  target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
177  omp::TaskyieldOp, omp::TerminatorOp>();
// Registers a type-based dynamic legality callback for every op in the OpenMP
// op list. NOTE(review): this covers the ops passed to addLegalOp above as
// well, and appears to replace their static legality (one action per op
// name) — confirm whether the addLegalOp call still has any effect.
178  configureOpenMPToLLVMConversionLegality(target, converter);
179  if (failed(applyPartialConversion(module, target, std::move(patterns))))
180  signalPassFailure();
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // ConvertToLLVMPatternInterface implementation
185 //===----------------------------------------------------------------------===//
186 namespace {
187 /// Implement the interface to convert OpenMP to LLVM.
188 struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
190  void loadDependentDialects(MLIRContext *context) const final {
191  context->loadDialect<LLVM::LLVMDialect>();
192  }
193 
194  /// Hook for derived dialect interface to provide conversion patterns
195  /// and mark dialect legal for the conversion target.
196  void populateConvertToLLVMConversionPatterns(
197  ConversionTarget &target, LLVMTypeConverter &typeConverter,
198  RewritePatternSet &patterns) const final {
199  configureOpenMPToLLVMConversionLegality(target, typeConverter);
201  }
202 };
203 } // namespace
204 
206  registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
207  dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
208  });
209 }
static MLIRContext * getContext(OpFoldResult val)
static RewritePatternSet & addOpenMPOpConversions(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Add an OpenMPOpConversion<T> conversion pattern for each operation type passed as template argument.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:195
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true, SymbolTableCollection *symbolTables=nullptr)
Populate the cf.assert to LLVM conversion pattern.
Include the generated interface declarations.
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from OpenMP to LLVM.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:757
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
void registerConvertOpenMPToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMPatternInterface interface in the OpenMP dialect.
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target, const LLVMTypeConverter &typeConverter)
Configure dynamic conversion legality of regionless operations from OpenMP to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.