//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"

#include <optional>

#define DEBUG_TYPE "memref-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
    LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return ShapedType::isStatic(strideOrOffset);
}

static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
          Operation *module, SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables);

  return LLVM::lookupOrCreateFreeFn(b, module, symbolTables);
}

static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                     Operation *module, Type indexType,
                     SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType,
                                              symbolTables);

  return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables);
}

static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                  Operation *module, Type indexType,
                  SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType,
                                                     symbolTables);

  return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType, symbolTables);
}

/// Computes the aligned value for 'input' as follows:
///   bumped = input + alignment - 1
///   aligned = bumped - bumped % alignment
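/// For example, with illustrative values not taken from the original source:
/// input = 13 and alignment = 8 give bumped = 20, bumped % alignment = 4, and
/// an aligned result of 16, the smallest multiple of 8 that is >= 13.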
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                           Value input, Value alignment) {
  Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(),
                                       rewriter.getIndexAttr(1));
  Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
  Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
  Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
  return LLVM::SubOp::create(rewriter, loc, bumped, mod);
}

/// Computes the byte size for the MemRef element type.
static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
                                        MemRefType memRefType, Operation *op,
                                        const DataLayout *defaultLayout) {
  const DataLayout *layout = defaultLayout;
  if (const DataLayoutAnalysis *analysis =
          typeConverter->getDataLayoutAnalysis()) {
    layout = &analysis->getAbove(op);
  }
  Type elementType = memRefType.getElementType();
  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
    return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
    return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
                                                          *layout);
  return layout->getTypeSize(elementType);
}

static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                 Location loc, Value allocatedPtr,
                                 MemRefType memRefType, Type elementPtrType,
                                 const LLVMTypeConverter &typeConverter) {
  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
  FailureOr<unsigned> maybeMemrefAddrSpace =
      typeConverter.getMemRefAddressSpace(memRefType);
  assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
    allocatedPtr = LLVM::AddrSpaceCastOp::create(
        rewriter, loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
        allocatedPtr);
  return allocatedPtr;
}

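/// Lowers memref.alloc to a heap allocation through `malloc` (or the generic
/// allocation function when generic functions are requested). A sketch of the
/// resulting IR, with illustrative SSA names:
///
///   %m = memref.alloc(%d) : memref<?xf32>
///   // becomes roughly:
///   %size = ...   // element count * sizeof(f32), plus alignment slack
///   %raw = llvm.call @malloc(%size) : (i64) -> !llvm.ptr
///   // %raw becomes the allocated (and possibly re-aligned) pointer of the
///   // memref descriptor.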
class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
                           SymbolTableCollection *symbolTables = nullptr,
                           PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
        getNotalignedAllocFn(rewriter, getTypeConverter(),
                             op->getParentWithTrait<OpTrait::SymbolTable>(),
                             getIndexType(), symbolTables);
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes, true);

    Value alignment = getAlignment(rewriter, loc, op);
    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer.
    Type elementPtrType = this->getElementPtrType(memRefType);
    assert(elementPtrType && "could not compute element ptr type");
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);

    Value allocatedPtr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());
    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
    }

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// Computes the alignment for the given memory allocation op.
  template <typename OpType>
  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
                     OpType op) const {
    MemRefType memRefType = op.getType();
    Value alignment;
    if (auto alignmentAttr = op.getAlignment()) {
      Type indexType = getIndexType();
      alignment =
          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }
    return alignment;
  }
};

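/// Lowers memref.alloc to `aligned_alloc` (or its generic counterpart). The
/// IR sketch below uses illustrative names; `aligned_alloc` takes the
/// alignment first and requires the size to be a multiple of it:
///
///   %align = llvm.mlir.constant(16 : index) : i64
///   %mem = llvm.call @aligned_alloc(%align, %padded_size)
///       : (i64, i64) -> !llvm.ptr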
class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert alloc function into module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
        getAlignedAllocFn(rewriter, getTypeConverter(),
                          op->getParentWithTrait<OpTrait::SymbolTable>(),
                          getIndexType(), symbolTables);
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);

    Value allocAlignment =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);

    // Function aligned_alloc requires size to be a multiple of alignment; we
    // pad the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
                             ValueRange({allocAlignment, sizeBytes}));

    Value ptr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, ptr, ptr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Computes the alignment for aligned_alloc used to allocate the buffer for
  /// the memory allocation op.
  ///
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of the alignment.
  int64_t alignedAllocationGetAlignment(memref::AllocOp op,
                                        const DataLayout *defaultLayout) const {
    if (std::optional<uint64_t> alignment = op.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a
    // power of two, we bump it to the next power of two if it isn't.
    unsigned eltSizeBytes = getMemRefEltSizeInBytes(
        getTypeConverter(), op.getType(), op, defaultLayout);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// factor.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
                              const DataLayout *defaultLayout) const {
    uint64_t sizeDivisor =
        getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamicDim(i))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

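/// Lowers memref.alloca to llvm.alloca, e.g. (illustrative types and names):
///
///   %m = memref.alloca() : memref<4xf32>
///   // becomes roughly:
///   %c4 = llvm.mlir.constant(4 : index) : i64
///   %p = llvm.alloca %c4 x f32 : (i64) -> !llvm.ptr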
struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
  using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;

  /// Allocates the underlying buffer directly on the stack with llvm.alloca,
  /// which already yields a pointer to the element type; no separate
  /// allocated pointer or post-allocation alignment step is needed.
  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value size;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, size,
                                   /*sizeInBytes=*/false);

    // With alloca, one gets a pointer to the element type right away for
    // stack allocations.
    auto elementType =
        typeConverter->convertType(op.getType().getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(op.getType());
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned addrSpace = *maybeAddressSpace;
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
                               op.getAlignment().value_or(0));

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
        strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};

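/// Lowers memref.alloca_scope by bracketing the inlined body with
/// llvm.intr.stacksave / llvm.intr.stackrestore, so allocas created inside
/// the scope are released on exit. Control flow, sketched with illustrative
/// labels:
///
///   %sp = llvm.intr.stacksave; br ^body
///   ^body: ... ; llvm.intr.stackrestore %sp; br ^cont(%results)
///   ^cont(%results): <code after the scope>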
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
    LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body; the stack is restored before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    // Replace the op with values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

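/// Lowers memref.assume_alignment to an llvm.intr.assume carrying an "align"
/// operand bundle, e.g. (illustrative, syntax approximate):
///
///   memref.assume_alignment %m, 64 : memref<?xf32>
///   // becomes roughly:
///   llvm.intr.assume %true ["align"(%ptr, %c64)]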
struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
                                     /*indices=*/{});

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
    Value trueCond =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                           alignmentConst);
    rewriter.replaceOp(op, memref);
    return success();
  }
};

struct DistinctObjectsOpLowering
    : public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
  using ConvertOpToLLVMPattern<
      memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
  explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ValueRange operands = adaptor.getOperands();
    if (operands.size() <= 1) {
      // Fast path.
      rewriter.replaceOp(op, operands);
      return success();
    }

    Location loc = op.getLoc();
    SmallVector<Value> ptrs;
    for (auto [origOperand, newOperand] :
         llvm::zip_equal(op.getOperands(), operands)) {
      auto memrefType = cast<MemRefType>(origOperand.getType());
      MemRefDescriptor memRefDescriptor(newOperand);
      Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                             memrefType);
      ptrs.push_back(ptr);
    }

    auto cond =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
    // Generate separate_storage assumptions for each pair of pointers.
    for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
      for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
        Value ptr1 = ptrs[i];
        Value ptr2 = ptrs[j];
        LLVM::AssumeOp::create(rewriter, loc, cond,
                               LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
      }
    }

    rewriter.replaceOp(op, operands);
    return success();
  }
};

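// For example (illustrative): `memref.dealloc %m : memref<?xf32>` becomes
// `llvm.call @free(%alloc_ptr)`, where %alloc_ptr is the descriptor's
// allocated pointer (not the aligned one).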
// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
                             SymbolTableCollection *symbolTables = nullptr,
                             PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc =
        getFreeFn(rewriter, getTypeConverter(),
                  op->getParentWithTrait<OpTrait::SymbolTable>(), symbolTables);
    if (failed(freeFunc))
      return failure();
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
                                              allocatedPtr);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
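// For example (illustrative): `memref.dim %m, %c0 : memref<4x?xf32>` folds to
// an index constant 4, while the dynamic dimension at index 1 is read from
// the descriptor's size array.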
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to offset field of memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr =
        LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
                            underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = LLVM::AddOp::create(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
                                        getTypeConverter()->getIndexType(),
                                        offsetPtr, idxPlusOne);
    return LLVM::LoadOp::create(rewriter, loc,
                                getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage if index is constant.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use constant for static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using Base = LoadStoreOpLowering<Derived>;
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |         |
///  |----------+         |
///                       v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, std::next(Block::iterator(atomicOp)));

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(
        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
    Value init = LLVM::LoadOp::create(
        rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
        dataPtr);
    LLVM::BrOp::create(rewriter, loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg =
        LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
                                      result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
    Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef<Value>(),
                           loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimension array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
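
// For instance (illustrative): memref<2x3xf32> maps to
// !llvm.array<2 x array<3 x f32>> rather than a flattened !llvm.array<6 x f32>,
// so a multi-dimensional initial value can be attached unchanged.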

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::GlobalOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
    bool isExternal = global.isExternal();
    bool isUninitialized = global.isUninitialized();

    Attribute initialValue = nullptr;
    if (!isExternal && !isUninitialized) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");

    // Remove old operation from symbol table.
    SymbolTable *symbolTable = nullptr;
    if (symbolTables) {
      Operation *symbolTableOp =
          global->getParentWithTrait<OpTrait::SymbolTable>();
      symbolTable = &symbolTables->getSymbolTable(symbolTableOp);
      symbolTable->remove(global);
    }

    // Create new operation.
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);

    // Insert new operation into symbol table.
    if (symbolTable)
      symbolTable->insert(newGlobal, rewriter.getInsertionPoint());

    if (!isExternal && isUninitialized) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
      LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. The descriptor construction
/// mirrors that of the alloc-like lowerings above.
struct GetGlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
  using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the referenced global variable.
  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    MemRefType type = cast<MemRefType>(op.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep =
        LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
                            SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
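// For example (illustrative SSA names):
//
//   %v = memref.load %m[%i] : memref<?xf32>
//   // becomes roughly:
//   %p = llvm.getelementptr inbounds|nuw %aligned[%idx]
//       : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %v = llvm.load %p : !llvm.ptr -> f32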
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    // Per memref.load spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
    Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
                                         adaptor.getMemref(),
                                         adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
        loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    // Per memref.store spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
    Value dataPtr =
        getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               op.getAlignment().value_or(0),
                                               false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(
        rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

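// For example (illustrative): `memref.cast %m : memref<4xf32> to memref<?xf32>`
// keeps the ranked descriptor as-is, while casting to memref<*xf32> packs the
// rank together with a pointer to a stack copy of the descriptor.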
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires source and result type to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      if (typeConverter->convertType(srcType) !=
          typeConverter->convertType(dstType))
        return failure();

    // Unranked to unranked cast is disallowed.
    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
      return failure();

    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
      return success();
    }

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // set the rank in the destination from the memref type,
      // allocate space on the stack and copy the src memref descriptor,
      // and set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                              rewriter.getIndexAttr(rank));
      // poison = PoisonOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
      // d1 = InsertValueOp poison, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }

    return success();
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
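///
/// For example (illustrative): copying a contiguous memref<4xf32> becomes one
/// `llvm.intr.memcpy` of 4 * sizeof(f32) bytes between the aligned pointers
/// (with offsets applied) of the two descriptors.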
class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
                                SymbolTableCollection *symbolTables = nullptr,
                                PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::CopyOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(),
                                       elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr =
        LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType,
                            targetBasePtr, targetOffset);
    LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
                           /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                           type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save stack position before promoting descriptors.
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                        rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
      LLVM::StoreOp::create(rewriter, loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive size from llvm.getelementptr which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
        sourcePtr.getType(), symbolTables);
    if (failed(copyFn))
      return failure();
    LLVM::CallOp::create(rewriter, loc, copyFn.value(),
                         ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore stack used for descriptors.
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

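/// Lowers memref.memory_space_cast by address-space-casting the descriptor
/// pointers, e.g. (illustrative):
///
///   %d = memref.memory_space_cast %m : memref<4xf32, 1> to memref<4xf32>
///   // allocated/aligned pointers become:
///   %p = llvm.addrspacecast %src : !llvm.ptr<1> to !llvm.ptr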
struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
      descVals[1] =
          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for a new memref descriptor.
      auto result = UnrankedMemRefDescriptor::poison(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
          rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
      Value resultUnderlyingDesc =
          LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
                                 rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = LLVM::AddrSpaceCastOp::create(
          rewriter, loc, resultElemPtrType, allocatedPtr);
      alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
                                                 resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = LLVM::ConstantOp::create(
          rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize =
          LLVM::SubOp::create(rewriter, loc, getIndexType(),
                              resultUnderlyingSize, bytesToSkipConst);
      LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
                             copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

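/// Lowers memref.reinterpret_cast by assembling a fresh descriptor from the
/// source pointers and the requested offset/sizes/strides, e.g.
/// (illustrative):
///
///   %r = memref.reinterpret_cast %m to offset: [0], sizes: [%n], strides: [1]
///       : memref<?xf32> to memref<?xf32, strided<[1]>>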
1389 struct MemRefReinterpretCastOpLowering
1390  : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1391  using ConvertOpToLLVMPattern<
1392  memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1393 
1394  LogicalResult
1395  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1396  ConversionPatternRewriter &rewriter) const override {
1397  Type srcType = castOp.getSource().getType();
1398 
1399  Value descriptor;
1400  if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1401  adaptor, &descriptor)))
1402  return failure();
1403  rewriter.replaceOp(castOp, {descriptor});
1404  return success();
1405  }
1406 
1407 private:
1408  LogicalResult convertSourceMemRefToDescriptor(
1409  ConversionPatternRewriter &rewriter, Type srcType,
1410  memref::ReinterpretCastOp castOp,
1411  memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1412  MemRefType targetMemRefType =
1413  cast<MemRefType>(castOp.getResult().getType());
1414  auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1415  typeConverter->convertType(targetMemRefType));
1416  if (!llvmTargetDescriptorTy)
1417  return failure();
1418 
1419  // Create descriptor.
1420  Location loc = castOp.getLoc();
1421  auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1422 
1423  // Set allocated and aligned pointers.
1424  Value allocatedPtr, alignedPtr;
1425  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1426  castOp.getSource(), adaptor.getSource(),
1427  &allocatedPtr, &alignedPtr);
1428  desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1429  desc.setAlignedPtr(rewriter, loc, alignedPtr);
1430 
1431  // Set offset.
1432  if (castOp.isDynamicOffset(0))
1433  desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1434  else
1435  desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1436 
1437  // Set sizes and strides.
1438  unsigned dynSizeId = 0;
1439  unsigned dynStrideId = 0;
1440  for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1441  if (castOp.isDynamicSize(i))
1442  desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1443  else
1444  desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1445 
1446  if (castOp.isDynamicStride(i))
1447  desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1448  else
1449  desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1450  }
1451  *descriptor = desc;
1452  return success();
1453  }
1454 };
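// An IR sketch of what this pattern converts (types are illustrative only):
//   %dst = memref.reinterpret_cast %src to
//            offset: [0], sizes: [%sz, 4], strides: [4, 1]
//          : memref<?x?xf32> to memref<?x4xf32, strided<[4, 1]>>
// Static offsets/sizes/strides become constants in the result descriptor;
// dynamic ones are taken from the adaptor's operand lists.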
1455 
1456 struct MemRefReshapeOpLowering
1457  : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1458  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1459 
1460  LogicalResult
1461  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1462  ConversionPatternRewriter &rewriter) const override {
1463  Type srcType = reshapeOp.getSource().getType();
1464 
1465  Value descriptor;
1466  if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1467  adaptor, &descriptor)))
1468  return failure();
1469  rewriter.replaceOp(reshapeOp, {descriptor});
1470  return success();
1471  }
1472 
1473 private:
1474  LogicalResult
1475  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1476  Type srcType, memref::ReshapeOp reshapeOp,
1477  memref::ReshapeOp::Adaptor adaptor,
1478  Value *descriptor) const {
1479  auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1480  if (shapeMemRefType.hasStaticShape()) {
1481  MemRefType targetMemRefType =
1482  cast<MemRefType>(reshapeOp.getResult().getType());
1483  auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1484  typeConverter->convertType(targetMemRefType));
1485  if (!llvmTargetDescriptorTy)
1486  return failure();
1487 
1488  // Create descriptor.
1489  Location loc = reshapeOp.getLoc();
1490  auto desc =
1491  MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1492 
1493  // Set allocated and aligned pointers.
1494  Value allocatedPtr, alignedPtr;
1495  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1496  reshapeOp.getSource(), adaptor.getSource(),
1497  &allocatedPtr, &alignedPtr);
1498  desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1499  desc.setAlignedPtr(rewriter, loc, alignedPtr);
1500 
1501  // Extract the offset and strides from the type.
1502  int64_t offset;
1503  SmallVector<int64_t> strides;
1504  if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1505  return rewriter.notifyMatchFailure(
1506  reshapeOp, "failed to get stride and offset exprs");
1507 
1508  if (!isStaticStrideOrOffset(offset))
1509  return rewriter.notifyMatchFailure(reshapeOp,
1510  "dynamic offset is unsupported");
1511 
1512  desc.setConstantOffset(rewriter, loc, offset);
1513 
1514  assert(targetMemRefType.getLayout().isIdentity() &&
1515  "Identity layout map is a precondition of a valid reshape op");
1516 
1517  Type indexType = getIndexType();
1518  Value stride = nullptr;
1519  int64_t targetRank = targetMemRefType.getRank();
1520  for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1521  if (ShapedType::isStatic(strides[i])) {
1522  // The stride for this dimension is static; materialize it directly
1523  // as a constant.
1524  stride =
1525  createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
1526  } else if (!stride) {
1527  // `stride` is null only in the first iteration of the loop. However,
1528  // since the target memref has an identity layout, we can safely set
1529  // the innermost stride to 1.
1530  stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
1531  }
1532 
1533  Value dimSize;
1534  // Use the static size when it is known; otherwise load the size at
1535  // runtime from the shape operand.
1536  if (!targetMemRefType.isDynamicDim(i)) {
1537  dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1538  targetMemRefType.getDimSize(i));
1539  } else {
1540  Value shapeOp = reshapeOp.getShape();
1541  Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
1542  dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
1544  if (dimSize.getType() != indexType)
1545  dimSize = typeConverter->materializeTargetConversion(
1546  rewriter, loc, indexType, dimSize);
1547  assert(dimSize && "Invalid memref element type");
1548  }
1549 
1550  desc.setSize(rewriter, loc, i, dimSize);
1551  desc.setStride(rewriter, loc, i, stride);
1552 
1553  // Prepare the stride value for the next dimension.
1554  stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
1555  }
1556 
1557  *descriptor = desc;
1558  return success();
1559  }
1560 
1561  // The shape is a rank-1 memref with unknown length.
1562  Location loc = reshapeOp.getLoc();
1563  MemRefDescriptor shapeDesc(adaptor.getShape());
1564  Value resultRank = shapeDesc.size(rewriter, loc, 0);
1565 
1566  // Extract address space and element type.
1567  auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1568  unsigned addressSpace =
1569  *getTypeConverter()->getMemRefAddressSpace(targetType);
1570 
1571  // Create the unranked memref descriptor that holds the ranked one. The
1572  // inner descriptor is allocated on stack.
1573  auto targetDesc = UnrankedMemRefDescriptor::poison(
1574  rewriter, loc, typeConverter->convertType(targetType));
1575  targetDesc.setRank(rewriter, loc, resultRank);
1576  Value allocationSize = UnrankedMemRefDescriptor::computeSize(
1577  rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
1578  Value underlyingDescPtr = LLVM::AllocaOp::create(
1579  rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
1580  allocationSize);
1581  targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1582 
1583  // Extract pointers and offset from the source memref.
1584  Value allocatedPtr, alignedPtr, offset;
1585  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1586  reshapeOp.getSource(), adaptor.getSource(),
1587  &allocatedPtr, &alignedPtr, &offset);
1588 
1589  // Set pointers and offset.
1590  auto elementPtrType =
1591  LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1592 
1593  UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1594  elementPtrType, allocatedPtr);
1595  UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1596  underlyingDescPtr, elementPtrType,
1597  alignedPtr);
1598  UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1599  underlyingDescPtr, elementPtrType,
1600  offset);
1601 
1602  // Use the offset pointer as base for further addressing. Copy over the new
1603  // shape and compute strides. For this, we create a loop from rank-1 to 0.
1604  Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1605  rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1606  Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1607  rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1608  Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1609  Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
1610  Value resultRankMinusOne =
1611  LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);
1612 
1613  Block *initBlock = rewriter.getInsertionBlock();
1614  Type indexType = getTypeConverter()->getIndexType();
1615  Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1616 
1617  Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1618  {indexType, indexType}, {loc, loc});
1619 
1620  // Move the remaining initBlock ops to condBlock.
1621  Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1622  rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1623 
1624  rewriter.setInsertionPointToEnd(initBlock);
1625  LLVM::BrOp::create(rewriter, loc,
1626  ValueRange({resultRankMinusOne, oneIndex}), condBlock);
1627  rewriter.setInsertionPointToStart(condBlock);
1628  Value indexArg = condBlock->getArgument(0);
1629  Value strideArg = condBlock->getArgument(1);
1630 
1631  Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
1632  Value pred = LLVM::ICmpOp::create(
1633  rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
1634  LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1635 
1636  Block *bodyBlock =
1637  rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1638  rewriter.setInsertionPointToStart(bodyBlock);
1639 
1640  // Copy size from shape to descriptor.
1641  auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1642  Value sizeLoadGep = LLVM::GEPOp::create(
1643  rewriter, loc, llvmIndexPtrType,
1644  typeConverter->convertType(shapeMemRefType.getElementType()),
1645  shapeOperandPtr, indexArg);
1646  Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
1647  UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1648  targetSizesBase, indexArg, size);
1649 
1650  // Write stride value and compute next one.
1651  UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1652  targetStridesBase, indexArg, strideArg);
1653  Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);
1654 
1655  // Decrement loop counter and branch back.
1656  Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
1657  LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}),
1658  condBlock);
1659 
1660  Block *remainder =
1661  rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1662 
1663  // Hook up the cond exit to the remainder.
1664  rewriter.setInsertionPointToEnd(condBlock);
1665  LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(),
1666  remainder, ValueRange());
1667 
1668  // Reset position to beginning of new remainder block.
1669  rewriter.setInsertionPointToStart(remainder);
1670 
1671  *descriptor = targetDesc;
1672  return success();
1673  }
1674 };
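// IR sketches of the two paths above (illustrative only):
//   // Statically-shaped shape operand: the descriptor is built inline.
//   %d = memref.reshape %src(%shape)
//      : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
//   // Dynamically-sized shape operand: the result must be unranked, and
//   // the sizes/strides are written by the emitted loop.
//   %u = memref.reshape %src(%shape)
//      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>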
1675 
1676 /// Reassociating reshape ops must be expanded before we reach this stage;
1677 /// report a match failure otherwise.
1678 template <typename ReshapeOp>
1679 class ReassociatingReshapeOpConversion
1680  : public ConvertOpToLLVMPattern<ReshapeOp> {
1681 public:
1682  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1683  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1684 
1685  LogicalResult
1686  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1687  ConversionPatternRewriter &rewriter) const override {
1688  return rewriter.notifyMatchFailure(
1689  reshapeOp,
1690  "reassociation operations should have been expanded beforehand");
1691  }
1692 };
1693 
1694 /// Subviews must be expanded before we reach this stage;
1695 /// report a match failure otherwise.
1696 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1697  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1698 
1699  LogicalResult
1700  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1701  ConversionPatternRewriter &rewriter) const override {
1702  return rewriter.notifyMatchFailure(
1703  subViewOp, "subview operations should have been expanded beforehand");
1704  }
1705 };
1706 
1707 /// Conversion pattern that transforms a transpose op into:
1708 /// 1. A `poison` memref descriptor of the target type.
1709 /// 2. Copies of the allocated and aligned pointers and the offset from
1710 /// the source descriptor.
1711 /// 3. Permuted copies of the source sizes and strides, following the
1712 /// permutation map.
1713 /// The transpose op is replaced by the new descriptor; no data is moved.
1714 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1715 public:
1716  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1717 
1718  LogicalResult
1719  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1720  ConversionPatternRewriter &rewriter) const override {
1721  auto loc = transposeOp.getLoc();
1722  MemRefDescriptor viewMemRef(adaptor.getIn());
1723 
1724  // No permutation, early exit.
1725  if (transposeOp.getPermutation().isIdentity())
1726  return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1727 
1728  auto targetMemRef = MemRefDescriptor::poison(
1729  rewriter, loc,
1730  typeConverter->convertType(transposeOp.getIn().getType()));
1731 
1732  // Copy the base and aligned pointers from the old descriptor to the new
1733  // one.
1734  targetMemRef.setAllocatedPtr(rewriter, loc,
1735  viewMemRef.allocatedPtr(rewriter, loc));
1736  targetMemRef.setAlignedPtr(rewriter, loc,
1737  viewMemRef.alignedPtr(rewriter, loc));
1738 
1739  // Copy the offset pointer from the old descriptor to the new one.
1740  targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1741 
1742  // Iterate over the dimensions and apply size/stride permutation:
1743  // When enumerating the results of the permutation map, the enumeration
1744  // index is the index into the target dimensions and the DimExpr points to
1745  // the dimension of the source memref.
1746  for (const auto &en :
1747  llvm::enumerate(transposeOp.getPermutation().getResults())) {
1748  int targetPos = en.index();
1749  int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1750  targetMemRef.setSize(rewriter, loc, targetPos,
1751  viewMemRef.size(rewriter, loc, sourcePos));
1752  targetMemRef.setStride(rewriter, loc, targetPos,
1753  viewMemRef.stride(rewriter, loc, sourcePos));
1754  }
1755 
1756  rewriter.replaceOp(transposeOp, {targetMemRef});
1757  return success();
1758  }
1759 };
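// IR sketch (illustrative only):
//   %t = memref.transpose %m (d0, d1) -> (d1, d0)
//      : memref<?x?xf32> to memref<?x?xf32, strided<[1, ?]>>
// The loop above swaps the size/stride pairs in the descriptor; no data is
// copied.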
1760 
1761 /// Conversion pattern that transforms a view op into:
1762 /// 1. An `llvm.mlir.poison` operation to create a memref descriptor.
1763 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1764 /// and stride.
1765 /// The view op is replaced by the descriptor.
1766 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1767  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1768 
1769  // Build and return the value for the idx^th shape dimension: a constant
1770  // for static dims, else the dynamic size matching its dynamic-dim index.
1771  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1772  ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1773  Type indexType) const {
1774  assert(idx < shape.size());
1775  if (ShapedType::isStatic(shape[idx]))
1776  return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1777  // Count the number of dynamic dims in the range [0, idx).
1778  unsigned nDynamic =
1779  llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1780  return dynamicSizes[nDynamic];
1781  }
1782 
1783  // Build and return the idx^th stride, either by returning the constant stride
1784  // or by computing the dynamic stride from the current `runningStride` and
1785  // `nextSize`. The caller should keep a running stride and update it with the
1786  // result returned by this function.
1787  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1788  ArrayRef<int64_t> strides, Value nextSize,
1789  Value runningStride, unsigned idx, Type indexType) const {
1790  assert(idx < strides.size());
1791  if (ShapedType::isStatic(strides[idx]))
1792  return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1793  if (nextSize)
1794  return runningStride
1795  ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
1796  : nextSize;
1797  assert(!runningStride);
1798  return createIndexAttrConstant(rewriter, loc, indexType, 1);
1799  }
1800 
1801  LogicalResult
1802  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1803  ConversionPatternRewriter &rewriter) const override {
1804  auto loc = viewOp.getLoc();
1805 
1806  auto viewMemRefType = viewOp.getType();
1807  auto targetElementTy =
1808  typeConverter->convertType(viewMemRefType.getElementType());
1809  auto targetDescTy = typeConverter->convertType(viewMemRefType);
1810  if (!targetDescTy || !targetElementTy ||
1811  !LLVM::isCompatibleType(targetElementTy) ||
1812  !LLVM::isCompatibleType(targetDescTy))
1813  return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1814  failure();
1815 
1816  int64_t offset;
1817  SmallVector<int64_t, 4> strides;
1818  auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1819  if (failed(successStrides))
1820  return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1821  assert(offset == 0 && "expected offset to be 0");
1822 
1823  // Target memref must be contiguous in memory (innermost stride is 1), or
1824  // empty (special case when at least one of the memref dimensions is 0).
1825  if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1826  return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1827  failure();
1828 
1829  // Create the descriptor.
1830  MemRefDescriptor sourceMemRef(adaptor.getSource());
1831  auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);
1832 
1833  // Field 1: Copy the allocated pointer, used for malloc/free.
1834  Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1835  auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1836  targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1837 
1838  // Field 2: Copy the actual aligned pointer to payload.
1839  Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1840  alignedPtr = LLVM::GEPOp::create(
1841  rewriter, loc, alignedPtr.getType(),
1842  typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1843  adaptor.getByteShift());
1844 
1845  targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1846 
1847  Type indexType = getIndexType();
1848  // Field 3: The offset in the resulting type must be 0. This is
1849  // because of the type change: an offset on srcType* may not be
1850  // expressible as an offset on dstType*.
1851  targetMemRef.setOffset(
1852  rewriter, loc,
1853  createIndexAttrConstant(rewriter, loc, indexType, offset));
1854 
1855  // Early exit for 0-D corner case.
1856  if (viewMemRefType.getRank() == 0)
1857  return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1858 
1859  // Fields 4 and 5: Update sizes and strides.
1860  Value stride = nullptr, nextSize = nullptr;
1861  for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1862  // Update size.
1863  Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1864  adaptor.getSizes(), i, indexType);
1865  targetMemRef.setSize(rewriter, loc, i, size);
1866  // Update stride.
1867  stride =
1868  getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1869  targetMemRef.setStride(rewriter, loc, i, stride);
1870  nextSize = size;
1871  }
1872 
1873  rewriter.replaceOp(viewOp, {targetMemRef});
1874  return success();
1875  }
1876 };
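// IR sketch (illustrative only):
//   %v = memref.view %buf[%byte_shift][%sz]
//      : memref<2048xi8> to memref<?x4xf32>
// The aligned pointer is advanced by %byte_shift (a GEP over the source
// element type) and contiguous sizes/strides are rebuilt from the result
// type and the dynamic size operands.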
1877 
1878 //===----------------------------------------------------------------------===//
1879 // AtomicRMWOpLowering
1880 //===----------------------------------------------------------------------===//
1881 
1882 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1883 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1884 static std::optional<LLVM::AtomicBinOp>
1885 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1886  switch (atomicOp.getKind()) {
1887  case arith::AtomicRMWKind::addf:
1888  return LLVM::AtomicBinOp::fadd;
1889  case arith::AtomicRMWKind::addi:
1890  return LLVM::AtomicBinOp::add;
1891  case arith::AtomicRMWKind::assign:
1892  return LLVM::AtomicBinOp::xchg;
1893  case arith::AtomicRMWKind::maximumf:
1894  // TODO: remove this by end of 2025.
1895  LDBG() << "the lowering of memref.atomicrmw maximumf changed "
1896  "from fmax to fmaximum, expect more NaNs";
1897  return LLVM::AtomicBinOp::fmaximum;
1898  case arith::AtomicRMWKind::maxnumf:
1899  return LLVM::AtomicBinOp::fmax;
1900  case arith::AtomicRMWKind::maxs:
1901  return LLVM::AtomicBinOp::max;
1902  case arith::AtomicRMWKind::maxu:
1903  return LLVM::AtomicBinOp::umax;
1904  case arith::AtomicRMWKind::minimumf:
1905  // TODO: remove this by end of 2025.
1906  LDBG() << "the lowering of memref.atomicrmw minimumf changed "
1907  "from fmin to fminimum, expect more NaNs";
1908  return LLVM::AtomicBinOp::fminimum;
1909  case arith::AtomicRMWKind::minnumf:
1910  return LLVM::AtomicBinOp::fmin;
1911  case arith::AtomicRMWKind::mins:
1912  return LLVM::AtomicBinOp::min;
1913  case arith::AtomicRMWKind::minu:
1914  return LLVM::AtomicBinOp::umin;
1915  case arith::AtomicRMWKind::ori:
1916  return LLVM::AtomicBinOp::_or;
1917  case arith::AtomicRMWKind::xori:
1918  return LLVM::AtomicBinOp::_xor;
1919  case arith::AtomicRMWKind::andi:
1920  return LLVM::AtomicBinOp::_and;
1921  default:
1922  return std::nullopt;
1923  }
1924  llvm_unreachable("Invalid AtomicRMWKind");
1925 }
1926 
1927 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1928  using Base::Base;
1929 
1930  LogicalResult
1931  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1932  ConversionPatternRewriter &rewriter) const override {
1933  auto maybeKind = matchSimpleAtomicOp(atomicOp);
1934  if (!maybeKind)
1935  return failure();
1936  auto memRefType = atomicOp.getMemRefType();
1937  SmallVector<int64_t> strides;
1938  int64_t offset;
1939  if (failed(memRefType.getStridesAndOffset(strides, offset)))
1940  return failure();
1941  auto dataPtr =
1942  getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1943  adaptor.getMemref(), adaptor.getIndices());
1944  rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1945  atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1946  LLVM::AtomicOrdering::acq_rel);
1947  return success();
1948  }
1949 };
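// IR sketch (illustrative only):
//   %old = memref.atomic_rmw addf %v, %m[%i] : (f32, memref<?xf32>) -> f32
// lowers to an `llvm.atomicrmw fadd` with acq_rel ordering on the strided
// element pointer.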
1950 
1951 /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1952 class ConvertExtractAlignedPointerAsIndex
1953  : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1954 public:
1955  using ConvertOpToLLVMPattern<
1956  memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1957 
1958  LogicalResult
1959  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1960  OpAdaptor adaptor,
1961  ConversionPatternRewriter &rewriter) const override {
1962  BaseMemRefType sourceTy = extractOp.getSource().getType();
1963 
1964  Value alignedPtr;
1965  if (sourceTy.hasRank()) {
1966  MemRefDescriptor desc(adaptor.getSource());
1967  alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1968  } else {
1969  auto elementPtrTy = LLVM::LLVMPointerType::get(
1970  rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1971 
1972  UnrankedMemRefDescriptor desc(adaptor.getSource());
1973  Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1974 
1975  alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1976  rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1977  elementPtrTy);
1978  }
1979 
1980  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1981  extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1982  return success();
1983  }
1984 };
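// IR sketch (illustrative only):
//   %idx = memref.extract_aligned_pointer_as_index %m
//        : memref<?xf32> -> index
// becomes an `llvm.ptrtoint` of the aligned pointer extracted from the
// (un)ranked descriptor.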
1985 
1986 /// Materialize the MemRef descriptor represented by the results of
1987 /// ExtractStridedMetadataOp.
1988 class ExtractStridedMetadataOpLowering
1989  : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1990 public:
1991  using ConvertOpToLLVMPattern<
1992  memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1993 
1994  LogicalResult
1995  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1996  OpAdaptor adaptor,
1997  ConversionPatternRewriter &rewriter) const override {
1998 
1999  if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
2000  return failure();
2001 
2002  // Create the descriptor.
2003  MemRefDescriptor sourceMemRef(adaptor.getSource());
2004  Location loc = extractStridedMetadataOp.getLoc();
2005  Value source = extractStridedMetadataOp.getSource();
2006 
2007  auto sourceMemRefType = cast<MemRefType>(source.getType());
2008  int64_t rank = sourceMemRefType.getRank();
2009  SmallVector<Value> results;
2010  results.reserve(2 + rank * 2);
2011 
2012  // Base buffer.
2013  Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
2014  Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
2015  auto dstMemRef = MemRefDescriptor::fromStaticShape(
2016  rewriter, loc, *getTypeConverter(),
2017  cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
2018  baseBuffer, alignedBuffer);
2019  results.push_back((Value)dstMemRef);
2020 
2021  // Offset.
2022  results.push_back(sourceMemRef.offset(rewriter, loc));
2023 
2024  // Sizes.
2025  for (unsigned i = 0; i < rank; ++i)
2026  results.push_back(sourceMemRef.size(rewriter, loc, i));
2027  // Strides.
2028  for (unsigned i = 0; i < rank; ++i)
2029  results.push_back(sourceMemRef.stride(rewriter, loc, i));
2030 
2031  rewriter.replaceOp(extractStridedMetadataOp, results);
2032  return success();
2033  }
2034 };
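// IR sketch (illustrative only):
//   %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
//     : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
// Every result is read directly out of the source descriptor.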
2035 
2036 } // namespace
2037 
2038 void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
2039  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2040  SymbolTableCollection *symbolTables) {
2041  // clang-format off
2042  patterns.add<
2043  AllocaOpLowering,
2044  AllocaScopeOpLowering,
2045  AssumeAlignmentOpLowering,
2046  AtomicRMWOpLowering,
2047  ConvertExtractAlignedPointerAsIndex,
2048  DimOpLowering,
2049  DistinctObjectsOpLowering,
2050  ExtractStridedMetadataOpLowering,
2051  GenericAtomicRMWOpLowering,
2052  GetGlobalMemrefOpLowering,
2053  LoadOpLowering,
2054  MemRefCastOpLowering,
2055  MemRefReinterpretCastOpLowering,
2056  MemRefReshapeOpLowering,
2057  MemorySpaceCastOpLowering,
2058  PrefetchOpLowering,
2059  RankOpLowering,
2060  ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2061  ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2062  StoreOpLowering,
2063  SubViewOpLowering,
2064  TransposeOpLowering,
2065  ViewOpLowering>(converter);
2066  // clang-format on
2067  patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
2068  symbolTables);
2069  auto allocLowering = converter.getOptions().allocLowering;
2070  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2071  patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
2072  symbolTables);
2073  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2074  patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
2075 }
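// Minimal wiring sketch (illustrative; the pass below does exactly this):
//   LLVMTypeConverter typeConverter(ctx, options);
//   RewritePatternSet patterns(ctx);
//   populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
//   LLVMConversionTarget target(ctx);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();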
2076 
2077 namespace {
2078 struct FinalizeMemRefToLLVMConversionPass
2079  : public impl::FinalizeMemRefToLLVMConversionPassBase<
2080  FinalizeMemRefToLLVMConversionPass> {
2081  using FinalizeMemRefToLLVMConversionPassBase::
2082  FinalizeMemRefToLLVMConversionPassBase;
2083 
2084  void runOnOperation() override {
2085  Operation *op = getOperation();
2086  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2087  LowerToLLVMOptions options(&getContext(),
2088  dataLayoutAnalysis.getAtOrAbove(op));
2089  options.allocLowering =
2090  (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2091  : LowerToLLVMOptions::AllocLowering::Malloc);
2092 
2093  options.useGenericFunctions = useGenericFunctions;
2094 
2095  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2096  options.overrideIndexBitwidth(indexBitwidth);
2097 
2098  LLVMTypeConverter typeConverter(&getContext(), options,
2099  &dataLayoutAnalysis);
2100  RewritePatternSet patterns(&getContext());
2101  SymbolTableCollection symbolTables;
2102  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns,
2103  &symbolTables);
2104  LLVMConversionTarget target(getContext());
2105  target.addLegalOp<func::FuncOp>();
2106  if (failed(applyPartialConversion(op, target, std::move(patterns))))
2107  signalPassFailure();
2108  }
2109 };
2110 
2111 /// Implement the interface to convert MemRef to LLVM.
2112 struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2113  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
2114  void loadDependentDialects(MLIRContext *context) const final {
2115  context->loadDialect<LLVM::LLVMDialect>();
2116  }
2117 
2118  /// Hook for derived dialect interface to provide conversion patterns
2119  /// and mark dialect legal for the conversion target.
2120  void populateConvertToLLVMConversionPatterns(
2121  ConversionTarget &target, LLVMTypeConverter &typeConverter,
2122  RewritePatternSet &patterns) const final {
2123  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
2124  }
2125 };
2126 
2127 } // namespace
2128 
2129 void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
2130  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
2131  dialect->addInterfaces<MemRefToLLVMDialectInterface>();
2132  });
2133 }