//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/MathExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}

static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(module);

  return LLVM::lookupOrCreateFreeFn(module);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    return allocateBufferManuallyAlign(
        rewriter, loc, sizeBytes, op,
        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      &defaultLayout));
    if (!ptr)
      return std::make_tuple(Value(), Value());
    return std::make_tuple(ptr, ptr);
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {
    setRequiresNumElements();
  }

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (e.g., in conjunction with malloc).
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value size,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementType =
        typeConverter->convertType(allocaOp.getType().getElementType());
    unsigned addrSpace =
        *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
                                        allocaOp.getAlignment().value_or(0));

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                     rewriter);

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
    Value trueCond =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                                    alignmentConst);

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
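// As an illustrative sketch (assuming the default descriptor layout and
// 64-bit indices), for a ranked memref:
//
//   memref.dealloc %m : memref<?xf32>
//
// becomes roughly:
//
//   %alloc = llvm.extractvalue %desc[0]
//       : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
//   llvm.call @free(%alloc) : (!llvm.ptr) -> ()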
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    if (failed(freeFunc))
      return failure();
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
                                              allocatedPtr);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
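// As an illustrative sketch (assuming 64-bit indices), for the dynamic
// dimension of a `memref<4x?xf32>`:
//
//   %d = memref.dim %m, %c1 : memref<4x?xf32>
//
// becomes roughly:
//
//   %d = llvm.extractvalue %desc[3, 1]
//       : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
//
// whereas `memref.dim %m, %c0` folds to the constant 4.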
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to offset field of memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using a GEPOp
    // with `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the index being constant when possible.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use a constant for the static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1-D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
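
// For instance, convertGlobalMemrefTypeToLLVM maps `memref<4x2xf32>` to
// `!llvm.array<4 x array<2 x f32>>`, and a rank-0 `memref<f32>` to plain
// `f32`.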

/// GlobalMemrefOp is lowered to an LLVM global variable.
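/// As an illustrative sketch:
///
///   memref.global "private" constant @g : memref<2xf32> = dense<[1.0, 2.0]>
///
/// becomes roughly:
///
///   llvm.mlir.global private constant @g(dense<[1.000000e+00, 2.000000e+00]>
///       : tensor<2xf32>) : !llvm.array<2 x f32>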
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This builds on
/// `AllocLikeOpLowering` to reuse the MemRef descriptor construction.
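/// As an illustrative sketch:
///
///   %0 = memref.get_global @g : memref<2xf32>
///
/// becomes roughly:
///
///   %p = llvm.mlir.addressof @g : !llvm.ptr
///   %e = llvm.getelementptr %p[0, 0]
///       : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x f32>
///
/// with %e used as the aligned pointer in the resulting descriptor.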
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(maybeAddressSpace))
      return std::make_tuple(Value(), Value());
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
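// As an illustrative sketch (assuming a 1-D identity layout and 64-bit
// indices), the access path is roughly:
//
//   %v = memref.load %m[%i] : memref<?xf32>
//     -->
//   %base = llvm.extractvalue %desc[1]
//       : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
//   %p = llvm.getelementptr %base[%i] : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %v = llvm.load %p : !llvm.ptr -> f32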
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
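// The pointer computation mirrors memref.load; the final op is roughly:
//
//   llvm.store %v, %p : f32, !llvm.ptr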
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
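// Schematically (the i32 flags carry isWrite, the locality hint, and the
// data/instruction-cache bit):
//
//   memref.prefetch %m[%i], read, locality<3>, data : memref<?xf32>
//     -->
//   llvm.intr.prefetch on the computed element pointer,
//   with rw = 0, hint = 3, cache = 1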
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(isa<UnrankedMemRefType>(srcType) ||
           isa<UnrankedMemRefType>(dstType));

    // Unranked-to-unranked casts are disallowed.
    return !(isa<UnrankedMemRefType>(srcType) &&
             isa<UnrankedMemRefType>(dstType))
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // poison = PoisonOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
      // d1 = InsertValueOp poison, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to LLVM.
///
/// For memrefs with an identity layout, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
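/// As an illustrative sketch, for a contiguous static shape (4x4xf32 is 64
/// bytes):
///
///   memref.copy %src, %dst : memref<4x4xf32> to memref<4x4xf32>
///
/// becomes roughly:
///
///   "llvm.intr.memcpy"(%dstPtr, %srcPtr, %c64) <{isVolatile = false}>
///       : (!llvm.ptr, !llvm.ptr, i64) -> ()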
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save stack position before promoting descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive size from llvm.getelementptr which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    if (failed(copyFn))
      return failure();
    rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore the stack used for the descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new memref descriptor.
      auto result = UnrankedMemRefDescriptor::poison(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
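/// (An unranked descriptor is the pair (rank, ptr), where ptr points to a
/// ranked descriptor holding the allocated/aligned pointers, the offset, and
/// the size/stride arrays; those fields are what gets read here.)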
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamic(strides[i])) {
          // If the stride for this dimension is static, use the constant
          // value directly.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. Since
          // the target memref has an identity layout, the innermost stride
          // is 1. For the remaining dynamic strides, the running product of
          // the inner sizes (updated at the bottom of this loop) is used.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }

        Value dimSize;
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::poison(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, addressSpace, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
        sizes.front());
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
                                    remainder, std::nullopt);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// ReassociatingReshapeOp must be expanded before we reach this stage.
/// Report that information.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};

/// Subviews must be expanded before we reach this stage.
/// Report that information.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
///   2. A load of the ViewDescriptor from the pointer allocated in 1.
///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride. Size and stride are permutations of the original values.
///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The transpose op is replaced by the alloca'ed pointer.
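/// As an illustrative sketch:
///
///   %t = memref.transpose %m (d0, d1) -> (d1, d0)
///       : memref<?x?xf32> to memref<?x?xf32, strided<[1, ?], offset: ?>>
///
/// keeps the pointers and offset of %m's descriptor and swaps the size and
/// stride pairs: size'[0] = size[1], size'[1] = size[0], and likewise for the
/// strides.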
1355 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1356 public:
1358 
1359  LogicalResult
1360  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1361  ConversionPatternRewriter &rewriter) const override {
1362  auto loc = transposeOp.getLoc();
1363  MemRefDescriptor viewMemRef(adaptor.getIn());
1364 
1365  // No permutation, early exit.
1366  if (transposeOp.getPermutation().isIdentity())
1367  return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1368 
1369  auto targetMemRef = MemRefDescriptor::poison(
1370  rewriter, loc,
1371  typeConverter->convertType(transposeOp.getIn().getType()));
1372 
1373  // Copy the base and aligned pointers from the old descriptor to the new
1374  // one.
1375  targetMemRef.setAllocatedPtr(rewriter, loc,
1376  viewMemRef.allocatedPtr(rewriter, loc));
1377  targetMemRef.setAlignedPtr(rewriter, loc,
1378  viewMemRef.alignedPtr(rewriter, loc));
1379 
1380  // Copy the offset pointer from the old descriptor to the new one.
1381  targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1382 
1383  // Iterate over the dimensions and apply size/stride permutation:
1384  // When enumerating the results of the permutation map, the enumeration
1385  // index is the index into the target dimensions and the DimExpr points to
1386  // the dimension of the source memref.
1387  for (const auto &en :
1388  llvm::enumerate(transposeOp.getPermutation().getResults())) {
1389  int targetPos = en.index();
1390  int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1391  targetMemRef.setSize(rewriter, loc, targetPos,
1392  viewMemRef.size(rewriter, loc, sourcePos));
1393  targetMemRef.setStride(rewriter, loc, targetPos,
1394  viewMemRef.stride(rewriter, loc, sourcePos));
1395  }
1396 
1397  rewriter.replaceOp(transposeOp, {targetMemRef});
1398  return success();
1399  }
1400 };
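// Sketch of the net effect on an assumed example: lowering
//
//   %1 = memref.transpose %0 (i, j) -> (j, i)
//       : memref<?x?xf32> to memref<?x?xf32, strided<[1, ?]>>
//
// yields a descriptor identical to %0's except that the size and stride
// pairs of the two dimensions are swapped; no data is copied.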
1401 
1402 /// Conversion pattern that transforms a view op into:
1403 /// 1. An `llvm.mlir.poison` operation to create a memref descriptor
1404 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1405 /// and stride.
1406 /// The view op is replaced by the descriptor.
1407 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1408  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1409 
1410  // Build and return the value for the idx^th shape dimension: either the
1411  // constant shape dimension or the corresponding dynamic size operand.
1412  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1413  ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1414  Type indexType) const {
1415  assert(idx < shape.size());
1416  if (!ShapedType::isDynamic(shape[idx]))
1417  return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1418  // Count the number of dynamic dims in range [0, idx]
1419  unsigned nDynamic =
1420  llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1421  return dynamicSizes[nDynamic];
1422  }
1423 
1424  // Build and return the idx^th stride, either by returning the constant stride
1425  // or by computing the dynamic stride from the current `runningStride` and
1426  // `nextSize`. The caller should keep a running stride and update it with the
1427  // result returned by this function.
1428  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1429  ArrayRef<int64_t> strides, Value nextSize,
1430  Value runningStride, unsigned idx, Type indexType) const {
1431  assert(idx < strides.size());
1432  if (!ShapedType::isDynamic(strides[idx]))
1433  return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1434  if (nextSize)
1435  return runningStride
1436  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1437  : nextSize;
1438  assert(!runningStride);
1439  return createIndexAttrConstant(rewriter, loc, indexType, 1);
1440  }
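// Worked example (illustrative): for a contiguous view result of type
// memref<?x?x8xf32>, the strides are [dynamic, 8, 1]. Iterating from the
// innermost dimension outwards, getStride returns the constants 1 and 8,
// then the dynamic stride runningStride * nextSize = 8 * size[1] for the
// outermost dimension.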
1441 
1442  LogicalResult
1443  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1444  ConversionPatternRewriter &rewriter) const override {
1445  auto loc = viewOp.getLoc();
1446 
1447  auto viewMemRefType = viewOp.getType();
1448  auto targetElementTy =
1449  typeConverter->convertType(viewMemRefType.getElementType());
1450  auto targetDescTy = typeConverter->convertType(viewMemRefType);
1451  if (!targetDescTy || !targetElementTy ||
1452  !LLVM::isCompatibleType(targetElementTy) ||
1453  !LLVM::isCompatibleType(targetDescTy))
1454  return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1455  failure();
1456 
1457  int64_t offset;
1458  SmallVector<int64_t, 4> strides;
1459  auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1460  if (failed(successStrides))
1461  return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1462  assert(offset == 0 && "expected offset to be 0");
1463 
1464  // Target memref must be contiguous in memory (innermost stride is 1), or
1465  // empty (special case when at least one of the memref dimensions is 0).
1466  if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1467  return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1468  failure();
1469 
1470  // Create the descriptor.
1471  MemRefDescriptor sourceMemRef(adaptor.getSource());
1472  auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);
1473 
1474  // Field 1: Copy the allocated pointer, used for malloc/free.
1475  Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1476  auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1477  targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1478 
1479  // Field 2: Copy the actual aligned pointer to payload.
1480  Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1481  alignedPtr = rewriter.create<LLVM::GEPOp>(
1482  loc, alignedPtr.getType(),
1483  typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1484  adaptor.getByteShift());
1485 
1486  targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1487 
1488  Type indexType = getIndexType();
1489  // Field 3: The offset in the resulting type must be 0. This is
1490  // because of the type change: an offset on srcType* may not be
1491  // expressible as an offset on dstType*.
1492  targetMemRef.setOffset(
1493  rewriter, loc,
1494  createIndexAttrConstant(rewriter, loc, indexType, offset));
1495 
1496  // Early exit for 0-D corner case.
1497  if (viewMemRefType.getRank() == 0)
1498  return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1499 
1500  // Fields 4 and 5: Update sizes and strides.
1501  Value stride = nullptr, nextSize = nullptr;
1502  for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1503  // Update size.
1504  Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1505  adaptor.getSizes(), i, indexType);
1506  targetMemRef.setSize(rewriter, loc, i, size);
1507  // Update stride.
1508  stride =
1509  getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1510  targetMemRef.setStride(rewriter, loc, i, stride);
1511  nextSize = size;
1512  }
1513 
1514  rewriter.replaceOp(viewOp, {targetMemRef});
1515  return success();
1516  }
1517 };
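// Hypothetical input for this pattern:
//
//   %v = memref.view %buf[%shift][%size]
//       : memref<2048xi8> to memref<?x128xf32>
//
// The lowering reuses %buf's allocated pointer, advances the aligned pointer
// by %shift bytes with a GEP, sets the offset field to 0, and fills in sizes
// [%size, 128] with contiguous strides [128, 1].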
1518 
1519 //===----------------------------------------------------------------------===//
1520 // AtomicRMWOpLowering
1521 //===----------------------------------------------------------------------===//
1522 
1523 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1524 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1525 static std::optional<LLVM::AtomicBinOp>
1526 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1527  switch (atomicOp.getKind()) {
1528  case arith::AtomicRMWKind::addf:
1529  return LLVM::AtomicBinOp::fadd;
1530  case arith::AtomicRMWKind::addi:
1531  return LLVM::AtomicBinOp::add;
1532  case arith::AtomicRMWKind::assign:
1533  return LLVM::AtomicBinOp::xchg;
1534  case arith::AtomicRMWKind::maximumf:
1535  return LLVM::AtomicBinOp::fmax;
1536  case arith::AtomicRMWKind::maxs:
1537  return LLVM::AtomicBinOp::max;
1538  case arith::AtomicRMWKind::maxu:
1539  return LLVM::AtomicBinOp::umax;
1540  case arith::AtomicRMWKind::minimumf:
1541  return LLVM::AtomicBinOp::fmin;
1542  case arith::AtomicRMWKind::mins:
1543  return LLVM::AtomicBinOp::min;
1544  case arith::AtomicRMWKind::minu:
1545  return LLVM::AtomicBinOp::umin;
1546  case arith::AtomicRMWKind::ori:
1547  return LLVM::AtomicBinOp::_or;
1548  case arith::AtomicRMWKind::andi:
1549  return LLVM::AtomicBinOp::_and;
1550  default:
1551  return std::nullopt;
1552  }
1553  llvm_unreachable("Invalid AtomicRMWKind");
1554 }
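// For instance (illustrative), the kind `addf` in
//
//   %old = memref.atomic_rmw addf %val, %m[%i] : (f32, memref<8xf32>) -> f32
//
// maps to LLVM::AtomicBinOp::fadd, whereas a kind with no direct
// llvm.atomicrmw counterpart (e.g. mulf) yields std::nullopt and must be
// handled through a compare-exchange loop instead.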
1555 
1556 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1557  using Base::Base;
1558 
1559  LogicalResult
1560  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1561  ConversionPatternRewriter &rewriter) const override {
1562  auto maybeKind = matchSimpleAtomicOp(atomicOp);
1563  if (!maybeKind)
1564  return failure();
1565  auto memRefType = atomicOp.getMemRefType();
1566  SmallVector<int64_t> strides;
1567  int64_t offset;
1568  if (failed(memRefType.getStridesAndOffset(strides, offset)))
1569  return failure();
1570  auto dataPtr =
1571  getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1572  adaptor.getIndices(), rewriter);
1573  rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1574  atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1575  LLVM::AtomicOrdering::acq_rel);
1576  return success();
1577  }
1578 };
1579 
1580 /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1581 class ConvertExtractAlignedPointerAsIndex
1582  : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1583 public:
1584  using ConvertOpToLLVMPattern<
1585  memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1586 
1587  LogicalResult
1588  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1589  OpAdaptor adaptor,
1590  ConversionPatternRewriter &rewriter) const override {
1591  BaseMemRefType sourceTy = extractOp.getSource().getType();
1592 
1593  Value alignedPtr;
1594  if (sourceTy.hasRank()) {
1595  MemRefDescriptor desc(adaptor.getSource());
1596  alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1597  } else {
1598  auto elementPtrTy = LLVM::LLVMPointerType::get(
1599  rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1600 
1601  UnrankedMemRefDescriptor desc(adaptor.getSource());
1602  Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1603 
1604  alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1605  rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1606  elementPtrTy);
1607  }
1608 
1609  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1610  extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1611  return success();
1612  }
1613 };
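// Hypothetical example:
//
//   %p = memref.extract_aligned_pointer_as_index %m : memref<?xf32> -> index
//
// lowers to reading the aligned pointer out of %m's descriptor followed by
// an llvm.ptrtoint to the converted index type.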
1614 
1615 /// Materialize the MemRef descriptor represented by the results of
1616 /// ExtractStridedMetadataOp.
1617 class ExtractStridedMetadataOpLowering
1618  : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1619 public:
1620  using ConvertOpToLLVMPattern<
1621  memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1622 
1623  LogicalResult
1624  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1625  OpAdaptor adaptor,
1626  ConversionPatternRewriter &rewriter) const override {
1627 
1628  if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1629  return failure();
1630 
1631  // Create the descriptor.
1632  MemRefDescriptor sourceMemRef(adaptor.getSource());
1633  Location loc = extractStridedMetadataOp.getLoc();
1634  Value source = extractStridedMetadataOp.getSource();
1635 
1636  auto sourceMemRefType = cast<MemRefType>(source.getType());
1637  int64_t rank = sourceMemRefType.getRank();
1638  SmallVector<Value> results;
1639  results.reserve(2 + rank * 2);
1640 
1641  // Base buffer.
1642  Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1643  Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1644  MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
1645  rewriter, loc, *getTypeConverter(),
1646  cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1647  baseBuffer, alignedBuffer);
1648  results.push_back((Value)dstMemRef);
1649 
1650  // Offset.
1651  results.push_back(sourceMemRef.offset(rewriter, loc));
1652 
1653  // Sizes.
1654  for (unsigned i = 0; i < rank; ++i)
1655  results.push_back(sourceMemRef.size(rewriter, loc, i));
1656  // Strides.
1657  for (unsigned i = 0; i < rank; ++i)
1658  results.push_back(sourceMemRef.stride(rewriter, loc, i));
1659 
1660  rewriter.replaceOp(extractStridedMetadataOp, results);
1661  return success();
1662  }
1663 };
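// Hypothetical example:
//
//   %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
//       : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
//
// is materialized by reading each field of %m's descriptor; %base becomes a
// rank-0 descriptor rebuilt from the allocated and aligned pointers.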
1664 
1665 } // namespace
1666 
1667 void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
1668  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1669  // clang-format off
1670  patterns.add<
1671  AllocaOpLowering,
1672  AllocaScopeOpLowering,
1673  AtomicRMWOpLowering,
1674  AssumeAlignmentOpLowering,
1675  ConvertExtractAlignedPointerAsIndex,
1676  DimOpLowering,
1677  ExtractStridedMetadataOpLowering,
1678  GenericAtomicRMWOpLowering,
1679  GlobalMemrefOpLowering,
1680  GetGlobalMemrefOpLowering,
1681  LoadOpLowering,
1682  MemRefCastOpLowering,
1683  MemRefCopyOpLowering,
1684  MemorySpaceCastOpLowering,
1685  MemRefReinterpretCastOpLowering,
1686  MemRefReshapeOpLowering,
1687  PrefetchOpLowering,
1688  RankOpLowering,
1689  ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1690  ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1691  StoreOpLowering,
1692  SubViewOpLowering,
1693  TransposeOpLowering,
1694  ViewOpLowering>(converter);
1695  // clang-format on
1696  auto allocLowering = converter.getOptions().allocLowering;
1697  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1698  patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1699  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1700  patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1701 }
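// Minimal usage sketch (assuming a `ModuleOp module` to convert); this
// mirrors what the pass below does:
//
//   LowerToLLVMOptions options(module.getContext());
//   LLVMTypeConverter typeConverter(module.getContext(), options);
//   RewritePatternSet patterns(module.getContext());
//   populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
//   LLVMConversionTarget target(*module.getContext());
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     /* handle the error */;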
1702 
1703 namespace {
1704 struct FinalizeMemRefToLLVMConversionPass
1705  : public impl::FinalizeMemRefToLLVMConversionPassBase<
1706  FinalizeMemRefToLLVMConversionPass> {
1707  using FinalizeMemRefToLLVMConversionPassBase::
1708  FinalizeMemRefToLLVMConversionPassBase;
1709 
1710  void runOnOperation() override {
1711  Operation *op = getOperation();
1712  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1713  LowerToLLVMOptions options(&getContext(),
1714  dataLayoutAnalysis.getAtOrAbove(op));
1715  options.allocLowering =
1716  (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1717  : LowerToLLVMOptions::AllocLowering::Malloc);
1718 
1719  options.useGenericFunctions = useGenericFunctions;
1720 
1721  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1722  options.overrideIndexBitwidth(indexBitwidth);
1723 
1724  LLVMTypeConverter typeConverter(&getContext(), options,
1725  &dataLayoutAnalysis);
1726  RewritePatternSet patterns(&getContext());
1727  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1728  LLVMConversionTarget target(getContext());
1729  target.addLegalOp<func::FuncOp>();
1730  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1731  signalPassFailure();
1732  }
1733 };
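// From the command line this pass is typically invoked as, e.g. (option
// names are illustrative of the registered pass options):
//
//   mlir-opt --finalize-memref-to-llvm="use-aligned-alloc=1 index-bitwidth=32"
//
// after reshape-like ops have been expanded, e.g. by
// --expand-strided-metadata.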
1734 
1735 /// Implement the interface to convert MemRef to LLVM.
1736 struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
1737  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
1738  void loadDependentDialects(MLIRContext *context) const final {
1739  context->loadDialect<LLVM::LLVMDialect>();
1740  }
1741 
1742  /// Hook for derived dialect interface to provide conversion patterns
1743  /// and mark dialect legal for the conversion target.
1744  void populateConvertToLLVMConversionPatterns(
1745  ConversionTarget &target, LLVMTypeConverter &typeConverter,
1746  RewritePatternSet &patterns) const final {
1747  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1748  }
1749 };
1750 
1751 } // namespace
1752 
1753 void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
1754  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
1755  dialect->addInterfaces<MemRefToLLVMDialectInterface>();
1756  });
1757 }
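// Typical registration sketch in a tool, so that the generic convert-to-llvm
// infrastructure picks up these patterns (assumes tool-side setup):
//
//   DialectRegistry registry;
//   registerConvertMemRefToLLVMInterface(registry);
//   MLIRContext context(registry);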