MLIR  16.0.0git
OpenMPToLLVM.cpp
Go to the documentation of this file.
1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Pass/Pass.h"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVM
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 /// A pattern that converts the region arguments in a single-region OpenMP
31 /// operation to the LLVM dialect. The body of the region is not modified and is
32 /// expected to either be processed by the conversion infrastructure or already
33 /// contain ops compatible with LLVM dialect types.
34 template <typename OpType>
35 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
37 
39  matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
40  ConversionPatternRewriter &rewriter) const override {
41  auto newOp = rewriter.create<OpType>(
42  curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
43  rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
44  newOp.getRegion().end());
45  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
46  *this->getTypeConverter())))
47  return failure();
48 
49  rewriter.eraseOp(curOp);
50  return success();
51  }
52 };
53 
54 template <typename T>
55 struct RegionLessOpWithVarOperandsConversion
56  : public ConvertOpToLLVMPattern<T> {
59  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
60  ConversionPatternRewriter &rewriter) const override {
62  SmallVector<Type> resTypes;
63  if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
64  return failure();
65  SmallVector<Value> convertedOperands;
66  assert(curOp.getNumVariableOperands() ==
67  curOp.getOperation()->getNumOperands() &&
68  "unexpected non-variable operands");
69  for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
70  Value originalVariableOperand = curOp.getVariableOperand(idx);
71  if (!originalVariableOperand)
72  return failure();
73  if (originalVariableOperand.getType().isa<MemRefType>()) {
74  // TODO: Support memref type in variable operands
75  return rewriter.notifyMatchFailure(curOp,
76  "memref is not supported yet");
77  }
78  convertedOperands.emplace_back(adaptor.getOperands()[idx]);
79  }
80  rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
81  curOp->getAttrs());
82  return success();
83  }
84 };
85 
86 struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> {
89  matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor,
90  ConversionPatternRewriter &rewriter) const override {
91  if (curOp.getAccumulator().getType().isa<MemRefType>()) {
92  // TODO: Support memref type in variable operands
93  return rewriter.notifyMatchFailure(curOp, "memref is not supported yet");
94  }
95  rewriter.replaceOpWithNewOp<omp::ReductionOp>(
96  curOp, TypeRange(), adaptor.getOperands(), curOp->getAttrs());
97  return success();
98  }
99 };
100 } // namespace
101 
103  ConversionTarget &target, LLVMTypeConverter &typeConverter) {
104  target.addDynamicallyLegalOp<mlir::omp::CriticalOp, mlir::omp::ParallelOp,
105  mlir::omp::WsLoopOp, mlir::omp::SimdLoopOp,
106  mlir::omp::MasterOp, mlir::omp::SectionsOp,
107  mlir::omp::SingleOp>([&](Operation *op) {
108  return typeConverter.isLegal(&op->getRegion(0)) &&
109  typeConverter.isLegal(op->getOperandTypes()) &&
110  typeConverter.isLegal(op->getResultTypes());
111  });
112  target
113  .addDynamicallyLegalOp<mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp,
114  mlir::omp::FlushOp, mlir::omp::ThreadprivateOp>(
115  [&](Operation *op) {
116  return typeConverter.isLegal(op->getOperandTypes()) &&
117  typeConverter.isLegal(op->getResultTypes());
118  });
119  target.addDynamicallyLegalOp<mlir::omp::ReductionOp>([&](Operation *op) {
120  return typeConverter.isLegal(op->getOperandTypes());
121  });
122 }
123 
125  RewritePatternSet &patterns) {
126  patterns.add<
127  ReductionOpConversion, RegionOpConversion<omp::CriticalOp>,
128  RegionOpConversion<omp::MasterOp>, ReductionOpConversion,
129  RegionOpConversion<omp::MasterOp>, RegionOpConversion<omp::ParallelOp>,
130  RegionOpConversion<omp::WsLoopOp>, RegionOpConversion<omp::SectionsOp>,
131  RegionOpConversion<omp::SimdLoopOp>, RegionOpConversion<omp::SingleOp>,
132  RegionLessOpWithVarOperandsConversion<omp::AtomicReadOp>,
133  RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
134  RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
135  RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>>(converter);
136 }
137 
138 namespace {
139 struct ConvertOpenMPToLLVMPass
140  : public impl::ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
141  void runOnOperation() override;
142 };
143 } // namespace
144 
145 void ConvertOpenMPToLLVMPass::runOnOperation() {
146  auto module = getOperation();
147 
148  // Convert to OpenMP operations with LLVM IR dialect
149  RewritePatternSet patterns(&getContext());
150  LLVMTypeConverter converter(&getContext());
153  populateMemRefToLLVMConversionPatterns(converter, patterns);
154  populateFuncToLLVMConversionPatterns(converter, patterns);
155  populateOpenMPToLLVMConversionPatterns(converter, patterns);
156 
157  LLVMConversionTarget target(getContext());
158  target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
159  omp::BarrierOp, omp::TaskwaitOp>();
160  configureOpenMPToLLVMConversionLegality(target, converter);
161  if (failed(applyPartialConversion(module, target, std::move(patterns))))
162  signalPassFailure();
163 }
164 
165 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
166  return std::make_unique<ConvertOpenMPToLLVMPass>();
167 }
Include the generated interface declarations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:133
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results)
Convert the given set of types, filling 'results' as necessary.
void addLegalOp(OperationName op)
Register the given operations as legal.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
Derived class that automatically populates legalization information for different LLVM ops...
void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from OpenMP to LLVM.
void configureOpenMPToLLVMConversionLegality(ConversionTarget &target, LLVMTypeConverter &typeConverter)
Configure dynamic conversion legality of regionless operations from OpenMP to LLVM.
void populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:414
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
bool isLegal(Type type)
Return true if the given type is legal for this type converter, i.e.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:118
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Type conversion class.
LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:28
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
This class implements a pattern rewriter for use with ConversionPatterns.
std::unique_ptr< OperationPass< ModuleOp > > createConvertOpenMPToLLVMPass()
Create a pass to convert OpenMP operations to the LLVMIR dialect.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:692
bool isa() const
Definition: Types.h:258
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.