//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}

LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
                           ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(module);

  return LLVM::lookupOrCreateFreeFn(module);
}

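// Illustrative sketch (not verbatim output; the exact IR depends on the
// descriptor layout and index bitwidth): `%0 = memref.alloc() : memref<4xf32>`
// lowers to roughly
//   %sz  = llvm.mlir.constant(16 : index) : i64
//   %ptr = llvm.call @malloc(%sz) : (i64) -> !llvm.ptr
// followed by inserting %ptr (as both allocated and aligned pointer), a zero
// offset, the size 4, and the stride 1 into a
// !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> descriptor.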
struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    return allocateBufferManuallyAlign(
        rewriter, loc, sizeBytes, op,
        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      &defaultLayout));
    if (!ptr)
      return std::make_tuple(Value(), Value());
    return std::make_tuple(ptr, ptr);
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

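// Illustrative sketch: `%0 = memref.alloca() : memref<4xf32>` maps to roughly
//   %c4  = llvm.mlir.constant(4 : index) : i64
//   %ptr = llvm.alloca %c4 x f32 : (i64) -> !llvm.ptr
// with %ptr serving as both the allocated and the aligned pointer in the
// resulting descriptor.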
struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {
    setRequiresNumElements();
  }

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (e.g. in conjunction with malloc).
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value size,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away: allocas
    // are used for stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementType =
        typeConverter->convertType(allocaOp.getType().getElementType());
    unsigned addrSpace =
        *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
                                        allocaOp.getAlignment().value_or(0));

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

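// Illustrative control-flow sketch for the lowering below: a
// `memref.alloca_scope` body is bracketed by stack save/restore, roughly
//   %sp = llvm.intr.stacksave : !llvm.ptr
//   llvm.br ^body
// ^body:
//   ...                                  // inlined scope body
//   llvm.intr.stackrestore %sp : !llvm.ptr
//   llvm.br ^continue(...)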
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save the stack, then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. The stack is restored before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert a stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

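// Illustrative sketch: `memref.assume_alignment %m, 16 : memref<4xf32>`
// becomes, roughly,
//   %p    = llvm.ptrtoint %ptr : !llvm.ptr to i64
//   %mask = llvm.mlir.constant(15 : index) : i64
//   %and  = llvm.and %p, %mask : i64
//   %zero = llvm.mlir.constant(0 : index) : i64
//   %cond = llvm.icmp "eq" %and, %zero : i64
//   llvm.intr.assume %cond : i1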
struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                     rewriter);

    // Emit llvm.assume(memref & (alignment - 1) == 0).
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref instances should get de-duplicated into the same
    // pointer SSA value.
    MemRefDescriptor memRefDescriptor(memref);
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
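// For example (illustrative): `memref.dealloc %m : memref<4xf32>` becomes
// roughly
//   %ptr = llvm.extractvalue %desc[0]
//       : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
//   llvm.call @free(%ptr) : (!llvm.ptr) -> ()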
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    LLVM::LLVMFuncOp freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
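// For example (illustrative): given `%d = memref.dim %m, %c1 : memref<4x?xf32>`,
// the dynamic size is read from the descriptor, roughly
//   %d = llvm.extractvalue %desc[3, 1]
//       : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// whereas `memref.dim %m, %c0` folds to `llvm.mlir.constant(4 : index)`.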
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get the pointer to the offset field of the memref<element_type>
    // descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the index being a constant, if it is.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use a constant for the static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |         |
///  |-----------         |
///                       v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dim at index 0, so it needs to be walked
  // backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
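
// For example, a global of type `memref<4x8xf32>` is given the LLVM type
// `!llvm.array<4 x array<8 x f32>>`: the shape is walked from the innermost
// dimension outwards, so the outermost memref dimension ends up as the
// outermost array.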

/// GlobalMemrefOp is lowered to an LLVM Global Variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into it. This inherits from
/// `AllocLikeOpLowering` to reuse the MemRef descriptor construction.
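///
/// Illustrative sketch: `%0 = memref.get_global @g : memref<4xf32>` becomes
/// roughly
///   %addr = llvm.mlir.addressof @g : !llvm.ptr
///   %gep  = llvm.getelementptr %addr[0, 0]
///       : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x f32>
/// with %gep used as the aligned pointer and a 0xdeadbeef sentinel stashed as
/// the allocated pointer (see below).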
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(maybeAddressSpace))
      return std::make_tuple(Value(), Value());
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
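// Illustrative sketch: `%v = memref.load %m[%i] : memref<4xf32>` becomes
// roughly
//   %base = llvm.extractvalue %desc[1] : !llvm.struct<(ptr, ptr, i64, ...)>
//   %gep  = llvm.getelementptr %base[%i] : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %v    = llvm.load %gep : !llvm.ptr -> f32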
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
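// Illustrative sketch: `memref.store %v, %m[%i] : memref<4xf32>` similarly
// becomes, roughly,
//   %gep = llvm.getelementptr %base[%i] : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   llvm.store %v, %gep : f32, !llvm.ptr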
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (dyn_cast<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result type to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands has an unranked type.
    assert(isa<UnrankedMemRefType>(srcType) ||
           isa<UnrankedMemRefType>(dstType));

    // Unranked to unranked casts are disallowed.
    return !(isa<UnrankedMemRefType>(srcType) &&
             isa<UnrankedMemRefType>(dstType))
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // - Set the rank in the destination from the memref type.
      // - Allocate space on the stack and copy the src memref descriptor.
      // - Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with an identity layout, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
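///
/// Illustrative sketch of the intrinsic path: copying a `memref<4x8xf32>`
/// with identity layout issues, roughly,
///   %bytes = llvm.mul ...          // numElements * sizeof(f32)
///   "llvm.intr.memcpy"(%dst, %src, %bytes) <{isVolatile = false}>
/// while non-contiguous cases fall back to a call to the `memrefCopy`
/// runtime function.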
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute the number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get the element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute the total size in bytes.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save the stack position before promoting the descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive the size from llvm.getelementptr, which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore the stack used for the descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new memref descriptor.
      auto result = UnrankedMemRefDescriptor::undef(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy the pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 *
          ceilDiv(getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract the pointer to the underlying ranked memref descriptor and cast
  // it to ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create the descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set the allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set the offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set the sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create the descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set the allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamic(strides[i])) {
          // The stride for this dimension is statically known; use the
          // constant directly.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. However,
          // since the target memref has an identity layout, we can safely set
          // the innermost stride to 1.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }

        Value dimSize;
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          Type indexType = getIndexType();
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract the address space and element type.
    auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, addressSpace, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
        sizes.front());
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract the pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set the pointers and offset.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy the size from the shape to the descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write the stride value and compute the next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement the loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
                                    remainder, std::nullopt);

    // Reset the position to the beginning of the new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// ReassociatingReshapeOps must be expanded before we reach this stage.
/// Report that information.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};

/// Subviews must be expanded before we reach this stage.
/// Report that information.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor that copy the data pointer and offset from
///      the source descriptor and permute its sizes and strides.
/// The transpose op is replaced by the new descriptor.
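///
/// For example (illustrative): `memref.transpose %m (i, j) -> (j, i)` on a
/// `memref<?x?xf32>` swaps sizes[0]/sizes[1] and strides[0]/strides[1] in the
/// target descriptor; no data is moved.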
1359 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1360 public:
1362 
1364  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1365  ConversionPatternRewriter &rewriter) const override {
1366  auto loc = transposeOp.getLoc();
1367  MemRefDescriptor viewMemRef(adaptor.getIn());
1368 
1369  // No permutation, early exit.
1370  if (transposeOp.getPermutation().isIdentity())
1371  return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1372 
1373  auto targetMemRef = MemRefDescriptor::undef(
1374  rewriter, loc,
1375  typeConverter->convertType(transposeOp.getIn().getType()));
1376 
1377  // Copy the base and aligned pointers from the old descriptor to the new
1378  // one.
1379  targetMemRef.setAllocatedPtr(rewriter, loc,
1380  viewMemRef.allocatedPtr(rewriter, loc));
1381  targetMemRef.setAlignedPtr(rewriter, loc,
1382  viewMemRef.alignedPtr(rewriter, loc));
1383 
1384  // Copy the offset pointer from the old descriptor to the new one.
1385  targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1386 
1387  // Iterate over the dimensions and apply size/stride permutation:
1388  // When enumerating the results of the permutation map, the enumeration
1389  // index is the index into the target dimensions and the DimExpr points to
1390  // the dimension of the source memref.
1391  for (const auto &en :
1392  llvm::enumerate(transposeOp.getPermutation().getResults())) {
1393  int targetPos = en.index();
1394  int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1395  targetMemRef.setSize(rewriter, loc, targetPos,
1396  viewMemRef.size(rewriter, loc, sourcePos));
1397  targetMemRef.setStride(rewriter, loc, targetPos,
1398  viewMemRef.stride(rewriter, loc, sourcePos));
1399  }
1400 
1401  rewriter.replaceOp(transposeOp, {targetMemRef});
1402  return success();
1403  }
1404 };
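// Worked example (illustrative; the IR below is made up for this note):
//
//   %t = memref.transpose %m (d0, d1) -> (d1, d0)
//       : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
//
// keeps both pointers and the offset of %m's descriptor unchanged and swaps
// the (size, stride) pairs of the two dimensions; no data is copied.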
1405 
1406 /// Conversion pattern that transforms an op into:
1407 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1408 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1409 /// and stride.
1410 /// The view op is replaced by the descriptor.
1411 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1412  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1413 
1414  // Build and return the value for the idx^th shape dimension, either the
1415  // constant shape dimension or the matching operand among the dynamic sizes.
1416  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1417  ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1418  Type indexType) const {
1419  assert(idx < shape.size());
1420  if (!ShapedType::isDynamic(shape[idx]))
1421  return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1422  // Count the number of dynamic dims in range [0, idx).
1423  unsigned nDynamic =
1424  llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1425  return dynamicSizes[nDynamic];
1426  }
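  // Worked example (illustrative): for shape <4 x ? x 8 x ?> and idx = 3,
  // shape[3] is dynamic and exactly one dynamic dimension precedes it, so
  // nDynamic == 1 and the value returned is dynamicSizes[1].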
1427 
1428  // Build and return the idx^th stride, either by returning the constant stride
1429  // or by computing the dynamic stride from the current `runningStride` and
1430  // `nextSize`. The caller should keep a running stride and update it with the
1431  // result returned by this function.
1432  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1433  ArrayRef<int64_t> strides, Value nextSize,
1434  Value runningStride, unsigned idx, Type indexType) const {
1435  assert(idx < strides.size());
1436  if (!ShapedType::isDynamic(strides[idx]))
1437  return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1438  if (nextSize)
1439  return runningStride
1440  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1441  : nextSize;
1442  assert(!runningStride);
1443  return createIndexAttrConstant(rewriter, loc, indexType, 1);
1444  }
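  // Worked example (illustrative): for a view of type memref<?x128xf32> the
  // strides are [128, 1], both static, so constants are returned directly.
  // For a contiguous layout with strides [?, 1], idx = 1 returns the
  // constant 1 and idx = 0 returns runningStride * nextSize, i.e.
  // 1 * size[1].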
1445 
1446  LogicalResult
1447  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1448  ConversionPatternRewriter &rewriter) const override {
1449  auto loc = viewOp.getLoc();
1450 
1451  auto viewMemRefType = viewOp.getType();
1452  auto targetElementTy =
1453  typeConverter->convertType(viewMemRefType.getElementType());
1454  auto targetDescTy = typeConverter->convertType(viewMemRefType);
1455  if (!targetDescTy || !targetElementTy ||
1456  !LLVM::isCompatibleType(targetElementTy) ||
1457  !LLVM::isCompatibleType(targetDescTy))
1458  return viewOp.emitWarning("target descriptor type not converted to LLVM"),
1459  failure();
1460 
1461  int64_t offset;
1462  SmallVector<int64_t, 4> strides;
1463  auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1464  if (failed(successStrides))
1465  return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1466  assert(offset == 0 && "expected offset to be 0");
1467 
1468  // Target memref must be contiguous in memory (innermost stride is 1), or
1469  // empty (special case when at least one of the memref dimensions is 0).
1470  if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1471  return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1472  failure();
1473 
1474  // Create the descriptor.
1475  MemRefDescriptor sourceMemRef(adaptor.getSource());
1476  auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1477 
1478  // Field 1: Copy the allocated pointer, used for malloc/free.
1479  Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1480  auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1481  targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1482 
1483  // Field 2: Copy the actual aligned pointer to payload.
1484  Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1485  alignedPtr = rewriter.create<LLVM::GEPOp>(
1486  loc, alignedPtr.getType(),
1487  typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1488  adaptor.getByteShift());
1489 
1490  targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1491 
1492  Type indexType = getIndexType();
1493  // Field 3: The offset in the resulting type must be 0. This is
1494  // because of the type change: an offset on srcType* may not be
1495  // expressible as an offset on dstType*.
1496  targetMemRef.setOffset(
1497  rewriter, loc,
1498  createIndexAttrConstant(rewriter, loc, indexType, offset));
1499 
1500  // Early exit for 0-D corner case.
1501  if (viewMemRefType.getRank() == 0)
1502  return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1503 
1504  // Fields 4 and 5: Update sizes and strides.
1505  Value stride = nullptr, nextSize = nullptr;
1506  for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1507  // Update size.
1508  Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1509  adaptor.getSizes(), i, indexType);
1510  targetMemRef.setSize(rewriter, loc, i, size);
1511  // Update stride.
1512  stride =
1513  getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1514  targetMemRef.setStride(rewriter, loc, i, stride);
1515  nextSize = size;
1516  }
1517 
1518  rewriter.replaceOp(viewOp, {targetMemRef});
1519  return success();
1520  }
1521 };
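// Worked example (illustrative, assuming a flat i8 source buffer):
//
//   %v = memref.view %buf[%shift][%n] : memref<2048xi8> to memref<?x128xf32>
//
// produces a descriptor whose aligned pointer is the source pointer advanced
// by %shift bytes, whose offset is the constant 0, whose sizes are (%n, 128)
// and whose strides are (128, 1).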
1522 
1523 //===----------------------------------------------------------------------===//
1524 // AtomicRMWOpLowering
1525 //===----------------------------------------------------------------------===//
1526 
1527 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1528 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1529 static std::optional<LLVM::AtomicBinOp>
1530 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1531  switch (atomicOp.getKind()) {
1532  case arith::AtomicRMWKind::addf:
1533  return LLVM::AtomicBinOp::fadd;
1534  case arith::AtomicRMWKind::addi:
1535  return LLVM::AtomicBinOp::add;
1536  case arith::AtomicRMWKind::assign:
1537  return LLVM::AtomicBinOp::xchg;
1538  case arith::AtomicRMWKind::maximumf:
1539  return LLVM::AtomicBinOp::fmax;
1540  case arith::AtomicRMWKind::maxs:
1541  return LLVM::AtomicBinOp::max;
1542  case arith::AtomicRMWKind::maxu:
1543  return LLVM::AtomicBinOp::umax;
1544  case arith::AtomicRMWKind::minimumf:
1545  return LLVM::AtomicBinOp::fmin;
1546  case arith::AtomicRMWKind::mins:
1547  return LLVM::AtomicBinOp::min;
1548  case arith::AtomicRMWKind::minu:
1549  return LLVM::AtomicBinOp::umin;
1550  case arith::AtomicRMWKind::ori:
1551  return LLVM::AtomicBinOp::_or;
1552  case arith::AtomicRMWKind::andi:
1553  return LLVM::AtomicBinOp::_and;
1554  default:
1555  return std::nullopt;
1556  }
1557  llvm_unreachable("Invalid AtomicRMWKind");
1558 }
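// For example (illustrative): arith::AtomicRMWKind::addf maps to
// LLVM::AtomicBinOp::fadd, so `memref.atomic_rmw addf` lowers to a single
// llvm.atomicrmw below; a kind with no direct LLVM equivalent (e.g. mulf)
// yields std::nullopt and the pattern simply fails to match.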
1559 
1560 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1561  using Base::Base;
1562 
1563  LogicalResult
1564  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1565  ConversionPatternRewriter &rewriter) const override {
1566  auto maybeKind = matchSimpleAtomicOp(atomicOp);
1567  if (!maybeKind)
1568  return failure();
1569  auto memRefType = atomicOp.getMemRefType();
1570  SmallVector<int64_t> strides;
1571  int64_t offset;
1572  if (failed(getStridesAndOffset(memRefType, strides, offset)))
1573  return failure();
1574  auto dataPtr =
1575  getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1576  adaptor.getIndices(), rewriter);
1577  rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1578  atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1579  LLVM::AtomicOrdering::acq_rel);
1580  return success();
1581  }
1582 };
1583 
1584 /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1585 class ConvertExtractAlignedPointerAsIndex
1586  : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1587 public:
1588  using ConvertOpToLLVMPattern<
1589  memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1590 
1591  LogicalResult
1592  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1593  OpAdaptor adaptor,
1594  ConversionPatternRewriter &rewriter) const override {
1595  MemRefDescriptor desc(adaptor.getSource());
1596  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1597  extractOp, getTypeConverter()->getIndexType(),
1598  desc.alignedPtr(rewriter, extractOp->getLoc()));
1599  return success();
1600  }
1601 };
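// For example (illustrative):
//
//   %i = memref.extract_aligned_pointer_as_index %m : memref<?xf32> -> index
//
// becomes a single llvm.ptrtoint of the descriptor's aligned pointer,
// yielding an integer of the converter's index type (e.g. i64).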
1602 
1603 /// Materialize the MemRef descriptor represented by the results of
1604 /// ExtractStridedMetadataOp.
1605 class ExtractStridedMetadataOpLowering
1606  : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1607 public:
1608  using ConvertOpToLLVMPattern<
1609  memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1610 
1611  LogicalResult
1612  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1613  OpAdaptor adaptor,
1614  ConversionPatternRewriter &rewriter) const override {
1615 
1616  if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1617  return failure();
1618 
1619  // Create the descriptor.
1620  MemRefDescriptor sourceMemRef(adaptor.getSource());
1621  Location loc = extractStridedMetadataOp.getLoc();
1622  Value source = extractStridedMetadataOp.getSource();
1623 
1624  auto sourceMemRefType = cast<MemRefType>(source.getType());
1625  int64_t rank = sourceMemRefType.getRank();
1626  SmallVector<Value> results;
1627  results.reserve(2 + rank * 2);
1628 
1629  // Base buffer.
1630  Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1631  Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1632  MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
1633  rewriter, loc, *getTypeConverter(),
1634  cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1635  baseBuffer, alignedBuffer);
1636  results.push_back((Value)dstMemRef);
1637 
1638  // Offset.
1639  results.push_back(sourceMemRef.offset(rewriter, loc));
1640 
1641  // Sizes.
1642  for (unsigned i = 0; i < rank; ++i)
1643  results.push_back(sourceMemRef.size(rewriter, loc, i));
1644  // Strides.
1645  for (unsigned i = 0; i < rank; ++i)
1646  results.push_back(sourceMemRef.stride(rewriter, loc, i));
1647 
1648  rewriter.replaceOp(extractStridedMetadataOp, results);
1649  return success();
1650  }
1651 };
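// For a rank-2 source, the replacement values above are, in order
// (illustrative): results[0] the base buffer repacked as a 0-d descriptor,
// results[1] the offset, results[2..3] the sizes, results[4..5] the strides.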
1652 
1653 } // namespace
1654 
1655 void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
1656  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1657  // clang-format off
1658  patterns.add<
1659  AllocaOpLowering,
1660  AllocaScopeOpLowering,
1661  AtomicRMWOpLowering,
1662  AssumeAlignmentOpLowering,
1663  ConvertExtractAlignedPointerAsIndex,
1664  DimOpLowering,
1665  ExtractStridedMetadataOpLowering,
1666  GenericAtomicRMWOpLowering,
1667  GlobalMemrefOpLowering,
1668  GetGlobalMemrefOpLowering,
1669  LoadOpLowering,
1670  MemRefCastOpLowering,
1671  MemRefCopyOpLowering,
1672  MemorySpaceCastOpLowering,
1673  MemRefReinterpretCastOpLowering,
1674  MemRefReshapeOpLowering,
1675  PrefetchOpLowering,
1676  RankOpLowering,
1677  ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1678  ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1679  StoreOpLowering,
1680  SubViewOpLowering,
1681  TransposeOpLowering,
1682  ViewOpLowering>(converter);
1683  // clang-format on
1684  auto allocLowering = converter.getOptions().allocLowering;
1685  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1686  patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1687  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1688  patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1689 }
1690 
1691 namespace {
1692 struct FinalizeMemRefToLLVMConversionPass
1693  : public impl::FinalizeMemRefToLLVMConversionPassBase<
1694  FinalizeMemRefToLLVMConversionPass> {
1695  using FinalizeMemRefToLLVMConversionPassBase::
1696  FinalizeMemRefToLLVMConversionPassBase;
1697 
1698  void runOnOperation() override {
1699  Operation *op = getOperation();
1700  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1701  LowerToLLVMOptions options(&getContext(),
1702  dataLayoutAnalysis.getAtOrAbove(op));
1703  options.allocLowering =
1704  (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1705  : LowerToLLVMOptions::AllocLowering::Malloc);
1706 
1707  options.useGenericFunctions = useGenericFunctions;
1708 
1709  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1710  options.overrideIndexBitwidth(indexBitwidth);
1711 
1712  LLVMTypeConverter typeConverter(&getContext(), options,
1713  &dataLayoutAnalysis);
1714  RewritePatternSet patterns(&getContext());
1715  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1716  LLVMConversionTarget target(getContext());
1717  target.addLegalOp<func::FuncOp>();
1718  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1719  signalPassFailure();
1720  }
1721 };
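// For illustration: the pass above is exposed as --finalize-memref-to-llvm.
// Since it uses applyPartialConversion, memref ops that no pattern matches
// (e.g. a subview that was not expanded beforehand) are left in place rather
// than aborting the conversion.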
1722 
1723 /// Implement the interface to convert MemRef to LLVM.
1724 struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
1725  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
1726  void loadDependentDialects(MLIRContext *context) const final {
1727  context->loadDialect<LLVM::LLVMDialect>();
1728  }
1729 
1730  /// Hook for derived dialect interface to provide conversion patterns
1731  /// and mark dialect legal for the conversion target.
1732  void populateConvertToLLVMConversionPatterns(
1733  ConversionTarget &target, LLVMTypeConverter &typeConverter,
1734  RewritePatternSet &patterns) const final {
1735  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1736  }
1737 };
1738 
1739 } // namespace
1740 
1741 void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
1742  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
1743  dialect->addInterfaces<MemRefToLLVMDialectInterface>();
1744  });
1745 }
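// Usage sketch (illustrative, not part of this file): a client that wants
// the generic --convert-to-llvm driver to pick up these patterns registers
// the extension before creating the context:
//
//   DialectRegistry registry;
//   registry.insert<memref::MemRefDialect>();
//   registerConvertMemRefToLLVMInterface(registry);
//   MLIRContext context(registry);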