MLIR  14.0.0git
OpenACCToLLVM.cpp
Go to the documentation of this file.
1 //===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "../PassDetail.h"
14 #include "mlir/IR/Builders.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // DataDescriptor implementation
20 //===----------------------------------------------------------------------===//
21 
22 constexpr StringRef getStructName() { return "openacc_data"; }
23 
24 /// Construct a helper for the given descriptor value.
26  assert(value != nullptr && "value cannot be null");
27 }
28 
29 /// Builds IR creating an `undef` value of the data descriptor.
31  Type basePtrTy, Type ptrTy) {
33  builder.getContext(), getStructName(),
34  {basePtrTy, ptrTy, builder.getI64Type()});
35  Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
36  return DataDescriptor(descriptor);
37 }
38 
39 /// Check whether the type is a valid data descriptor.
40 bool DataDescriptor::isValid(Value descriptor) {
41  if (auto type = descriptor.getType().dyn_cast<LLVM::LLVMStructType>()) {
42  if (type.isIdentified() && type.getName().startswith(getStructName()) &&
43  type.getBody().size() == 3 &&
44  (type.getBody()[kPtrBasePosInDataDescriptor]
45  .isa<LLVM::LLVMPointerType>() ||
46  type.getBody()[kPtrBasePosInDataDescriptor]
47  .isa<LLVM::LLVMStructType>()) &&
48  type.getBody()[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>() &&
49  type.getBody()[kSizePosInDataDescriptor].isInteger(64))
50  return true;
51  }
52  return false;
53 }
54 
55 /// Builds IR inserting the base pointer value into the descriptor.
57  Value basePtr) {
58  setPtr(builder, loc, kPtrBasePosInDataDescriptor, basePtr);
59 }
60 
61 /// Builds IR inserting the pointer value into the descriptor.
63  setPtr(builder, loc, kPtrPosInDataDescriptor, ptr);
64 }
65 
66 /// Builds IR inserting the size value into the descriptor.
67 void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) {
68  setPtr(builder, loc, kSizePosInDataDescriptor, size);
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Conversion patterns
73 //===----------------------------------------------------------------------===//
74 
75 namespace {
76 
77 template <typename Op>
78 class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
80 
82  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
83  ConversionPatternRewriter &builder) const override {
84  Location loc = op.getLoc();
86 
87  unsigned numDataOperand = op.getNumDataOperands();
88 
89  // Keep the non data operands without modification.
90  auto nonDataOperands = adaptor.getOperands().take_front(
91  adaptor.getOperands().size() - numDataOperand);
92  SmallVector<Value> convertedOperands;
93  convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());
94 
95  // Go over the data operand and legalize them for translation.
96  for (unsigned idx = 0; idx < numDataOperand; ++idx) {
97  Value originalDataOperand = op.getDataOperand(idx);
98 
99  // Traverse operands that were converted to MemRefDescriptors.
100  if (auto memRefType =
101  originalDataOperand.getType().dyn_cast<MemRefType>()) {
102  Type structType = converter->convertType(memRefType);
103  Value memRefDescriptor = builder
104  .create<UnrealizedConversionCastOp>(
105  loc, structType, originalDataOperand)
106  .getResult(0);
107 
108  // Calculate the size of the memref and get the pointer to the allocated
109  // buffer.
110  SmallVector<Value> sizes;
111  SmallVector<Value> strides;
112  Value sizeBytes;
114  loc, memRefType, {}, builder, sizes, strides, sizeBytes);
115  MemRefDescriptor descriptor(memRefDescriptor);
116  Value dataPtr = descriptor.alignedPtr(builder, loc);
117  auto ptrType = descriptor.getElementPtrType();
118 
119  auto descr = DataDescriptor::undef(builder, loc, structType, ptrType);
120  descr.setBasePointer(builder, loc, memRefDescriptor);
121  descr.setPointer(builder, loc, dataPtr);
122  descr.setSize(builder, loc, sizeBytes);
123  convertedOperands.push_back(descr);
124  } else if (originalDataOperand.getType().isa<LLVM::LLVMPointerType>()) {
125  convertedOperands.push_back(originalDataOperand);
126  } else {
127  // Type not supported.
128  return builder.notifyMatchFailure(op, "unsupported type");
129  }
130  }
131 
132  builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands,
133  op.getOperation()->getAttrs());
134 
135  return success();
136  }
137 };
138 } // namespace
139 
141  LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
142  patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>>(converter);
143  patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter);
144  patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter);
145  patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter);
146  patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
147 }
148 
149 namespace {
150 struct ConvertOpenACCToLLVMPass
151  : public ConvertOpenACCToLLVMBase<ConvertOpenACCToLLVMPass> {
152  void runOnOperation() override;
153 };
154 } // namespace
155 
156 void ConvertOpenACCToLLVMPass::runOnOperation() {
157  auto op = getOperation();
158  auto *context = op.getContext();
159 
160  // Convert to OpenACC operations with LLVM IR dialect
161  RewritePatternSet patterns(context);
162  LLVMTypeConverter converter(context);
163  populateOpenACCToLLVMConversionPatterns(converter, patterns);
164 
165  ConversionTarget target(*context);
166  target.addLegalDialect<LLVM::LLVMDialect>();
167  target.addLegalOp<UnrealizedConversionCastOp>();
168 
169  auto allDataOperandsAreConverted = [](ValueRange operands) {
170  for (Value operand : operands) {
171  if (!DataDescriptor::isValid(operand) &&
172  !operand.getType().isa<LLVM::LLVMPointerType>())
173  return false;
174  }
175  return true;
176  };
177 
178  target.addDynamicallyLegalOp<acc::DataOp>(
179  [allDataOperandsAreConverted](acc::DataOp op) {
180  return allDataOperandsAreConverted(op.copyOperands()) &&
181  allDataOperandsAreConverted(op.copyinOperands()) &&
182  allDataOperandsAreConverted(op.copyinReadonlyOperands()) &&
183  allDataOperandsAreConverted(op.copyoutOperands()) &&
184  allDataOperandsAreConverted(op.copyoutZeroOperands()) &&
185  allDataOperandsAreConverted(op.createOperands()) &&
186  allDataOperandsAreConverted(op.createZeroOperands()) &&
187  allDataOperandsAreConverted(op.noCreateOperands()) &&
188  allDataOperandsAreConverted(op.presentOperands()) &&
189  allDataOperandsAreConverted(op.deviceptrOperands()) &&
190  allDataOperandsAreConverted(op.attachOperands());
191  });
192 
193  target.addDynamicallyLegalOp<acc::EnterDataOp>(
194  [allDataOperandsAreConverted](acc::EnterDataOp op) {
195  return allDataOperandsAreConverted(op.copyinOperands()) &&
196  allDataOperandsAreConverted(op.createOperands()) &&
197  allDataOperandsAreConverted(op.createZeroOperands()) &&
198  allDataOperandsAreConverted(op.attachOperands());
199  });
200 
201  target.addDynamicallyLegalOp<acc::ExitDataOp>(
202  [allDataOperandsAreConverted](acc::ExitDataOp op) {
203  return allDataOperandsAreConverted(op.copyoutOperands()) &&
204  allDataOperandsAreConverted(op.deleteOperands()) &&
205  allDataOperandsAreConverted(op.detachOperands());
206  });
207 
208  target.addDynamicallyLegalOp<acc::ParallelOp>(
209  [allDataOperandsAreConverted](acc::ParallelOp op) {
210  return allDataOperandsAreConverted(op.reductionOperands()) &&
211  allDataOperandsAreConverted(op.copyOperands()) &&
212  allDataOperandsAreConverted(op.copyinOperands()) &&
213  allDataOperandsAreConverted(op.copyinReadonlyOperands()) &&
214  allDataOperandsAreConverted(op.copyoutOperands()) &&
215  allDataOperandsAreConverted(op.copyoutZeroOperands()) &&
216  allDataOperandsAreConverted(op.createOperands()) &&
217  allDataOperandsAreConverted(op.createZeroOperands()) &&
218  allDataOperandsAreConverted(op.noCreateOperands()) &&
219  allDataOperandsAreConverted(op.presentOperands()) &&
220  allDataOperandsAreConverted(op.devicePtrOperands()) &&
221  allDataOperandsAreConverted(op.attachOperands()) &&
222  allDataOperandsAreConverted(op.gangPrivateOperands()) &&
223  allDataOperandsAreConverted(op.gangFirstPrivateOperands());
224  });
225 
226  target.addDynamicallyLegalOp<acc::UpdateOp>(
227  [allDataOperandsAreConverted](acc::UpdateOp op) {
228  return allDataOperandsAreConverted(op.hostOperands()) &&
229  allDataOperandsAreConverted(op.deviceOperands());
230  });
231 
232  if (failed(applyPartialConversion(op, target, std::move(patterns))))
233  signalPassFailure();
234 }
235 
236 std::unique_ptr<OperationPass<ModuleOp>>
238  return std::make_unique<ConvertOpenACCToLLVMPass>();
239 }
Include the generated interface declarations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:132
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
void setBasePointer(OpBuilder &builder, Location loc, Value basePtr)
Builds IR inserting the base pointer value into the descriptor.
MLIRContext * getContext() const
Definition: Builders.h:54
static LLVMStructType getNewIdentified(MLIRContext *context, StringRef name, ArrayRef< Type > elements, bool isPacked=false)
Gets a new identified struct with the given body.
Definition: LLVMTypes.cpp:354
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
static constexpr unsigned kPtrBasePosInDataDescriptor
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:308
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Operation * getOperation()
Return the operation that this refers to.
Definition: OpDefinition.h:106
LogicalResult notifyMatchFailure(Operation *op, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
static constexpr unsigned kPtrPosInDataDescriptor
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
U dyn_cast() const
Definition: Types.h:244
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Helper class to produce LLVM dialect operations extracting or inserting values to a struct...
Definition: StructBuilder.h:26
void populateOpenACCToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the OpenACC dialect to the LLVMIR dialect.
constexpr StringRef getStructName()
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
IntegerType getI64Type()
Definition: Builders.cpp:56
Helper class to produce LLVM dialect operations inserting elements into a data descriptor.
void setSize(OpBuilder &builder, Location loc, Value size)
Builds IR inserting the size value into the descriptor.
static DataDescriptor undef(OpBuilder &builder, Location loc, Type basePtrTy, Type ptrTy)
Builds IR creating an undef value of the descriptor type.
static constexpr unsigned kSizePosInDataDescriptor
std::unique_ptr< OperationPass< ModuleOp > > createConvertOpenACCToLLVMPass()
Create a pass to convert the OpenACC dialect into the LLVMIR dialect.
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
Type getType() const
Return the type of this value.
Definition: Value.h:117
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:124
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
Type conversion class.
static bool isValid(Value descriptor)
Check whether the type is a valid data descriptor.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:252
LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
Conversion from types in the Standard dialect to the LLVM IR dialect.
Definition: TypeConverter.h:30
void addLegalOp()
Register the given operations as legal.
This class implements a pattern rewriter for use with ConversionPatterns.
LLVM dialect pointer type.
Definition: LLVMTypes.h:181
This provides public APIs that all operations should have.
DataDescriptor(Value descriptor)
Construct a helper for the given descriptor value.
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &sizeBytes) const
Computes sizes, strides and buffer size in bytes of memRefType with identity layout.
Definition: Pattern.cpp:119
This class describes a specific conversion target.
bool isa() const
Definition: Types.h:234
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr)
Builds IR to set a value in the struct at position pos.
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
void setPointer(OpBuilder &builder, Location loc, Value ptr)
Builds IR inserting the pointer value into the descriptor.