MLIR  22.0.0git
PtrToLLVM.cpp
Go to the documentation of this file.
1 //===- PtrToLLVM.cpp - Ptr to LLVM dialect conversion ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/TypeUtilities.h"
20 #include <type_traits>
21 
22 using namespace mlir;
23 
24 namespace {
25 //===----------------------------------------------------------------------===//
26 // FromPtrOpConversion
27 //===----------------------------------------------------------------------===//
28 struct FromPtrOpConversion : public ConvertOpToLLVMPattern<ptr::FromPtrOp> {
30  LogicalResult
31  matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor,
32  ConversionPatternRewriter &rewriter) const override;
33 };
34 
35 //===----------------------------------------------------------------------===//
36 // GetMetadataOpConversion
37 //===----------------------------------------------------------------------===//
38 struct GetMetadataOpConversion
39  : public ConvertOpToLLVMPattern<ptr::GetMetadataOp> {
41  LogicalResult
42  matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor,
43  ConversionPatternRewriter &rewriter) const override;
44 };
45 
46 //===----------------------------------------------------------------------===//
47 // PtrAddOpConversion
48 //===----------------------------------------------------------------------===//
49 struct PtrAddOpConversion : public ConvertOpToLLVMPattern<ptr::PtrAddOp> {
51  LogicalResult
52  matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
53  ConversionPatternRewriter &rewriter) const override;
54 };
55 
56 //===----------------------------------------------------------------------===//
57 // ToPtrOpConversion
58 //===----------------------------------------------------------------------===//
59 struct ToPtrOpConversion : public ConvertOpToLLVMPattern<ptr::ToPtrOp> {
61  LogicalResult
62  matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
63  ConversionPatternRewriter &rewriter) const override;
64 };
65 
66 //===----------------------------------------------------------------------===//
67 // TypeOffsetOpConversion
68 //===----------------------------------------------------------------------===//
69 struct TypeOffsetOpConversion
70  : public ConvertOpToLLVMPattern<ptr::TypeOffsetOp> {
72  LogicalResult
73  matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor,
74  ConversionPatternRewriter &rewriter) const override;
75 };
76 } // namespace
77 
78 //===----------------------------------------------------------------------===//
79 // Internal functions
80 //===----------------------------------------------------------------------===//
81 
82 // Function to create an LLVM struct type representing a memref metadata.
83 static FailureOr<LLVM::LLVMStructType>
84 createMemRefMetadataType(MemRefType type,
85  const LLVMTypeConverter &typeConverter) {
86  MLIRContext *context = type.getContext();
87  // Get the address space.
88  FailureOr<unsigned> addressSpace = typeConverter.getMemRefAddressSpace(type);
89  if (failed(addressSpace))
90  return failure();
91 
92  // Get pointer type (using address space 0 by default)
93  auto ptrType = LLVM::LLVMPointerType::get(context, *addressSpace);
94 
95  // Get the strides offsets and shape.
96  SmallVector<int64_t> strides;
97  int64_t offset;
98  if (failed(type.getStridesAndOffset(strides, offset)))
99  return failure();
100  ArrayRef<int64_t> shape = type.getShape();
101 
102  // Use index type from the type converter for the descriptor elements
103  Type indexType = typeConverter.getIndexType();
104 
105  // For a ranked memref, the descriptor contains:
106  // 1. The pointer to the allocated data
107  // 2. The pointer to the aligned data
108  // 3. The dynamic offset?
109  // 4. The dynamic sizes?
110  // 5. The dynamic strides?
111  SmallVector<Type, 5> elements;
112 
113  // Allocated pointer.
114  elements.push_back(ptrType);
115 
116  // Potentially add the dynamic offset.
117  if (offset == ShapedType::kDynamic)
118  elements.push_back(indexType);
119 
120  // Potentially add the dynamic sizes.
121  for (int64_t dim : shape) {
122  if (dim == ShapedType::kDynamic)
123  elements.push_back(indexType);
124  }
125 
126  // Potentially add the dynamic strides.
127  for (int64_t stride : strides) {
128  if (stride == ShapedType::kDynamic)
129  elements.push_back(indexType);
130  }
131  return LLVM::LLVMStructType::getLiteral(context, elements);
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // FromPtrOpConversion
136 //===----------------------------------------------------------------------===//
137 
138 LogicalResult FromPtrOpConversion::matchAndRewrite(
139  ptr::FromPtrOp op, OpAdaptor adaptor,
140  ConversionPatternRewriter &rewriter) const {
141  // Get the target memref type
142  auto mTy = dyn_cast<MemRefType>(op.getResult().getType());
143  if (!mTy)
144  return rewriter.notifyMatchFailure(op, "Expected memref result type");
145 
146  if (!op.getMetadata() && op.getType().hasPtrMetadata()) {
147  return rewriter.notifyMatchFailure(
148  op, "Can convert only memrefs with metadata");
149  }
150 
151  // Convert the result type
152  Type descriptorTy = getTypeConverter()->convertType(mTy);
153  if (!descriptorTy)
154  return rewriter.notifyMatchFailure(op, "Failed to convert result type");
155 
156  // Get the strides, offsets and shape.
157  SmallVector<int64_t> strides;
158  int64_t offset;
159  if (failed(mTy.getStridesAndOffset(strides, offset))) {
160  return rewriter.notifyMatchFailure(op,
161  "Failed to get the strides and offset");
162  }
163  ArrayRef<int64_t> shape = mTy.getShape();
164 
165  // Create a new memref descriptor
166  Location loc = op.getLoc();
167  auto desc = MemRefDescriptor::poison(rewriter, loc, descriptorTy);
168 
169  // Set the allocated and aligned pointers.
170  desc.setAllocatedPtr(
171  rewriter, loc,
172  LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getMetadata(), 0));
173  desc.setAlignedPtr(rewriter, loc, adaptor.getPtr());
174 
175  // Extract metadata from the passed struct.
176  unsigned fieldIdx = 1;
177 
178  // Set dynamic offset if needed.
179  if (offset == ShapedType::kDynamic) {
180  Value offsetValue = LLVM::ExtractValueOp::create(
181  rewriter, loc, adaptor.getMetadata(), fieldIdx++);
182  desc.setOffset(rewriter, loc, offsetValue);
183  } else {
184  desc.setConstantOffset(rewriter, loc, offset);
185  }
186 
187  // Set dynamic sizes if needed.
188  for (auto [i, dim] : llvm::enumerate(shape)) {
189  if (dim == ShapedType::kDynamic) {
190  Value sizeValue = LLVM::ExtractValueOp::create(
191  rewriter, loc, adaptor.getMetadata(), fieldIdx++);
192  desc.setSize(rewriter, loc, i, sizeValue);
193  } else {
194  desc.setConstantSize(rewriter, loc, i, dim);
195  }
196  }
197 
198  // Set dynamic strides if needed.
199  for (auto [i, stride] : llvm::enumerate(strides)) {
200  if (stride == ShapedType::kDynamic) {
201  Value strideValue = LLVM::ExtractValueOp::create(
202  rewriter, loc, adaptor.getMetadata(), fieldIdx++);
203  desc.setStride(rewriter, loc, i, strideValue);
204  } else {
205  desc.setConstantStride(rewriter, loc, i, stride);
206  }
207  }
208 
209  rewriter.replaceOp(op, static_cast<Value>(desc));
210  return success();
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // GetMetadataOpConversion
215 //===----------------------------------------------------------------------===//
216 
217 LogicalResult GetMetadataOpConversion::matchAndRewrite(
218  ptr::GetMetadataOp op, OpAdaptor adaptor,
219  ConversionPatternRewriter &rewriter) const {
220  auto mTy = dyn_cast<MemRefType>(op.getPtr().getType());
221  if (!mTy)
222  return rewriter.notifyMatchFailure(op, "Only memref metadata is supported");
223 
224  // Get the metadata type.
225  FailureOr<LLVM::LLVMStructType> mdTy =
226  createMemRefMetadataType(mTy, *getTypeConverter());
227  if (failed(mdTy)) {
228  return rewriter.notifyMatchFailure(op,
229  "Failed to create the metadata type");
230  }
231 
232  // Get the memref descriptor.
233  MemRefDescriptor descriptor(adaptor.getPtr());
234 
235  // Get the strides offsets and shape.
236  SmallVector<int64_t> strides;
237  int64_t offset;
238  if (failed(mTy.getStridesAndOffset(strides, offset))) {
239  return rewriter.notifyMatchFailure(op,
240  "Failed to get the strides and offset");
241  }
242  ArrayRef<int64_t> shape = mTy.getShape();
243 
244  // Create a new LLVM struct to hold the metadata
245  Location loc = op.getLoc();
246  Value sV = LLVM::UndefOp::create(rewriter, loc, *mdTy);
247 
248  // First element is the allocated pointer.
249  SmallVector<int64_t> pos{0};
250  sV = LLVM::InsertValueOp::create(rewriter, loc, sV,
251  descriptor.allocatedPtr(rewriter, loc), pos);
252 
253  // Track the current field index.
254  unsigned fieldIdx = 1;
255 
256  // Add dynamic offset if needed.
257  if (offset == ShapedType::kDynamic) {
258  sV = LLVM::InsertValueOp::create(
259  rewriter, loc, sV, descriptor.offset(rewriter, loc), fieldIdx++);
260  }
261 
262  // Add dynamic sizes if needed.
263  for (auto [i, dim] : llvm::enumerate(shape)) {
264  if (dim != ShapedType::kDynamic)
265  continue;
266  sV = LLVM::InsertValueOp::create(
267  rewriter, loc, sV, descriptor.size(rewriter, loc, i), fieldIdx++);
268  }
269 
270  // Add dynamic strides if needed
271  for (auto [i, stride] : llvm::enumerate(strides)) {
272  if (stride != ShapedType::kDynamic)
273  continue;
274  sV = LLVM::InsertValueOp::create(
275  rewriter, loc, sV, descriptor.stride(rewriter, loc, i), fieldIdx++);
276  }
277  rewriter.replaceOp(op, sV);
278  return success();
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // PtrAddOpConversion
283 //===----------------------------------------------------------------------===//
284 
285 LogicalResult
286 PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
287  ConversionPatternRewriter &rewriter) const {
288  // Get and check the base.
289  Value base = adaptor.getBase();
290  if (!isa<LLVM::LLVMPointerType>(base.getType()))
291  return rewriter.notifyMatchFailure(op, "Incompatible pointer type");
292 
293  // Get the offset.
294  Value offset = adaptor.getOffset();
295 
296  // Ptr assumes the offset is in bytes.
297  Type elementType = IntegerType::get(rewriter.getContext(), 8);
298 
299  // Convert the `ptradd` flags.
300  LLVM::GEPNoWrapFlags flags;
301  switch (op.getFlags()) {
302  case ptr::PtrAddFlags::none:
303  flags = LLVM::GEPNoWrapFlags::none;
304  break;
305  case ptr::PtrAddFlags::nusw:
306  flags = LLVM::GEPNoWrapFlags::nusw;
307  break;
308  case ptr::PtrAddFlags::nuw:
309  flags = LLVM::GEPNoWrapFlags::nuw;
310  break;
311  case ptr::PtrAddFlags::inbounds:
312  flags = LLVM::GEPNoWrapFlags::inbounds;
313  break;
314  }
315 
316  // Create the GEP operation with appropriate arguments
317  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, base.getType(), elementType,
318  base, ValueRange{offset}, flags);
319  return success();
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // ToPtrOpConversion
324 //===----------------------------------------------------------------------===//
325 
326 LogicalResult
327 ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
328  ConversionPatternRewriter &rewriter) const {
329  // Bail if it's not a memref.
330  if (!isa<MemRefType>(op.getPtr().getType()))
331  return rewriter.notifyMatchFailure(op, "Expected a memref input");
332 
333  // Extract the aligned pointer from the memref descriptor.
334  rewriter.replaceOp(
335  op, MemRefDescriptor(adaptor.getPtr()).alignedPtr(rewriter, op.getLoc()));
336  return success();
337 }
338 
339 //===----------------------------------------------------------------------===//
340 // TypeOffsetOpConversion
341 //===----------------------------------------------------------------------===//
342 
343 LogicalResult TypeOffsetOpConversion::matchAndRewrite(
344  ptr::TypeOffsetOp op, OpAdaptor adaptor,
345  ConversionPatternRewriter &rewriter) const {
346  // Convert the type attribute.
347  Type type = getTypeConverter()->convertType(op.getElementType());
348  if (!type)
349  return rewriter.notifyMatchFailure(op, "Couldn't convert the type");
350 
351  // Convert the result type.
352  Type rTy = getTypeConverter()->convertType(op.getResult().getType());
353  if (!rTy)
354  return rewriter.notifyMatchFailure(op, "Couldn't convert the result type");
355 
356  // TODO: Use MLIR's data layout. We don't use it because overall support is
357  // still flaky.
358 
359  // Create an LLVM pointer type for the GEP operation.
360  auto ptrTy = LLVM::LLVMPointerType::get(getContext());
361 
362  // Create a GEP operation to compute the offset of the type.
363  auto offset =
364  LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, type,
365  LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy),
366  ArrayRef<LLVM::GEPArg>({LLVM::GEPArg(1)}));
367 
368  // Replace the original op with a PtrToIntOp using the computed offset.
369  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, rTy, offset.getRes());
370  return success();
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // ConvertToLLVMPatternInterface implementation
375 //===----------------------------------------------------------------------===//
376 
377 namespace {
378 /// Implement the interface to convert Ptr to LLVM.
379 struct PtrToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
381  void loadDependentDialects(MLIRContext *context) const final {
382  context->loadDialect<LLVM::LLVMDialect>();
383  }
384 
385  /// Hook for derived dialect interface to provide conversion patterns
386  /// and mark dialect legal for the conversion target.
387  void populateConvertToLLVMConversionPatterns(
388  ConversionTarget &target, LLVMTypeConverter &converter,
389  RewritePatternSet &patterns) const final {
391  }
392 };
393 } // namespace
394 
395 //===----------------------------------------------------------------------===//
396 // API
397 //===----------------------------------------------------------------------===//
398 
401  // Add address space conversions.
  // Attribute conversion for memory spaces of ptr-like types whose memory
  // space attribute is a `ptr::GenericSpaceAttr`. On the path shown, it
  // returns an i32 IntegerAttr with value 0, i.e. LLVM address space 0.
  // NOTE(review): original lines 404 and 406 are missing from this rendering.
  // That covers the lambda's return-type line and the body of the
  // `getMemorySpace() != memorySpace` branch, which presumably aborts the
  // conversion; confirm against the real source.
402  converter.addTypeAttributeConversion(
403  [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace)
405  if (type.getMemorySpace() != memorySpace)
407  return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0);
408  });
409 
410  // Add type conversions.
  // `ptr::PtrType` lowers to an LLVM pointer. Its address space is the
  // integer obtained by running the pointer's memory space through the
  // attribute conversions registered on `converter` (sign-extended).
  // A null Type is returned when that conversion yields no IntegerAttr.
411  converter.addConversion([&](ptr::PtrType type) -> Type {
412  std::optional<Attribute> maybeAttr =
413  converter.convertTypeAttribute(type, type.getMemorySpace());
414  auto memSpace =
415  maybeAttr ? dyn_cast_or_null<IntegerAttr>(*maybeAttr) : IntegerAttr();
416  if (!memSpace)
417  return {};
418  return LLVM::LLVMPointerType::get(type.getContext(),
419  memSpace.getValue().getSExtValue());
420  });
421 
422  // Convert ptr metadata of memref type.
  // `ptr::PtrMetadataType` wrapping a memref lowers to the literal struct
  // built by createMemRefMetadataType. Non-memref metadata, or memrefs whose
  // address space or strided layout cannot be determined, yield a null Type.
423  converter.addConversion([&](ptr::PtrMetadataType type) -> Type {
424  auto mTy = dyn_cast<MemRefType>(type.getType());
425  if (!mTy)
426  return {};
427  FailureOr<LLVM::LLVMStructType> res =
428  createMemRefMetadataType(mTy, converter);
429  return failed(res) ? Type() : res.value();
430  });
431 
432  // Add conversion patterns.
  // Each pattern is constructed from `converter`; see the pattern structs in
  // the anonymous namespace at the top of this file.
433  patterns.add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion,
434  ToPtrOpConversion, TypeOffsetOpConversion>(converter);
435 }
436 
// Registers an extension on `registry` that attaches PtrToLLVMDialectInterface
// to the ptr dialect. This exposes this file's conversions through the
// ConvertToLLVMPatternInterface hook. `ctx` is unused.
// NOTE(review): the function signature (original line 437) is missing from
// this rendering.
438  registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) {
439  dialect->addInterfaces<PtrToLLVMDialectInterface>();
440  });
441 }
static MLIRContext * getContext(OpFoldResult val)
static FailureOr< LLVM::LLVMStructType > createMemRefMetadataType(MemRefType type, const LLVMTypeConverter &typeConverter)
Definition: PtrToLLVM.cpp:84
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:209
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:215
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
The general result of a type attribute conversion callback, allowing for early termination.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the convert to LLVM patterns for the ptr dialect.
Definition: PtrToLLVM.cpp:399
void registerConvertPtrToLLVMInterface(DialectRegistry &registry)
Register the convert to LLVM interface for the ptr dialect.
Definition: PtrToLLVM.cpp:437
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...