MLIR  21.0.0git
OpenMPToLLVM.cpp
Go to the documentation of this file.
1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//

#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Pass/Pass.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 /// A pattern that converts the region arguments in a single-region OpenMP
32 /// operation to the LLVM dialect. The body of the region is not modified and is
33 /// expected to either be processed by the conversion infrastructure or already
34 /// contain ops compatible with LLVM dialect types.
35 template <typename OpType>
36 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
38 
39  LogicalResult
40  matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
41  ConversionPatternRewriter &rewriter) const override {
42  auto newOp = rewriter.create<OpType>(
43  curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
44  rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
45  newOp.getRegion().end());
46  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
47  *this->getTypeConverter())))
48  return failure();
49 
50  rewriter.eraseOp(curOp);
51  return success();
52  }
53 };
54 
55 template <typename T>
56 struct RegionLessOpWithVarOperandsConversion
57  : public ConvertOpToLLVMPattern<T> {
59  LogicalResult
60  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
61  ConversionPatternRewriter &rewriter) const override {
63  SmallVector<Type> resTypes;
64  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
65  return failure();
66  SmallVector<Value> convertedOperands;
67  assert(curOp.getNumVariableOperands() ==
68  curOp.getOperation()->getNumOperands() &&
69  "unexpected non-variable operands");
70  for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
71  Value originalVariableOperand = curOp.getVariableOperand(idx);
72  if (!originalVariableOperand)
73  return failure();
74  if (isa<MemRefType>(originalVariableOperand.getType())) {
75  // TODO: Support memref type in variable operands
76  return rewriter.notifyMatchFailure(curOp,
77  "memref is not supported yet");
78  }
79  convertedOperands.emplace_back(adaptor.getOperands()[idx]);
80  }
81 
82  rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
83  curOp->getAttrs());
84  return success();
85  }
86 };
87 
88 template <typename T>
89 struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
91  LogicalResult
92  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
93  ConversionPatternRewriter &rewriter) const override {
95  SmallVector<Type> resTypes;
96  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
97  return failure();
98  SmallVector<Value> convertedOperands;
99  assert(curOp.getNumVariableOperands() ==
100  curOp.getOperation()->getNumOperands() &&
101  "unexpected non-variable operands");
102  for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
103  Value originalVariableOperand = curOp.getVariableOperand(idx);
104  if (!originalVariableOperand)
105  return failure();
106  if (isa<MemRefType>(originalVariableOperand.getType())) {
107  // TODO: Support memref type in variable operands
108  return rewriter.notifyMatchFailure(curOp,
109  "memref is not supported yet");
110  }
111  convertedOperands.emplace_back(adaptor.getOperands()[idx]);
112  }
113  auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
114  curOp->getAttrs());
115  rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
116  newOp.getRegion().end());
117  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
118  *this->getTypeConverter())))
119  return failure();
120 
121  rewriter.eraseOp(curOp);
122  return success();
123  }
124 };
125 
126 template <typename T>
127 struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> {
129  LogicalResult
130  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
131  ConversionPatternRewriter &rewriter) const override {
133  SmallVector<Type> resTypes;
134  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
135  return failure();
136 
137  rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(),
138  curOp->getAttrs());
139  return success();
140  }
141 };
142 
143 struct AtomicReadOpConversion
144  : public ConvertOpToLLVMPattern<omp::AtomicReadOp> {
146  LogicalResult
147  matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
148  ConversionPatternRewriter &rewriter) const override {
150  Type curElementType = curOp.getElementType();
151  auto newOp = rewriter.create<omp::AtomicReadOp>(
152  curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
153  TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType));
154  newOp.setElementTypeAttr(typeAttr);
155  rewriter.eraseOp(curOp);
156  return success();
157  }
158 };
159 
160 struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> {
162  LogicalResult
163  matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor,
164  ConversionPatternRewriter &rewriter) const override {
166 
167  SmallVector<Type> resTypes;
168  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
169  return failure();
170 
171  // Copy attributes of the curOp except for the typeAttr which should
172  // be converted
174  for (NamedAttribute attr : curOp->getAttrs()) {
175  if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
176  Type newAttr = converter->convertType(typeAttr.getValue());
177  newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
178  } else {
179  newAttrs.push_back(attr);
180  }
181  }
182 
183  rewriter.replaceOpWithNewOp<omp::MapInfoOp>(
184  curOp, resTypes, adaptor.getOperands(), newAttrs);
185  return success();
186  }
187 };
188 
189 template <typename OpType>
190 struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
192 
193  void forwardOpAttrs(OpType curOp, OpType newOp) const {}
194 
195  LogicalResult
196  matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
197  ConversionPatternRewriter &rewriter) const override {
198  auto newOp = rewriter.create<OpType>(
199  curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(),
200  TypeAttr::get(this->getTypeConverter()->convertType(
201  curOp.getTypeAttr().getValue())));
202  forwardOpAttrs(curOp, newOp);
203 
204  for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
205  rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx),
206  newOp.getRegion(idx).end());
207  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx),
208  *this->getTypeConverter())))
209  return failure();
210  }
211 
212  rewriter.eraseOp(curOp);
213  return success();
214  }
215 };
216 
217 template <>
218 void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
219  omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const {
220  newOp.setDataSharingType(curOp.getDataSharingType());
221 }
222 } // namespace
223 
225  ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
226  target.addDynamicallyLegalOp<
227  omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
228  omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
229  omp::MapInfoOp, omp::OrderedOp, omp::ScanOp, omp::TargetEnterDataOp,
230  omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
231  omp::YieldOp>([&](Operation *op) {
232  return typeConverter.isLegal(op->getOperandTypes()) &&
233  typeConverter.isLegal(op->getResultTypes());
234  });
235  target.addDynamicallyLegalOp<
236  omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareReductionOp,
237  omp::DistributeOp, omp::LoopNestOp, omp::LoopOp, omp::MasterOp,
238  omp::OrderedRegionOp, omp::ParallelOp, omp::SectionOp, omp::SectionsOp,
239  omp::SimdOp, omp::SingleOp, omp::TargetDataOp, omp::TargetOp,
240  omp::TaskgroupOp, omp::TaskloopOp, omp::TaskOp, omp::TeamsOp,
241  omp::WsloopOp>([&](Operation *op) {
242  return std::all_of(op->getRegions().begin(), op->getRegions().end(),
243  [&](Region &region) {
244  return typeConverter.isLegal(&region);
245  }) &&
246  typeConverter.isLegal(op->getOperandTypes()) &&
247  typeConverter.isLegal(op->getResultTypes());
248  });
249  target.addDynamicallyLegalOp<omp::PrivateClauseOp>(
250  [&](omp::PrivateClauseOp op) -> bool {
251  return std::all_of(op->getRegions().begin(), op->getRegions().end(),
252  [&](Region &region) {
253  return typeConverter.isLegal(&region);
254  }) &&
255  typeConverter.isLegal(op->getOperandTypes()) &&
256  typeConverter.isLegal(op->getResultTypes()) &&
257  typeConverter.isLegal(op.getType());
258  });
259 }
260 
263  // This type is allowed when converting OpenMP to LLVM Dialect, it carries
264  // bounds information for map clauses and the operation and type are
265  // discarded on lowering to LLVM-IR from the OpenMP dialect.
266  converter.addConversion(
267  [&](omp::MapBoundsType type) -> Type { return type; });
268 
269  patterns.add<
270  AtomicReadOpConversion, MapInfoOpConversion,
271  MultiRegionOpConversion<omp::DeclareReductionOp>,
272  MultiRegionOpConversion<omp::PrivateClauseOp>,
273  RegionLessOpConversion<omp::CancellationPointOp>,
274  RegionLessOpConversion<omp::CancelOp>,
275  RegionLessOpConversion<omp::CriticalDeclareOp>,
276  RegionLessOpConversion<omp::OrderedOp>,
277  RegionLessOpConversion<omp::ScanOp>,
278  RegionLessOpConversion<omp::TargetEnterDataOp>,
279  RegionLessOpConversion<omp::TargetExitDataOp>,
280  RegionLessOpConversion<omp::TargetUpdateOp>,
281  RegionLessOpConversion<omp::YieldOp>,
282  RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
283  RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
284  RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>,
285  RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>,
286  RegionOpConversion<omp::AtomicCaptureOp>,
287  RegionOpConversion<omp::CriticalOp>,
288  RegionOpConversion<omp::DistributeOp>,
289  RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::LoopOp>,
290  RegionOpConversion<omp::MaskedOp>, RegionOpConversion<omp::MasterOp>,
291  RegionOpConversion<omp::OrderedRegionOp>,
292  RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::SectionOp>,
293  RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SimdOp>,
294  RegionOpConversion<omp::SingleOp>, RegionOpConversion<omp::TargetDataOp>,
295  RegionOpConversion<omp::TargetOp>, RegionOpConversion<omp::TaskgroupOp>,
296  RegionOpConversion<omp::TaskloopOp>, RegionOpConversion<omp::TaskOp>,
297  RegionOpConversion<omp::TeamsOp>, RegionOpConversion<omp::WsloopOp>,
298  RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>>(converter);
299 }
300 
301 namespace {
302 struct ConvertOpenMPToLLVMPass
303  : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
304  using Base::Base;
305 
306  void runOnOperation() override;
307 };
308 } // namespace
309 
310 void ConvertOpenMPToLLVMPass::runOnOperation() {
311  auto module = getOperation();
312 
313  // Convert to OpenMP operations with LLVM IR dialect
315  LLVMTypeConverter converter(&getContext());
322 
324  target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
325  omp::TaskyieldOp, omp::TerminatorOp>();
326  configureOpenMPToLLVMConversionLegality(target, converter);
327  if (failed(applyPartialConversion(module, target, std::move(patterns))))
328  signalPassFailure();
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // ConvertToLLVMPatternInterface implementation
333 //===----------------------------------------------------------------------===//
334 namespace {
335 /// Implement the interface to convert OpenMP to LLVM.
336 struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
338  void loadDependentDialects(MLIRContext *context) const final {
339  context->loadDialect<LLVM::LLVMDialect>();
340  }
341 
342  /// Hook for derived dialect interface to provide conversion patterns
343  /// and mark dialect legal for the conversion target.
344  void populateConvertToLLVMConversionPatterns(
345  ConversionTarget &target, LLVMTypeConverter &typeConverter,
346  RewritePatternSet &patterns) const final {
347  configureOpenMPToLLVMConversionLegality(target, typeConverter);
349  }
350 };
351 } // namespace
352 
354  registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
355  dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
356  });
357 }
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure=true)
Populate the cf.assert to LLVM conversion pattern.
void populateControlFlowToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
Include the generated interface declarations.
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from OpenMP to LLVM.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
void registerConvertOpenMPToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMPatternInterface interface in the OpenMP dialect.
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target, const LLVMTypeConverter &typeConverter)
Configure dynamic conversion legality of regionless operations from OpenMP to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:733
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.