//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"

#include <optional>

#define DEBUG_TYPE "memref-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
    LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return ShapedType::isStatic(strideOrOffset);
}

static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module,
          SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables);

  return LLVM::lookupOrCreateFreeFn(b, module, symbolTables);
}

static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                     Operation *module, Type indexType,
                     SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType,
                                              symbolTables);

  return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables);
}

static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                  Operation *module, Type indexType,
                  SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType,
                                                     symbolTables);

  return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType, symbolTables);
}

/// Computes the aligned value for 'input' as follows:
///   bumped = input + alignment - 1
///   aligned = bumped - bumped % alignment
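/// For example, with input = 13 and alignment = 8: bumped = 13 + 7 = 20 and
/// aligned = 20 - 20 % 8 = 16, the smallest multiple of 8 not less than 13.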
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                           Value input, Value alignment) {
  Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(),
                                       rewriter.getIndexAttr(1));
  Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
  Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
  Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
  return LLVM::SubOp::create(rewriter, loc, bumped, mod);
}

/// Computes the byte size for the MemRef element type.
static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
                                        MemRefType memRefType, Operation *op,
                                        const DataLayout *defaultLayout) {
  const DataLayout *layout = defaultLayout;
  if (const DataLayoutAnalysis *analysis =
          typeConverter->getDataLayoutAnalysis()) {
    layout = &analysis->getAbove(op);
  }
  Type elementType = memRefType.getElementType();
  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
    return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
    return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
                                                          *layout);
  return layout->getTypeSize(elementType);
}

static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                 Location loc, Value allocatedPtr,
                                 MemRefType memRefType, Type elementPtrType,
                                 const LLVMTypeConverter &typeConverter) {
  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
  FailureOr<unsigned> maybeMemrefAddrSpace =
      typeConverter.getMemRefAddressSpace(memRefType);
  assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
    allocatedPtr = LLVM::AddrSpaceCastOp::create(
        rewriter, loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
        allocatedPtr);
  return allocatedPtr;
}

class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
                           SymbolTableCollection *symbolTables = nullptr,
                           PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
        getNotalignedAllocFn(rewriter, getTypeConverter(),
                             op->getParentWithTrait<OpTrait::SymbolTable>(),
                             getIndexType(), symbolTables);
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    Value alignment = getAlignment(rewriter, loc, op);
    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer.
    Type elementPtrType = this->getElementPtrType(memRefType);
    assert(elementPtrType && "could not compute element ptr type");
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);

    Value allocatedPtr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());
    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
    }

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// Computes the alignment for the given memory allocation op.
  template <typename OpType>
  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
                     OpType op) const {
    MemRefType memRefType = op.getType();
    Value alignment;
    if (auto alignmentAttr = op.getAlignment()) {
      Type indexType = getIndexType();
      alignment =
          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }
    return alignment;
  }
};
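
// For reference, a minimal sketch of the IR this pattern produces (assuming
// the default options, i.e. `malloc` rather than the generic allocation
// functions): `%0 = memref.alloc() : memref<4xf32>` becomes roughly
//   %sz  = llvm.mlir.constant(16 : index) : i64
//   %ptr = llvm.call @malloc(%sz) : (i64) -> !llvm.ptr
//   ...    // descriptor construction via llvm.insertvalue
// with an extra ptrtoint/inttoptr round-trip to compute the aligned pointer
// when an alignment attribute is present.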

class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
        getAlignedAllocFn(rewriter, getTypeConverter(),
                          op->getParentWithTrait<OpTrait::SymbolTable>(),
                          getIndexType(), symbolTables);
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);

    Value allocAlignment =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);

    // Function aligned_alloc requires size to be a multiple of alignment; we
    // pad the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
                             ValueRange({allocAlignment, sizeBytes}));

    Value ptr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, ptr, ptr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Computes the alignment for aligned_alloc used to allocate the buffer for
  /// the memory allocation op.
  ///
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of the alignment.
  int64_t alignedAllocationGetAlignment(memref::AllocOp op,
                                        const DataLayout *defaultLayout) const {
    if (std::optional<uint64_t> alignment = op.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a
    // power of two, we will bump to the next power of two if it isn't.
    unsigned eltSizeBytes = getMemRefEltSizeInBytes(
        getTypeConverter(), op.getType(), op, defaultLayout);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// factor.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
                              const DataLayout *defaultLayout) const {
    uint64_t sizeDivisor =
        getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamicDim(i))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};
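
// Worked example (illustrative, not from the source): for
// `memref.alloc() : memref<5xf32>` with the default minimum alignment of 16,
// the raw size is 20 bytes, which is not a multiple of 16, so the padding
// above kicks in: createAligned(20, 16) = 32, and the emitted call is roughly
// `llvm.call @aligned_alloc(%c16, %c32)`.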

struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
  using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;

  /// Lowers `memref.alloca` to llvm.alloca, i.e. allocates the buffer on the
  /// stack. The requested alignment, if any, is passed directly to
  /// llvm.alloca, so no post-allocation alignment fixup is needed.
  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1). The size here is in
    // number of elements, not bytes, since llvm.alloca takes an element count.
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value size;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, size,
                                   /*sizeInBytes=*/false);

    // With alloca, one gets a pointer to the element type right away; it
    // serves as both the allocated and the aligned pointer for stack
    // allocations.
    auto elementType =
        typeConverter->convertType(op.getType().getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(op.getType());
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned addrSpace = *maybeAddressSpace;
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
                               op.getAlignment().value_or(0));

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
        strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};
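
// A minimal sketch of the result (illustrative):
// `%0 = memref.alloca() : memref<4xf32>` becomes roughly
//   %c4  = llvm.mlir.constant(4 : index) : i64
//   %ptr = llvm.alloca %c4 x f32 : (i64) -> !llvm.ptr
// followed by the usual llvm.insertvalue sequence building the descriptor,
// with %ptr used as both the allocated and the aligned pointer.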

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save the stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
    LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. The stack is restored before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};
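
// In effect (a sketch under the default lowering), `memref.alloca_scope { ...
// }` becomes:
//   %sp = llvm.intr.stacksave : !llvm.ptr
//   llvm.br ^body
// ^body:
//   ...                        // inlined scope body, allocas included
//   llvm.intr.stackrestore %sp
//   llvm.br ^continue
// so stack allocations made inside the scope are released on exit.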

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
                                     /*indices=*/{});

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
    Value trueCond =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                           alignmentConst);
    rewriter.replaceOp(op, memref);
    return success();
  }
};
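
// The emitted form corresponds to LLVM's "align" operand bundle on the assume
// intrinsic; schematically, for `memref.assume_alignment %m, 16`:
//   %true = llvm.mlir.constant(true) : i1
//   %c16  = llvm.mlir.constant(16 : index) : i64
//   llvm.intr.assume %true ["align"(%ptr, %c16 : !llvm.ptr, i64)]
// (Sketch only; the exact printed syntax may differ.)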

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. Since the memref descriptor is an SSA value, there is no need to
// clean it up in any way.
class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
                             SymbolTableCollection *symbolTables = nullptr,
                             PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc =
        getFreeFn(rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>(),
                  symbolTables);
    if (failed(freeFunc))
      return failure();
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
                                              allocatedPtr);
    return success();
  }
};
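
// Net effect (sketch): `memref.dealloc %m : memref<?xf32>` becomes
//   %alloc = llvm.extractvalue %desc[0] : !llvm.struct<...>
//   llvm.call @free(%alloc) : (!llvm.ptr) -> ()
// where %desc is the lowered memref descriptor and field 0 holds the
// allocated (not the aligned) pointer.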

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to the offset field of the memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr =
        LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
                            underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = LLVM::AddOp::create(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
                                        getTypeConverter()->getIndexType(),
                                        offsetPtr, idxPlusOne);
    return LLVM::LoadOp::create(rewriter, loc,
                                getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the index being constant when possible.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use a constant for the static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |        |      |
///  |--------       |
///                  v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(
        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
    Value init = LLVM::LoadOp::create(
        rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
        dataPtr);
    LLVM::BrOp::create(rewriter, loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg =
        LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
                                      result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
    Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef<Value>(),
                           loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1-D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
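
// For instance, memref<4x8xf32> maps to !llvm.array<4 x array<8 x f32>>:
// walking the shape in reverse wraps f32 into array<8 x f32> first, then into
// array<4 x ...>, so the outermost dimension ends up outermost.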

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::GlobalOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
    bool isExternal = global.isExternal();
    bool isUninitialized = global.isUninitialized();

    Attribute initialValue = nullptr;
    if (!isExternal && !isUninitialized) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");

    // Remove the old operation from the symbol table.
    SymbolTable *symbolTable = nullptr;
    if (symbolTables) {
      Operation *symbolTableOp =
          global->getParentWithTrait<OpTrait::SymbolTable>();
      symbolTable = &symbolTables->getSymbolTable(symbolTableOp);
      symbolTable->remove(global);
    }

    // Create the new operation.
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);

    // Insert the new operation into the symbol table.
    if (symbolTable)
      symbolTable->insert(newGlobal, rewriter.getInsertionPoint());

    if (!isExternal && isUninitialized) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
      LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses the MemRef
/// descriptor construction of the alloc-like lowerings.
struct GetGlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
  using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;

  /// Buffer "allocation" for the memref.get_global op is getting the address
  /// of the referenced global variable.
  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    MemRefType type = cast<MemRefType>(op.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep =
        LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
                            SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};
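
// Sketch of the result for a rank-2 global (illustrative):
//   %g   = llvm.mlir.addressof @the_global : !llvm.ptr
//   %gep = llvm.getelementptr %g[0, 0, 0]
//            : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x array<8 x f32>>
// i.e. rank + 1 = 3 zero indices: one to step through the pointer, plus one
// per array dimension to reach the first scalar element.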

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    // Per the memref.load spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
    Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
                                         adaptor.getMemref(),
                                         adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
        loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
    return success();
  }
};
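
// Sketch (assuming an identity layout):
// `%v = memref.load %m[%i] : memref<?xf32>` lowers to roughly
//   %base = llvm.extractvalue %desc[1]   // aligned pointer
//   ...                                  // linearize: offset + %i * stride
//   %p = llvm.getelementptr inbounds|nuw %base[%idx] : ..., f32
//   %v = llvm.load %p : !llvm.ptr -> f32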

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    // Per the memref.store spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
    Value dataPtr =
        getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               op.getAlignment().value_or(0),
                                               false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(
        rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can
    // be used for type erasure. For now it must preserve the underlying
    // element type and requires the source and result types to have the same
    // rank. Therefore, perform a sanity check that the underlying structs are
    // the same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      if (typeConverter->convertType(srcType) !=
          typeConverter->convertType(dstType))
        return failure();

    // Unranked to unranked casts are disallowed.
    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
      return failure();

    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
      return success();
    }

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // - Set the rank in the destination from the memref type.
      // - Allocate space on the stack and copy the src memref descriptor.
      // - Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                              rewriter.getIndexAttr(rank));
      // poison = PoisonOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
      // d1 = InsertValueOp poison, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, the behavior is
      // undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }

    return success();
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
                                SymbolTableCollection *symbolTables = nullptr,
                                PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::CopyOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute the number of elements.
    Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
    }

    // Get the element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute the total size in bytes.
    Value totalSize =
        LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(),
                                       elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr =
        LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType,
                            targetBasePtr, targetOffset);
    LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
                           /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                           type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save the stack position before promoting descriptors.
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                        rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
      LLVM::StoreOp::create(rewriter, loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive the size from llvm.getelementptr, which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
        sourcePtr.getType(), symbolTables);
    if (failed(copyFn))
      return failure();
    LLVM::CallOp::create(rewriter, loc, copyFn.value(),
                         ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore the stack used for the descriptors.
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};
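
// To summarize the dispatch (a sketch): a copy between two contiguous
// memrefs, e.g. `memref.copy %a, %b : memref<4x8xf32> to memref<4x8xf32>`,
// becomes a single llvm.intr.memcpy of 4 * 8 * 4 = 128 bytes, while any
// non-contiguous (strided) case falls back to `llvm.call @memrefCopy(...)`
// operating on stack-promoted unranked descriptors.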

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
      descVals[1] =
          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new memref descriptor.
      auto result = UnrankedMemRefDescriptor::poison(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
          rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
      Value resultUnderlyingDesc =
          LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
                                 rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = LLVM::AddrSpaceCastOp::create(
          rewriter, loc, resultElemPtrType, allocatedPtr);
      alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
                                                 resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = LLVM::ConstantOp::create(
          rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize =
          LLVM::SubOp::create(rewriter, loc, getIndexType(),
                              resultUnderlyingSize, bytesToSkipConst);
      LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
                             copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};
1410 
1411 struct MemRefReshapeOpLowering
1412  : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1413  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1414 
1415  LogicalResult
1416  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1417  ConversionPatternRewriter &rewriter) const override {
1418  Type srcType = reshapeOp.getSource().getType();
1419 
1420  Value descriptor;
1421  if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1422  adaptor, &descriptor)))
1423  return failure();
1424  rewriter.replaceOp(reshapeOp, {descriptor});
1425  return success();
1426  }
1427 
1428 private:
1429  LogicalResult
1430  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1431  Type srcType, memref::ReshapeOp reshapeOp,
1432  memref::ReshapeOp::Adaptor adaptor,
1433  Value *descriptor) const {
1434  auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1435  if (shapeMemRefType.hasStaticShape()) {
1436  MemRefType targetMemRefType =
1437  cast<MemRefType>(reshapeOp.getResult().getType());
1438  auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1439  typeConverter->convertType(targetMemRefType));
1440  if (!llvmTargetDescriptorTy)
1441  return failure();
1442 
1443  // Create descriptor.
1444  Location loc = reshapeOp.getLoc();
1445  auto desc =
1446  MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1447 
1448  // Set allocated and aligned pointers.
1449  Value allocatedPtr, alignedPtr;
1450  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1451  reshapeOp.getSource(), adaptor.getSource(),
1452  &allocatedPtr, &alignedPtr);
1453  desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1454  desc.setAlignedPtr(rewriter, loc, alignedPtr);
1455 
1456  // Extract the offset and strides from the type.
1457  int64_t offset;
1458  SmallVector<int64_t> strides;
1459  if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1460  return rewriter.notifyMatchFailure(
1461  reshapeOp, "failed to get stride and offset exprs");
1462 
1463  if (!isStaticStrideOrOffset(offset))
1464  return rewriter.notifyMatchFailure(reshapeOp,
1465  "dynamic offset is unsupported");
1466 
1467  desc.setConstantOffset(rewriter, loc, offset);
1468 
1469  assert(targetMemRefType.getLayout().isIdentity() &&
1470  "Identity layout map is a precondition of a valid reshape op");
1471 
1472  Type indexType = getIndexType();
1473  Value stride = nullptr;
1474  int64_t targetRank = targetMemRefType.getRank();
1475  for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1476  if (ShapedType::isStatic(strides[i])) {
1477  // If the stride for this dimension is static, use it directly; dynamic
1478  // strides are the running product of the inner sizes (see loop bottom).
1479  stride =
1480  createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
1481  } else if (!stride) {
1482  // `stride` is null only in the first iteration of the loop. However,
1483  // since the target memref has an identity layout, we can safely set
1484  // the innermost stride to 1.
1485  stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
1486  }
1487 
1488  Value dimSize;
1489  // Static sizes are materialized as constants; dynamic sizes are loaded
1490  // at runtime from the shape operand.
1491  if (!targetMemRefType.isDynamicDim(i)) {
1492  dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1493  targetMemRefType.getDimSize(i));
1494  } else {
1495  Value shapeOp = reshapeOp.getShape();
1496  Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
1497  dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
1498  Type indexType = getIndexType();
1499  if (dimSize.getType() != indexType)
1500  dimSize = typeConverter->materializeTargetConversion(
1501  rewriter, loc, indexType, dimSize);
1502  assert(dimSize && "Invalid memref element type");
1503  }
1504 
1505  desc.setSize(rewriter, loc, i, dimSize);
1506  desc.setStride(rewriter, loc, i, stride);
1507 
1508  // Prepare the stride value for the next dimension.
1509  stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
1510  }
1511 
1512  *descriptor = desc;
1513  return success();
1514  }
1515 
1516  // The shape is a rank-1 memref with unknown (dynamic) length.
1517  Location loc = reshapeOp.getLoc();
1518  MemRefDescriptor shapeDesc(adaptor.getShape());
1519  Value resultRank = shapeDesc.size(rewriter, loc, 0);
1520 
1521  // Extract address space and element type.
1522  auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1523  unsigned addressSpace =
1524  *getTypeConverter()->getMemRefAddressSpace(targetType);
1525 
1526  // Create the unranked memref descriptor that holds the ranked one. The
1527  // inner descriptor is allocated on the stack.
1528  auto targetDesc = UnrankedMemRefDescriptor::poison(
1529  rewriter, loc, typeConverter->convertType(targetType));
1530  targetDesc.setRank(rewriter, loc, resultRank);
1531  Value allocationSize = UnrankedMemRefDescriptor::computeSize(
1532  rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
1533  Value underlyingDescPtr = LLVM::AllocaOp::create(
1534  rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
1535  allocationSize);
1536  targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1537 
1538  // Extract pointers and offset from the source memref.
1539  Value allocatedPtr, alignedPtr, offset;
1540  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1541  reshapeOp.getSource(), adaptor.getSource(),
1542  &allocatedPtr, &alignedPtr, &offset);
1543 
1544  // Set pointers and offset.
1545  auto elementPtrType =
1546  LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1547 
1548  UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1549  elementPtrType, allocatedPtr);
1550  UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1551  underlyingDescPtr, elementPtrType,
1552  alignedPtr);
1553  UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1554  underlyingDescPtr, elementPtrType,
1555  offset);
1556 
1557  // Use the offset pointer as base for further addressing. Copy over the new
1558  // shape and compute strides. For this, we create a loop from rank-1 to 0.
1559  Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1560  rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1561  Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1562  rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1563  Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1564  Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
1565  Value resultRankMinusOne =
1566  LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);
1567 
1568  Block *initBlock = rewriter.getInsertionBlock();
1569  Type indexType = getTypeConverter()->getIndexType();
1570  Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1571 
1572  Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1573  {indexType, indexType}, {loc, loc});
1574 
1575  // Move the remaining initBlock ops to condBlock.
1576  Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1577  rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1578 
1579  rewriter.setInsertionPointToEnd(initBlock);
1580  LLVM::BrOp::create(rewriter, loc,
1581  ValueRange({resultRankMinusOne, oneIndex}), condBlock);
1582  rewriter.setInsertionPointToStart(condBlock);
1583  Value indexArg = condBlock->getArgument(0);
1584  Value strideArg = condBlock->getArgument(1);
1585 
1586  Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
1587  Value pred = LLVM::ICmpOp::create(
1588  rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
1589  LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1590 
1591  Block *bodyBlock =
1592  rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1593  rewriter.setInsertionPointToStart(bodyBlock);
1594 
1595  // Copy size from shape to descriptor.
1596  auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1597  Value sizeLoadGep = LLVM::GEPOp::create(
1598  rewriter, loc, llvmIndexPtrType,
1599  typeConverter->convertType(shapeMemRefType.getElementType()),
1600  shapeOperandPtr, indexArg);
1601  Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
1602  UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1603  targetSizesBase, indexArg, size);
1604 
1605  // Write stride value and compute next one.
1606  UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1607  targetStridesBase, indexArg, strideArg);
1608  Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);
1609 
1610  // Decrement loop counter and branch back.
1611  Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
1612  LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}),
1613  condBlock);
1614 
1615  Block *remainder =
1616  rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1617 
1618  // Hook up the cond exit to the remainder.
1619  rewriter.setInsertionPointToEnd(condBlock);
1620  LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(),
1621  remainder, ValueRange());
1622 
1623  // Reset position to beginning of new remainder block.
1624  rewriter.setInsertionPointToStart(remainder);
1625 
1626  *descriptor = targetDesc;
1627  return success();
1628  }
1629 };
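// For the dynamically-shaped path above, the emitted control flow is,
// schematically (a sketch, not verbatim IR):
//   init:              br cond(%rank - 1, 1)
//   cond(%i, %stride): %p = icmp sge %i, 0
//                      cond_br %p, body, remainder
//   body:              sizes[%i] = load shape[%i]
//                      strides[%i] = %stride
//                      br cond(%i - 1, %stride * sizes[%i])
//   remainder:         (rest of the lowering)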
1630 
1631 /// ReassociatingReshapeOp must be expanded before we reach this stage.
1632 /// Report that information.
1633 template <typename ReshapeOp>
1634 class ReassociatingReshapeOpConversion
1635  : public ConvertOpToLLVMPattern<ReshapeOp> {
1636 public:
1637  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1638  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1639 
1640  LogicalResult
1641  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1642  ConversionPatternRewriter &rewriter) const override {
1643  return rewriter.notifyMatchFailure(
1644  reshapeOp,
1645  "reassociation operations should have been expanded beforehand");
1646  }
1647 };
1648 
1649 /// Subviews must be expanded before we reach this stage.
1650 /// Report that information.
1651 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1652  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1653 
1654  LogicalResult
1655  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1656  ConversionPatternRewriter &rewriter) const override {
1657  return rewriter.notifyMatchFailure(
1658  subViewOp, "subview operations should have been expanded beforehand");
1659  }
1660 };
1661 
1662 /// Conversion pattern that transforms a transpose op into:
1663 /// 1. A poison memref descriptor of the result type.
1664 /// 2. A copy of the allocated and aligned pointers and the offset from the
1665 /// source descriptor.
1666 /// 3. Updates to the descriptor's sizes and strides, permuted according to
1667 /// the transpose's permutation map.
1668 /// The transpose op is replaced by the new descriptor.
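///
/// For example (illustrative IR):
///   %1 = memref.transpose %0 (i, j) -> (j, i)
///       : memref<?x?xf32> to memref<?x?xf32, strided<[1, ?]>>
/// exchanges the two size/stride pairs in the result descriptor.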
1669 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1670 public:
1671  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1672 
1673  LogicalResult
1674  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1675  ConversionPatternRewriter &rewriter) const override {
1676  auto loc = transposeOp.getLoc();
1677  MemRefDescriptor viewMemRef(adaptor.getIn());
1678 
1679  // No permutation, early exit.
1680  if (transposeOp.getPermutation().isIdentity())
1681  return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1682 
1683  auto targetMemRef = MemRefDescriptor::poison(
1684  rewriter, loc,
1685  typeConverter->convertType(transposeOp.getIn().getType()));
1686 
1687  // Copy the base and aligned pointers from the old descriptor to the new
1688  // one.
1689  targetMemRef.setAllocatedPtr(rewriter, loc,
1690  viewMemRef.allocatedPtr(rewriter, loc));
1691  targetMemRef.setAlignedPtr(rewriter, loc,
1692  viewMemRef.alignedPtr(rewriter, loc));
1693 
1694  // Copy the offset pointer from the old descriptor to the new one.
1695  targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1696 
1697  // Iterate over the dimensions and apply size/stride permutation:
1698  // When enumerating the results of the permutation map, the enumeration
1699  // index is the index into the target dimensions and the DimExpr points to
1700  // the dimension of the source memref.
1701  for (const auto &en :
1702  llvm::enumerate(transposeOp.getPermutation().getResults())) {
1703  int targetPos = en.index();
1704  int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1705  targetMemRef.setSize(rewriter, loc, targetPos,
1706  viewMemRef.size(rewriter, loc, sourcePos));
1707  targetMemRef.setStride(rewriter, loc, targetPos,
1708  viewMemRef.stride(rewriter, loc, sourcePos));
1709  }
1710 
1711  rewriter.replaceOp(transposeOp, {targetMemRef});
1712  return success();
1713  }
1714 };
1715 
1716 /// Conversion pattern that transforms an op into:
1717 /// 1. An `llvm.mlir.poison` operation to create a memref descriptor
1718 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1719 /// and stride.
1720 /// The view op is replaced by the descriptor.
1721 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1722  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1723 
1724  // Build and return the value for the idx^th shape dimension: either the
1725  // constant shape dimension, or the matching dynamic size operand.
1726  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1727  ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1728  Type indexType) const {
1729  assert(idx < shape.size());
1730  if (ShapedType::isStatic(shape[idx]))
1731  return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1732  // Count the number of dynamic dims in range [0, idx]
1733  unsigned nDynamic =
1734  llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1735  return dynamicSizes[nDynamic];
1736  }
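  // For example (illustrative): with shape [2, ?, 4, ?] and idx == 3,
  // shape.take_front(3) contains one dynamic dimension, so the second
  // dynamic size operand, dynamicSizes[1], is returned.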
1737 
1738  // Build and return the idx^th stride, either by returning the constant stride
1739  // or by computing the dynamic stride from the current `runningStride` and
1740  // `nextSize`. The caller should keep a running stride and update it with the
1741  // result returned by this function.
1742  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1743  ArrayRef<int64_t> strides, Value nextSize,
1744  Value runningStride, unsigned idx, Type indexType) const {
1745  assert(idx < strides.size());
1746  if (ShapedType::isStatic(strides[idx]))
1747  return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1748  if (nextSize)
1749  return runningStride
1750  ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
1751  : nextSize;
1752  assert(!runningStride);
1753  return createIndexAttrConstant(rewriter, loc, indexType, 1);
1754  }
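  // For example (illustrative): with strides [?, ?, 1], the loop below calls
  // this with idx = 2, 1, 0 and obtains the constant 1, then 1 * size[2],
  // then (1 * size[2]) * size[1], i.e. the running row-major stride.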
1755 
1756  LogicalResult
1757  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1758  ConversionPatternRewriter &rewriter) const override {
1759  auto loc = viewOp.getLoc();
1760 
1761  auto viewMemRefType = viewOp.getType();
1762  auto targetElementTy =
1763  typeConverter->convertType(viewMemRefType.getElementType());
1764  auto targetDescTy = typeConverter->convertType(viewMemRefType);
1765  if (!targetDescTy || !targetElementTy ||
1766  !LLVM::isCompatibleType(targetElementTy) ||
1767  !LLVM::isCompatibleType(targetDescTy))
1768  return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1769  failure();
1770 
1771  int64_t offset;
1772  SmallVector<int64_t, 4> strides;
1773  auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1774  if (failed(successStrides))
1775  return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1776  assert(offset == 0 && "expected offset to be 0");
1777 
1778  // Target memref must be contiguous in memory (innermost stride is 1), or
1779  // empty (special case when at least one of the memref dimensions is 0).
1780  if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1781  return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1782  failure();
1783 
1784  // Create the descriptor.
1785  MemRefDescriptor sourceMemRef(adaptor.getSource());
1786  auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);
1787 
1788  // Field 1: Copy the allocated pointer, used for malloc/free.
1789  Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1790  auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1791  targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1792 
1793  // Field 2: Copy the actual aligned pointer to payload.
1794  Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1795  alignedPtr = LLVM::GEPOp::create(
1796  rewriter, loc, alignedPtr.getType(),
1797  typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1798  adaptor.getByteShift());
1799 
1800  targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1801 
1802  Type indexType = getIndexType();
1803  // Field 3: The offset in the resulting type must be 0. This is
1804  // because of the type change: an offset on srcType* may not be
1805  // expressible as an offset on dstType*.
1806  targetMemRef.setOffset(
1807  rewriter, loc,
1808  createIndexAttrConstant(rewriter, loc, indexType, offset));
1809 
1810  // Early exit for 0-D corner case.
1811  if (viewMemRefType.getRank() == 0)
1812  return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1813 
1814  // Fields 4 and 5: Update sizes and strides.
1815  Value stride = nullptr, nextSize = nullptr;
1816  for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1817  // Update size.
1818  Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1819  adaptor.getSizes(), i, indexType);
1820  targetMemRef.setSize(rewriter, loc, i, size);
1821  // Update stride.
1822  stride =
1823  getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1824  targetMemRef.setStride(rewriter, loc, i, stride);
1825  nextSize = size;
1826  }
1827 
1828  rewriter.replaceOp(viewOp, {targetMemRef});
1829  return success();
1830  }
1831 };
1832 
1833 //===----------------------------------------------------------------------===//
1834 // AtomicRMWOpLowering
1835 //===----------------------------------------------------------------------===//
1836 
1837 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1838 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1839 static std::optional<LLVM::AtomicBinOp>
1840 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1841  switch (atomicOp.getKind()) {
1842  case arith::AtomicRMWKind::addf:
1843  return LLVM::AtomicBinOp::fadd;
1844  case arith::AtomicRMWKind::addi:
1845  return LLVM::AtomicBinOp::add;
1846  case arith::AtomicRMWKind::assign:
1847  return LLVM::AtomicBinOp::xchg;
1848  case arith::AtomicRMWKind::maximumf:
1849  // TODO: remove this by end of 2025.
1850  LDBG() << "the lowering of memref.atomicrmw maximumf changed "
1851  "from fmax to fmaximum, expect more NaNs";
1852  return LLVM::AtomicBinOp::fmaximum;
1853  case arith::AtomicRMWKind::maxnumf:
1854  return LLVM::AtomicBinOp::fmax;
1855  case arith::AtomicRMWKind::maxs:
1856  return LLVM::AtomicBinOp::max;
1857  case arith::AtomicRMWKind::maxu:
1858  return LLVM::AtomicBinOp::umax;
1859  case arith::AtomicRMWKind::minimumf:
1860  // TODO: remove this by end of 2025.
1861  LDBG() << "the lowering of memref.atomicrmw minimum changed "
1862  "from fmin to fminimum, expect more NaNs";
1863  return LLVM::AtomicBinOp::fminimum;
1864  case arith::AtomicRMWKind::minnumf:
1865  return LLVM::AtomicBinOp::fmin;
1866  case arith::AtomicRMWKind::mins:
1867  return LLVM::AtomicBinOp::min;
1868  case arith::AtomicRMWKind::minu:
1869  return LLVM::AtomicBinOp::umin;
1870  case arith::AtomicRMWKind::ori:
1871  return LLVM::AtomicBinOp::_or;
1872  case arith::AtomicRMWKind::xori:
1873  return LLVM::AtomicBinOp::_xor;
1874  case arith::AtomicRMWKind::andi:
1875  return LLVM::AtomicBinOp::_and;
1876  default:
1877  return std::nullopt;
1878  }
1879  llvm_unreachable("Invalid AtomicRMWKind");
1880 }
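// For example (illustrative IR), using the mapping above:
//   %0 = memref.atomic_rmw addf %val, %mem[%i] : (f32, memref<8xf32>) -> f32
// lowers to an llvm.atomicrmw fadd on the strided element pointer with
// acq_rel ordering (see AtomicRMWOpLowering below).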
1881 
1882 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1883  using Base::Base;
1884 
1885  LogicalResult
1886  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1887  ConversionPatternRewriter &rewriter) const override {
1888  auto maybeKind = matchSimpleAtomicOp(atomicOp);
1889  if (!maybeKind)
1890  return failure();
1891  auto memRefType = atomicOp.getMemRefType();
1892  SmallVector<int64_t> strides;
1893  int64_t offset;
1894  if (failed(memRefType.getStridesAndOffset(strides, offset)))
1895  return failure();
1896  auto dataPtr =
1897  getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1898  adaptor.getMemref(), adaptor.getIndices());
1899  rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1900  atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1901  LLVM::AtomicOrdering::acq_rel);
1902  return success();
1903  }
1904 };
1905 
1906 /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1907 class ConvertExtractAlignedPointerAsIndex
1908  : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1909 public:
1910  using ConvertOpToLLVMPattern<
1911  memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1912 
1913  LogicalResult
1914  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1915  OpAdaptor adaptor,
1916  ConversionPatternRewriter &rewriter) const override {
1917  BaseMemRefType sourceTy = extractOp.getSource().getType();
1918 
1919  Value alignedPtr;
1920  if (sourceTy.hasRank()) {
1921  MemRefDescriptor desc(adaptor.getSource());
1922  alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1923  } else {
1924  auto elementPtrTy = LLVM::LLVMPointerType::get(
1925  rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1926 
1927  UnrankedMemRefDescriptor desc(adaptor.getSource());
1928  Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1929 
1930  alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1931  rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1932  elementPtrTy);
1933  }
1934 
1935  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1936  extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1937  return success();
1938  }
1939 };
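// For example (illustrative IR):
//   %i = memref.extract_aligned_pointer_as_index %m : memref<?xf32> -> index
// becomes an llvm.ptrtoint of the descriptor's aligned pointer.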
1940 
1941 /// Materialize the MemRef descriptor represented by the results of
1942 /// ExtractStridedMetadataOp.
1943 class ExtractStridedMetadataOpLowering
1944  : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1945 public:
1946  using ConvertOpToLLVMPattern<
1947  memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1948 
1949  LogicalResult
1950  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1951  OpAdaptor adaptor,
1952  ConversionPatternRewriter &rewriter) const override {
1953 
1954  if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1955  return failure();
1956 
1957  // Create the descriptor.
1958  MemRefDescriptor sourceMemRef(adaptor.getSource());
1959  Location loc = extractStridedMetadataOp.getLoc();
1960  Value source = extractStridedMetadataOp.getSource();
1961 
1962  auto sourceMemRefType = cast<MemRefType>(source.getType());
1963  int64_t rank = sourceMemRefType.getRank();
1964  SmallVector<Value> results;
1965  results.reserve(2 + rank * 2);
1966 
1967  // Base buffer.
1968  Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1969  Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1970  MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
1971  rewriter, loc, *getTypeConverter(),
1972  cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1973  baseBuffer, alignedBuffer);
1974  results.push_back((Value)dstMemRef);
1975 
1976  // Offset.
1977  results.push_back(sourceMemRef.offset(rewriter, loc));
1978 
1979  // Sizes.
1980  for (unsigned i = 0; i < rank; ++i)
1981  results.push_back(sourceMemRef.size(rewriter, loc, i));
1982  // Strides.
1983  for (unsigned i = 0; i < rank; ++i)
1984  results.push_back(sourceMemRef.stride(rewriter, loc, i));
1985 
1986  rewriter.replaceOp(extractStridedMetadataOp, results);
1987  return success();
1988  }
1989 };
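// For example (illustrative IR):
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %m : memref<?x?xf32>
//       -> memref<f32>, index, index, index, index, index
// simply forwards the corresponding descriptor fields as results.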
1990 
1991 } // namespace
1992 
1993 void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
1994  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
1995  SymbolTableCollection *symbolTables) {
1996  // clang-format off
1997  patterns.add<
1998  AllocaOpLowering,
1999  AllocaScopeOpLowering,
2000  AtomicRMWOpLowering,
2001  AssumeAlignmentOpLowering,
2002  ConvertExtractAlignedPointerAsIndex,
2003  DimOpLowering,
2004  ExtractStridedMetadataOpLowering,
2005  GenericAtomicRMWOpLowering,
2006  GetGlobalMemrefOpLowering,
2007  LoadOpLowering,
2008  MemRefCastOpLowering,
2009  MemorySpaceCastOpLowering,
2010  MemRefReinterpretCastOpLowering,
2011  MemRefReshapeOpLowering,
2012  PrefetchOpLowering,
2013  RankOpLowering,
2014  ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2015  ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2016  StoreOpLowering,
2017  SubViewOpLowering,
2018  TransposeOpLowering,
2019  ViewOpLowering>(converter);
2020  // clang-format on
2021  patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
2022  symbolTables);
2023  auto allocLowering = converter.getOptions().allocLowering;
2024  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2025  patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
2026  symbolTables);
2027  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2028  patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
2029 }
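// Typical driver usage (a sketch mirroring the pass below; `ctx` and `op` are
// placeholders):
//   LLVMTypeConverter converter(ctx);
//   RewritePatternSet patterns(ctx);
//   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(*ctx);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     /* handle failure */;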
2030 
2031 namespace {
2032 struct FinalizeMemRefToLLVMConversionPass
2033  : public impl::FinalizeMemRefToLLVMConversionPassBase<
2034  FinalizeMemRefToLLVMConversionPass> {
2035  using FinalizeMemRefToLLVMConversionPassBase::
2036  FinalizeMemRefToLLVMConversionPassBase;
2037 
2038  void runOnOperation() override {
2039  Operation *op = getOperation();
2040  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2041  LowerToLLVMOptions options(&getContext(),
2042  dataLayoutAnalysis.getAtOrAbove(op));
2043  options.allocLowering =
2044  (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2045  : LowerToLLVMOptions::AllocLowering::Malloc);
2046 
2047  options.useGenericFunctions = useGenericFunctions;
2048 
2049  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2050  options.overrideIndexBitwidth(indexBitwidth);
2051 
2052  LLVMTypeConverter typeConverter(&getContext(), options,
2053  &dataLayoutAnalysis);
2054  RewritePatternSet patterns(&getContext());
2055  SymbolTableCollection symbolTables;
2056  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns,
2057  &symbolTables);
2058  LLVMConversionTarget target(getContext());
2059  target.addLegalOp<func::FuncOp>();
2060  if (failed(applyPartialConversion(op, target, std::move(patterns))))
2061  signalPassFailure();
2062  }
2063 };
2064 
2065 /// Implement the interface to convert MemRef to LLVM.
2066 struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2067  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
2068  void loadDependentDialects(MLIRContext *context) const final {
2069  context->loadDialect<LLVM::LLVMDialect>();
2070  }
2071 
2072  /// Hook for derived dialect interface to provide conversion patterns
2073  /// and mark dialect legal for the conversion target.
2074  void populateConvertToLLVMConversionPatterns(
2075  ConversionTarget &target, LLVMTypeConverter &typeConverter,
2076  RewritePatternSet &patterns) const final {
2077  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
2078  }
2079 };
2080 
2081 } // namespace
2082 
2083 void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
2084  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
2085  dialect->addInterfaces<MemRefToLLVMDialectInterface>();
2086  });
2087 }