MLIR  19.0.0git
OpenMPToLLVM.cpp
Go to the documentation of this file.
1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/Pass/Pass.h"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 /// A pattern that converts the region arguments in a single-region OpenMP
31 /// operation to the LLVM dialect. The body of the region is not modified and is
32 /// expected to either be processed by the conversion infrastructure or already
33 /// contain ops compatible with LLVM dialect types.
34 template <typename OpType>
35 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
37 
39  matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
40  ConversionPatternRewriter &rewriter) const override {
41  auto newOp = rewriter.create<OpType>(
42  curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
43  rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
44  newOp.getRegion().end());
45  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
46  *this->getTypeConverter())))
47  return failure();
48 
49  rewriter.eraseOp(curOp);
50  return success();
51  }
52 };
53 
54 template <typename T>
55 struct RegionLessOpWithVarOperandsConversion
56  : public ConvertOpToLLVMPattern<T> {
59  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
60  ConversionPatternRewriter &rewriter) const override {
62  SmallVector<Type> resTypes;
63  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
64  return failure();
65  SmallVector<Value> convertedOperands;
66  assert(curOp.getNumVariableOperands() ==
67  curOp.getOperation()->getNumOperands() &&
68  "unexpected non-variable operands");
69  for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
70  Value originalVariableOperand = curOp.getVariableOperand(idx);
71  if (!originalVariableOperand)
72  return failure();
73  if (isa<MemRefType>(originalVariableOperand.getType())) {
74  // TODO: Support memref type in variable operands
75  return rewriter.notifyMatchFailure(curOp,
76  "memref is not supported yet");
77  }
78  convertedOperands.emplace_back(adaptor.getOperands()[idx]);
79  }
80 
81  rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
82  curOp->getAttrs());
83  return success();
84  }
85 };
86 
87 template <typename T>
88 struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
91  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
92  ConversionPatternRewriter &rewriter) const override {
94  SmallVector<Type> resTypes;
95  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
96  return failure();
97  SmallVector<Value> convertedOperands;
98  assert(curOp.getNumVariableOperands() ==
99  curOp.getOperation()->getNumOperands() &&
100  "unexpected non-variable operands");
101  for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
102  Value originalVariableOperand = curOp.getVariableOperand(idx);
103  if (!originalVariableOperand)
104  return failure();
105  if (isa<MemRefType>(originalVariableOperand.getType())) {
106  // TODO: Support memref type in variable operands
107  return rewriter.notifyMatchFailure(curOp,
108  "memref is not supported yet");
109  }
110  convertedOperands.emplace_back(adaptor.getOperands()[idx]);
111  }
112  auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
113  curOp->getAttrs());
114  rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
115  newOp.getRegion().end());
116  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
117  *this->getTypeConverter())))
118  return failure();
119 
120  rewriter.eraseOp(curOp);
121  return success();
122  }
123 };
124 
125 template <typename T>
126 struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> {
129  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
130  ConversionPatternRewriter &rewriter) const override {
132  SmallVector<Type> resTypes;
133  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
134  return failure();
135 
136  rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(),
137  curOp->getAttrs());
138  return success();
139  }
140 };
141 
142 struct AtomicReadOpConversion
143  : public ConvertOpToLLVMPattern<omp::AtomicReadOp> {
146  matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
147  ConversionPatternRewriter &rewriter) const override {
149  Type curElementType = curOp.getElementType();
150  auto newOp = rewriter.create<omp::AtomicReadOp>(
151  curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
152  TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType));
153  newOp.setElementTypeAttr(typeAttr);
154  rewriter.eraseOp(curOp);
155  return success();
156  }
157 };
158 
159 struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> {
162  matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor,
163  ConversionPatternRewriter &rewriter) const override {
165 
166  SmallVector<Type> resTypes;
167  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
168  return failure();
169 
170  // Copy attributes of the curOp except for the typeAttr which should
171  // be converted
173  for (NamedAttribute attr : curOp->getAttrs()) {
174  if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
175  Type newAttr = converter->convertType(typeAttr.getValue());
176  newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
177  } else {
178  newAttrs.push_back(attr);
179  }
180  }
181 
182  rewriter.replaceOpWithNewOp<omp::MapInfoOp>(
183  curOp, resTypes, adaptor.getOperands(), newAttrs);
184  return success();
185  }
186 };
187 
188 struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> {
191  matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor,
192  ConversionPatternRewriter &rewriter) const override {
193  if (isa<MemRefType>(curOp.getAccumulator().getType())) {
194  // TODO: Support memref type in variable operands
195  return rewriter.notifyMatchFailure(curOp, "memref is not supported yet");
196  }
197  rewriter.replaceOpWithNewOp<omp::ReductionOp>(
198  curOp, TypeRange(), adaptor.getOperands(), curOp->getAttrs());
199  return success();
200  }
201 };
202 
203 template <typename OpType>
204 struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
206 
207  void forwardOpAttrs(OpType curOp, OpType newOp) const {}
208 
210  matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
211  ConversionPatternRewriter &rewriter) const override {
212  auto newOp = rewriter.create<OpType>(
213  curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(),
214  TypeAttr::get(this->getTypeConverter()->convertType(
215  curOp.getTypeAttr().getValue())));
216  forwardOpAttrs(curOp, newOp);
217 
218  for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
219  rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx),
220  newOp.getRegion(idx).end());
221  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx),
222  *this->getTypeConverter())))
223  return failure();
224  }
225 
226  rewriter.eraseOp(curOp);
227  return success();
228  }
229 };
230 
231 template <>
232 void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
233  omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const {
234  newOp.setDataSharingType(curOp.getDataSharingType());
235 }
236 } // namespace
237 
239  ConversionTarget &target, LLVMTypeConverter &typeConverter) {
240  target.addDynamicallyLegalOp<
241  mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp, mlir::omp::FlushOp,
242  mlir::omp::ThreadprivateOp, mlir::omp::YieldOp,
243  mlir::omp::TargetEnterDataOp, mlir::omp::TargetExitDataOp,
244  mlir::omp::TargetUpdateOp, mlir::omp::MapBoundsOp, mlir::omp::MapInfoOp>(
245  [&](Operation *op) {
246  return typeConverter.isLegal(op->getOperandTypes()) &&
247  typeConverter.isLegal(op->getResultTypes());
248  });
249  target.addDynamicallyLegalOp<mlir::omp::ReductionOp>([&](Operation *op) {
250  return typeConverter.isLegal(op->getOperandTypes());
251  });
252  target.addDynamicallyLegalOp<
253  mlir::omp::AtomicUpdateOp, mlir::omp::CriticalOp, mlir::omp::TargetOp,
254  mlir::omp::TargetDataOp, mlir::omp::LoopNestOp,
255  mlir::omp::OrderedRegionOp, mlir::omp::ParallelOp, mlir::omp::WsloopOp,
256  mlir::omp::SimdOp, mlir::omp::MasterOp, mlir::omp::SectionOp,
257  mlir::omp::SectionsOp, mlir::omp::SingleOp, mlir::omp::TaskgroupOp,
258  mlir::omp::TaskOp, mlir::omp::DeclareReductionOp,
259  mlir::omp::PrivateClauseOp>([&](Operation *op) {
260  return std::all_of(op->getRegions().begin(), op->getRegions().end(),
261  [&](Region &region) {
262  return typeConverter.isLegal(&region);
263  }) &&
264  typeConverter.isLegal(op->getOperandTypes()) &&
265  typeConverter.isLegal(op->getResultTypes());
266  });
267 }
268 
270  RewritePatternSet &patterns) {
271  // This type is allowed when converting OpenMP to LLVM Dialect, it carries
272  // bounds information for map clauses and the operation and type are
273  // discarded on lowering to LLVM-IR from the OpenMP dialect.
274  converter.addConversion(
275  [&](omp::MapBoundsType type) -> Type { return type; });
276 
277  patterns.add<
278  AtomicReadOpConversion, MapInfoOpConversion, ReductionOpConversion,
279  MultiRegionOpConversion<omp::DeclareReductionOp>,
280  MultiRegionOpConversion<omp::PrivateClauseOp>,
281  RegionOpConversion<omp::CriticalOp>, RegionOpConversion<omp::LoopNestOp>,
282  RegionOpConversion<omp::MasterOp>, ReductionOpConversion,
283  RegionOpConversion<omp::OrderedRegionOp>,
284  RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::WsloopOp>,
285  RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SectionOp>,
286  RegionOpConversion<omp::SimdOp>, RegionOpConversion<omp::SingleOp>,
287  RegionOpConversion<omp::TaskgroupOp>, RegionOpConversion<omp::TaskOp>,
288  RegionOpConversion<omp::TargetDataOp>, RegionOpConversion<omp::TargetOp>,
289  RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
290  RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>,
291  RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
292  RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>,
293  RegionLessOpConversion<omp::YieldOp>,
294  RegionLessOpConversion<omp::TargetEnterDataOp>,
295  RegionLessOpConversion<omp::TargetExitDataOp>,
296  RegionLessOpConversion<omp::TargetUpdateOp>,
297  RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>>(converter);
298 }
299 
300 namespace {
301 struct ConvertOpenMPToLLVMPass
302  : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
303  using Base::Base;
304 
305  void runOnOperation() override;
306 };
307 } // namespace
308 
309 void ConvertOpenMPToLLVMPass::runOnOperation() {
310  auto module = getOperation();
311 
312  // Convert to OpenMP operations with LLVM IR dialect
313  RewritePatternSet patterns(&getContext());
314  LLVMTypeConverter converter(&getContext());
318  populateFuncToLLVMConversionPatterns(converter, patterns);
319  populateOpenMPToLLVMConversionPatterns(converter, patterns);
320 
322  target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
323  omp::BarrierOp, omp::TaskwaitOp>();
324  configureOpenMPToLLVMConversionLegality(target, converter);
325  if (failed(applyPartialConversion(module, target, std::move(patterns))))
326  signalPassFailure();
327 }
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:143
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from OpenMP to LLVM.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:752
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target, LLVMTypeConverter &typeConverter)
Configure dynamic conversion legality of regionless operations from OpenMP to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26