MLIR  20.0.0git
OpenMPToLLVM.cpp
Go to the documentation of this file.
1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 /// A pattern that converts the region arguments in a single-region OpenMP
32 /// operation to the LLVM dialect. The body of the region is not modified and is
33 /// expected to either be processed by the conversion infrastructure or already
34 /// contain ops compatible with LLVM dialect types.
35 template <typename OpType>
36 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
38 
39  LogicalResult
40  matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
41  ConversionPatternRewriter &rewriter) const override {
42  auto newOp = rewriter.create<OpType>(
43  curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
44  rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
45  newOp.getRegion().end());
46  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
47  *this->getTypeConverter())))
48  return failure();
49 
50  rewriter.eraseOp(curOp);
51  return success();
52  }
53 };
54 
55 template <typename T>
56 struct RegionLessOpWithVarOperandsConversion
57  : public ConvertOpToLLVMPattern<T> {
59  LogicalResult
60  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
61  ConversionPatternRewriter &rewriter) const override {
63  SmallVector<Type> resTypes;
64  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
65  return failure();
66  SmallVector<Value> convertedOperands;
67  assert(curOp.getNumVariableOperands() ==
68  curOp.getOperation()->getNumOperands() &&
69  "unexpected non-variable operands");
70  for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
71  Value originalVariableOperand = curOp.getVariableOperand(idx);
72  if (!originalVariableOperand)
73  return failure();
74  if (isa<MemRefType>(originalVariableOperand.getType())) {
75  // TODO: Support memref type in variable operands
76  return rewriter.notifyMatchFailure(curOp,
77  "memref is not supported yet");
78  }
79  convertedOperands.emplace_back(adaptor.getOperands()[idx]);
80  }
81 
82  rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
83  curOp->getAttrs());
84  return success();
85  }
86 };
87 
88 template <typename T>
89 struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
91  LogicalResult
92  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
93  ConversionPatternRewriter &rewriter) const override {
95  SmallVector<Type> resTypes;
96  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
97  return failure();
98  SmallVector<Value> convertedOperands;
99  assert(curOp.getNumVariableOperands() ==
100  curOp.getOperation()->getNumOperands() &&
101  "unexpected non-variable operands");
102  for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
103  Value originalVariableOperand = curOp.getVariableOperand(idx);
104  if (!originalVariableOperand)
105  return failure();
106  if (isa<MemRefType>(originalVariableOperand.getType())) {
107  // TODO: Support memref type in variable operands
108  return rewriter.notifyMatchFailure(curOp,
109  "memref is not supported yet");
110  }
111  convertedOperands.emplace_back(adaptor.getOperands()[idx]);
112  }
113  auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
114  curOp->getAttrs());
115  rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
116  newOp.getRegion().end());
117  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
118  *this->getTypeConverter())))
119  return failure();
120 
121  rewriter.eraseOp(curOp);
122  return success();
123  }
124 };
125 
126 template <typename T>
127 struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> {
129  LogicalResult
130  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
131  ConversionPatternRewriter &rewriter) const override {
133  SmallVector<Type> resTypes;
134  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
135  return failure();
136 
137  rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(),
138  curOp->getAttrs());
139  return success();
140  }
141 };
142 
143 struct AtomicReadOpConversion
144  : public ConvertOpToLLVMPattern<omp::AtomicReadOp> {
146  LogicalResult
147  matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
148  ConversionPatternRewriter &rewriter) const override {
150  Type curElementType = curOp.getElementType();
151  auto newOp = rewriter.create<omp::AtomicReadOp>(
152  curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
153  TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType));
154  newOp.setElementTypeAttr(typeAttr);
155  rewriter.eraseOp(curOp);
156  return success();
157  }
158 };
159 
160 struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> {
162  LogicalResult
163  matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor,
164  ConversionPatternRewriter &rewriter) const override {
166 
167  SmallVector<Type> resTypes;
168  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
169  return failure();
170 
171  // Copy attributes of the curOp except for the typeAttr which should
172  // be converted
174  for (NamedAttribute attr : curOp->getAttrs()) {
175  if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
176  Type newAttr = converter->convertType(typeAttr.getValue());
177  newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
178  } else {
179  newAttrs.push_back(attr);
180  }
181  }
182 
183  rewriter.replaceOpWithNewOp<omp::MapInfoOp>(
184  curOp, resTypes, adaptor.getOperands(), newAttrs);
185  return success();
186  }
187 };
188 
189 template <typename OpType>
190 struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
192 
193  void forwardOpAttrs(OpType curOp, OpType newOp) const {}
194 
195  LogicalResult
196  matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
197  ConversionPatternRewriter &rewriter) const override {
198  auto newOp = rewriter.create<OpType>(
199  curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(),
200  TypeAttr::get(this->getTypeConverter()->convertType(
201  curOp.getTypeAttr().getValue())));
202  forwardOpAttrs(curOp, newOp);
203 
204  for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
205  rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx),
206  newOp.getRegion(idx).end());
207  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx),
208  *this->getTypeConverter())))
209  return failure();
210  }
211 
212  rewriter.eraseOp(curOp);
213  return success();
214  }
215 };
216 
217 template <>
218 void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
219  omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const {
220  newOp.setDataSharingType(curOp.getDataSharingType());
221 }
222 } // namespace
223 
225  ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
226  target.addDynamicallyLegalOp<
227  omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
228  omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
229  omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
230  omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
231  omp::YieldOp>([&](Operation *op) {
232  return typeConverter.isLegal(op->getOperandTypes()) &&
233  typeConverter.isLegal(op->getResultTypes());
234  });
235  target.addDynamicallyLegalOp<
236  omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareReductionOp,
237  omp::DistributeOp, omp::LoopNestOp, omp::LoopOp, omp::MasterOp,
238  omp::OrderedRegionOp, omp::ParallelOp, omp::PrivateClauseOp,
239  omp::SectionOp, omp::SectionsOp, omp::SimdOp, omp::SingleOp,
240  omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp, omp::TaskloopOp,
241  omp::TaskOp, omp::TeamsOp, omp::WsloopOp>([&](Operation *op) {
242  return std::all_of(op->getRegions().begin(), op->getRegions().end(),
243  [&](Region &region) {
244  return typeConverter.isLegal(&region);
245  }) &&
246  typeConverter.isLegal(op->getOperandTypes()) &&
247  typeConverter.isLegal(op->getResultTypes());
248  });
249 }
250 
252  RewritePatternSet &patterns) {
253  // This type is allowed when converting OpenMP to LLVM Dialect, it carries
254  // bounds information for map clauses and the operation and type are
255  // discarded on lowering to LLVM-IR from the OpenMP dialect.
256  converter.addConversion(
257  [&](omp::MapBoundsType type) -> Type { return type; });
258 
259  patterns.add<
260  AtomicReadOpConversion, MapInfoOpConversion,
261  MultiRegionOpConversion<omp::DeclareReductionOp>,
262  MultiRegionOpConversion<omp::PrivateClauseOp>,
263  RegionLessOpConversion<omp::CancellationPointOp>,
264  RegionLessOpConversion<omp::CancelOp>,
265  RegionLessOpConversion<omp::CriticalDeclareOp>,
266  RegionLessOpConversion<omp::OrderedOp>,
267  RegionLessOpConversion<omp::TargetEnterDataOp>,
268  RegionLessOpConversion<omp::TargetExitDataOp>,
269  RegionLessOpConversion<omp::TargetUpdateOp>,
270  RegionLessOpConversion<omp::YieldOp>,
271  RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
272  RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
273  RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>,
274  RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>,
275  RegionOpConversion<omp::AtomicCaptureOp>,
276  RegionOpConversion<omp::CriticalOp>,
277  RegionOpConversion<omp::DistributeOp>,
278  RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::LoopOp>,
279  RegionOpConversion<omp::MaskedOp>, RegionOpConversion<omp::MasterOp>,
280  RegionOpConversion<omp::OrderedRegionOp>,
281  RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::SectionOp>,
282  RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SimdOp>,
283  RegionOpConversion<omp::SingleOp>, RegionOpConversion<omp::TargetDataOp>,
284  RegionOpConversion<omp::TargetOp>, RegionOpConversion<omp::TaskgroupOp>,
285  RegionOpConversion<omp::TaskloopOp>, RegionOpConversion<omp::TaskOp>,
286  RegionOpConversion<omp::TeamsOp>, RegionOpConversion<omp::WsloopOp>,
287  RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>>(converter);
288 }
289 
290 namespace {
291 struct ConvertOpenMPToLLVMPass
292  : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
293  using Base::Base;
294 
295  void runOnOperation() override;
296 };
297 } // namespace
298 
299 void ConvertOpenMPToLLVMPass::runOnOperation() {
300  auto module = getOperation();
301 
302  // Convert to OpenMP operations with LLVM IR dialect
303  RewritePatternSet patterns(&getContext());
304  LLVMTypeConverter converter(&getContext());
308  populateFuncToLLVMConversionPatterns(converter, patterns);
309  populateOpenMPToLLVMConversionPatterns(converter, patterns);
310 
312  target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
313  omp::TaskyieldOp, omp::TerminatorOp>();
314  configureOpenMPToLLVMConversionLegality(target, converter);
315  if (failed(applyPartialConversion(module, target, std::move(patterns))))
316  signalPassFailure();
317 }
318 
319 //===----------------------------------------------------------------------===//
320 // ConvertToLLVMPatternInterface implementation
321 //===----------------------------------------------------------------------===//
322 namespace {
323 /// Implement the interface to convert OpenMP to LLVM.
324 struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
326  void loadDependentDialects(MLIRContext *context) const final {
327  context->loadDialect<LLVM::LLVMDialect>();
328  }
329 
330  /// Hook for derived dialect interface to provide conversion patterns
331  /// and mark dialect legal for the conversion target.
332  void populateConvertToLLVMConversionPatterns(
333  ConversionTarget &target, LLVMTypeConverter &typeConverter,
334  RewritePatternSet &patterns) const final {
335  configureOpenMPToLLVMConversionLegality(target, typeConverter);
336  populateOpenMPToLLVMConversionPatterns(typeConverter, patterns);
337  }
338 };
339 } // namespace
340 
342  registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
343  dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
344  });
345 }
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
Include the generated interface declarations.
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from OpenMP to LLVM.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
void registerConvertOpenMPToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMPatternInterface interface in the OpenMP dialect.
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target, const LLVMTypeConverter &typeConverter)
Configure dynamic conversion legality of OpenMP operations, both region-less and region-carrying, for conversion from OpenMP to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:733
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
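Apply a partial conversion to the given operations and all nested operations, converting as many as possible; operations that fail to legalize are left in place unless they are explicitly illegal.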