//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/MathExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
    LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}

static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
          ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(b, module);

  return LLVM::lookupOrCreateFreeFn(b, module);
}

static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                     Operation *module, Type indexType) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType);

  return LLVM::lookupOrCreateMallocFn(b, module, indexType);
}

static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                  Operation *module, Type indexType) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType);

  return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType);
}

/// Computes the aligned value for 'input' as follows:
///   bumped = input + alignment - 1
///   aligned = bumped - bumped % alignment
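/// For example (an illustrative trace, not emitted IR): with input = 30 and
/// alignment = 16, bumped = 30 + 16 - 1 = 45, and
/// aligned = 45 - 45 % 16 = 45 - 13 = 32, the next multiple of 16.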
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                           Value input, Value alignment) {
  Value one = rewriter.create<LLVM::ConstantOp>(loc, alignment.getType(),
                                                rewriter.getIndexAttr(1));
  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
}

/// Computes the byte size for the MemRef element type.
static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
                                        MemRefType memRefType, Operation *op,
                                        const DataLayout *defaultLayout) {
  const DataLayout *layout = defaultLayout;
  if (const DataLayoutAnalysis *analysis =
          typeConverter->getDataLayoutAnalysis()) {
    layout = &analysis->getAbove(op);
  }
  Type elementType = memRefType.getElementType();
  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
    return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
    return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
                                                          *layout);
  return layout->getTypeSize(elementType);
}

static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                 Location loc, Value allocatedPtr,
                                 MemRefType memRefType, Type elementPtrType,
                                 const LLVMTypeConverter &typeConverter) {
  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
  FailureOr<unsigned> maybeMemrefAddrSpace =
      typeConverter.getMemRefAddressSpace(memRefType);
  assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
    allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
        loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
        allocatedPtr);
  return allocatedPtr;
}

struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
        rewriter, getTypeConverter(),
        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    Value alignment = getAlignment(rewriter, loc, op);
    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer.
    Type elementPtrType = this->getElementPtrType(memRefType);
    assert(elementPtrType && "could not compute element ptr type");
    auto results =
        rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);

    Value allocatedPtr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());
    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignedInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignedInt);
    }

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// Computes the alignment for the given memory allocation op.
  template <typename OpType>
  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
                     OpType op) const {
    MemRefType memRefType = op.getType();
    Value alignment;
    if (auto alignmentAttr = op.getAlignment()) {
      Type indexType = getIndexType();
      alignment =
          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }
    return alignment;
  }
};
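
// For illustration only: a rough sketch of what this pattern produces for a
// statically shaped allocation (value names are invented; the real pattern
// also fills in the full descriptor with sizes and strides):
//
//   %0 = memref.alloc() : memref<4x8xf32>
//
// lowers approximately to:
//
//   %size = llvm.mlir.constant(128 : index) : i64  // 32 elements * 4 bytes
//   %ptr  = llvm.call @malloc(%size) : (i64) -> !llvm.ptr
//   ...   = llvm.insertvalue %ptr, ...  // build the memref descriptor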

struct AlignedAllocOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocOp> {
  using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert the aligned_alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
        rewriter, getTypeConverter(),
        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);

    Value allocAlignment =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; we
    // pad the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto results = rewriter.create<LLVM::CallOp>(
        loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));

    Value ptr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, ptr, ptr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Computes the alignment for aligned_alloc used to allocate the buffer for
  /// the memory allocation op.
  ///
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of that alignment.
  int64_t alignedAllocationGetAlignment(memref::AllocOp op,
                                        const DataLayout *defaultLayout) const {
    if (std::optional<uint64_t> alignment = op.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a
    // power of two, we bump the element size to the next power of two if it
    // isn't one already.
    unsigned eltSizeBytes = getMemRefEltSizeInBytes(
        getTypeConverter(), op.getType(), op, defaultLayout);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
                              const DataLayout *defaultLayout) const {
    uint64_t sizeDivisor =
        getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamicDim(i))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};
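
// For illustration only: with an alignment attribute, the pattern above emits
// roughly the following (names invented for readability):
//
//   %0 = memref.alloc() {alignment = 64 : i64} : memref<10xf32>
//
// becomes approximately:
//
//   %align = llvm.mlir.constant(64 : index) : i64
//   %size  = llvm.mlir.constant(64 : index) : i64  // 40 bytes padded to 64
//   %ptr   = llvm.call @aligned_alloc(%align, %size) : (i64, i64) -> !llvm.ptr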

struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
  using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;

  /// Allocates the underlying buffer on the stack with llvm.alloca, which
  /// yields a pointer to the element type directly and honors the requested
  /// alignment itself, so no separate aligned pointer has to be computed.
  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1). Note that the size
    // here is a number of elements, not bytes, since llvm.alloca scales by
    // the element type itself.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value size;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, size,
                                   /*sizeInBytes=*/false);

    // With alloca, one gets a pointer to the element type right away.
    auto elementType =
        typeConverter->convertType(op.getType().getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(op.getType());
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned addrSpace = *maybeAddressSpace;
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, elementType, size, op.getAlignment().value_or(0));

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
        strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};
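
// For illustration only: a dynamically sized stack allocation such as
//
//   %0 = memref.alloca(%n) : memref<?xf32>
//
// lowers approximately to (invented names):
//
//   %buf = llvm.alloca %n x f32 : (i64) -> !llvm.ptr
//   ...  = llvm.insertvalue %buf, ...  // descriptor with size %n, stride 1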

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save the stack, then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body; the stack is restored before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};
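
// For illustration only: the resulting CFG brackets the inlined body with
// stack save/restore intrinsics, roughly:
//
//   %sp = llvm.intr.stacksave : !llvm.ptr
//   llvm.br ^body
// ^body:
//   ...                                  // allocas made here die at the restore
//   llvm.intr.stackrestore %sp : !llvm.ptr
//   llvm.br ^continue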

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
                                     /*indices=*/{});

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
    Value trueCond =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                                    alignmentConst);
    rewriter.replaceOp(op, memref);
    return success();
  }
};
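
// For illustration only: the emitted assumption looks roughly like
//
//   %true  = llvm.mlir.constant(true) : i1
//   %align = llvm.mlir.constant(16 : index) : i64
//   llvm.intr.assume %true ["align"(%ptr, %align : !llvm.ptr, i64)] : i1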

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc = getFreeFn(
        rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>());
    if (failed(freeFunc))
      return failure();
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
                                              allocatedPtr);
    return success();
  }
};
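
// For illustration only:
//
//   memref.dealloc %m : memref<4xf32>
//
// becomes approximately (the allocated pointer is first extracted from the
// descriptor):
//
//   %ptr = llvm.extractvalue %desc[0]
//       : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
//   llvm.call @free(%ptr) : (!llvm.ptr) -> ()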

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get the pointer to the offset field of the memref<element_type>
    // descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of a constant index when possible.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use a constant for the static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};
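
// For illustration only: for a ranked memref with a constant index into a
// dynamic dimension, e.g.
//
//   %d = memref.dim %m, %c1 : memref<4x?xf32>
//
// the size is read straight out of the descriptor, roughly:
//
//   %d = llvm.extractvalue %desc[3, 1]
//       : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>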

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %new = %pair[0]              |
///  |   |   %ok = %pair[1]               |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |         |
///  |-----------         |
///                       v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, std::next(Block::iterator(atomicOp)));

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(
        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
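
// For illustration only: a memref type such as memref<4x8xf32> maps to the
// nested LLVM array type !llvm.array<4 x array<8 x f32>>, with the outermost
// memref dimension outermost in the array nesting.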

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};
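
// For illustration only:
//
//   memref.global "private" constant @cst : memref<2xf32> = dense<1.0>
//
// becomes approximately:
//
//   llvm.mlir.global private constant @cst(dense<1.0> : tensor<2xf32>)
//       : !llvm.array<2 x f32>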

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses the MemRef
/// descriptor construction of `AllocLikeOpLowering`.
struct GetGlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
  using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;

  /// Buffer "allocation" for the memref.get_global op is getting the address
  /// of the referenced global variable.
  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    MemRefType type = cast<MemRefType>(op.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, op.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};
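
// For illustration only:
//
//   %0 = memref.get_global @gv : memref<2x3xf32>
//
// becomes approximately (rank 2, hence three zero indices on the GEP):
//
//   %addr = llvm.mlir.addressof @gv : !llvm.ptr
//   %elem = llvm.getelementptr %addr[0, 0, 0]
//       : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x array<3 x f32>>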

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    // Per the memref.load spec, the indices must be in-bounds:
    // 0 <= idx < dim_size. Additionally, all offsets are non-negative, hence
    // the inbounds and nuw flags are used when lowering to
    // llvm.getelementptr.
    Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
                                         adaptor.getMemref(),
                                         adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};
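
// For illustration only: a strided element access such as
//
//   %v = memref.load %m[%i, %j] : memref<4x8xf32>
//
// lowers approximately to (invented names):
//
//   %off0 = llvm.mul %i, %c8 : i64       // stride of dim 0 is 8
//   %off  = llvm.add %off0, %j : i64     // innermost stride is 1
//   %ptr  = llvm.getelementptr inbounds|nuw %base[%off]
//       : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %v    = llvm.load %ptr : !llvm.ptr -> f32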

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    // Per the memref.store spec, the indices must be in-bounds:
    // 0 <= idx < dim_size. Additionally, all offsets are non-negative, hence
    // the inbounds and nuw flags are used when lowering to
    // llvm.getelementptr.
    Value dataPtr =
        getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(
        rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can
    // be used for type erasure. For now it must preserve the underlying
    // element type and requires the source and result type to have the same
    // rank. Therefore, perform a sanity check that the underlying structs are
    // the same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      if (typeConverter->convertType(srcType) !=
          typeConverter->convertType(dstType))
        return failure();

    // Unranked to unranked casts are disallowed.
    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
      return failure();

    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
      return success();
    }

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      //  - set the rank in the destination from the memref type,
      //  - allocate space on the stack and copy the src memref descriptor,
      //  - set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // poison = PoisonOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
      // d1 = InsertValueOp poison, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked to ranked. The operation is assumed to be doing
      // a correct cast. If the destination type mismatches the unranked type,
      // it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }

    return success();
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with an identity layout, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a
/// call to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total size in bytes.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save the stack position before promoting the descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive the size from llvm.getelementptr, which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
        sourcePtr.getType());
    if (failed(copyFn))
      return failure();
    rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore the stack used for the descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};
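
// For illustration only: for two contiguous memrefs the copy reduces to a
// single intrinsic call, roughly:
//
//   "llvm.intr.memcpy"(%dst, %src, %totalBytes) <{isVolatile = false}>
//       : (!llvm.ptr, !llvm.ptr, i64) -> ()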

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new memref descriptor.
      auto result = UnrankedMemRefDescriptor::poison(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy the pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract the pointer to the underlying ranked memref descriptor and cast
  // it to ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};
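
// For illustration only: memref.reinterpret_cast populates a fresh descriptor
// field by field, e.g.
//
//   %0 = memref.reinterpret_cast %src to offset: [0], sizes: [%n],
//        strides: [1] : memref<?xf32> to memref<?xf32>
//
// turns into a chain of llvm.insertvalue ops writing the pointers, offset,
// size %n, and stride 1 into a new !llvm.struct descriptor value.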

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
1410  for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1411  if (!ShapedType::isDynamic(strides[i])) {
1412  // If the stride for this dimension is dynamic, then use the product
1413  // of the sizes of the inner dimensions.
1414  stride =
1415  createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
1416  } else if (!stride) {
1417  // `stride` is null only in the first iteration of the loop. However,
1418  // since the target memref has an identity layout, we can safely set
1419  // the innermost stride to 1.
1420  stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
1421  }
1422 
1423  Value dimSize;
1424  // If the size of this dimension is dynamic, then load it at runtime
1425  // from the shape operand.
1426  if (!targetMemRefType.isDynamicDim(i)) {
1427  dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1428  targetMemRefType.getDimSize(i));
1429  } else {
1430  Value shapeOp = reshapeOp.getShape();
1431  Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
1432  dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1433  Type indexType = getIndexType();
1434  if (dimSize.getType() != indexType)
1435  dimSize = typeConverter->materializeTargetConversion(
1436  rewriter, loc, indexType, dimSize);
1437  assert(dimSize && "Invalid memref element type");
1438  }
1439 
1440  desc.setSize(rewriter, loc, i, dimSize);
1441  desc.setStride(rewriter, loc, i, stride);
1442 
1443  // Prepare the stride value for the next dimension.
1444  stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1445  }
1446 
1447  *descriptor = desc;
1448  return success();
1449  }
1450 
1451  // The shape is a rank-1 tensor with unknown length.
1452  Location loc = reshapeOp.getLoc();
1453  MemRefDescriptor shapeDesc(adaptor.getShape());
1454  Value resultRank = shapeDesc.size(rewriter, loc, 0);
1455 
1456  // Extract address space and element type.
1457  auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1458  unsigned addressSpace =
1459  *getTypeConverter()->getMemRefAddressSpace(targetType);
1460 
1461  // Create the unranked memref descriptor that holds the ranked one. The
1462  // inner descriptor is allocated on stack.
1463  auto targetDesc = UnrankedMemRefDescriptor::poison(
1464  rewriter, loc, typeConverter->convertType(targetType));
1465  targetDesc.setRank(rewriter, loc, resultRank);
1466  SmallVector<Value, 4> sizes;
1467  UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1468  targetDesc, addressSpace, sizes);
1469  Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1470  loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
1471  sizes.front());
1472  targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1473 
1474  // Extract pointers and offset from the source memref.
1475  Value allocatedPtr, alignedPtr, offset;
1476  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1477  reshapeOp.getSource(), adaptor.getSource(),
1478  &allocatedPtr, &alignedPtr, &offset);
1479 
1480  // Set pointers and offset.
1481  auto elementPtrType =
1482  LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1483 
1484  UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1485  elementPtrType, allocatedPtr);
1486  UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1487  underlyingDescPtr, elementPtrType,
1488  alignedPtr);
1489  UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1490  underlyingDescPtr, elementPtrType,
1491  offset);
1492 
1493  // Use the offset pointer as base for further addressing. Copy over the new
1494  // shape and compute strides. For this, we create a loop from rank-1 to 0.
1495  Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1496  rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1497  Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1498  rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1499  Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1500  Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
1501  Value resultRankMinusOne =
1502  rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1503 
1504  Block *initBlock = rewriter.getInsertionBlock();
1505  Type indexType = getTypeConverter()->getIndexType();
1506  Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1507 
1508  Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1509  {indexType, indexType}, {loc, loc});
1510 
1511  // Move the remaining initBlock ops to condBlock.
1512  Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1513  rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1514 
1515  rewriter.setInsertionPointToEnd(initBlock);
1516  rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1517  condBlock);
1518  rewriter.setInsertionPointToStart(condBlock);
1519  Value indexArg = condBlock->getArgument(0);
1520  Value strideArg = condBlock->getArgument(1);
1521 
1522  Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
1523  Value pred = rewriter.create<LLVM::ICmpOp>(
1524  loc, IntegerType::get(rewriter.getContext(), 1),
1525  LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1526 
1527  Block *bodyBlock =
1528  rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1529  rewriter.setInsertionPointToStart(bodyBlock);
1530 
1531  // Copy size from shape to descriptor.
1532  auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1533  Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1534  loc, llvmIndexPtrType,
1535  typeConverter->convertType(shapeMemRefType.getElementType()),
1536  shapeOperandPtr, indexArg);
1537  Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1538  UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1539  targetSizesBase, indexArg, size);
1540 
1541  // Write stride value and compute next one.
1542  UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1543  targetStridesBase, indexArg, strideArg);
1544  Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1545 
1546  // Decrement loop counter and branch back.
1547  Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1548  rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1549  condBlock);
1550 
1551  Block *remainder =
1552  rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1553 
1554  // Hook up the cond exit to the remainder.
1555  rewriter.setInsertionPointToEnd(condBlock);
1556  rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1557  remainder, std::nullopt);
1558 
1559  // Reset position to beginning of new remainder block.
1560  rewriter.setInsertionPointToStart(remainder);
1561 
1562  *descriptor = targetDesc;
1563  return success();
1564  }
1565 };
1566 
1567 /// RessociatingReshapeOp must be expanded before we reach this stage.
1568 /// Report that information.
1569 template <typename ReshapeOp>
1570 class ReassociatingReshapeOpConversion
1571  : public ConvertOpToLLVMPattern<ReshapeOp> {
1572 public:
1574  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1575 
1576  LogicalResult
1577  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1578  ConversionPatternRewriter &rewriter) const override {
1579  return rewriter.notifyMatchFailure(
1580  reshapeOp,
1581  "reassociation operations should have been expanded beforehand");
1582  }
1583 };
1584 
1585 /// Subviews must be expanded before we reach this stage.
1586 /// Report that information.
1587 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1589 
1590  LogicalResult
1591  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1592  ConversionPatternRewriter &rewriter) const override {
1593  return rewriter.notifyMatchFailure(
1594  subViewOp, "subview operations should have been expanded beforehand");
1595  }
1596 };
1597 
1598 /// Conversion pattern that transforms a transpose op into:
1599 /// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
1600 /// 2. A load of the ViewDescriptor from the pointer allocated in 1.
1601 /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1602 /// and stride. Size and stride are permutations of the original values.
1603 /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1604 /// The transpose op is replaced by the alloca'ed pointer.
1605 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1606 public:
1608 
1609  LogicalResult
1610  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1611  ConversionPatternRewriter &rewriter) const override {
1612  auto loc = transposeOp.getLoc();
1613  MemRefDescriptor viewMemRef(adaptor.getIn());
1614 
1615  // No permutation, early exit.
1616  if (transposeOp.getPermutation().isIdentity())
1617  return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1618 
1619  auto targetMemRef = MemRefDescriptor::poison(
1620  rewriter, loc,
1621  typeConverter->convertType(transposeOp.getIn().getType()));
1622 
1623  // Copy the base and aligned pointers from the old descriptor to the new
1624  // one.
1625  targetMemRef.setAllocatedPtr(rewriter, loc,
1626  viewMemRef.allocatedPtr(rewriter, loc));
1627  targetMemRef.setAlignedPtr(rewriter, loc,
1628  viewMemRef.alignedPtr(rewriter, loc));
1629 
1630  // Copy the offset pointer from the old descriptor to the new one.
1631  targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1632 
1633  // Iterate over the dimensions and apply size/stride permutation:
1634  // When enumerating the results of the permutation map, the enumeration
1635  // index is the index into the target dimensions and the DimExpr points to
1636  // the dimension of the source memref.
1637  for (const auto &en :
1638  llvm::enumerate(transposeOp.getPermutation().getResults())) {
1639  int targetPos = en.index();
1640  int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1641  targetMemRef.setSize(rewriter, loc, targetPos,
1642  viewMemRef.size(rewriter, loc, sourcePos));
1643  targetMemRef.setStride(rewriter, loc, targetPos,
1644  viewMemRef.stride(rewriter, loc, sourcePos));
1645  }
1646 
1647  rewriter.replaceOp(transposeOp, {targetMemRef});
1648  return success();
1649  }
1650 };
1651 
1652 /// Conversion pattern that transforms an op into:
1653 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1654 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1655 /// and stride.
1656 /// The view op is replaced by the descriptor.
1657 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1659 
1660  // Build and return the value for the idx^th shape dimension, either by
1661  // returning the constant shape dimension or counting the proper dynamic size.
1662  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1663  ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1664  Type indexType) const {
1665  assert(idx < shape.size());
1666  if (!ShapedType::isDynamic(shape[idx]))
1667  return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1668  // Count the number of dynamic dims in range [0, idx]
1669  unsigned nDynamic =
1670  llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1671  return dynamicSizes[nDynamic];
1672  }
1673 
1674  // Build and return the idx^th stride, either by returning the constant stride
1675  // or by computing the dynamic stride from the current `runningStride` and
1676  // `nextSize`. The caller should keep a running stride and update it with the
1677  // result returned by this function.
1679  ArrayRef<int64_t> strides, Value nextSize,
1680  Value runningStride, unsigned idx, Type indexType) const {
1681  assert(idx < strides.size());
1682  if (!ShapedType::isDynamic(strides[idx]))
1683  return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1684  if (nextSize)
1685  return runningStride
1686  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1687  : nextSize;
1688  assert(!runningStride);
1689  return createIndexAttrConstant(rewriter, loc, indexType, 1);
1690  }
1691 
1692  LogicalResult
1693  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1694  ConversionPatternRewriter &rewriter) const override {
1695  auto loc = viewOp.getLoc();
1696 
1697  auto viewMemRefType = viewOp.getType();
1698  auto targetElementTy =
1699  typeConverter->convertType(viewMemRefType.getElementType());
1700  auto targetDescTy = typeConverter->convertType(viewMemRefType);
1701  if (!targetDescTy || !targetElementTy ||
1702  !LLVM::isCompatibleType(targetElementTy) ||
1703  !LLVM::isCompatibleType(targetDescTy))
1704  return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1705  failure();
1706 
1707  int64_t offset;
1708  SmallVector<int64_t, 4> strides;
1709  auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1710  if (failed(successStrides))
1711  return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1712  assert(offset == 0 && "expected offset to be 0");
1713 
1714  // Target memref must be contiguous in memory (innermost stride is 1), or
1715  // empty (special case when at least one of the memref dimensions is 0).
1716  if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1717  return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1718  failure();
1719 
1720  // Create the descriptor.
1721  MemRefDescriptor sourceMemRef(adaptor.getSource());
1722  auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);
1723 
1724  // Field 1: Copy the allocated pointer, used for malloc/free.
1725  Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1726  auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1727  targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1728 
1729  // Field 2: Copy the actual aligned pointer to payload.
1730  Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1731  alignedPtr = rewriter.create<LLVM::GEPOp>(
1732  loc, alignedPtr.getType(),
1733  typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1734  adaptor.getByteShift());
1735 
1736  targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1737 
1738  Type indexType = getIndexType();
1739  // Field 3: The offset in the resulting type must be 0. This is
1740  // because of the type change: an offset on srcType* may not be
1741  // expressible as an offset on dstType*.
1742  targetMemRef.setOffset(
1743  rewriter, loc,
1744  createIndexAttrConstant(rewriter, loc, indexType, offset));
1745 
1746  // Early exit for 0-D corner case.
1747  if (viewMemRefType.getRank() == 0)
1748  return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1749 
1750  // Fields 4 and 5: Update sizes and strides.
1751  Value stride = nullptr, nextSize = nullptr;
1752  for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1753  // Update size.
1754  Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1755  adaptor.getSizes(), i, indexType);
1756  targetMemRef.setSize(rewriter, loc, i, size);
1757  // Update stride.
1758  stride =
1759  getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1760  targetMemRef.setStride(rewriter, loc, i, stride);
1761  nextSize = size;
1762  }
1763 
1764  rewriter.replaceOp(viewOp, {targetMemRef});
1765  return success();
1766  }
1767 };
1768 
1769 //===----------------------------------------------------------------------===//
1770 // AtomicRMWOpLowering
1771 //===----------------------------------------------------------------------===//
1772 
1773 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1774 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1775 static std::optional<LLVM::AtomicBinOp>
1776 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1777  switch (atomicOp.getKind()) {
1778  case arith::AtomicRMWKind::addf:
1779  return LLVM::AtomicBinOp::fadd;
1780  case arith::AtomicRMWKind::addi:
1781  return LLVM::AtomicBinOp::add;
1782  case arith::AtomicRMWKind::assign:
1783  return LLVM::AtomicBinOp::xchg;
1784  case arith::AtomicRMWKind::maximumf:
1785  return LLVM::AtomicBinOp::fmax;
1786  case arith::AtomicRMWKind::maxs:
1787  return LLVM::AtomicBinOp::max;
1788  case arith::AtomicRMWKind::maxu:
1789  return LLVM::AtomicBinOp::umax;
1790  case arith::AtomicRMWKind::minimumf:
1791  return LLVM::AtomicBinOp::fmin;
1792  case arith::AtomicRMWKind::mins:
1793  return LLVM::AtomicBinOp::min;
1794  case arith::AtomicRMWKind::minu:
1795  return LLVM::AtomicBinOp::umin;
1796  case arith::AtomicRMWKind::ori:
1797  return LLVM::AtomicBinOp::_or;
1798  case arith::AtomicRMWKind::andi:
1799  return LLVM::AtomicBinOp::_and;
1800  default:
1801  return std::nullopt;
1802  }
1803  llvm_unreachable("Invalid AtomicRMWKind");
1804 }
1805 
1806 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1807  using Base::Base;
1808 
1809  LogicalResult
1810  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1811  ConversionPatternRewriter &rewriter) const override {
1812  auto maybeKind = matchSimpleAtomicOp(atomicOp);
1813  if (!maybeKind)
1814  return failure();
1815  auto memRefType = atomicOp.getMemRefType();
1816  SmallVector<int64_t> strides;
1817  int64_t offset;
1818  if (failed(memRefType.getStridesAndOffset(strides, offset)))
1819  return failure();
1820  auto dataPtr =
1821  getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1822  adaptor.getMemref(), adaptor.getIndices());
1823  rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1824  atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1825  LLVM::AtomicOrdering::acq_rel);
1826  return success();
1827  }
1828 };
1829 
1830 /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1831 class ConvertExtractAlignedPointerAsIndex
1832  : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1833 public:
1834  using ConvertOpToLLVMPattern<
1835  memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1836 
1837  LogicalResult
1838  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1839  OpAdaptor adaptor,
1840  ConversionPatternRewriter &rewriter) const override {
1841  BaseMemRefType sourceTy = extractOp.getSource().getType();
1842 
1843  Value alignedPtr;
1844  if (sourceTy.hasRank()) {
1845  MemRefDescriptor desc(adaptor.getSource());
1846  alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1847  } else {
1848  auto elementPtrTy = LLVM::LLVMPointerType::get(
1849  rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1850 
1851  UnrankedMemRefDescriptor desc(adaptor.getSource());
1852  Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1853 
1855  rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1856  elementPtrTy);
1857  }
1858 
1859  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1860  extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1861  return success();
1862  }
1863 };
1864 
1865 /// Materialize the MemRef descriptor represented by the results of
1866 /// ExtractStridedMetadataOp.
1867 class ExtractStridedMetadataOpLowering
1868  : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1869 public:
1870  using ConvertOpToLLVMPattern<
1871  memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1872 
1873  LogicalResult
1874  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1875  OpAdaptor adaptor,
1876  ConversionPatternRewriter &rewriter) const override {
1877 
1878  if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1879  return failure();
1880 
1881  // Create the descriptor.
1882  MemRefDescriptor sourceMemRef(adaptor.getSource());
1883  Location loc = extractStridedMetadataOp.getLoc();
1884  Value source = extractStridedMetadataOp.getSource();
1885 
1886  auto sourceMemRefType = cast<MemRefType>(source.getType());
1887  int64_t rank = sourceMemRefType.getRank();
1888  SmallVector<Value> results;
1889  results.reserve(2 + rank * 2);
1890 
1891  // Base buffer.
1892  Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1893  Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1895  rewriter, loc, *getTypeConverter(),
1896  cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1897  baseBuffer, alignedBuffer);
1898  results.push_back((Value)dstMemRef);
1899 
1900  // Offset.
1901  results.push_back(sourceMemRef.offset(rewriter, loc));
1902 
1903  // Sizes.
1904  for (unsigned i = 0; i < rank; ++i)
1905  results.push_back(sourceMemRef.size(rewriter, loc, i));
1906  // Strides.
1907  for (unsigned i = 0; i < rank; ++i)
1908  results.push_back(sourceMemRef.stride(rewriter, loc, i));
1909 
1910  rewriter.replaceOp(extractStridedMetadataOp, results);
1911  return success();
1912  }
1913 };
1914 
1915 } // namespace
1916 
1918  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1919  // clang-format off
1920  patterns.add<
1921  AllocaOpLowering,
1922  AllocaScopeOpLowering,
1923  AtomicRMWOpLowering,
1924  AssumeAlignmentOpLowering,
1925  ConvertExtractAlignedPointerAsIndex,
1926  DimOpLowering,
1927  ExtractStridedMetadataOpLowering,
1928  GenericAtomicRMWOpLowering,
1929  GlobalMemrefOpLowering,
1930  GetGlobalMemrefOpLowering,
1931  LoadOpLowering,
1932  MemRefCastOpLowering,
1933  MemRefCopyOpLowering,
1934  MemorySpaceCastOpLowering,
1935  MemRefReinterpretCastOpLowering,
1936  MemRefReshapeOpLowering,
1937  PrefetchOpLowering,
1938  RankOpLowering,
1939  ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1940  ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1941  StoreOpLowering,
1942  SubViewOpLowering,
1944  ViewOpLowering>(converter);
1945  // clang-format on
1946  auto allocLowering = converter.getOptions().allocLowering;
1948  patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1949  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1950  patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1951 }
1952 
1953 namespace {
1954 struct FinalizeMemRefToLLVMConversionPass
1955  : public impl::FinalizeMemRefToLLVMConversionPassBase<
1956  FinalizeMemRefToLLVMConversionPass> {
1957  using FinalizeMemRefToLLVMConversionPassBase::
1958  FinalizeMemRefToLLVMConversionPassBase;
1959 
1960  void runOnOperation() override {
1961  Operation *op = getOperation();
1962  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1964  dataLayoutAnalysis.getAtOrAbove(op));
1965  options.allocLowering =
1968 
1969  options.useGenericFunctions = useGenericFunctions;
1970 
1971  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1972  options.overrideIndexBitwidth(indexBitwidth);
1973 
1974  LLVMTypeConverter typeConverter(&getContext(), options,
1975  &dataLayoutAnalysis);
1978  LLVMConversionTarget target(getContext());
1979  target.addLegalOp<func::FuncOp>();
1980  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1981  signalPassFailure();
1982  }
1983 };
1984 
1985 /// Implement the interface to convert MemRef to LLVM.
1986 struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
1988  void loadDependentDialects(MLIRContext *context) const final {
1989  context->loadDialect<LLVM::LLVMDialect>();
1990  }
1991 
1992  /// Hook for derived dialect interface to provide conversion patterns
1993  /// and mark dialect legal for the conversion target.
1994  void populateConvertToLLVMConversionPatterns(
1995  ConversionTarget &target, LLVMTypeConverter &typeConverter,
1996  RewritePatternSet &patterns) const final {
1998  }
1999 };
2000 
2001 } // namespace
2002 
2004  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
2005  dialect->addInterfaces<MemRefToLLVMDialectInterface>();
2006  });
2007 }
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static MLIRContext * getContext(OpFoldResult val)
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:102
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:88
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI8Type()
Definition: Builders.cpp:59
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:191
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:197
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
const LowerToLLVMOptions & getOptions() const
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
const DataLayoutAnalysis * getDataLayoutAnalysis() const
Returns the data layout analysis to query during conversion.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:442
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:428
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:395
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:433
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:439
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:451
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
result_range getResults()
Definition: Operation.h:415
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
static Value offsetBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR for getting the pointer to the offset's location.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition: Pattern.cpp:487
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
Definition: MemRefUtils.cpp:24
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:651
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: MemoryOps.cpp:263
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.