MLIR  22.0.0git
PtrToLLVM.cpp
Go to the documentation of this file.
1 //===- PtrToLLVM.cpp - Ptr to LLVM dialect conversion ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/TypeUtilities.h"
20 #include <type_traits>
21 
22 using namespace mlir;
23 
24 namespace {
25 //===----------------------------------------------------------------------===//
26 // FromPtrOpConversion
27 //===----------------------------------------------------------------------===//
28 struct FromPtrOpConversion : public ConvertOpToLLVMPattern<ptr::FromPtrOp> {
30  LogicalResult
31  matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor,
32  ConversionPatternRewriter &rewriter) const override;
33 };
34 
35 //===----------------------------------------------------------------------===//
36 // GetMetadataOpConversion
37 //===----------------------------------------------------------------------===//
38 struct GetMetadataOpConversion
39  : public ConvertOpToLLVMPattern<ptr::GetMetadataOp> {
41  LogicalResult
42  matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor,
43  ConversionPatternRewriter &rewriter) const override;
44 };
45 
46 //===----------------------------------------------------------------------===//
47 // PtrAddOpConversion
48 //===----------------------------------------------------------------------===//
49 struct PtrAddOpConversion : public ConvertOpToLLVMPattern<ptr::PtrAddOp> {
51  LogicalResult
52  matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
53  ConversionPatternRewriter &rewriter) const override;
54 };
55 
56 //===----------------------------------------------------------------------===//
57 // ToPtrOpConversion
58 //===----------------------------------------------------------------------===//
59 struct ToPtrOpConversion : public ConvertOpToLLVMPattern<ptr::ToPtrOp> {
61  LogicalResult
62  matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
63  ConversionPatternRewriter &rewriter) const override;
64 };
65 
66 //===----------------------------------------------------------------------===//
67 // TypeOffsetOpConversion
68 //===----------------------------------------------------------------------===//
69 struct TypeOffsetOpConversion
70  : public ConvertOpToLLVMPattern<ptr::TypeOffsetOp> {
72  LogicalResult
73  matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor,
74  ConversionPatternRewriter &rewriter) const override;
75 };
76 } // namespace
77 
78 //===----------------------------------------------------------------------===//
79 // Internal functions
80 //===----------------------------------------------------------------------===//
81 
82 // Function to create an LLVM struct type representing a memref metadata.
83 static FailureOr<LLVM::LLVMStructType>
84 createMemRefMetadataType(MemRefType type,
85  const LLVMTypeConverter &typeConverter) {
86  MLIRContext *context = type.getContext();
87  // Get the address space.
88  FailureOr<unsigned> addressSpace = typeConverter.getMemRefAddressSpace(type);
89  if (failed(addressSpace))
90  return failure();
91 
92  // Get pointer type (using address space 0 by default)
93  auto ptrType = LLVM::LLVMPointerType::get(context, *addressSpace);
94 
95  // Get the strides offsets and shape.
96  SmallVector<int64_t> strides;
97  int64_t offset;
98  if (failed(type.getStridesAndOffset(strides, offset)))
99  return failure();
100  ArrayRef<int64_t> shape = type.getShape();
101 
102  // Use index type from the type converter for the descriptor elements
103  Type indexType = typeConverter.getIndexType();
104 
105  // For a ranked memref, the descriptor contains:
106  // 1. The pointer to the allocated data
107  // 2. The pointer to the aligned data
108  // 3. The dynamic offset?
109  // 4. The dynamic sizes?
110  // 5. The dynamic strides?
111  SmallVector<Type, 5> elements;
112 
113  // Allocated pointer.
114  elements.push_back(ptrType);
115 
116  // Potentially add the dynamic offset.
117  if (offset == ShapedType::kDynamic)
118  elements.push_back(indexType);
119 
120  // Potentially add the dynamic sizes.
121  for (int64_t dim : shape) {
122  if (dim == ShapedType::kDynamic)
123  elements.push_back(indexType);
124  }
125 
126  // Potentially add the dynamic strides.
127  for (int64_t stride : strides) {
128  if (stride == ShapedType::kDynamic)
129  elements.push_back(indexType);
130  }
131  return LLVM::LLVMStructType::getLiteral(context, elements);
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // FromPtrOpConversion
136 //===----------------------------------------------------------------------===//
137 
138 LogicalResult FromPtrOpConversion::matchAndRewrite(
139  ptr::FromPtrOp op, OpAdaptor adaptor,
140  ConversionPatternRewriter &rewriter) const {
141  // Get the target memref type
142  auto mTy = dyn_cast<MemRefType>(op.getResult().getType());
143  if (!mTy)
144  return rewriter.notifyMatchFailure(op, "Expected memref result type");
145 
146  if (!op.getMetadata() && op.getType().hasPtrMetadata()) {
147  return rewriter.notifyMatchFailure(
148  op, "Can convert only memrefs with metadata");
149  }
150 
151  // Convert the result type
152  Type descriptorTy = getTypeConverter()->convertType(mTy);
153  if (!descriptorTy)
154  return rewriter.notifyMatchFailure(op, "Failed to convert result type");
155 
156  // Get the strides, offsets and shape.
157  SmallVector<int64_t> strides;
158  int64_t offset;
159  if (failed(mTy.getStridesAndOffset(strides, offset))) {
160  return rewriter.notifyMatchFailure(op,
161  "Failed to get the strides and offset");
162  }
163  ArrayRef<int64_t> shape = mTy.getShape();
164 
165  // Create a new memref descriptor
166  Location loc = op.getLoc();
167  auto desc = MemRefDescriptor::poison(rewriter, loc, descriptorTy);
168 
169  // Set the allocated and aligned pointers.
170  desc.setAllocatedPtr(
171  rewriter, loc,
172  rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getMetadata(), 0));
173  desc.setAlignedPtr(rewriter, loc, adaptor.getPtr());
174 
175  // Extract metadata from the passed struct.
176  unsigned fieldIdx = 1;
177 
178  // Set dynamic offset if needed.
179  if (offset == ShapedType::kDynamic) {
180  Value offsetValue = rewriter.create<LLVM::ExtractValueOp>(
181  loc, adaptor.getMetadata(), fieldIdx++);
182  desc.setOffset(rewriter, loc, offsetValue);
183  } else {
184  desc.setConstantOffset(rewriter, loc, offset);
185  }
186 
187  // Set dynamic sizes if needed.
188  for (auto [i, dim] : llvm::enumerate(shape)) {
189  if (dim == ShapedType::kDynamic) {
190  Value sizeValue = rewriter.create<LLVM::ExtractValueOp>(
191  loc, adaptor.getMetadata(), fieldIdx++);
192  desc.setSize(rewriter, loc, i, sizeValue);
193  } else {
194  desc.setConstantSize(rewriter, loc, i, dim);
195  }
196  }
197 
198  // Set dynamic strides if needed.
199  for (auto [i, stride] : llvm::enumerate(strides)) {
200  if (stride == ShapedType::kDynamic) {
201  Value strideValue = rewriter.create<LLVM::ExtractValueOp>(
202  loc, adaptor.getMetadata(), fieldIdx++);
203  desc.setStride(rewriter, loc, i, strideValue);
204  } else {
205  desc.setConstantStride(rewriter, loc, i, stride);
206  }
207  }
208 
209  rewriter.replaceOp(op, static_cast<Value>(desc));
210  return success();
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // GetMetadataOpConversion
215 //===----------------------------------------------------------------------===//
216 
217 LogicalResult GetMetadataOpConversion::matchAndRewrite(
218  ptr::GetMetadataOp op, OpAdaptor adaptor,
219  ConversionPatternRewriter &rewriter) const {
220  auto mTy = dyn_cast<MemRefType>(op.getPtr().getType());
221  if (!mTy)
222  return rewriter.notifyMatchFailure(op, "Only memref metadata is supported");
223 
224  // Get the metadata type.
225  FailureOr<LLVM::LLVMStructType> mdTy =
226  createMemRefMetadataType(mTy, *getTypeConverter());
227  if (failed(mdTy)) {
228  return rewriter.notifyMatchFailure(op,
229  "Failed to create the metadata type");
230  }
231 
232  // Get the memref descriptor.
233  MemRefDescriptor descriptor(adaptor.getPtr());
234 
235  // Get the strides offsets and shape.
236  SmallVector<int64_t> strides;
237  int64_t offset;
238  if (failed(mTy.getStridesAndOffset(strides, offset))) {
239  return rewriter.notifyMatchFailure(op,
240  "Failed to get the strides and offset");
241  }
242  ArrayRef<int64_t> shape = mTy.getShape();
243 
244  // Create a new LLVM struct to hold the metadata
245  Location loc = op.getLoc();
246  Value sV = rewriter.create<LLVM::UndefOp>(loc, *mdTy);
247 
248  // First element is the allocated pointer.
249  sV = rewriter.create<LLVM::InsertValueOp>(
250  loc, sV, descriptor.allocatedPtr(rewriter, loc), 0);
251 
252  // Track the current field index.
253  unsigned fieldIdx = 1;
254 
255  // Add dynamic offset if needed.
256  if (offset == ShapedType::kDynamic) {
257  sV = rewriter.create<LLVM::InsertValueOp>(
258  loc, sV, descriptor.offset(rewriter, loc), fieldIdx++);
259  }
260 
261  // Add dynamic sizes if needed.
262  for (auto [i, dim] : llvm::enumerate(shape)) {
263  if (dim != ShapedType::kDynamic)
264  continue;
265  sV = rewriter.create<LLVM::InsertValueOp>(
266  loc, sV, descriptor.size(rewriter, loc, i), fieldIdx++);
267  }
268 
269  // Add dynamic strides if needed
270  for (auto [i, stride] : llvm::enumerate(strides)) {
271  if (stride != ShapedType::kDynamic)
272  continue;
273  sV = rewriter.create<LLVM::InsertValueOp>(
274  loc, sV, descriptor.stride(rewriter, loc, i), fieldIdx++);
275  }
276  rewriter.replaceOp(op, sV);
277  return success();
278 }
279 
280 //===----------------------------------------------------------------------===//
281 // PtrAddOpConversion
282 //===----------------------------------------------------------------------===//
283 
284 LogicalResult
285 PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
286  ConversionPatternRewriter &rewriter) const {
287  // Get and check the base.
288  Value base = adaptor.getBase();
289  if (!isa<LLVM::LLVMPointerType>(base.getType()))
290  return rewriter.notifyMatchFailure(op, "Incompatible pointer type");
291 
292  // Get the offset.
293  Value offset = adaptor.getOffset();
294 
295  // Ptr assumes the offset is in bytes.
296  Type elementType = IntegerType::get(rewriter.getContext(), 8);
297 
298  // Convert the `ptradd` flags.
299  LLVM::GEPNoWrapFlags flags;
300  switch (op.getFlags()) {
301  case ptr::PtrAddFlags::none:
302  flags = LLVM::GEPNoWrapFlags::none;
303  break;
304  case ptr::PtrAddFlags::nusw:
305  flags = LLVM::GEPNoWrapFlags::nusw;
306  break;
307  case ptr::PtrAddFlags::nuw:
308  flags = LLVM::GEPNoWrapFlags::nuw;
309  break;
310  case ptr::PtrAddFlags::inbounds:
311  flags = LLVM::GEPNoWrapFlags::inbounds;
312  break;
313  }
314 
315  // Create the GEP operation with appropriate arguments
316  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, base.getType(), elementType,
317  base, ValueRange{offset}, flags);
318  return success();
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // ToPtrOpConversion
323 //===----------------------------------------------------------------------===//
324 
325 LogicalResult
326 ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
327  ConversionPatternRewriter &rewriter) const {
328  // Bail if it's not a memref.
329  if (!isa<MemRefType>(op.getPtr().getType()))
330  return rewriter.notifyMatchFailure(op, "Expected a memref input");
331 
332  // Extract the aligned pointer from the memref descriptor.
333  rewriter.replaceOp(
334  op, MemRefDescriptor(adaptor.getPtr()).alignedPtr(rewriter, op.getLoc()));
335  return success();
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // TypeOffsetOpConversion
340 //===----------------------------------------------------------------------===//
341 
342 LogicalResult TypeOffsetOpConversion::matchAndRewrite(
343  ptr::TypeOffsetOp op, OpAdaptor adaptor,
344  ConversionPatternRewriter &rewriter) const {
345  // Convert the type attribute.
346  Type type = getTypeConverter()->convertType(op.getElementType());
347  if (!type)
348  return rewriter.notifyMatchFailure(op, "Couldn't convert the type");
349 
350  // Convert the result type.
351  Type rTy = getTypeConverter()->convertType(op.getResult().getType());
352  if (!rTy)
353  return rewriter.notifyMatchFailure(op, "Couldn't convert the result type");
354 
355  // TODO: Use MLIR's data layout. We don't use it because overall support is
356  // still flaky.
357 
358  // Create an LLVM pointer type for the GEP operation.
359  auto ptrTy = LLVM::LLVMPointerType::get(getContext());
360 
361  // Create a GEP operation to compute the offset of the type.
362  auto offset =
363  LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, type,
364  LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy),
365  ArrayRef<LLVM::GEPArg>({LLVM::GEPArg(1)}));
366 
367  // Replace the original op with a PtrToIntOp using the computed offset.
368  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, rTy, offset.getRes());
369  return success();
370 }
371 
372 //===----------------------------------------------------------------------===//
373 // ConvertToLLVMPatternInterface implementation
374 //===----------------------------------------------------------------------===//
375 
376 namespace {
377 /// Implement the interface to convert Ptr to LLVM.
378 struct PtrToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
380  void loadDependentDialects(MLIRContext *context) const final {
381  context->loadDialect<LLVM::LLVMDialect>();
382  }
383 
384  /// Hook for derived dialect interface to provide conversion patterns
385  /// and mark dialect legal for the conversion target.
386  void populateConvertToLLVMConversionPatterns(
387  ConversionTarget &target, LLVMTypeConverter &converter,
388  RewritePatternSet &patterns) const final {
390  }
391 };
392 } // namespace
393 
394 //===----------------------------------------------------------------------===//
395 // API
396 //===----------------------------------------------------------------------===//
397 
400  // Add address space conversions.
401  converter.addTypeAttributeConversion(
402  [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace)
404  if (type.getMemorySpace() != memorySpace)
406  return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0);
407  });
408 
409  // Add type conversions.
410  converter.addConversion([&](ptr::PtrType type) -> Type {
411  std::optional<Attribute> maybeAttr =
412  converter.convertTypeAttribute(type, type.getMemorySpace());
413  auto memSpace =
414  maybeAttr ? dyn_cast_or_null<IntegerAttr>(*maybeAttr) : IntegerAttr();
415  if (!memSpace)
416  return {};
417  return LLVM::LLVMPointerType::get(type.getContext(),
418  memSpace.getValue().getSExtValue());
419  });
420 
421  // Convert ptr metadata of memref type.
422  converter.addConversion([&](ptr::PtrMetadataType type) -> Type {
423  auto mTy = dyn_cast<MemRefType>(type.getType());
424  if (!mTy)
425  return {};
426  FailureOr<LLVM::LLVMStructType> res =
427  createMemRefMetadataType(mTy, converter);
428  return failed(res) ? Type() : res.value();
429  });
430 
431  // Add conversion patterns.
432  patterns.add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion,
433  ToPtrOpConversion, TypeOffsetOpConversion>(converter);
434 }
435 
437  registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) {
438  dialect->addInterfaces<PtrToLLVMDialectInterface>();
439  });
440 }
static MLIRContext * getContext(OpFoldResult val)
static FailureOr< LLVM::LLVMStructType > createMemRefMetadataType(MemRefType type, const LLVMTypeConverter &typeConverter)
Definition: PtrToLLVM.cpp:84
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:209
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:215
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:456
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
The general result of a type attribute conversion callback, allowing for early termination.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the convert to LLVM patterns for the ptr dialect.
Definition: PtrToLLVM.cpp:398
void registerConvertPtrToLLVMInterface(DialectRegistry &registry)
Register the convert to LLVM interface for the ptr dialect.
Definition: PtrToLLVM.cpp:436
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...