MLIR  20.0.0git
MemRefToLLVM.cpp
Go to the documentation of this file.
1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/IRMapping.h"
27 #include "mlir/Pass/Pass.h"
28 #include "llvm/ADT/SmallBitVector.h"
29 #include "llvm/Support/MathExtras.h"
30 #include <optional>
31 
32 namespace mlir {
33 #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
34 #include "mlir/Conversion/Passes.h.inc"
35 } // namespace mlir
36 
37 using namespace mlir;
38 
39 namespace {
40 
41 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
42  return !ShapedType::isDynamic(strideOrOffset);
43 }
44 
45 LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
46  ModuleOp module) {
47  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
48 
49  if (useGenericFn)
50  return LLVM::lookupOrCreateGenericFreeFn(module);
51 
52  return LLVM::lookupOrCreateFreeFn(module);
53 }
54 
55 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
56  AllocOpLowering(const LLVMTypeConverter &converter)
57  : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
58  converter) {}
59  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
60  Location loc, Value sizeBytes,
61  Operation *op) const override {
62  return allocateBufferManuallyAlign(
63  rewriter, loc, sizeBytes, op,
64  getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
65  }
66 };
67 
68 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
69  AlignedAllocOpLowering(const LLVMTypeConverter &converter)
70  : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
71  converter) {}
72  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
73  Location loc, Value sizeBytes,
74  Operation *op) const override {
75  Value ptr = allocateBufferAutoAlign(
76  rewriter, loc, sizeBytes, op, &defaultLayout,
77  alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
78  &defaultLayout));
79  if (!ptr)
80  return std::make_tuple(Value(), Value());
81  return std::make_tuple(ptr, ptr);
82  }
83 
84 private:
85  /// Default layout to use in absence of the corresponding analysis.
86  DataLayout defaultLayout;
87 };
88 
89 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
90  AllocaOpLowering(const LLVMTypeConverter &converter)
91  : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
92  converter) {
93  setRequiresNumElements();
94  }
95 
96  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
97  /// is set to null for stack allocations. `accessAlignment` is set if
98  /// alignment is needed post allocation (for eg. in conjunction with malloc).
99  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
100  Location loc, Value size,
101  Operation *op) const override {
102 
103  // With alloca, one gets a pointer to the element type right away.
104  // For stack allocations.
105  auto allocaOp = cast<memref::AllocaOp>(op);
106  auto elementType =
107  typeConverter->convertType(allocaOp.getType().getElementType());
108  unsigned addrSpace =
109  *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
110  auto elementPtrType =
111  LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
112 
113  auto allocatedElementPtr =
114  rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
115  allocaOp.getAlignment().value_or(0));
116 
117  return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
118  }
119 };
120 
121 struct AllocaScopeOpLowering
122  : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
124 
125  LogicalResult
126  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
127  ConversionPatternRewriter &rewriter) const override {
128  OpBuilder::InsertionGuard guard(rewriter);
129  Location loc = allocaScopeOp.getLoc();
130 
131  // Split the current block before the AllocaScopeOp to create the inlining
132  // point.
133  auto *currentBlock = rewriter.getInsertionBlock();
134  auto *remainingOpsBlock =
135  rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
136  Block *continueBlock;
137  if (allocaScopeOp.getNumResults() == 0) {
138  continueBlock = remainingOpsBlock;
139  } else {
140  continueBlock = rewriter.createBlock(
141  remainingOpsBlock, allocaScopeOp.getResultTypes(),
142  SmallVector<Location>(allocaScopeOp->getNumResults(),
143  allocaScopeOp.getLoc()));
144  rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
145  }
146 
147  // Inline body region.
148  Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
149  Block *afterBody = &allocaScopeOp.getBodyRegion().back();
150  rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
151 
152  // Save stack and then branch into the body of the region.
153  rewriter.setInsertionPointToEnd(currentBlock);
154  auto stackSaveOp =
155  rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
156  rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
157 
158  // Replace the alloca_scope return with a branch that jumps out of the body.
159  // Stack restore before leaving the body region.
160  rewriter.setInsertionPointToEnd(afterBody);
161  auto returnOp =
162  cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
163  auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
164  returnOp, returnOp.getResults(), continueBlock);
165 
166  // Insert stack restore before jumping out the body of the region.
167  rewriter.setInsertionPoint(branchOp);
168  rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
169 
170  // Replace the op with values return from the body region.
171  rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
172 
173  return success();
174  }
175 };
176 
177 struct AssumeAlignmentOpLowering
178  : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
180  memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
181  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
182  : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}
183 
184  LogicalResult
185  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
186  ConversionPatternRewriter &rewriter) const override {
187  Value memref = adaptor.getMemref();
188  unsigned alignment = op.getAlignment();
189  auto loc = op.getLoc();
190 
191  auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
192  Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
193  rewriter);
194 
195  // Emit llvm.assume(memref & (alignment - 1) == 0).
196  //
197  // This relies on LLVM's CSE optimization (potentially after SROA), since
198  // after CSE all memref instances should get de-duplicated into the same
199  // pointer SSA value.
200  MemRefDescriptor memRefDescriptor(memref);
201  auto intPtrType =
202  getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
203  Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
204  Value mask =
205  createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
206  Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
207  rewriter.create<LLVM::AssumeOp>(
208  loc, rewriter.create<LLVM::ICmpOp>(
209  loc, LLVM::ICmpPredicate::eq,
210  rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
211 
212  rewriter.eraseOp(op);
213  return success();
214  }
215 };
216 
217 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
218 // The memref descriptor being an SSA value, there is no need to clean it up
219 // in any way.
220 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
222 
223  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
224  : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
225 
226  LogicalResult
227  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
228  ConversionPatternRewriter &rewriter) const override {
229  // Insert the `free` declaration if it is not already present.
230  LLVM::LLVMFuncOp freeFunc =
231  getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
232  Value allocatedPtr;
233  if (auto unrankedTy =
234  llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
235  auto elementPtrTy = LLVM::LLVMPointerType::get(
236  rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
238  rewriter, op.getLoc(),
239  UnrankedMemRefDescriptor(adaptor.getMemref())
240  .memRefDescPtr(rewriter, op.getLoc()),
241  elementPtrTy);
242  } else {
243  allocatedPtr = MemRefDescriptor(adaptor.getMemref())
244  .allocatedPtr(rewriter, op.getLoc());
245  }
246  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
247  return success();
248  }
249 };
250 
251 // A `dim` is converted to a constant for static sizes and to an access to the
252 // size stored in the memref descriptor for dynamic sizes.
253 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
255 
256  LogicalResult
257  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
258  ConversionPatternRewriter &rewriter) const override {
259  Type operandType = dimOp.getSource().getType();
260  if (isa<UnrankedMemRefType>(operandType)) {
261  FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
262  operandType, dimOp, adaptor.getOperands(), rewriter);
263  if (failed(extractedSize))
264  return failure();
265  rewriter.replaceOp(dimOp, {*extractedSize});
266  return success();
267  }
268  if (isa<MemRefType>(operandType)) {
269  rewriter.replaceOp(
270  dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
271  adaptor.getOperands(), rewriter)});
272  return success();
273  }
274  llvm_unreachable("expected MemRefType or UnrankedMemRefType");
275  }
276 
277 private:
278  FailureOr<Value>
279  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
280  OpAdaptor adaptor,
281  ConversionPatternRewriter &rewriter) const {
282  Location loc = dimOp.getLoc();
283 
284  auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
285  auto scalarMemRefType =
286  MemRefType::get({}, unrankedMemRefType.getElementType());
287  FailureOr<unsigned> maybeAddressSpace =
288  getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
289  if (failed(maybeAddressSpace)) {
290  dimOp.emitOpError("memref memory space must be convertible to an integer "
291  "address space");
292  return failure();
293  }
294  unsigned addressSpace = *maybeAddressSpace;
295 
296  // Extract pointer to the underlying ranked descriptor and bitcast it to a
297  // memref<element_type> descriptor pointer to minimize the number of GEP
298  // operations.
299  UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
300  Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
301 
302  Type elementType = typeConverter->convertType(scalarMemRefType);
303 
304  // Get pointer to offset field of memref<element_type> descriptor.
305  auto indexPtrTy =
306  LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
307  Value offsetPtr = rewriter.create<LLVM::GEPOp>(
308  loc, indexPtrTy, elementType, underlyingRankedDesc,
309  ArrayRef<LLVM::GEPArg>{0, 2});
310 
311  // The size value that we have to extract can be obtained using GEPop with
312  // `dimOp.index() + 1` index argument.
313  Value idxPlusOne = rewriter.create<LLVM::AddOp>(
314  loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
315  adaptor.getIndex());
316  Value sizePtr = rewriter.create<LLVM::GEPOp>(
317  loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
318  idxPlusOne);
319  return rewriter
320  .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
321  .getResult();
322  }
323 
324  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
325  if (auto idx = dimOp.getConstantIndex())
326  return idx;
327 
328  if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
329  return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
330 
331  return std::nullopt;
332  }
333 
334  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
335  OpAdaptor adaptor,
336  ConversionPatternRewriter &rewriter) const {
337  Location loc = dimOp.getLoc();
338 
339  // Take advantage if index is constant.
340  MemRefType memRefType = cast<MemRefType>(operandType);
341  Type indexType = getIndexType();
342  if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
343  int64_t i = *index;
344  if (i >= 0 && i < memRefType.getRank()) {
345  if (memRefType.isDynamicDim(i)) {
346  // extract dynamic size from the memref descriptor.
347  MemRefDescriptor descriptor(adaptor.getSource());
348  return descriptor.size(rewriter, loc, i);
349  }
350  // Use constant for static size.
351  int64_t dimSize = memRefType.getDimSize(i);
352  return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
353  }
354  }
355  Value index = adaptor.getIndex();
356  int64_t rank = memRefType.getRank();
357  MemRefDescriptor memrefDescriptor(adaptor.getSource());
358  return memrefDescriptor.size(rewriter, loc, index, rank);
359  }
360 };
361 
362 /// Common base for load and store operations on MemRefs. Restricts the match
363 /// to supported MemRef types. Provides functionality to emit code accessing a
364 /// specific element of the underlying data buffer.
365 template <typename Derived>
366 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
369  using Base = LoadStoreOpLowering<Derived>;
370 
371  LogicalResult match(Derived op) const override {
372  MemRefType type = op.getMemRefType();
373  return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
374  }
375 };
376 
377 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
378 /// retried until it succeeds in atomically storing a new value into memory.
379 ///
380 /// +---------------------------------+
381 /// | <code before the AtomicRMWOp> |
382 /// | <compute initial %loaded> |
383 /// | cf.br loop(%loaded) |
384 /// +---------------------------------+
385 /// |
386 /// -------| |
387 /// | v v
388 /// | +--------------------------------+
389 /// | | loop(%loaded): |
390 /// | | <body contents> |
391 /// | | %pair = cmpxchg |
392 /// | | %ok = %pair[0] |
393 /// | | %new = %pair[1] |
394 /// | | cf.cond_br %ok, end, loop(%new) |
395 /// | +--------------------------------+
396 /// | | |
397 /// |----------- |
398 /// v
399 /// +--------------------------------+
400 /// | end: |
401 /// | <code after the AtomicRMWOp> |
402 /// +--------------------------------+
403 ///
404 struct GenericAtomicRMWOpLowering
405  : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
406  using Base::Base;
407 
408  LogicalResult
409  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
410  ConversionPatternRewriter &rewriter) const override {
411  auto loc = atomicOp.getLoc();
412  Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
413 
414  // Split the block into initial, loop, and ending parts.
415  auto *initBlock = rewriter.getInsertionBlock();
416  auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
417  loopBlock->addArgument(valueType, loc);
418 
419  auto *endBlock =
420  rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);
421 
422  // Compute the loaded value and branch to the loop block.
423  rewriter.setInsertionPointToEnd(initBlock);
424  auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
425  auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
426  adaptor.getIndices(), rewriter);
427  Value init = rewriter.create<LLVM::LoadOp>(
428  loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
429  rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
430 
431  // Prepare the body of the loop block.
432  rewriter.setInsertionPointToStart(loopBlock);
433 
434  // Clone the GenericAtomicRMWOp region and extract the result.
435  auto loopArgument = loopBlock->getArgument(0);
436  IRMapping mapping;
437  mapping.map(atomicOp.getCurrentValue(), loopArgument);
438  Block &entryBlock = atomicOp.body().front();
439  for (auto &nestedOp : entryBlock.without_terminator()) {
440  Operation *clone = rewriter.clone(nestedOp, mapping);
441  mapping.map(nestedOp.getResults(), clone->getResults());
442  }
443  Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
444 
445  // Prepare the epilog of the loop block.
446  // Append the cmpxchg op to the end of the loop block.
447  auto successOrdering = LLVM::AtomicOrdering::acq_rel;
448  auto failureOrdering = LLVM::AtomicOrdering::monotonic;
449  auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
450  loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
451  // Extract the %new_loaded and %ok values from the pair.
452  Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
453  Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
454 
455  // Conditionally branch to the end or back to the loop depending on %ok.
456  rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
457  loopBlock, newLoaded);
458 
459  rewriter.setInsertionPointToEnd(endBlock);
460 
461  // The 'result' of the atomic_rmw op is the newly loaded value.
462  rewriter.replaceOp(atomicOp, {newLoaded});
463 
464  return success();
465  }
466 };
467 
468 /// Returns the LLVM type of the global variable given the memref type `type`.
469 static Type
470 convertGlobalMemrefTypeToLLVM(MemRefType type,
471  const LLVMTypeConverter &typeConverter) {
472  // LLVM type for a global memref will be a multi-dimension array. For
473  // declarations or uninitialized global memrefs, we can potentially flatten
474  // this to a 1D array. However, for memref.global's with an initial value,
475  // we do not intend to flatten the ElementsAttribute when going from std ->
476  // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
477  Type elementType = typeConverter.convertType(type.getElementType());
478  Type arrayTy = elementType;
479  // Shape has the outermost dim at index 0, so need to walk it backwards
480  for (int64_t dim : llvm::reverse(type.getShape()))
481  arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
482  return arrayTy;
483 }
484 
485 /// GlobalMemrefOp is lowered to a LLVM Global Variable.
486 struct GlobalMemrefOpLowering
487  : public ConvertOpToLLVMPattern<memref::GlobalOp> {
489 
490  LogicalResult
491  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
492  ConversionPatternRewriter &rewriter) const override {
493  MemRefType type = global.getType();
494  if (!isConvertibleAndHasIdentityMaps(type))
495  return failure();
496 
497  Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
498 
499  LLVM::Linkage linkage =
500  global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
501 
502  Attribute initialValue = nullptr;
503  if (!global.isExternal() && !global.isUninitialized()) {
504  auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
505  initialValue = elementsAttr;
506 
507  // For scalar memrefs, the global variable created is of the element type,
508  // so unpack the elements attribute to extract the value.
509  if (type.getRank() == 0)
510  initialValue = elementsAttr.getSplatValue<Attribute>();
511  }
512 
513  uint64_t alignment = global.getAlignment().value_or(0);
514  FailureOr<unsigned> addressSpace =
515  getTypeConverter()->getMemRefAddressSpace(type);
516  if (failed(addressSpace))
517  return global.emitOpError(
518  "memory space cannot be converted to an integer address space");
519  auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
520  global, arrayTy, global.getConstant(), linkage, global.getSymName(),
521  initialValue, alignment, *addressSpace);
522  if (!global.isExternal() && global.isUninitialized()) {
523  rewriter.createBlock(&newGlobal.getInitializerRegion());
524  Value undef[] = {
525  rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
526  rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
527  }
528  return success();
529  }
530 };
531 
532 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
533 /// the first element stashed into the descriptor. This reuses
534 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
535 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
536  GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
537  : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
538  converter) {}
539 
540  /// Buffer "allocation" for memref.get_global op is getting the address of
541  /// the global variable referenced.
542  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
543  Location loc, Value sizeBytes,
544  Operation *op) const override {
545  auto getGlobalOp = cast<memref::GetGlobalOp>(op);
546  MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
547 
548  // This is called after a type conversion, which would have failed if this
549  // call fails.
550  FailureOr<unsigned> maybeAddressSpace =
551  getTypeConverter()->getMemRefAddressSpace(type);
552  if (failed(maybeAddressSpace))
553  return std::make_tuple(Value(), Value());
554  unsigned memSpace = *maybeAddressSpace;
555 
556  Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
557  auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
558  auto addressOf =
559  rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());
560 
561  // Get the address of the first element in the array by creating a GEP with
562  // the address of the GV as the base, and (rank + 1) number of 0 indices.
563  auto gep = rewriter.create<LLVM::GEPOp>(
564  loc, ptrTy, arrayTy, addressOf,
565  SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
566 
567  // We do not expect the memref obtained using `memref.get_global` to be
568  // ever deallocated. Set the allocated pointer to be known bad value to
569  // help debug if that ever happens.
570  auto intPtrType = getIntPtrType(memSpace);
571  Value deadBeefConst =
572  createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
573  auto deadBeefPtr =
574  rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
575 
576  // Both allocated and aligned pointers are same. We could potentially stash
577  // a nullptr for the allocated pointer since we do not expect any dealloc.
578  return std::make_tuple(deadBeefPtr, gep);
579  }
580 };
581 
582 // Load operation is lowered to obtaining a pointer to the indexed element
583 // and loading it.
584 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
585  using Base::Base;
586 
587  LogicalResult
588  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
589  ConversionPatternRewriter &rewriter) const override {
590  auto type = loadOp.getMemRefType();
591 
592  Value dataPtr =
593  getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
594  adaptor.getIndices(), rewriter);
595  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
596  loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
597  false, loadOp.getNontemporal());
598  return success();
599  }
600 };
601 
602 // Store operation is lowered to obtaining a pointer to the indexed element,
603 // and storing the given value to it.
604 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
605  using Base::Base;
606 
607  LogicalResult
608  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
609  ConversionPatternRewriter &rewriter) const override {
610  auto type = op.getMemRefType();
611 
612  Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
613  adaptor.getIndices(), rewriter);
614  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
615  0, false, op.getNontemporal());
616  return success();
617  }
618 };
619 
620 // The prefetch operation is lowered in a way similar to the load operation
621 // except that the llvm.prefetch operation is used for replacement.
622 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
623  using Base::Base;
624 
625  LogicalResult
626  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
627  ConversionPatternRewriter &rewriter) const override {
628  auto type = prefetchOp.getMemRefType();
629  auto loc = prefetchOp.getLoc();
630 
631  Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
632  adaptor.getIndices(), rewriter);
633 
634  // Replace with llvm.prefetch.
635  IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
636  IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
637  IntegerAttr isData =
638  rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
639  rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
640  localityHint, isData);
641  return success();
642  }
643 };
644 
645 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
647 
648  LogicalResult
649  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
650  ConversionPatternRewriter &rewriter) const override {
651  Location loc = op.getLoc();
652  Type operandType = op.getMemref().getType();
653  if (dyn_cast<UnrankedMemRefType>(operandType)) {
654  UnrankedMemRefDescriptor desc(adaptor.getMemref());
655  rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
656  return success();
657  }
658  if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
659  Type indexType = getIndexType();
660  rewriter.replaceOp(op,
661  {createIndexAttrConstant(rewriter, loc, indexType,
662  rankedMemRefType.getRank())});
663  return success();
664  }
665  return failure();
666  }
667 };
668 
669 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
671 
672  LogicalResult match(memref::CastOp memRefCastOp) const override {
673  Type srcType = memRefCastOp.getOperand().getType();
674  Type dstType = memRefCastOp.getType();
675 
676  // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
677  // used for type erasure. For now they must preserve underlying element type
678  // and require source and result type to have the same rank. Therefore,
679  // perform a sanity check that the underlying structs are the same. Once op
680  // semantics are relaxed we can revisit.
681  if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
682  return success(typeConverter->convertType(srcType) ==
683  typeConverter->convertType(dstType));
684 
685  // At least one of the operands is unranked type
686  assert(isa<UnrankedMemRefType>(srcType) ||
687  isa<UnrankedMemRefType>(dstType));
688 
689  // Unranked to unranked cast is disallowed
690  return !(isa<UnrankedMemRefType>(srcType) &&
691  isa<UnrankedMemRefType>(dstType))
692  ? success()
693  : failure();
694  }
695 
696  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
697  ConversionPatternRewriter &rewriter) const override {
698  auto srcType = memRefCastOp.getOperand().getType();
699  auto dstType = memRefCastOp.getType();
700  auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
701  auto loc = memRefCastOp.getLoc();
702 
703  // For ranked/ranked case, just keep the original descriptor.
704  if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
705  return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
706 
707  if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
708  // Casting ranked to unranked memref type
709  // Set the rank in the destination from the memref type
710  // Allocate space on the stack and copy the src memref descriptor
711  // Set the ptr in the destination to the stack space
712  auto srcMemRefType = cast<MemRefType>(srcType);
713  int64_t rank = srcMemRefType.getRank();
714  // ptr = AllocaOp sizeof(MemRefDescriptor)
715  auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
716  loc, adaptor.getSource(), rewriter);
717 
718  // rank = ConstantOp srcRank
719  auto rankVal = rewriter.create<LLVM::ConstantOp>(
720  loc, getIndexType(), rewriter.getIndexAttr(rank));
721  // undef = UndefOp
722  UnrankedMemRefDescriptor memRefDesc =
723  UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
724  // d1 = InsertValueOp undef, rank, 0
725  memRefDesc.setRank(rewriter, loc, rankVal);
726  // d2 = InsertValueOp d1, ptr, 1
727  memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
728  rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
729 
730  } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
731  // Casting from unranked type to ranked.
732  // The operation is assumed to be doing a correct cast. If the destination
733  // type mismatches the unranked the type, it is undefined behavior.
734  UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
735  // ptr = ExtractValueOp src, 1
736  auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
737 
738  // struct = LoadOp ptr
739  auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
740  rewriter.replaceOp(memRefCastOp, loadOp.getResult());
741  } else {
742  llvm_unreachable("Unsupported unranked memref to unranked memref cast");
743  }
744  }
745 };
746 
747 /// Pattern to lower a `memref.copy` to llvm.
748 ///
749 /// For memrefs with identity layouts, the copy is lowered to the llvm
750 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
751 /// to the generic `MemrefCopyFn`.
752 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
754 
755  LogicalResult
756  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
757  ConversionPatternRewriter &rewriter) const {
758  auto loc = op.getLoc();
759  auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
760 
761  MemRefDescriptor srcDesc(adaptor.getSource());
762 
763  // Compute number of elements.
764  Value numElements = rewriter.create<LLVM::ConstantOp>(
765  loc, getIndexType(), rewriter.getIndexAttr(1));
766  for (int pos = 0; pos < srcType.getRank(); ++pos) {
767  auto size = srcDesc.size(rewriter, loc, pos);
768  numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
769  }
770 
771  // Get element size.
772  auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
773  // Compute total.
774  Value totalSize =
775  rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
776 
777  Type elementType = typeConverter->convertType(srcType.getElementType());
778 
779  Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
780  Value srcOffset = srcDesc.offset(rewriter, loc);
781  Value srcPtr = rewriter.create<LLVM::GEPOp>(
782  loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
783  MemRefDescriptor targetDesc(adaptor.getTarget());
784  Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
785  Value targetOffset = targetDesc.offset(rewriter, loc);
786  Value targetPtr = rewriter.create<LLVM::GEPOp>(
787  loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
788  rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
789  /*isVolatile=*/false);
790  rewriter.eraseOp(op);
791 
792  return success();
793  }
794 
795  LogicalResult
796  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
797  ConversionPatternRewriter &rewriter) const {
798  auto loc = op.getLoc();
799  auto srcType = cast<BaseMemRefType>(op.getSource().getType());
800  auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
801 
802  // First make sure we have an unranked memref descriptor representation.
803  auto makeUnranked = [&, this](Value ranked, MemRefType type) {
804  auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
805  type.getRank());
806  auto *typeConverter = getTypeConverter();
807  auto ptr =
808  typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
809 
810  auto unrankedType =
811  UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
813  rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
814  };
815 
816  // Save stack position before promoting descriptors
817  auto stackSaveOp =
818  rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
819 
820  auto srcMemRefType = dyn_cast<MemRefType>(srcType);
821  Value unrankedSource =
822  srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
823  : adaptor.getSource();
824  auto targetMemRefType = dyn_cast<MemRefType>(targetType);
825  Value unrankedTarget =
826  targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
827  : adaptor.getTarget();
828 
829  // Now promote the unranked descriptors to the stack.
830  auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
831  rewriter.getIndexAttr(1));
832  auto promote = [&](Value desc) {
833  auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
834  auto allocated =
835  rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
836  rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
837  return allocated;
838  };
839 
840  auto sourcePtr = promote(unrankedSource);
841  auto targetPtr = promote(unrankedTarget);
842 
843  // Derive size from llvm.getelementptr which will account for any
844  // potential alignment
845  auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
846  auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
847  op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
848  rewriter.create<LLVM::CallOp>(loc, copyFn,
849  ValueRange{elemSize, sourcePtr, targetPtr});
850 
851  // Restore stack used for descriptors
852  rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
853 
854  rewriter.eraseOp(op);
855 
856  return success();
857  }
858 
859  LogicalResult
860  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
861  ConversionPatternRewriter &rewriter) const override {
862  auto srcType = cast<BaseMemRefType>(op.getSource().getType());
863  auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
864 
865  auto isContiguousMemrefType = [&](BaseMemRefType type) {
866  auto memrefType = dyn_cast<mlir::MemRefType>(type);
867  // We can use memcpy for memrefs if they have an identity layout or are
868  // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
869  // special case handled by memrefCopy.
870  return memrefType &&
871  (memrefType.getLayout().isIdentity() ||
872  (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
874  };
875 
876  if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
877  return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
878 
879  return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
880  }
881 };
882 
883 struct MemorySpaceCastOpLowering
884  : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
886  memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;
887 
888  LogicalResult
889  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
890  ConversionPatternRewriter &rewriter) const override {
891  Location loc = op.getLoc();
892 
893  Type resultType = op.getDest().getType();
894  if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
895  auto resultDescType =
896  cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
897  Type newPtrType = resultDescType.getBody()[0];
898 
899  SmallVector<Value> descVals;
900  MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
901  descVals);
902  descVals[0] =
903  rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
904  descVals[1] =
905  rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
906  Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
907  resultTypeR, descVals);
908  rewriter.replaceOp(op, result);
909  return success();
910  }
911  if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
912  // Since the type converter won't be doing this for us, get the address
913  // space.
914  auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
915  FailureOr<unsigned> maybeSourceAddrSpace =
916  getTypeConverter()->getMemRefAddressSpace(sourceType);
917  if (failed(maybeSourceAddrSpace))
918  return rewriter.notifyMatchFailure(loc,
919  "non-integer source address space");
920  unsigned sourceAddrSpace = *maybeSourceAddrSpace;
921  FailureOr<unsigned> maybeResultAddrSpace =
922  getTypeConverter()->getMemRefAddressSpace(resultTypeU);
923  if (failed(maybeResultAddrSpace))
924  return rewriter.notifyMatchFailure(loc,
925  "non-integer result address space");
926  unsigned resultAddrSpace = *maybeResultAddrSpace;
927 
928  UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
929  Value rank = sourceDesc.rank(rewriter, loc);
930  Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
931 
932  // Create and allocate storage for new memref descriptor.
933  auto result = UnrankedMemRefDescriptor::undef(
934  rewriter, loc, typeConverter->convertType(resultTypeU));
935  result.setRank(rewriter, loc, rank);
936  SmallVector<Value, 1> sizes;
937  UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
938  result, resultAddrSpace, sizes);
939  Value resultUnderlyingSize = sizes.front();
940  Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
941  loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
942  result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
943 
944  // Copy pointers, performing address space casts.
945  auto sourceElemPtrType =
946  LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
947  auto resultElemPtrType =
948  LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);
949 
950  Value allocatedPtr = sourceDesc.allocatedPtr(
951  rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
952  Value alignedPtr =
953  sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
954  sourceUnderlyingDesc, sourceElemPtrType);
955  allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
956  loc, resultElemPtrType, allocatedPtr);
957  alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
958  loc, resultElemPtrType, alignedPtr);
959 
960  result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
961  resultElemPtrType, allocatedPtr);
962  result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
963  resultUnderlyingDesc, resultElemPtrType, alignedPtr);
964 
965  // Copy all the index-valued operands.
966  Value sourceIndexVals =
967  sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
968  sourceUnderlyingDesc, sourceElemPtrType);
969  Value resultIndexVals =
970  result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
971  resultUnderlyingDesc, resultElemPtrType);
972 
973  int64_t bytesToSkip =
974  2 * llvm::divideCeil(
975  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
976  Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
977  loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
978  Value copySize = rewriter.create<LLVM::SubOp>(
979  loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
980  rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
981  copySize, /*isVolatile=*/false);
982 
983  rewriter.replaceOp(op, ValueRange{result});
984  return success();
985  }
986  return rewriter.notifyMatchFailure(loc, "unexpected memref type");
987  }
988 };
989 
990 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
991 /// memref type. In unranked case, the fields are extracted from the underlying
992 /// ranked descriptor.
993 static void extractPointersAndOffset(Location loc,
994  ConversionPatternRewriter &rewriter,
995  const LLVMTypeConverter &typeConverter,
996  Value originalOperand,
997  Value convertedOperand,
998  Value *allocatedPtr, Value *alignedPtr,
999  Value *offset = nullptr) {
1000  Type operandType = originalOperand.getType();
1001  if (isa<MemRefType>(operandType)) {
1002  MemRefDescriptor desc(convertedOperand);
1003  *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1004  *alignedPtr = desc.alignedPtr(rewriter, loc);
1005  if (offset != nullptr)
1006  *offset = desc.offset(rewriter, loc);
1007  return;
1008  }
1009 
1010  // These will all cause assert()s on unconvertible types.
1011  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
1012  cast<UnrankedMemRefType>(operandType));
1013  auto elementPtrType =
1014  LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
1015 
1016  // Extract pointer to the underlying ranked memref descriptor and cast it to
1017  // ElemType**.
1018  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1019  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1020 
1022  rewriter, loc, underlyingDescPtr, elementPtrType);
1024  rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1025  if (offset != nullptr) {
1027  rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1028  }
1029 }
1030 
1031 struct MemRefReinterpretCastOpLowering
1032  : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1033  using ConvertOpToLLVMPattern<
1034  memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1035 
1036  LogicalResult
1037  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1038  ConversionPatternRewriter &rewriter) const override {
1039  Type srcType = castOp.getSource().getType();
1040 
1041  Value descriptor;
1042  if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1043  adaptor, &descriptor)))
1044  return failure();
1045  rewriter.replaceOp(castOp, {descriptor});
1046  return success();
1047  }
1048 
1049 private:
1050  LogicalResult convertSourceMemRefToDescriptor(
1051  ConversionPatternRewriter &rewriter, Type srcType,
1052  memref::ReinterpretCastOp castOp,
1053  memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1054  MemRefType targetMemRefType =
1055  cast<MemRefType>(castOp.getResult().getType());
1056  auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1057  typeConverter->convertType(targetMemRefType));
1058  if (!llvmTargetDescriptorTy)
1059  return failure();
1060 
1061  // Create descriptor.
1062  Location loc = castOp.getLoc();
1063  auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1064 
1065  // Set allocated and aligned pointers.
1066  Value allocatedPtr, alignedPtr;
1067  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1068  castOp.getSource(), adaptor.getSource(),
1069  &allocatedPtr, &alignedPtr);
1070  desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1071  desc.setAlignedPtr(rewriter, loc, alignedPtr);
1072 
1073  // Set offset.
1074  if (castOp.isDynamicOffset(0))
1075  desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1076  else
1077  desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1078 
1079  // Set sizes and strides.
1080  unsigned dynSizeId = 0;
1081  unsigned dynStrideId = 0;
1082  for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1083  if (castOp.isDynamicSize(i))
1084  desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1085  else
1086  desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1087 
1088  if (castOp.isDynamicStride(i))
1089  desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1090  else
1091  desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1092  }
1093  *descriptor = desc;
1094  return success();
1095  }
1096 };
1097 
1098 struct MemRefReshapeOpLowering
1099  : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1101 
1102  LogicalResult
1103  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1104  ConversionPatternRewriter &rewriter) const override {
1105  Type srcType = reshapeOp.getSource().getType();
1106 
1107  Value descriptor;
1108  if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1109  adaptor, &descriptor)))
1110  return failure();
1111  rewriter.replaceOp(reshapeOp, {descriptor});
1112  return success();
1113  }
1114 
1115 private:
1116  LogicalResult
1117  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1118  Type srcType, memref::ReshapeOp reshapeOp,
1119  memref::ReshapeOp::Adaptor adaptor,
1120  Value *descriptor) const {
1121  auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1122  if (shapeMemRefType.hasStaticShape()) {
1123  MemRefType targetMemRefType =
1124  cast<MemRefType>(reshapeOp.getResult().getType());
1125  auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1126  typeConverter->convertType(targetMemRefType));
1127  if (!llvmTargetDescriptorTy)
1128  return failure();
1129 
1130  // Create descriptor.
1131  Location loc = reshapeOp.getLoc();
1132  auto desc =
1133  MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1134 
1135  // Set allocated and aligned pointers.
1136  Value allocatedPtr, alignedPtr;
1137  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1138  reshapeOp.getSource(), adaptor.getSource(),
1139  &allocatedPtr, &alignedPtr);
1140  desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1141  desc.setAlignedPtr(rewriter, loc, alignedPtr);
1142 
1143  // Extract the offset and strides from the type.
1144  int64_t offset;
1145  SmallVector<int64_t> strides;
1146  if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
1147  return rewriter.notifyMatchFailure(
1148  reshapeOp, "failed to get stride and offset exprs");
1149 
1150  if (!isStaticStrideOrOffset(offset))
1151  return rewriter.notifyMatchFailure(reshapeOp,
1152  "dynamic offset is unsupported");
1153 
1154  desc.setConstantOffset(rewriter, loc, offset);
1155 
1156  assert(targetMemRefType.getLayout().isIdentity() &&
1157  "Identity layout map is a precondition of a valid reshape op");
1158 
1159  Type indexType = getIndexType();
1160  Value stride = nullptr;
1161  int64_t targetRank = targetMemRefType.getRank();
1162  for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1163  if (!ShapedType::isDynamic(strides[i])) {
1164  // If the stride for this dimension is dynamic, then use the product
1165  // of the sizes of the inner dimensions.
1166  stride =
1167  createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
1168  } else if (!stride) {
1169  // `stride` is null only in the first iteration of the loop. However,
1170  // since the target memref has an identity layout, we can safely set
1171  // the innermost stride to 1.
1172  stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
1173  }
1174 
1175  Value dimSize;
1176  // If the size of this dimension is dynamic, then load it at runtime
1177  // from the shape operand.
1178  if (!targetMemRefType.isDynamicDim(i)) {
1179  dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1180  targetMemRefType.getDimSize(i));
1181  } else {
1182  Value shapeOp = reshapeOp.getShape();
1183  Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
1184  dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1185  Type indexType = getIndexType();
1186  if (dimSize.getType() != indexType)
1187  dimSize = typeConverter->materializeTargetConversion(
1188  rewriter, loc, indexType, dimSize);
1189  assert(dimSize && "Invalid memref element type");
1190  }
1191 
1192  desc.setSize(rewriter, loc, i, dimSize);
1193  desc.setStride(rewriter, loc, i, stride);
1194 
1195  // Prepare the stride value for the next dimension.
1196  stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1197  }
1198 
1199  *descriptor = desc;
1200  return success();
1201  }
1202 
1203  // The shape is a rank-1 tensor with unknown length.
1204  Location loc = reshapeOp.getLoc();
1205  MemRefDescriptor shapeDesc(adaptor.getShape());
1206  Value resultRank = shapeDesc.size(rewriter, loc, 0);
1207 
1208  // Extract address space and element type.
1209  auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1210  unsigned addressSpace =
1211  *getTypeConverter()->getMemRefAddressSpace(targetType);
1212 
1213  // Create the unranked memref descriptor that holds the ranked one. The
1214  // inner descriptor is allocated on stack.
1215  auto targetDesc = UnrankedMemRefDescriptor::undef(
1216  rewriter, loc, typeConverter->convertType(targetType));
1217  targetDesc.setRank(rewriter, loc, resultRank);
1218  SmallVector<Value, 4> sizes;
1219  UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1220  targetDesc, addressSpace, sizes);
1221  Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1222  loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
1223  sizes.front());
1224  targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1225 
1226  // Extract pointers and offset from the source memref.
1227  Value allocatedPtr, alignedPtr, offset;
1228  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1229  reshapeOp.getSource(), adaptor.getSource(),
1230  &allocatedPtr, &alignedPtr, &offset);
1231 
1232  // Set pointers and offset.
1233  auto elementPtrType =
1234  LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1235 
1236  UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1237  elementPtrType, allocatedPtr);
1238  UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1239  underlyingDescPtr, elementPtrType,
1240  alignedPtr);
1241  UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1242  underlyingDescPtr, elementPtrType,
1243  offset);
1244 
1245  // Use the offset pointer as base for further addressing. Copy over the new
1246  // shape and compute strides. For this, we create a loop from rank-1 to 0.
1247  Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1248  rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1249  Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1250  rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1251  Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1252  Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
1253  Value resultRankMinusOne =
1254  rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1255 
1256  Block *initBlock = rewriter.getInsertionBlock();
1257  Type indexType = getTypeConverter()->getIndexType();
1258  Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1259 
1260  Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1261  {indexType, indexType}, {loc, loc});
1262 
1263  // Move the remaining initBlock ops to condBlock.
1264  Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1265  rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1266 
1267  rewriter.setInsertionPointToEnd(initBlock);
1268  rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1269  condBlock);
1270  rewriter.setInsertionPointToStart(condBlock);
1271  Value indexArg = condBlock->getArgument(0);
1272  Value strideArg = condBlock->getArgument(1);
1273 
1274  Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
1275  Value pred = rewriter.create<LLVM::ICmpOp>(
1276  loc, IntegerType::get(rewriter.getContext(), 1),
1277  LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1278 
1279  Block *bodyBlock =
1280  rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1281  rewriter.setInsertionPointToStart(bodyBlock);
1282 
1283  // Copy size from shape to descriptor.
1284  auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1285  Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1286  loc, llvmIndexPtrType,
1287  typeConverter->convertType(shapeMemRefType.getElementType()),
1288  shapeOperandPtr, indexArg);
1289  Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1290  UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1291  targetSizesBase, indexArg, size);
1292 
1293  // Write stride value and compute next one.
1294  UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1295  targetStridesBase, indexArg, strideArg);
1296  Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1297 
1298  // Decrement loop counter and branch back.
1299  Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1300  rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1301  condBlock);
1302 
1303  Block *remainder =
1304  rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1305 
1306  // Hook up the cond exit to the remainder.
1307  rewriter.setInsertionPointToEnd(condBlock);
1308  rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1309  remainder, std::nullopt);
1310 
1311  // Reset position to beginning of new remainder block.
1312  rewriter.setInsertionPointToStart(remainder);
1313 
1314  *descriptor = targetDesc;
1315  return success();
1316  }
1317 };
1318 
1319 /// RessociatingReshapeOp must be expanded before we reach this stage.
1320 /// Report that information.
1321 template <typename ReshapeOp>
1322 class ReassociatingReshapeOpConversion
1323  : public ConvertOpToLLVMPattern<ReshapeOp> {
1324 public:
1326  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1327 
1328  LogicalResult
1329  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1330  ConversionPatternRewriter &rewriter) const override {
1331  return rewriter.notifyMatchFailure(
1332  reshapeOp,
1333  "reassociation operations should have been expanded beforehand");
1334  }
1335 };
1336 
1337 /// Subviews must be expanded before we reach this stage.
1338 /// Report that information.
1339 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1341 
1342  LogicalResult
1343  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1344  ConversionPatternRewriter &rewriter) const override {
1345  return rewriter.notifyMatchFailure(
1346  subViewOp, "subview operations should have been expanded beforehand");
1347  }
1348 };
1349 
1350 /// Conversion pattern that transforms a transpose op into:
1351 /// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
1352 /// 2. A load of the ViewDescriptor from the pointer allocated in 1.
1353 /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1354 /// and stride. Size and stride are permutations of the original values.
1355 /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1356 /// The transpose op is replaced by the alloca'ed pointer.
1357 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1358 public:
1360 
1361  LogicalResult
1362  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1363  ConversionPatternRewriter &rewriter) const override {
1364  auto loc = transposeOp.getLoc();
1365  MemRefDescriptor viewMemRef(adaptor.getIn());
1366 
1367  // No permutation, early exit.
1368  if (transposeOp.getPermutation().isIdentity())
1369  return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1370 
1371  auto targetMemRef = MemRefDescriptor::undef(
1372  rewriter, loc,
1373  typeConverter->convertType(transposeOp.getIn().getType()));
1374 
1375  // Copy the base and aligned pointers from the old descriptor to the new
1376  // one.
1377  targetMemRef.setAllocatedPtr(rewriter, loc,
1378  viewMemRef.allocatedPtr(rewriter, loc));
1379  targetMemRef.setAlignedPtr(rewriter, loc,
1380  viewMemRef.alignedPtr(rewriter, loc));
1381 
1382  // Copy the offset pointer from the old descriptor to the new one.
1383  targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1384 
1385  // Iterate over the dimensions and apply size/stride permutation:
1386  // When enumerating the results of the permutation map, the enumeration
1387  // index is the index into the target dimensions and the DimExpr points to
1388  // the dimension of the source memref.
1389  for (const auto &en :
1390  llvm::enumerate(transposeOp.getPermutation().getResults())) {
1391  int targetPos = en.index();
1392  int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1393  targetMemRef.setSize(rewriter, loc, targetPos,
1394  viewMemRef.size(rewriter, loc, sourcePos));
1395  targetMemRef.setStride(rewriter, loc, targetPos,
1396  viewMemRef.stride(rewriter, loc, sourcePos));
1397  }
1398 
1399  rewriter.replaceOp(transposeOp, {targetMemRef});
1400  return success();
1401  }
1402 };
1403 
1404 /// Conversion pattern that transforms an op into:
1405 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1406 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1407 /// and stride.
1408 /// The view op is replaced by the descriptor.
1409 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1411 
1412  // Build and return the value for the idx^th shape dimension, either by
1413  // returning the constant shape dimension or counting the proper dynamic size.
1414  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1415  ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1416  Type indexType) const {
1417  assert(idx < shape.size());
1418  if (!ShapedType::isDynamic(shape[idx]))
1419  return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1420  // Count the number of dynamic dims in range [0, idx]
1421  unsigned nDynamic =
1422  llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1423  return dynamicSizes[nDynamic];
1424  }
1425 
1426  // Build and return the idx^th stride, either by returning the constant stride
1427  // or by computing the dynamic stride from the current `runningStride` and
1428  // `nextSize`. The caller should keep a running stride and update it with the
1429  // result returned by this function.
1430  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1431  ArrayRef<int64_t> strides, Value nextSize,
1432  Value runningStride, unsigned idx, Type indexType) const {
1433  assert(idx < strides.size());
1434  if (!ShapedType::isDynamic(strides[idx]))
1435  return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1436  if (nextSize)
1437  return runningStride
1438  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1439  : nextSize;
1440  assert(!runningStride);
1441  return createIndexAttrConstant(rewriter, loc, indexType, 1);
1442  }
1443 
1444  LogicalResult
1445  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1446  ConversionPatternRewriter &rewriter) const override {
1447  auto loc = viewOp.getLoc();
1448 
1449  auto viewMemRefType = viewOp.getType();
1450  auto targetElementTy =
1451  typeConverter->convertType(viewMemRefType.getElementType());
1452  auto targetDescTy = typeConverter->convertType(viewMemRefType);
1453  if (!targetDescTy || !targetElementTy ||
1454  !LLVM::isCompatibleType(targetElementTy) ||
1455  !LLVM::isCompatibleType(targetDescTy))
1456  return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1457  failure();
1458 
1459  int64_t offset;
1460  SmallVector<int64_t, 4> strides;
1461  auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1462  if (failed(successStrides))
1463  return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1464  assert(offset == 0 && "expected offset to be 0");
1465 
1466  // Target memref must be contiguous in memory (innermost stride is 1), or
1467  // empty (special case when at least one of the memref dimensions is 0).
1468  if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1469  return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1470  failure();
1471 
1472  // Create the descriptor.
1473  MemRefDescriptor sourceMemRef(adaptor.getSource());
1474  auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1475 
1476  // Field 1: Copy the allocated pointer, used for malloc/free.
1477  Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1478  auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1479  targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1480 
1481  // Field 2: Copy the actual aligned pointer to payload.
1482  Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1483  alignedPtr = rewriter.create<LLVM::GEPOp>(
1484  loc, alignedPtr.getType(),
1485  typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1486  adaptor.getByteShift());
1487 
1488  targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1489 
1490  Type indexType = getIndexType();
1491  // Field 3: The offset in the resulting type must be 0. This is
1492  // because of the type change: an offset on srcType* may not be
1493  // expressible as an offset on dstType*.
1494  targetMemRef.setOffset(
1495  rewriter, loc,
1496  createIndexAttrConstant(rewriter, loc, indexType, offset));
1497 
1498  // Early exit for 0-D corner case.
1499  if (viewMemRefType.getRank() == 0)
1500  return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1501 
1502  // Fields 4 and 5: Update sizes and strides.
1503  Value stride = nullptr, nextSize = nullptr;
1504  for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1505  // Update size.
1506  Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1507  adaptor.getSizes(), i, indexType);
1508  targetMemRef.setSize(rewriter, loc, i, size);
1509  // Update stride.
1510  stride =
1511  getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1512  targetMemRef.setStride(rewriter, loc, i, stride);
1513  nextSize = size;
1514  }
1515 
1516  rewriter.replaceOp(viewOp, {targetMemRef});
1517  return success();
1518  }
1519 };
1520 
1521 //===----------------------------------------------------------------------===//
1522 // AtomicRMWOpLowering
1523 //===----------------------------------------------------------------------===//
1524 
1525 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1526 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1527 static std::optional<LLVM::AtomicBinOp>
1528 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1529  switch (atomicOp.getKind()) {
1530  case arith::AtomicRMWKind::addf:
1531  return LLVM::AtomicBinOp::fadd;
1532  case arith::AtomicRMWKind::addi:
1533  return LLVM::AtomicBinOp::add;
1534  case arith::AtomicRMWKind::assign:
1535  return LLVM::AtomicBinOp::xchg;
1536  case arith::AtomicRMWKind::maximumf:
1537  return LLVM::AtomicBinOp::fmax;
1538  case arith::AtomicRMWKind::maxs:
1539  return LLVM::AtomicBinOp::max;
1540  case arith::AtomicRMWKind::maxu:
1541  return LLVM::AtomicBinOp::umax;
1542  case arith::AtomicRMWKind::minimumf:
1543  return LLVM::AtomicBinOp::fmin;
1544  case arith::AtomicRMWKind::mins:
1545  return LLVM::AtomicBinOp::min;
1546  case arith::AtomicRMWKind::minu:
1547  return LLVM::AtomicBinOp::umin;
1548  case arith::AtomicRMWKind::ori:
1549  return LLVM::AtomicBinOp::_or;
1550  case arith::AtomicRMWKind::andi:
1551  return LLVM::AtomicBinOp::_and;
1552  default:
1553  return std::nullopt;
1554  }
1555  llvm_unreachable("Invalid AtomicRMWKind");
1556 }
1557 
1558 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1559  using Base::Base;
1560 
1561  LogicalResult
1562  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1563  ConversionPatternRewriter &rewriter) const override {
1564  auto maybeKind = matchSimpleAtomicOp(atomicOp);
1565  if (!maybeKind)
1566  return failure();
1567  auto memRefType = atomicOp.getMemRefType();
1568  SmallVector<int64_t> strides;
1569  int64_t offset;
1570  if (failed(getStridesAndOffset(memRefType, strides, offset)))
1571  return failure();
1572  auto dataPtr =
1573  getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1574  adaptor.getIndices(), rewriter);
1575  rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1576  atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1577  LLVM::AtomicOrdering::acq_rel);
1578  return success();
1579  }
1580 };
1581 
1582 /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1583 class ConvertExtractAlignedPointerAsIndex
1584  : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1585 public:
1586  using ConvertOpToLLVMPattern<
1587  memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1588 
1589  LogicalResult
1590  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1591  OpAdaptor adaptor,
1592  ConversionPatternRewriter &rewriter) const override {
1593  BaseMemRefType sourceTy = extractOp.getSource().getType();
1594 
1595  Value alignedPtr;
1596  if (sourceTy.hasRank()) {
1597  MemRefDescriptor desc(adaptor.getSource());
1598  alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1599  } else {
1600  auto elementPtrTy = LLVM::LLVMPointerType::get(
1601  rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1602 
1603  UnrankedMemRefDescriptor desc(adaptor.getSource());
1604  Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1605 
1607  rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1608  elementPtrTy);
1609  }
1610 
1611  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1612  extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1613  return success();
1614  }
1615 };
1616 
1617 /// Materialize the MemRef descriptor represented by the results of
1618 /// ExtractStridedMetadataOp.
1619 class ExtractStridedMetadataOpLowering
1620  : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1621 public:
1622  using ConvertOpToLLVMPattern<
1623  memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1624 
1625  LogicalResult
1626  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1627  OpAdaptor adaptor,
1628  ConversionPatternRewriter &rewriter) const override {
1629 
1630  if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1631  return failure();
1632 
1633  // Create the descriptor.
1634  MemRefDescriptor sourceMemRef(adaptor.getSource());
1635  Location loc = extractStridedMetadataOp.getLoc();
1636  Value source = extractStridedMetadataOp.getSource();
1637 
1638  auto sourceMemRefType = cast<MemRefType>(source.getType());
1639  int64_t rank = sourceMemRefType.getRank();
1640  SmallVector<Value> results;
1641  results.reserve(2 + rank * 2);
1642 
1643  // Base buffer.
1644  Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1645  Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1647  rewriter, loc, *getTypeConverter(),
1648  cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1649  baseBuffer, alignedBuffer);
1650  results.push_back((Value)dstMemRef);
1651 
1652  // Offset.
1653  results.push_back(sourceMemRef.offset(rewriter, loc));
1654 
1655  // Sizes.
1656  for (unsigned i = 0; i < rank; ++i)
1657  results.push_back(sourceMemRef.size(rewriter, loc, i));
1658  // Strides.
1659  for (unsigned i = 0; i < rank; ++i)
1660  results.push_back(sourceMemRef.stride(rewriter, loc, i));
1661 
1662  rewriter.replaceOp(extractStridedMetadataOp, results);
1663  return success();
1664  }
1665 };
1666 
1667 } // namespace
1668 
1670  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1671  // clang-format off
1672  patterns.add<
1673  AllocaOpLowering,
1674  AllocaScopeOpLowering,
1675  AtomicRMWOpLowering,
1676  AssumeAlignmentOpLowering,
1677  ConvertExtractAlignedPointerAsIndex,
1678  DimOpLowering,
1679  ExtractStridedMetadataOpLowering,
1680  GenericAtomicRMWOpLowering,
1681  GlobalMemrefOpLowering,
1682  GetGlobalMemrefOpLowering,
1683  LoadOpLowering,
1684  MemRefCastOpLowering,
1685  MemRefCopyOpLowering,
1686  MemorySpaceCastOpLowering,
1687  MemRefReinterpretCastOpLowering,
1688  MemRefReshapeOpLowering,
1689  PrefetchOpLowering,
1690  RankOpLowering,
1691  ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1692  ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1693  StoreOpLowering,
1694  SubViewOpLowering,
1696  ViewOpLowering>(converter);
1697  // clang-format on
1698  auto allocLowering = converter.getOptions().allocLowering;
1700  patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1701  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1702  patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1703 }
1704 
1705 namespace {
1706 struct FinalizeMemRefToLLVMConversionPass
1707  : public impl::FinalizeMemRefToLLVMConversionPassBase<
1708  FinalizeMemRefToLLVMConversionPass> {
1709  using FinalizeMemRefToLLVMConversionPassBase::
1710  FinalizeMemRefToLLVMConversionPassBase;
1711 
1712  void runOnOperation() override {
1713  Operation *op = getOperation();
1714  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1716  dataLayoutAnalysis.getAtOrAbove(op));
1717  options.allocLowering =
1720 
1721  options.useGenericFunctions = useGenericFunctions;
1722 
1723  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1724  options.overrideIndexBitwidth(indexBitwidth);
1725 
1726  LLVMTypeConverter typeConverter(&getContext(), options,
1727  &dataLayoutAnalysis);
1728  RewritePatternSet patterns(&getContext());
1729  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1730  LLVMConversionTarget target(getContext());
1731  target.addLegalOp<func::FuncOp>();
1732  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1733  signalPassFailure();
1734  }
1735 };
1736 
1737 /// Implement the interface to convert MemRef to LLVM.
1738 struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
1740  void loadDependentDialects(MLIRContext *context) const final {
1741  context->loadDialect<LLVM::LLVMDialect>();
1742  }
1743 
1744  /// Hook for derived dialect interface to provide conversion patterns
1745  /// and mark dialect legal for the conversion target.
1746  void populateConvertToLLVMConversionPatterns(
1747  ConversionTarget &target, LLVMTypeConverter &typeConverter,
1748  RewritePatternSet &patterns) const final {
1749  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1750  }
1751 };
1752 
1753 } // namespace
1754 
1756  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
1757  dialect->addInterfaces<MemRefToLLVMDialectInterface>();
1758  });
1759 }
static MLIRContext * getContext(OpFoldResult val)
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:146
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
Operation & back()
Definition: Block.h:150
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:207
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:136
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:228
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI8Type()
Definition: Builders.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The main mechanism for performing data layout queries.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
const LowerToLLVMOptions & getOptions() const
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:353
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:450
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:567
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:436
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:403
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:441
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:449
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:447
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
result_range getResults()
Definition: Operation.h:410
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void computeSizes(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, Type unrankedDescriptorType)
LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:856
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
Definition: MemRefUtils.cpp:23
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Lowering for AllocOp and AllocaOp.