//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/MathExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}

static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
          ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(b, module);

  return LLVM::lookupOrCreateFreeFn(b, module);
}

static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                     Operation *module, Type indexType) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType);

  return LLVM::lookupOrCreateMallocFn(b, module, indexType);
}

static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                  Operation *module, Type indexType) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType);

  return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType);
}

/// Computes the aligned value for 'input' as follows:
///   bumped = input + alignment - 1
///   aligned = bumped - bumped % alignment
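/// For example, with input = 13 and alignment = 8: bumped = 13 + 7 = 20,
/// 20 % 8 = 4, so aligned = 20 - 4 = 16, the smallest multiple of 8 that is
/// >= 13.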
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                           Value input, Value alignment) {
  Value one = rewriter.create<LLVM::ConstantOp>(loc, alignment.getType(),
                                                rewriter.getIndexAttr(1));
  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
}

/// Computes the byte size for the MemRef element type.
static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
                                        MemRefType memRefType, Operation *op,
                                        const DataLayout *defaultLayout) {
  const DataLayout *layout = defaultLayout;
  if (const DataLayoutAnalysis *analysis =
          typeConverter->getDataLayoutAnalysis()) {
    layout = &analysis->getAbove(op);
  }
  Type elementType = memRefType.getElementType();
  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
    return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
    return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
                                                          *layout);
  return layout->getTypeSize(elementType);
}

static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                 Location loc, Value allocatedPtr,
                                 MemRefType memRefType, Type elementPtrType,
                                 const LLVMTypeConverter &typeConverter) {
  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
  FailureOr<unsigned> maybeMemrefAddrSpace =
      typeConverter.getMemRefAddressSpace(memRefType);
  assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
    allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
        loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
        allocatedPtr);
  return allocatedPtr;
}

struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert the alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
        rewriter, getTypeConverter(),
        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    Value alignment = getAlignment(rewriter, loc, op);
    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer.
    Type elementPtrType = this->getElementPtrType(memRefType);
    assert(elementPtrType && "could not compute element ptr type");
    auto results =
        rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);

    Value allocatedPtr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());
    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// Computes the alignment for the given memory allocation op.
  template <typename OpType>
  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
                     OpType op) const {
    MemRefType memRefType = op.getType();
    Value alignment;
    if (auto alignmentAttr = op.getAlignment()) {
      Type indexType = getIndexType();
      alignment =
          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar type on the target hardware. For non-scalars, use the
      // natural alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }
    return alignment;
  }
};
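
// For illustration, a sketch of the malloc-based lowering above (not verbatim
// compiler output; SSA names are invented): a dynamic allocation such as
//
//   %m = memref.alloc(%n) : memref<?xf32>
//
// becomes, roughly,
//
//   %size = ... %n * sizeof(f32) in bytes ...
//   %raw = llvm.call @malloc(%size) : (i64) -> !llvm.ptr
//
// followed by insertion of %raw (as both allocated and aligned pointer) and
// the sizes/strides into the descriptor struct. When an alignment is
// requested, the size is over-allocated by `alignment` bytes and the aligned
// pointer is bumped using the createAligned arithmetic.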

struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert the alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
        rewriter, getTypeConverter(),
        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);

    Value allocAlignment =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto results = rewriter.create<LLVM::CallOp>(
        loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));

    Value ptr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, ptr, ptr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Computes the alignment for aligned_alloc used to allocate the buffer for
  /// the memory allocation op.
  ///
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
  int64_t alignedAllocationGetAlignment(memref::AllocOp op,
                                        const DataLayout *defaultLayout) const {
    if (std::optional<uint64_t> alignment = op.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump it to the next power of two if it isn't.
    unsigned eltSizeBytes = getMemRefEltSizeInBytes(
        getTypeConverter(), op.getType(), op, defaultLayout);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`.
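  /// For example, memref<4x8xf32> occupies 4 * 8 * 4 = 128 bytes, a multiple
  /// of 16. Dynamic dimensions are skipped, so only the statically known
  /// factors participate in the divisibility check; any dynamic extent can
  /// only multiply the total further.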
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
                              const DataLayout *defaultLayout) const {
    uint64_t sizeDivisor =
        getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamicDim(i))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};
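
// For illustration, a sketch of the aligned_alloc-based lowering above (not
// verbatim compiler output): with an explicit alignment,
//
//   %m = memref.alloc() {alignment = 32} : memref<5xf32>
//
// the 20-byte buffer is padded to 32 bytes (the next multiple of the
// alignment, via createAligned) and allocated with, roughly,
//
//   %p = llvm.call @aligned_alloc(%c32, %c32_bytes) : (i64, i64) -> !llvm.ptr
//
// The returned pointer serves as both the allocated and the aligned pointer
// of the descriptor, since aligned_alloc already guarantees the alignment.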

struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
  using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;

  /// Allocates the underlying buffer on the stack with an llvm.alloca of the
  /// converted element type, forwarding the optional alignment attribute to
  /// the alloca instruction.
  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value size;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, size,
                                   /*sizeInBytes=*/false);

    // With alloca, one gets a pointer to the element type right away: the
    // allocated and aligned pointers coincide for stack allocations.
    auto elementType =
        typeConverter->convertType(op.getType().getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(op.getType());
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned addrSpace = *maybeAddressSpace;
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, elementType, size, op.getAlignment().value_or(0));

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
        strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};
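
// For illustration, a sketch of the alloca lowering above (not verbatim
// compiler output): a stack allocation such as
//
//   %m = memref.alloca() : memref<4xf32>
//
// becomes, roughly,
//
//   %c4 = llvm.mlir.constant(4 : index) : i64
//   %p = llvm.alloca %c4 x f32 : (i64) -> !llvm.ptr
//
// followed by the usual descriptor construction, with %p serving as both the
// allocated and the aligned pointer.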

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save the stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body, restoring the stack before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};
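
// For illustration, a sketch of the alloca_scope lowering above (not verbatim
// compiler output): the scoped region
//
//   memref.alloca_scope {
//     ... allocas and other body ops ...
//   }
//
// becomes, roughly, a stack save/restore pair bracketing the inlined body:
//
//   %sp = llvm.intr.stacksave : !llvm.ptr
//   llvm.br ^body
// ^body:
//   ... inlined body ops ...
//   llvm.intr.stackrestore %sp : !llvm.ptr
//   llvm.br ^continue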

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                     rewriter);

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
    Value trueCond =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                                    alignmentConst);

    rewriter.eraseOp(op);
    return success();
  }
};
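
// For illustration, a sketch of the assume_alignment lowering above (not
// verbatim compiler output):
//
//   memref.assume_alignment %m, 16 : memref<?xf32>
//
// becomes, roughly,
//
//   %true = llvm.mlir.constant(true) : i1
//   %c16 = llvm.mlir.constant(16 : index) : i64
//   llvm.intr.assume %true ["align"(%ptr, %c16)]
//
// where %ptr is the strided element pointer extracted from the descriptor.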

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. Since the memref descriptor is an SSA value, there is no need to
// clean it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc = getFreeFn(
        rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>());
    if (failed(freeFunc))
      return failure();
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
                                              allocatedPtr);
    return success();
  }
};
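
// For illustration, a sketch of the dealloc lowering above (not verbatim
// compiler output):
//
//   memref.dealloc %m : memref<?xf32>
//
// becomes, roughly,
//
//   %alloc = llvm.extractvalue %desc[0] : !llvm.struct<...>
//   llvm.call @free(%alloc) : (!llvm.ptr) -> ()
//
// For unranked memrefs, the allocated pointer is instead read out of the
// underlying ranked descriptor that the unranked descriptor points to.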

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get the pointer to the offset field of the memref<element_type>
    // descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the constant index when available.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use a constant for the static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};
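
// For illustration, a sketch of the dim lowering above (not verbatim compiler
// output): with a constant index into a dynamic dimension,
//
//   %d = memref.dim %m, %c0 : memref<?x4xf32>
//
// the size is read straight out of the descriptor, roughly
//
//   %d = llvm.extractvalue %desc[3, 0] : !llvm.struct<...>
//
// (field 3 holds the sizes array), while a static dimension folds to an
// llvm.mlir.constant.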

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using Base = LoadStoreOpLowering<Derived>;
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   |   loop(%loaded):               |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      |   end:                         |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1-D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
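
// For illustration: memref<4x8xf32> converts to !llvm.array<4 x array<8 x
// f32>>, with the outermost memref dimension becoming the outermost array
// dimension; a rank-0 memref converts to the bare element type.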

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};
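
// For illustration, a sketch of the global lowering above (not verbatim
// compiler output):
//
//   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
//
// becomes, roughly,
//
//   llvm.mlir.global private constant @cst(dense<[1.0, 2.0]> : tensor<2xf32>)
//       : !llvm.array<2 x f32>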

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element of the global stashed into the descriptor. The
/// descriptor construction is shared with the allocation lowerings.
struct GetGlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
  using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;

  /// Buffer "allocation" for the memref.get_global op amounts to taking the
  /// address of the referenced global variable.
  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    MemRefType type = cast<MemRefType>(op.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, op.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};
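
// For illustration, a sketch of the get_global lowering above (not verbatim
// compiler output):
//
//   %m = memref.get_global @cst : memref<2xf32>
//
// becomes, roughly,
//
//   %addr = llvm.mlir.addressof @cst : !llvm.ptr
//   %gep = llvm.getelementptr %addr[0, 0]
//       : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x f32>
//
// with %gep stored as the aligned pointer and the 0xdeadbeef sentinel stored
// as the never-to-be-freed allocated pointer of the descriptor.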

// The load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
        /*alignment=*/0, /*isVolatile=*/false, loadOp.getNontemporal());
    return success();
  }
};
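
// For illustration, a sketch of the load lowering above (not verbatim
// compiler output):
//
//   %v = memref.load %m[%i] : memref<?xf32>
//
// becomes, roughly,
//
//   %aligned = llvm.extractvalue %desc[1] : !llvm.struct<...>
//   %off = ... offset + %i * stride, folded where statically known ...
//   %gep = llvm.getelementptr %aligned[%off]
//       : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %v = llvm.load %gep : !llvm.ptr -> f32
//
// Stores follow the same addressing path and end in an llvm.store.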

// The store operation is lowered to obtaining a pointer to the indexed
// element, and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
        op, adaptor.getValue(), dataPtr, /*alignment=*/0, /*isVolatile=*/false,
        op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      if (typeConverter->convertType(srcType) !=
          typeConverter->convertType(dstType))
        return failure();

    // Unranked to unranked casts are disallowed.
    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
      return failure();

    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
      return success();
    }

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // - set the rank in the destination from the memref type;
      // - allocate space on the stack and copy the src memref descriptor;
      // - set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // poison = PoisonOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
      // d1 = InsertValueOp poison, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked. The operation is assumed to be
      // doing a correct cast: if the destination type mismatches the runtime
      // type of the unranked operand, the behavior is undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }

    return success();
  }
};
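
// For illustration, a sketch of the ranked-to-unranked cast above (not
// verbatim compiler output):
//
//   %u = memref.cast %m : memref<4xf32> to memref<*xf32>
//
// stores the ranked descriptor into stack space and builds the two-field
// unranked descriptor, roughly
//
//   %p = llvm.alloca ...; llvm.store %desc, %p
//   %u0 = llvm.mlir.poison : !llvm.struct<(i64, ptr)>
//   %u1 = llvm.insertvalue %rank, %u0[0] : !llvm.struct<(i64, ptr)>
//   %u2 = llvm.insertvalue %p, %u1[1] : !llvm.struct<(i64, ptr)>
//
// The unranked-to-ranked direction simply loads the descriptor back from
// that pointer.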

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with an identity layout, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute the number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get the element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute the total size in bytes.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save the stack position before promoting the descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive the element size from llvm.getelementptr, which accounts for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
        sourcePtr.getType());
    if (failed(copyFn))
      return failure();
    rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore the stack used for the descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};
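
// For illustration, a sketch of the contiguous-copy path above (not verbatim
// compiler output):
//
//   memref.copy %src, %dst : memref<4x4xf32> to memref<4x4xf32>
//
// becomes, roughly,
//
//   %bytes = ... 16 elements * 4 bytes = 64 bytes ...
//   "llvm.intr.memcpy"(%dstPtr, %srcPtr, %bytes) <{isVolatile = false}>
//
// while non-contiguous layouts fall back to a call to the runtime helper
// memrefCopy with stack-promoted unranked descriptors.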

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new memref descriptor.
      auto result = UnrankedMemRefDescriptor::poison(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy the pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

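      // The underlying ranked descriptor begins with the allocated and
      // aligned pointers, which were already copied (with address space
      // casts) above; skip those two pointer-sized fields and memcpy only
      // the offset, sizes, and strides that follow them.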
      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create the descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

    // Set the allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set the offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set the sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};
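
// For illustration, a sketch of the reinterpret_cast lowering above (not
// verbatim compiler output):
//
//   %r = memref.reinterpret_cast %m to offset: [0], sizes: [4, %n],
//       strides: [%n, 1] : memref<?xf32> to memref<4x?xf32, strided<[?, 1]>>
//
// copies the source pointers into a fresh descriptor and then writes each
// offset/size/stride field either as an llvm.mlir.constant (static case) or
// from the converted SSA operand (dynamic case).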

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create the descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

      // Set the allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

1398  Type indexType = getIndexType();
1399  Value stride = nullptr;
1400  int64_t targetRank = targetMemRefType.getRank();
1401  for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1402  if (!ShapedType::isDynamic(strides[i])) {
1403  // The stride for this dimension is static; materialize it as a
1404  // constant.
1405  stride =
1406  createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
1407  } else if (!stride) {
1408  // `stride` is null only in the first iteration of the loop. However,
1409  // since the target memref has an identity layout, we can safely set
1410  // the innermost stride to 1.
1411  stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
1412  }
1413 
1414  Value dimSize;
1415  // The size of this dimension is either a static constant or, if
1416  // dynamic, loaded at runtime from the shape operand.
1417  if (!targetMemRefType.isDynamicDim(i)) {
1418  dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1419  targetMemRefType.getDimSize(i));
1420  } else {
1421  Value shapeOp = reshapeOp.getShape();
1422  Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
1423  dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1425  if (dimSize.getType() != indexType)
1426  dimSize = typeConverter->materializeTargetConversion(
1427  rewriter, loc, indexType, dimSize);
1428  assert(dimSize && "Invalid memref element type");
1429  }
1430 
1431  desc.setSize(rewriter, loc, i, dimSize);
1432  desc.setStride(rewriter, loc, i, stride);
1433 
1434  // Prepare the stride value for the next dimension.
1435  stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1436  }
1437 
1438  *descriptor = desc;
1439  return success();
1440  }
1441 
1442  // The shape is a rank-1 memref with unknown (dynamic) length.
1443  Location loc = reshapeOp.getLoc();
1444  MemRefDescriptor shapeDesc(adaptor.getShape());
1445  Value resultRank = shapeDesc.size(rewriter, loc, 0);
1446 
1447  // Extract address space and element type.
1448  auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1449  unsigned addressSpace =
1450  *getTypeConverter()->getMemRefAddressSpace(targetType);
1451 
1452  // Create the unranked memref descriptor that holds the ranked one. The
1453  // inner descriptor is allocated on the stack.
1454  auto targetDesc = UnrankedMemRefDescriptor::poison(
1455  rewriter, loc, typeConverter->convertType(targetType));
1456  targetDesc.setRank(rewriter, loc, resultRank);
1457  SmallVector<Value, 4> sizes;
1458  UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1459  targetDesc, addressSpace, sizes);
1460  Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1461  loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
1462  sizes.front());
1463  targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1464 
1465  // Extract pointers and offset from the source memref.
1466  Value allocatedPtr, alignedPtr, offset;
1467  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1468  reshapeOp.getSource(), adaptor.getSource(),
1469  &allocatedPtr, &alignedPtr, &offset);
1470 
1471  // Set pointers and offset.
1472  auto elementPtrType =
1473  LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1474 
1475  UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1476  elementPtrType, allocatedPtr);
1477  UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1478  underlyingDescPtr, elementPtrType,
1479  alignedPtr);
1480  UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1481  underlyingDescPtr, elementPtrType,
1482  offset);
1483 
1484  // Use the offset pointer as base for further addressing. Copy over the new
1485  // shape and compute strides. For this, we create a loop from rank-1 to 0.
1486  Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1487  rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1488  Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1489  rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1490  Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1491  Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
1492  Value resultRankMinusOne =
1493  rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1494 
1495  Block *initBlock = rewriter.getInsertionBlock();
1496  Type indexType = getTypeConverter()->getIndexType();
1497  Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1498 
1499  Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1500  {indexType, indexType}, {loc, loc});
1501 
1502  // Move the remaining initBlock ops to condBlock.
1503  Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1504  rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1505 
1506  rewriter.setInsertionPointToEnd(initBlock);
1507  rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1508  condBlock);
1509  rewriter.setInsertionPointToStart(condBlock);
1510  Value indexArg = condBlock->getArgument(0);
1511  Value strideArg = condBlock->getArgument(1);
1512 
1513  Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
1514  Value pred = rewriter.create<LLVM::ICmpOp>(
1515  loc, IntegerType::get(rewriter.getContext(), 1),
1516  LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1517 
1518  Block *bodyBlock =
1519  rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1520  rewriter.setInsertionPointToStart(bodyBlock);
1521 
1522  // Copy size from shape to descriptor.
1523  auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1524  Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1525  loc, llvmIndexPtrType,
1526  typeConverter->convertType(shapeMemRefType.getElementType()),
1527  shapeOperandPtr, indexArg);
1528  Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1529  UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1530  targetSizesBase, indexArg, size);
1531 
1532  // Write stride value and compute next one.
1533  UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1534  targetStridesBase, indexArg, strideArg);
1535  Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1536 
1537  // Decrement loop counter and branch back.
1538  Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1539  rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1540  condBlock);
1541 
1542  Block *remainder =
1543  rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1544 
1545  // Hook up the cond exit to the remainder.
1546  rewriter.setInsertionPointToEnd(condBlock);
1547  rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1548  remainder, std::nullopt);
1549 
1550  // Reset position to beginning of new remainder block.
1551  rewriter.setInsertionPointToStart(remainder);
1552 
1553  *descriptor = targetDesc;
1554  return success();
1555  }
1556 };
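// Hedged sketch of the two paths above. With a statically shaped shape
// operand the result is ranked and a plain descriptor is filled, e.g.:
//
//   %r = memref.reshape %src(%shape)
//       : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
//
// With a dynamically sized shape operand the result is unranked; the
// pattern then allocates the inner descriptor on the stack and emits the
// cond/body/remainder block loop above, copying sizes and accumulating
// strides from the innermost dimension outwards.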
1557 
1558 /// Reassociating reshape ops must be expanded before we reach this stage.
1559 /// Report that information.
1560 template <typename ReshapeOp>
1561 class ReassociatingReshapeOpConversion
1562  : public ConvertOpToLLVMPattern<ReshapeOp> {
1563 public:
1564  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1565  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1566 
1567  LogicalResult
1568  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1569  ConversionPatternRewriter &rewriter) const override {
1570  return rewriter.notifyMatchFailure(
1571  reshapeOp,
1572  "reassociation operations should have been expanded beforehand");
1573  }
1574 };
1575 
1576 /// Subviews must be expanded before we reach this stage.
1577 /// Report that information.
1578 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1579  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1580 
1581  LogicalResult
1582  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1583  ConversionPatternRewriter &rewriter) const override {
1584  return rewriter.notifyMatchFailure(
1585  subViewOp, "subview operations should have been expanded beforehand");
1586  }
1587 };
1588 
1589 /// Conversion pattern that transforms a transpose op into:
1590 ///   1. A `llvm.mlir.poison` operation to create a new memref descriptor.
1591 ///   2. Copies of the allocated and aligned pointers and of the offset
1592 ///      from the source descriptor to the new one.
1593 ///   3. Updates to the new descriptor's sizes and strides, permuted
1594 ///      according to the permutation map.
1595 /// The transpose op is replaced by the new descriptor; no data is moved.
1596 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1597 public:
1598  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1599 
1600  LogicalResult
1601  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1602  ConversionPatternRewriter &rewriter) const override {
1603  auto loc = transposeOp.getLoc();
1604  MemRefDescriptor viewMemRef(adaptor.getIn());
1605 
1606  // No permutation, early exit.
1607  if (transposeOp.getPermutation().isIdentity())
1608  return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1609 
1610  auto targetMemRef = MemRefDescriptor::poison(
1611  rewriter, loc,
1612  typeConverter->convertType(transposeOp.getIn().getType()));
1613 
1614  // Copy the base and aligned pointers from the old descriptor to the new
1615  // one.
1616  targetMemRef.setAllocatedPtr(rewriter, loc,
1617  viewMemRef.allocatedPtr(rewriter, loc));
1618  targetMemRef.setAlignedPtr(rewriter, loc,
1619  viewMemRef.alignedPtr(rewriter, loc));
1620 
1621  // Copy the offset pointer from the old descriptor to the new one.
1622  targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1623 
1624  // Iterate over the dimensions and apply size/stride permutation:
1625  // When enumerating the results of the permutation map, the enumeration
1626  // index is the index into the target dimensions and the DimExpr points to
1627  // the dimension of the source memref.
1628  for (const auto &en :
1629  llvm::enumerate(transposeOp.getPermutation().getResults())) {
1630  int targetPos = en.index();
1631  int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1632  targetMemRef.setSize(rewriter, loc, targetPos,
1633  viewMemRef.size(rewriter, loc, sourcePos));
1634  targetMemRef.setStride(rewriter, loc, targetPos,
1635  viewMemRef.stride(rewriter, loc, sourcePos));
1636  }
1637 
1638  rewriter.replaceOp(transposeOp, {targetMemRef});
1639  return success();
1640  }
1641 };
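// Illustrative example (the result layout is hedged; the verifier computes
// the exact strided form):
//
//   %t = memref.transpose %m (i, j) -> (j, i)
//       : memref<?x?xf32> to memref<?x?xf32, strided<[1, ?], offset: ?>>
//
// Only descriptor fields are permuted; no data is moved.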
1642 
1643 /// Conversion pattern that transforms a view op into:
1644 /// 1. An `llvm.mlir.poison` operation to create a memref descriptor.
1645 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1646 /// and stride.
1647 /// The view op is replaced by the descriptor.
1648 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1649  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1650 
1651  // Build and return the value for the idx^th shape dimension: either the
1652  // constant dimension, or the matching operand among the dynamic sizes.
1653  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1654  ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1655  Type indexType) const {
1656  assert(idx < shape.size());
1657  if (!ShapedType::isDynamic(shape[idx]))
1658  return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1659  // Count the number of dynamic dims in range [0, idx]
1660  unsigned nDynamic =
1661  llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1662  return dynamicSizes[nDynamic];
1663  }
1664 
1665  // Build and return the idx^th stride, either by returning the constant stride
1666  // or by computing the dynamic stride from the current `runningStride` and
1667  // `nextSize`. The caller should keep a running stride and update it with the
1668  // result returned by this function.
1669  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1670  ArrayRef<int64_t> strides, Value nextSize,
1671  Value runningStride, unsigned idx, Type indexType) const {
1672  assert(idx < strides.size());
1673  if (!ShapedType::isDynamic(strides[idx]))
1674  return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1675  if (nextSize)
1676  return runningStride
1677  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1678  : nextSize;
1679  assert(!runningStride);
1680  return createIndexAttrConstant(rewriter, loc, indexType, 1);
1681  }
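// Worked example (hedged): for static/dynamic strides [?, ?, 1], the loop
// in matchAndRewrite below yields, innermost first, stride(2) = 1 (the
// constant), then stride(1) = stride(2) * size(2), and finally
// stride(0) = stride(1) * size(1), threading `runningStride`/`nextSize`.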
1682 
1683  LogicalResult
1684  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1685  ConversionPatternRewriter &rewriter) const override {
1686  auto loc = viewOp.getLoc();
1687 
1688  auto viewMemRefType = viewOp.getType();
1689  auto targetElementTy =
1690  typeConverter->convertType(viewMemRefType.getElementType());
1691  auto targetDescTy = typeConverter->convertType(viewMemRefType);
1692  if (!targetDescTy || !targetElementTy ||
1693  !LLVM::isCompatibleType(targetElementTy) ||
1694  !LLVM::isCompatibleType(targetDescTy))
1695  return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1696  failure();
1697 
1698  int64_t offset;
1699  SmallVector<int64_t, 4> strides;
1700  auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1701  if (failed(successStrides))
1702  return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1703  assert(offset == 0 && "expected offset to be 0");
1704 
1705  // Target memref must be contiguous in memory (innermost stride is 1), or
1706  // empty (special case when at least one of the memref dimensions is 0).
1707  if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1708  return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1709  failure();
1710 
1711  // Create the descriptor.
1712  MemRefDescriptor sourceMemRef(adaptor.getSource());
1713  auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);
1714 
1715  // Field 1: Copy the allocated pointer, used for malloc/free.
1716  Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1717  auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1718  targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1719 
1720  // Field 2: Copy the actual aligned pointer to payload.
1721  Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1722  alignedPtr = rewriter.create<LLVM::GEPOp>(
1723  loc, alignedPtr.getType(),
1724  typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1725  adaptor.getByteShift());
1726 
1727  targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1728 
1729  Type indexType = getIndexType();
1730  // Field 3: The offset in the resulting type must be 0. This is
1731  // because of the type change: an offset on srcType* may not be
1732  // expressible as an offset on dstType*.
1733  targetMemRef.setOffset(
1734  rewriter, loc,
1735  createIndexAttrConstant(rewriter, loc, indexType, offset));
1736 
1737  // Early exit for 0-D corner case.
1738  if (viewMemRefType.getRank() == 0)
1739  return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1740 
1741  // Fields 4 and 5: Update sizes and strides.
1742  Value stride = nullptr, nextSize = nullptr;
1743  for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1744  // Update size.
1745  Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1746  adaptor.getSizes(), i, indexType);
1747  targetMemRef.setSize(rewriter, loc, i, size);
1748  // Update stride.
1749  stride =
1750  getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1751  targetMemRef.setStride(rewriter, loc, i, stride);
1752  nextSize = size;
1753  }
1754 
1755  rewriter.replaceOp(viewOp, {targetMemRef});
1756  return success();
1757  }
1758 };
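// Illustrative example: a byte-buffer view with one dynamic size,
//
//   %v = memref.view %buf[%offset][%n]
//       : memref<2048xi8> to memref<?x4xf32>
//
// The aligned pointer is advanced by %offset with a GEP over the i8 source
// element type, the descriptor offset is reset to 0, and sizes/strides are
// rebuilt innermost-first using getSize/getStride above.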
1759 
1760 //===----------------------------------------------------------------------===//
1761 // AtomicRMWOpLowering
1762 //===----------------------------------------------------------------------===//
1763 
1764 /// Try to match the kind of a memref.atomic_rmw to determine whether to use
1765 /// a lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
1766 static std::optional<LLVM::AtomicBinOp>
1767 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1768  switch (atomicOp.getKind()) {
1769  case arith::AtomicRMWKind::addf:
1770  return LLVM::AtomicBinOp::fadd;
1771  case arith::AtomicRMWKind::addi:
1772  return LLVM::AtomicBinOp::add;
1773  case arith::AtomicRMWKind::assign:
1774  return LLVM::AtomicBinOp::xchg;
1775  case arith::AtomicRMWKind::maximumf:
1776  return LLVM::AtomicBinOp::fmax;
1777  case arith::AtomicRMWKind::maxs:
1778  return LLVM::AtomicBinOp::max;
1779  case arith::AtomicRMWKind::maxu:
1780  return LLVM::AtomicBinOp::umax;
1781  case arith::AtomicRMWKind::minimumf:
1782  return LLVM::AtomicBinOp::fmin;
1783  case arith::AtomicRMWKind::mins:
1784  return LLVM::AtomicBinOp::min;
1785  case arith::AtomicRMWKind::minu:
1786  return LLVM::AtomicBinOp::umin;
1787  case arith::AtomicRMWKind::ori:
1788  return LLVM::AtomicBinOp::_or;
1789  case arith::AtomicRMWKind::andi:
1790  return LLVM::AtomicBinOp::_and;
1791  default:
1792  return std::nullopt;
1793  }
1795 }
1796 
1797 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1798  using Base::Base;
1799 
1800  LogicalResult
1801  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1802  ConversionPatternRewriter &rewriter) const override {
1803  auto maybeKind = matchSimpleAtomicOp(atomicOp);
1804  if (!maybeKind)
1805  return failure();
1806  auto memRefType = atomicOp.getMemRefType();
1807  SmallVector<int64_t> strides;
1808  int64_t offset;
1809  if (failed(memRefType.getStridesAndOffset(strides, offset)))
1810  return failure();
1811  auto dataPtr =
1812  getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1813  adaptor.getIndices(), rewriter);
1814  rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1815  atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1816  LLVM::AtomicOrdering::acq_rel);
1817  return success();
1818  }
1819 };
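// Illustrative mapping (hedged):
//
//   %old = memref.atomic_rmw addf %val, %m[%i] : (f32, memref<10xf32>) -> f32
//
// becomes, once the strided element pointer %ptr has been computed,
//
//   %old = llvm.atomicrmw fadd %ptr, %val acq_rel : !llvm.ptr, f32
//
// Kinds without a direct llvm.atomicrmw equivalent make matchSimpleAtomicOp
// return std::nullopt, so the pattern fails and another lowering can apply.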
1820 
1821 /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1822 class ConvertExtractAlignedPointerAsIndex
1823  : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1824 public:
1825  using ConvertOpToLLVMPattern<
1826  memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1827 
1828  LogicalResult
1829  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1830  OpAdaptor adaptor,
1831  ConversionPatternRewriter &rewriter) const override {
1832  BaseMemRefType sourceTy = extractOp.getSource().getType();
1833 
1834  Value alignedPtr;
1835  if (sourceTy.hasRank()) {
1836  MemRefDescriptor desc(adaptor.getSource());
1837  alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1838  } else {
1839  auto elementPtrTy = LLVM::LLVMPointerType::get(
1840  rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1841 
1842  UnrankedMemRefDescriptor desc(adaptor.getSource());
1843  Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1844 
1845  alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1846  rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1847  elementPtrTy);
1848  }
1849 
1850  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1851  extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1852  return success();
1853  }
1854 };
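// Illustrative example:
//
//   %p = memref.extract_aligned_pointer_as_index %m : memref<?xf32> -> index
//
// lowers to reading the aligned pointer out of the (un)ranked descriptor,
// followed by an llvm.ptrtoint to the converted index type.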
1855 
1856 /// Materialize the MemRef descriptor represented by the results of
1857 /// ExtractStridedMetadataOp.
1858 class ExtractStridedMetadataOpLowering
1859  : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1860 public:
1861  using ConvertOpToLLVMPattern<
1862  memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1863 
1864  LogicalResult
1865  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1866  OpAdaptor adaptor,
1867  ConversionPatternRewriter &rewriter) const override {
1868 
1869  if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1870  return failure();
1871 
1872  // Create the descriptor.
1873  MemRefDescriptor sourceMemRef(adaptor.getSource());
1874  Location loc = extractStridedMetadataOp.getLoc();
1875  Value source = extractStridedMetadataOp.getSource();
1876 
1877  auto sourceMemRefType = cast<MemRefType>(source.getType());
1878  int64_t rank = sourceMemRefType.getRank();
1879  SmallVector<Value> results;
1880  results.reserve(2 + rank * 2);
1881 
1882  // Base buffer.
1883  Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1884  Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1885  MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
1886  rewriter, loc, *getTypeConverter(),
1887  cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1888  baseBuffer, alignedBuffer);
1889  results.push_back((Value)dstMemRef);
1890 
1891  // Offset.
1892  results.push_back(sourceMemRef.offset(rewriter, loc));
1893 
1894  // Sizes.
1895  for (unsigned i = 0; i < rank; ++i)
1896  results.push_back(sourceMemRef.size(rewriter, loc, i));
1897  // Strides.
1898  for (unsigned i = 0; i < rank; ++i)
1899  results.push_back(sourceMemRef.stride(rewriter, loc, i));
1900 
1901  rewriter.replaceOp(extractStridedMetadataOp, results);
1902  return success();
1903  }
1904 };
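// Illustrative example:
//
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %m : memref<?x?xf32>
//       -> memref<f32>, index, index, index, index, index
//
// Every result is read straight off the source descriptor; the base buffer
// is rebuilt as a 0-d descriptor from the two extracted pointers.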
1905 
1906 } // namespace
1907 
1908 void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
1909  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1910  // clang-format off
1911  patterns.add<
1912  AllocaOpLowering,
1913  AllocaScopeOpLowering,
1914  AtomicRMWOpLowering,
1915  AssumeAlignmentOpLowering,
1916  ConvertExtractAlignedPointerAsIndex,
1917  DimOpLowering,
1918  ExtractStridedMetadataOpLowering,
1919  GenericAtomicRMWOpLowering,
1920  GlobalMemrefOpLowering,
1921  GetGlobalMemrefOpLowering,
1922  LoadOpLowering,
1923  MemRefCastOpLowering,
1924  MemRefCopyOpLowering,
1925  MemorySpaceCastOpLowering,
1926  MemRefReinterpretCastOpLowering,
1927  MemRefReshapeOpLowering,
1928  PrefetchOpLowering,
1929  RankOpLowering,
1930  ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1931  ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1932  StoreOpLowering,
1933  SubViewOpLowering,
1934  TransposeOpLowering,
1935  ViewOpLowering>(converter);
1936  // clang-format on
1937  auto allocLowering = converter.getOptions().allocLowering;
1938  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1939  patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1940  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1941  patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1942 }
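// Driver-side usage, as a minimal sketch (`ctx` stands for an MLIRContext*
// available in the caller; option values are illustrative):
//
//   LowerToLLVMOptions options(ctx);
//   options.allocLowering = LowerToLLVMOptions::AllocLowering::Malloc;
//   LLVMTypeConverter converter(ctx, options);
//   RewritePatternSet patterns(ctx);
//   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);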
1943 
1944 namespace {
1945 struct FinalizeMemRefToLLVMConversionPass
1946  : public impl::FinalizeMemRefToLLVMConversionPassBase<
1947  FinalizeMemRefToLLVMConversionPass> {
1948  using FinalizeMemRefToLLVMConversionPassBase::
1949  FinalizeMemRefToLLVMConversionPassBase;
1950 
1951  void runOnOperation() override {
1952  Operation *op = getOperation();
1953  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1954  LowerToLLVMOptions options(&getContext(),
1955  dataLayoutAnalysis.getAtOrAbove(op));
1956  options.allocLowering =
1957  (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1958  : LowerToLLVMOptions::AllocLowering::Malloc);
1959 
1960  options.useGenericFunctions = useGenericFunctions;
1961 
1962  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1963  options.overrideIndexBitwidth(indexBitwidth);
1964 
1965  LLVMTypeConverter typeConverter(&getContext(), options,
1966  &dataLayoutAnalysis);
1967  RewritePatternSet patterns(&getContext());
1968  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1969  LLVMConversionTarget target(getContext());
1970  target.addLegalOp<func::FuncOp>();
1971  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1972  signalPassFailure();
1973  }
1974 };
1975 
1976 /// Implement the interface to convert MemRef to LLVM.
1977 struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
1978  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
1979  void loadDependentDialects(MLIRContext *context) const final {
1980  context->loadDialect<LLVM::LLVMDialect>();
1981  }
1982 
1983  /// Hook for derived dialect interface to provide conversion patterns
1984  /// and mark dialect legal for the conversion target.
1985  void populateConvertToLLVMConversionPatterns(
1986  ConversionTarget &target, LLVMTypeConverter &typeConverter,
1987  RewritePatternSet &patterns) const final {
1988  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1989  }
1990 };
1991 
1992 } // namespace
1993 
1994 void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
1995  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
1996  dialect->addInterfaces<MemRefToLLVMDialectInterface>();
1997  });
1998 }
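// A minimal registration sketch (hedged): tools that rely on the generic
// convert-to-llvm infrastructure can attach the interface via
//
//   DialectRegistry registry;
//   registry.insert<memref::MemRefDialect>();
//   registerConvertMemRefToLLVMInterface(registry);
//
// after which the dialect provides its patterns through
// ConvertToLLVMPatternInterface.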