MLIR  21.0.0git
OpenMPToLLVM.cpp
Go to the documentation of this file.
1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
21 #include "mlir/Pass/Pass.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 /// A pattern that converts the result and operand types, attributes, and region
33 /// arguments of an OpenMP operation to the LLVM dialect.
34 ///
35 /// Attributes are copied verbatim by default, and only translated if they are
36 /// type attributes.
37 ///
38 /// Region bodies, if any, are not modified and expected to either be processed
39 /// by the conversion infrastructure or already contain ops compatible with LLVM
40 /// dialect types.
41 template <typename T>
42 struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
44 
45  LogicalResult
46  matchAndRewrite(T op, typename T::Adaptor adaptor,
47  ConversionPatternRewriter &rewriter) const override {
48  // Translate result types.
50  SmallVector<Type> resTypes;
51  if (failed(converter->convertTypes(op->getResultTypes(), resTypes)))
52  return failure();
53 
54  // Translate type attributes.
55  // They are kept unmodified except if they are type attributes.
56  SmallVector<NamedAttribute> convertedAttrs;
57  for (NamedAttribute attr : op->getAttrs()) {
58  if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
59  Type convertedType = converter->convertType(typeAttr.getValue());
60  convertedAttrs.emplace_back(attr.getName(),
61  TypeAttr::get(convertedType));
62  } else {
63  convertedAttrs.push_back(attr);
64  }
65  }
66 
67  // Translate operands.
68  SmallVector<Value> convertedOperands;
69  convertedOperands.reserve(op->getNumOperands());
70  for (auto [originalOperand, convertedOperand] :
71  llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
72  if (!originalOperand)
73  return failure();
74 
75  // TODO: Revisit whether we need to trigger an error specifically for this
76  // set of operations. Consider removing this check or updating the list.
77  if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
78  omp::FlushOp, omp::MapBoundsOp,
79  omp::ThreadprivateOp>::value) {
80  if (isa<MemRefType>(originalOperand.getType())) {
81  // TODO: Support memref type in variable operands
82  return rewriter.notifyMatchFailure(op, "memref is not supported yet");
83  }
84  }
85  convertedOperands.push_back(convertedOperand);
86  }
87 
88  // Create new operation.
89  auto newOp = rewriter.create<T>(op.getLoc(), resTypes, convertedOperands,
90  convertedAttrs);
91 
92  // Translate regions.
93  for (auto [originalRegion, convertedRegion] :
94  llvm::zip_equal(op->getRegions(), newOp->getRegions())) {
95  rewriter.inlineRegionBefore(originalRegion, convertedRegion,
96  convertedRegion.end());
97  if (failed(rewriter.convertRegionTypes(&convertedRegion,
98  *this->getTypeConverter())))
99  return failure();
100  }
101 
102  // Delete old operation and replace result uses with those of the new one.
103  rewriter.replaceOp(op, newOp->getResults());
104  return success();
105  }
106 };
107 
108 } // namespace
109 
111  ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
112  target.addDynamicallyLegalOp<
113 #define GET_OP_LIST
114 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
115  >([&](Operation *op) {
116  return typeConverter.isLegal(op->getOperandTypes()) &&
117  typeConverter.isLegal(op->getResultTypes()) &&
118  std::all_of(op->getRegions().begin(), op->getRegions().end(),
119  [&](Region &region) {
120  return typeConverter.isLegal(&region);
121  }) &&
122  std::all_of(op->getAttrs().begin(), op->getAttrs().end(),
123  [&](NamedAttribute attr) {
124  auto typeAttr = dyn_cast<TypeAttr>(attr.getValue());
125  return !typeAttr ||
126  typeConverter.isLegal(typeAttr.getValue());
127  });
128  });
129 }
130 
131 /// Add an `OpenMPOpConversion<T>` conversion pattern for each operation type
132 /// passed as template argument.
133 template <typename... Ts>
134 static inline RewritePatternSet &
137  return patterns.add<OpenMPOpConversion<Ts>...>(converter);
138 }
139 
142  // This type is allowed when converting OpenMP to LLVM Dialect, it carries
143  // bounds information for map clauses and the operation and type are
144  // discarded on lowering to LLVM-IR from the OpenMP dialect.
145  converter.addConversion(
146  [&](omp::MapBoundsType type) -> Type { return type; });
147 
148  // Add conversions for all OpenMP operations.
150 #define GET_OP_LIST
151 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
152  >(converter, patterns);
153 }
154 
155 namespace {
156 struct ConvertOpenMPToLLVMPass
157  : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
158  using Base::Base;
159 
160  void runOnOperation() override;
161 };
162 } // namespace
163 
164 void ConvertOpenMPToLLVMPass::runOnOperation() {
165  auto module = getOperation();
166 
167  // Convert to OpenMP operations with LLVM IR dialect
169  LLVMTypeConverter converter(&getContext());
176 
178  target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
179  omp::TaskyieldOp, omp::TerminatorOp>();
180  configureOpenMPToLLVMConversionLegality(target, converter);
181  if (failed(applyPartialConversion(module, target, std::move(patterns))))
182  signalPassFailure();
183 }
184 
185 //===----------------------------------------------------------------------===//
186 // ConvertToLLVMPatternInterface implementation
187 //===----------------------------------------------------------------------===//
188 namespace {
189 /// Implement the interface to convert OpenMP to LLVM.
190 struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
192  void loadDependentDialects(MLIRContext *context) const final {
193  context->loadDialect<LLVM::LLVMDialect>();
194  }
195 
196  /// Hook for derived dialect interface to provide conversion patterns
197  /// and mark dialect legal for the conversion target.
198  void populateConvertToLLVMConversionPatterns(
199  ConversionTarget &target, LLVMTypeConverter &typeConverter,
200  RewritePatternSet &patterns) const final {
201  configureOpenMPToLLVMConversionLegality(target, typeConverter);
203  }
204 };
205 } // namespace
206 
208  registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
209  dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
210  });
211 }
static MLIRContext * getContext(OpFoldResult val)
static RewritePatternSet & addOpenMPOpConversions(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Add an OpenMPOpConversion<T> conversion pattern for each operation type passed as template argument.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:155
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:686
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true)
Populate the cf.assert to LLVM conversion pattern.
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
Include the generated interface declarations.
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from OpenMP to LLVM.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
void registerConvertOpenMPToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMPatternInterface interface in the OpenMP dialect.
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target, const LLVMTypeConverter &typeConverter)
Configure dynamic conversion legality of regionless operations from OpenMP to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:742
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.