MLIR  21.0.0git
OpenMPToLLVM.cpp
Go to the documentation of this file.
1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
21 #include "mlir/Pass/Pass.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 /// A pattern that converts the region arguments in a single-region OpenMP
32 /// operation to the LLVM dialect. The body of the region is not modified and is
33 /// expected to either be processed by the conversion infrastructure or already
34 /// contain ops compatible with LLVM dialect types.
35 template <typename OpType>
36 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
38 
39  LogicalResult
40  matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
41  ConversionPatternRewriter &rewriter) const override {
42  auto newOp = rewriter.create<OpType>(
43  curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
44  rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
45  newOp.getRegion().end());
46  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
47  *this->getTypeConverter())))
48  return failure();
49 
50  rewriter.eraseOp(curOp);
51  return success();
52  }
53 };
54 
55 template <typename T>
56 struct RegionLessOpWithVarOperandsConversion
57  : public ConvertOpToLLVMPattern<T> {
59  LogicalResult
60  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
61  ConversionPatternRewriter &rewriter) const override {
63  SmallVector<Type> resTypes;
64  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
65  return failure();
66  SmallVector<Value> convertedOperands;
67  assert(curOp.getNumVariableOperands() ==
68  curOp.getOperation()->getNumOperands() &&
69  "unexpected non-variable operands");
70  for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
71  Value originalVariableOperand = curOp.getVariableOperand(idx);
72  if (!originalVariableOperand)
73  return failure();
74  if (isa<MemRefType>(originalVariableOperand.getType())) {
75  // TODO: Support memref type in variable operands
76  return rewriter.notifyMatchFailure(curOp,
77  "memref is not supported yet");
78  }
79  convertedOperands.emplace_back(adaptor.getOperands()[idx]);
80  }
81 
82  rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
83  curOp->getAttrs());
84  return success();
85  }
86 };
87 
88 template <typename T>
89 struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
91  LogicalResult
92  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
93  ConversionPatternRewriter &rewriter) const override {
95  SmallVector<Type> resTypes;
96  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
97  return failure();
98  SmallVector<Value> convertedOperands;
99  assert(curOp.getNumVariableOperands() ==
100  curOp.getOperation()->getNumOperands() &&
101  "unexpected non-variable operands");
102  for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
103  Value originalVariableOperand = curOp.getVariableOperand(idx);
104  if (!originalVariableOperand)
105  return failure();
106  if (isa<MemRefType>(originalVariableOperand.getType())) {
107  // TODO: Support memref type in variable operands
108  return rewriter.notifyMatchFailure(curOp,
109  "memref is not supported yet");
110  }
111  convertedOperands.emplace_back(adaptor.getOperands()[idx]);
112  }
113  auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
114  curOp->getAttrs());
115  rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
116  newOp.getRegion().end());
117  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
118  *this->getTypeConverter())))
119  return failure();
120 
121  rewriter.eraseOp(curOp);
122  return success();
123  }
124 };
125 
126 template <typename T>
127 struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> {
129  LogicalResult
130  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
131  ConversionPatternRewriter &rewriter) const override {
133  SmallVector<Type> resTypes;
134  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
135  return failure();
136 
137  rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(),
138  curOp->getAttrs());
139  return success();
140  }
141 };
142 
143 struct AtomicReadOpConversion
144  : public ConvertOpToLLVMPattern<omp::AtomicReadOp> {
146  LogicalResult
147  matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
148  ConversionPatternRewriter &rewriter) const override {
150  Type curElementType = curOp.getElementType();
151  auto newOp = rewriter.create<omp::AtomicReadOp>(
152  curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
153  TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType));
154  newOp.setElementTypeAttr(typeAttr);
155  rewriter.eraseOp(curOp);
156  return success();
157  }
158 };
159 
160 struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> {
162  LogicalResult
163  matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor,
164  ConversionPatternRewriter &rewriter) const override {
166 
167  SmallVector<Type> resTypes;
168  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
169  return failure();
170 
171  // Copy attributes of the curOp except for the typeAttr which should
172  // be converted
174  for (NamedAttribute attr : curOp->getAttrs()) {
175  if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
176  Type newAttr = converter->convertType(typeAttr.getValue());
177  newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
178  } else {
179  newAttrs.push_back(attr);
180  }
181  }
182 
183  rewriter.replaceOpWithNewOp<omp::MapInfoOp>(
184  curOp, resTypes, adaptor.getOperands(), newAttrs);
185  return success();
186  }
187 };
188 
189 struct DeclMapperOpConversion
190  : public ConvertOpToLLVMPattern<omp::DeclareMapperOp> {
192  LogicalResult
193  matchAndRewrite(omp::DeclareMapperOp curOp, OpAdaptor adaptor,
194  ConversionPatternRewriter &rewriter) const override {
197  newAttrs.emplace_back(curOp.getSymNameAttrName(), curOp.getSymNameAttr());
198  newAttrs.emplace_back(
199  curOp.getTypeAttrName(),
200  TypeAttr::get(converter->convertType(curOp.getType())));
201 
202  auto newOp = rewriter.create<omp::DeclareMapperOp>(
203  curOp.getLoc(), TypeRange(), adaptor.getOperands(), newAttrs);
204  rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
205  newOp.getRegion().end());
206  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
207  *this->getTypeConverter())))
208  return failure();
209 
210  rewriter.eraseOp(curOp);
211  return success();
212  }
213 };
214 
215 template <typename OpType>
216 struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
218 
219  void forwardOpAttrs(OpType curOp, OpType newOp) const {}
220 
221  LogicalResult
222  matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
223  ConversionPatternRewriter &rewriter) const override {
224  auto newOp = rewriter.create<OpType>(
225  curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(),
226  TypeAttr::get(this->getTypeConverter()->convertType(
227  curOp.getTypeAttr().getValue())));
228  forwardOpAttrs(curOp, newOp);
229 
230  for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
231  rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx),
232  newOp.getRegion(idx).end());
233  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx),
234  *this->getTypeConverter())))
235  return failure();
236  }
237 
238  rewriter.eraseOp(curOp);
239  return success();
240  }
241 };
242 
243 template <>
244 void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
245  omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const {
246  newOp.setDataSharingType(curOp.getDataSharingType());
247 }
248 } // namespace
249 
251  ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
  // Registers OpenMP ops as dynamically legal. An op counts as legal once the
  // type converter accepts all of its operand and result types and, for the
  // region-carrying ops in the second group, all of its regions.
  // The callbacks capture `typeConverter` by reference, so the converter must
  // outlive every use of `target`.
  //
  // Region-less ops: only operand and result types are checked.
252  target.addDynamicallyLegalOp<
253  omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
254  omp::CancelOp, omp::CriticalDeclareOp, omp::DeclareMapperInfoOp,
255  omp::FlushOp, omp::MapBoundsOp, omp::MapInfoOp, omp::OrderedOp,
256  omp::ScanOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
257  omp::TargetUpdateOp, omp::ThreadprivateOp, omp::YieldOp>(
258  [&](Operation *op) {
259  return typeConverter.isLegal(op->getOperandTypes()) &&
260  typeConverter.isLegal(op->getResultTypes());
261  });
  // Ops with regions: every region must also be legal for the converter
  // (presumably this means its block argument types are legal).
262  target.addDynamicallyLegalOp<
263  omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareMapperOp,
264  omp::DeclareReductionOp, omp::DistributeOp, omp::LoopNestOp, omp::LoopOp,
265  omp::MasterOp, omp::OrderedRegionOp, omp::ParallelOp,
266  omp::PrivateClauseOp, omp::SectionOp, omp::SectionsOp, omp::SimdOp,
267  omp::SingleOp, omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp,
268  omp::TaskloopOp, omp::TaskOp, omp::TeamsOp,
269  omp::WsloopOp>([&](Operation *op) {
270  return std::all_of(op->getRegions().begin(), op->getRegions().end(),
271  [&](Region &region) {
272  return typeConverter.isLegal(&region);
273  }) &&
274  typeConverter.isLegal(op->getOperandTypes()) &&
275  typeConverter.isLegal(op->getResultTypes());
276  });
  // PrivateClauseOp additionally requires its declared `type` to be legal.
  // NOTE(review): PrivateClauseOp is also listed in the group above.
  // Presumably this later registration takes precedence; confirm against
  // ConversionTarget::setLegalityCallback and drop the duplicate if so.
277  target.addDynamicallyLegalOp<omp::PrivateClauseOp>(
278  [&](omp::PrivateClauseOp op) -> bool {
279  return std::all_of(op->getRegions().begin(), op->getRegions().end(),
280  [&](Region &region) {
281  return typeConverter.isLegal(&region);
282  }) &&
283  typeConverter.isLegal(op->getOperandTypes()) &&
284  typeConverter.isLegal(op->getResultTypes()) &&
285  typeConverter.isLegal(op.getType());
286  });
287 }
288 
291  // This type is allowed when converting OpenMP to LLVM Dialect, it carries
292  // bounds information for map clauses and the operation and type are
293  // discarded on lowering to LLVM-IR from the OpenMP dialect.
294  converter.addConversion(
295  [&](omp::MapBoundsType type) -> Type { return type; });
296 
  // Every pattern below is constructed from `converter`. They are grouped by
  // conversion template: dedicated patterns, multi-region symbol ops,
  // region-less ops, region-less ops with only variable operands,
  // single-region ops, and single-region ops with only variable operands.
  // NOTE(review): AtomicCaptureOp and MaskedOp get patterns here but are not
  // listed in configureOpenMPToLLVMConversionLegality; confirm this is
  // intended.
297  patterns.add<
298  AtomicReadOpConversion, DeclMapperOpConversion, MapInfoOpConversion,
299  MultiRegionOpConversion<omp::DeclareReductionOp>,
300  MultiRegionOpConversion<omp::PrivateClauseOp>,
301  RegionLessOpConversion<omp::CancellationPointOp>,
302  RegionLessOpConversion<omp::CancelOp>,
303  RegionLessOpConversion<omp::CriticalDeclareOp>,
304  RegionLessOpConversion<omp::DeclareMapperInfoOp>,
305  RegionLessOpConversion<omp::OrderedOp>,
306  RegionLessOpConversion<omp::ScanOp>,
307  RegionLessOpConversion<omp::TargetEnterDataOp>,
308  RegionLessOpConversion<omp::TargetExitDataOp>,
309  RegionLessOpConversion<omp::TargetUpdateOp>,
310  RegionLessOpConversion<omp::YieldOp>,
311  RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
312  RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
313  RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>,
314  RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>,
315  RegionOpConversion<omp::AtomicCaptureOp>,
316  RegionOpConversion<omp::CriticalOp>,
317  RegionOpConversion<omp::DistributeOp>,
318  RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::LoopOp>,
319  RegionOpConversion<omp::MaskedOp>, RegionOpConversion<omp::MasterOp>,
320  RegionOpConversion<omp::OrderedRegionOp>,
321  RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::SectionOp>,
322  RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SimdOp>,
323  RegionOpConversion<omp::SingleOp>, RegionOpConversion<omp::TargetDataOp>,
324  RegionOpConversion<omp::TargetOp>, RegionOpConversion<omp::TaskgroupOp>,
325  RegionOpConversion<omp::TaskloopOp>, RegionOpConversion<omp::TaskOp>,
326  RegionOpConversion<omp::TeamsOp>, RegionOpConversion<omp::WsloopOp>,
327  RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>>(converter);
328 }
329 
330 namespace {
331 struct ConvertOpenMPToLLVMPass
332  : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
333  using Base::Base;
334 
335  void runOnOperation() override;
336 };
337 } // namespace
338 
339 void ConvertOpenMPToLLVMPass::runOnOperation() {
  // Runs a partial conversion over the pass's root operation using
  // `patterns`, a legality setup built from `converter`, and `target`.
  // Conversion failure marks the pass as failed.
  // NOTE(review): `patterns` and `target` are used below, but their
  // declarations are not visible in this listing.
340  auto module = getOperation();
341 
342  // Convert OpenMP operations to operate on LLVM dialect types.
344  LLVMTypeConverter converter(&getContext());
351 
  // NOTE(review): omp::FlushOp is marked legal here, but
  // configureOpenMPToLLVMConversionLegality (called right after) registers
  // it as dynamically legal again. Presumably the later registration wins,
  // which would make this entry redundant; confirm.
353  target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
354  omp::TaskyieldOp, omp::TerminatorOp>();
355  configureOpenMPToLLVMConversionLegality(target, converter);
356  if (failed(applyPartialConversion(module, target, std::move(patterns))))
357  signalPassFailure();
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // ConvertToLLVMPatternInterface implementation
362 //===----------------------------------------------------------------------===//
363 namespace {
364 /// Implement the interface to convert OpenMP to LLVM.
365 struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
367  void loadDependentDialects(MLIRContext *context) const final {
368  context->loadDialect<LLVM::LLVMDialect>();
369  }
370 
371  /// Hook for derived dialect interface to provide conversion patterns
372  /// and mark dialect legal for the conversion target.
373  void populateConvertToLLVMConversionPatterns(
374  ConversionTarget &target, LLVMTypeConverter &typeConverter,
375  RewritePatternSet &patterns) const final {
376  configureOpenMPToLLVMConversionLegality(target, typeConverter);
378  }
379 };
380 } // namespace
381 
// Attaches OpenMPToLLVMDialectInterface to the OpenMP dialect through a
// registry extension. The callback receives the context (unused here) and the
// omp::OpenMPDialect instance. It presumably runs when that dialect is loaded
// into a context created from `registry`; confirm against DialectRegistry.
// The unary `+` turns the capture-less lambda into a plain function pointer.
383  registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
384  dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
385  });
386 }
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:148
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true)
Populate the cf.assert to LLVM conversion pattern.
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
Include the generated interface declarations.
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from OpenMP to LLVM.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
void registerConvertOpenMPToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMPatternInterface interface in the OpenMP dialect.
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target, const LLVMTypeConverter &typeConverter)
Configure dynamic conversion legality of OpenMP operations (both region-less and region-carrying) for conversion to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:733
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.