//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/MathExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}

LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
                           ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(module);

  return LLVM::lookupOrCreateFreeFn(module);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    return allocateBufferManuallyAlign(
        rewriter, loc, sizeBytes, op,
        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      &defaultLayout));
    if (!ptr)
      return std::make_tuple(Value(), Value());
    return std::make_tuple(ptr, ptr);
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {
    setRequiresNumElements();
  }

  /// Allocates the underlying buffer with an `llvm.alloca` in the current
  /// function. The requested alignment is attached to the alloca itself, so
  /// no post-allocation realignment is needed and both returned pointers are
  /// the same.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value size,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away; this is
    // a stack allocation.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementType =
        typeConverter->convertType(allocaOp.getType().getElementType());
    unsigned addrSpace =
        *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
                                        allocaOp.getAlignment().value_or(0));

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};
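
// Illustrative sketch of what the pattern above produces (value names are
// hypothetical; the exact index type depends on the data layout):
//
//   %a = memref.alloca() : memref<4xf32>
//
// becomes, descriptor construction aside, roughly
//
//   %sz = llvm.mlir.constant(4 : index) : i64
//   %p  = llvm.alloca %sz x f32 : (i64) -> !llvm.ptr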

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save the stack, then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. The stack is restored before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};
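
// As an illustrative sketch (not taken from an actual test), the pattern
// above turns
//
//   %r = memref.alloca_scope -> (f32) {
//     %a = memref.alloca() : memref<4xf32>
//     ...
//     memref.alloca_scope.return %v : f32
//   }
//
// into CFG form roughly like
//
//   %sp = llvm.intr.stacksave : !llvm.ptr
//   llvm.br ^body
// ^body:
//   %a = llvm.alloca ...
//   ...
//   llvm.intr.stackrestore %sp : !llvm.ptr
//   llvm.br ^continue(%v : f32)
// ^continue(%r : f32):
//   ...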

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                     rewriter);

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
    Value trueCond =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                                    alignmentConst);

    rewriter.eraseOp(op);
    return success();
  }
};
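
// Illustrative sketch of the emitted IR (names hypothetical, assembly
// details approximate):
//
//   memref.assume_alignment %m, 16 : memref<?xf32>
//
// becomes roughly
//
//   %true  = llvm.mlir.constant(true) : i1
//   %align = llvm.mlir.constant(16 : index) : i64
//   llvm.intr.assume %true ["align"(%ptr, %align)]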

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    LLVM::LLVMFuncOp freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
    return success();
  }
};
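
// Sketch: for a ranked memref, `memref.dealloc %m : memref<?xf32>` becomes
// roughly (descriptor field 0 is the allocated pointer):
//
//   %alloc = llvm.extractvalue %desc[0]
//       : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
//   llvm.call @free(%alloc) : (!llvm.ptr) -> ()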

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an "
                        "integer address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract the pointer to the underlying ranked descriptor and bitcast it
    // to a memref<element_type> descriptor pointer to minimize the number of
    // GEP operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get the pointer to the offset field of the memref<element_type>
    // descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the constant index when available.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use a constant for the static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};
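
// Sketch: with %c1 a constant index,
//
//   %d = memref.dim %m, %c1 : memref<4x?xf32>
//
// reads the dynamic size straight out of the descriptor, roughly
//
//   %d = llvm.extractvalue %desc[3, 1]
//       : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
//
// whereas `memref.dim %m, %c0` folds to `llvm.mlir.constant(4 : index)`.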

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |        |          |
///  |--------           |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};
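
// The loop epilog built above corresponds roughly to the following IR
// (sketch; an f32 element type is assumed for illustration):
//
//   %pair = llvm.cmpxchg %ptr, %loaded, %result acq_rel monotonic
//       : !llvm.ptr, f32
//   %new  = llvm.extractvalue %pair[0] : !llvm.struct<(f32, i1)>
//   %ok   = llvm.extractvalue %pair[1] : !llvm.struct<(f32, i1)>
//   llvm.cond_br %ok, ^end, ^loop(%new : f32)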

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1-D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
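
// For example, memref<2x3xf32> maps to !llvm.array<2 x array<3 x f32>>: the
// array type is built innermost-first, starting from the converted f32
// element type.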

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};
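
// Sketch of the conversion (hypothetical symbol name):
//
//   memref.global "private" constant @cst : memref<2xf32> =
//       dense<[1.0, 2.0]>
//
// becomes roughly
//
//   llvm.mlir.global private constant @cst(dense<[1.0, 2.0]> : tensor<2xf32>)
//       : !llvm.array<2 x f32>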

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This uses
/// `AllocLikeOpLowering` to reuse the MemRef descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for a memref.get_global op is getting the address of
  /// the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(maybeAddressSpace))
      return std::make_tuple(Value(), Value());
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP
    // with the address of the GV as the base, and (rank + 1) number of 0
    // indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};
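
// Sketch (hypothetical symbol name): `memref.get_global @gv : memref<2x3xf32>`
// produces roughly
//
//   %addr = llvm.mlir.addressof @gv : !llvm.ptr
//   %gep  = llvm.getelementptr %addr[0, 0, 0]
//       : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x array<3 x f32>>
//
// with %gep stored as the aligned pointer and the 0xdeadbeef sentinel as the
// allocated pointer of the resulting descriptor.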

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};
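
// Sketch: `%v = memref.load %m[%i] : memref<?xf32>` lowers roughly to
//
//   %base = llvm.extractvalue %desc[1]
//       : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
//   %p = llvm.getelementptr %base[%i] : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %v = llvm.load %p : !llvm.ptr -> f32
//
// with the offset and stride arithmetic folded away for this identity-layout,
// offset-zero case.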

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (dyn_cast<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can
    // be used for type erasure. For now it must preserve the underlying
    // element type and requires the source and result types to have the same
    // rank. Therefore, perform a sanity check that the underlying structs are
    // the same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(isa<UnrankedMemRefType>(srcType) ||
           isa<UnrankedMemRefType>(dstType));

    // Unranked-to-unranked casts are disallowed.
    return !(isa<UnrankedMemRefType>(srcType) &&
             isa<UnrankedMemRefType>(dstType))
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting a ranked to an unranked memref type:
      //  - set the rank in the destination from the memref type;
      //  - allocate space on the stack and copy the src memref descriptor;
      //  - set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from an unranked type to a ranked one. The operation is
      // assumed to be doing a correct cast. If the destination type mismatches
      // the unranked type, the behavior is undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};
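
// Sketch: the ranked-to-unranked case, as in
//
//   %u = memref.cast %m : memref<4xf32> to memref<*xf32>
//
// produces roughly (names hypothetical)
//
//   %p = llvm.alloca ...             // stack copy of the ranked descriptor
//   %0 = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
//   %1 = llvm.insertvalue %rank, %0[0] : !llvm.struct<(i64, ptr)>
//   %2 = llvm.insertvalue %p, %1[1] : !llvm.struct<(i64, ptr)>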

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with an identity layout, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute the number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get the element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute the total size in bytes.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr,
        targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save the stack position before promoting descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive the size from llvm.getelementptr, which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore the stack used for the descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};
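
// Sketch: for two identity-layout memrefs,
//
//   memref.copy %src, %dst : memref<4x4xf32> to memref<4x4xf32>
//
// takes the intrinsic path and becomes a single llvm memcpy of
// 4 * 4 * sizeof(f32) bytes between the aligned pointers; any other layout
// falls back to a call to the memrefCopy runtime function with both
// descriptors promoted to the stack as unranked memrefs.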

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new memref descriptor.
      auto result = UnrankedMemRefDescriptor::undef(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy the pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract the pointer to the underlying ranked memref descriptor and cast
  // it to ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create the descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set the allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set the offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set the sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create the descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set the allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamic(strides[i])) {
          // If the stride for this dimension is static, materialize it as a
          // constant.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. However,
          // since the target memref has an identity layout, we can safely set
          // the innermost stride to 1.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }

        Value dimSize;
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          Type indexType = getIndexType();
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract the address space and element type.
    auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, addressSpace, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
        sizes.front());
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract the pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set the pointers and offset.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the
    // new shape and compute strides. For this, we create a loop from rank-1
    // to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy the size from the shape to the descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write the stride value and compute the next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement the loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
                                    remainder, std::nullopt);

    // Reset the position to the beginning of the new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// ReassociatingReshapeOp must be expanded before we reach this stage.
/// Report that information.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};

/// Subviews must be expanded before we reach this stage.
/// Report that information.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};

/// Conversion pattern that transforms a transpose op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor.
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
/// and stride. Sizes and strides are permutations of the original values.
/// The transpose op is replaced by the new descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getIn().getType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply the size/stride permutation:
    // When enumerating the results of the permutation map, the enumeration
    // index is the index into the target dimensions and the DimExpr points to
    // the dimension of the source memref.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int targetPos = en.index();
      int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};
1396 
1397 /// Conversion pattern that transforms a view op into:
1398 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1399 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1400 /// and stride.
1401 /// The view op is replaced by the descriptor.
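// Illustrative MLIR (sizes assumed):
//   %v = memref.view %buf[%byte_shift][%size]
//       : memref<2048xi8> to memref<?x4xf32>
// The aligned pointer is advanced by %byte_shift bytes with a GEP, the
// offset is reset to 0, and row-major sizes/strides are written
// innermost-first by the loop in matchAndRewrite below.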
1402 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1403  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1404 
1405  // Build and return the value for the idx^th shape dimension: either the
1406  // constant shape dimension, or the matching operand from `dynamicSizes`.
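// For example (illustrative): with shape [2, ?, 4, ?] and idx == 3, one
// dynamic dimension precedes idx, so dynamicSizes[1] is returned.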
1407  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1408  ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1409  Type indexType) const {
1410  assert(idx < shape.size());
1411  if (!ShapedType::isDynamic(shape[idx]))
1412  return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1413  // Count the number of dynamic dims in range [0, idx]
1414  unsigned nDynamic =
1415  llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1416  return dynamicSizes[nDynamic];
1417  }
1418 
1419  // Build and return the idx^th stride, either by returning the constant stride
1420  // or by computing the dynamic stride from the current `runningStride` and
1421  // `nextSize`. The caller should keep a running stride and update it with the
1422  // result returned by this function.
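// For example (illustrative): with strides [?, ?, 1] the caller's loop walks
// i = 2..0: stride 2 is the constant 1, stride 1 becomes
// runningStride * size(2), and stride 0 becomes stride(1) * size(1).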
1423  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1424  ArrayRef<int64_t> strides, Value nextSize,
1425  Value runningStride, unsigned idx, Type indexType) const {
1426  assert(idx < strides.size());
1427  if (!ShapedType::isDynamic(strides[idx]))
1428  return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1429  if (nextSize)
1430  return runningStride
1431  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1432  : nextSize;
1433  assert(!runningStride);
1434  return createIndexAttrConstant(rewriter, loc, indexType, 1);
1435  }
1436 
1437  LogicalResult
1438  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1439  ConversionPatternRewriter &rewriter) const override {
1440  auto loc = viewOp.getLoc();
1441 
1442  auto viewMemRefType = viewOp.getType();
1443  auto targetElementTy =
1444  typeConverter->convertType(viewMemRefType.getElementType());
1445  auto targetDescTy = typeConverter->convertType(viewMemRefType);
1446  if (!targetDescTy || !targetElementTy ||
1447  !LLVM::isCompatibleType(targetElementTy) ||
1448  !LLVM::isCompatibleType(targetDescTy))
1449  return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1450  failure();
1451 
1452  int64_t offset;
1453  SmallVector<int64_t, 4> strides;
1454  auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1455  if (failed(successStrides))
1456  return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1457  assert(offset == 0 && "expected offset to be 0");
1458 
1459  // Target memref must be contiguous in memory (innermost stride is 1), or
1460  // empty (special case when at least one of the memref dimensions is 0).
1461  if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1462  return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1463  failure();
1464 
1465  // Create the descriptor.
1466  MemRefDescriptor sourceMemRef(adaptor.getSource());
1467  auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1468 
1469  // Field 1: Copy the allocated pointer, used for malloc/free.
1470  Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1471  auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1472  targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1473 
1474  // Field 2: Copy the actual aligned pointer to payload.
1475  Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1476  alignedPtr = rewriter.create<LLVM::GEPOp>(
1477  loc, alignedPtr.getType(),
1478  typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1479  adaptor.getByteShift());
1480 
1481  targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1482 
1483  Type indexType = getIndexType();
1484  // Field 3: The offset in the resulting type must be 0. This is
1485  // because of the type change: an offset on srcType* may not be
1486  // expressible as an offset on dstType*.
1487  targetMemRef.setOffset(
1488  rewriter, loc,
1489  createIndexAttrConstant(rewriter, loc, indexType, offset));
1490 
1491  // Early exit for 0-D corner case.
1492  if (viewMemRefType.getRank() == 0)
1493  return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1494 
1495  // Fields 4 and 5: Update sizes and strides.
1496  Value stride = nullptr, nextSize = nullptr;
1497  for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1498  // Update size.
1499  Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1500  adaptor.getSizes(), i, indexType);
1501  targetMemRef.setSize(rewriter, loc, i, size);
1502  // Update stride.
1503  stride =
1504  getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1505  targetMemRef.setStride(rewriter, loc, i, stride);
1506  nextSize = size;
1507  }
1508 
1509  rewriter.replaceOp(viewOp, {targetMemRef});
1510  return success();
1511  }
1512 };
1513 
1514 //===----------------------------------------------------------------------===//
1515 // AtomicRMWOpLowering
1516 //===----------------------------------------------------------------------===//
1517 
1518 /// Try to match the kind of a memref.atomic_rmw to determine whether to use
1519 /// a lowering to llvm.atomicrmw or to fall back to llvm.cmpxchg.
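// Illustrative mapping: `memref.atomic_rmw addf %v, %m[%i]` maps to
// `llvm.atomicrmw fadd`; kinds without a direct llvm.atomicrmw equivalent
// (e.g. mulf) yield std::nullopt, leaving the op to the cmpxchg-based
// fallback mentioned above.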
1520 static std::optional<LLVM::AtomicBinOp>
1521 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1522  switch (atomicOp.getKind()) {
1523  case arith::AtomicRMWKind::addf:
1524  return LLVM::AtomicBinOp::fadd;
1525  case arith::AtomicRMWKind::addi:
1526  return LLVM::AtomicBinOp::add;
1527  case arith::AtomicRMWKind::assign:
1528  return LLVM::AtomicBinOp::xchg;
1529  case arith::AtomicRMWKind::maximumf:
1530  return LLVM::AtomicBinOp::fmax;
1531  case arith::AtomicRMWKind::maxs:
1532  return LLVM::AtomicBinOp::max;
1533  case arith::AtomicRMWKind::maxu:
1534  return LLVM::AtomicBinOp::umax;
1535  case arith::AtomicRMWKind::minimumf:
1536  return LLVM::AtomicBinOp::fmin;
1537  case arith::AtomicRMWKind::mins:
1538  return LLVM::AtomicBinOp::min;
1539  case arith::AtomicRMWKind::minu:
1540  return LLVM::AtomicBinOp::umin;
1541  case arith::AtomicRMWKind::ori:
1542  return LLVM::AtomicBinOp::_or;
1543  case arith::AtomicRMWKind::andi:
1544  return LLVM::AtomicBinOp::_and;
1545  default:
1546  return std::nullopt;
1547  }
1548  llvm_unreachable("Invalid AtomicRMWKind");
1549 }
1550 
1551 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1552  using Base::Base;
1553 
1554  LogicalResult
1555  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1556  ConversionPatternRewriter &rewriter) const override {
1557  auto maybeKind = matchSimpleAtomicOp(atomicOp);
1558  if (!maybeKind)
1559  return failure();
1560  auto memRefType = atomicOp.getMemRefType();
1561  SmallVector<int64_t> strides;
1562  int64_t offset;
1563  if (failed(getStridesAndOffset(memRefType, strides, offset)))
1564  return failure();
1565  auto dataPtr =
1566  getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1567  adaptor.getIndices(), rewriter);
1568  rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1569  atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1570  LLVM::AtomicOrdering::acq_rel);
1571  return success();
1572  }
1573 };
1574 
1575 /// Lower memref.extract_aligned_pointer_as_index to a ptrtoint of the aligned pointer.
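// Illustrative MLIR (types assumed):
//   %addr = memref.extract_aligned_pointer_as_index %m
//       : memref<?xf32> -> index
// The aligned pointer is read from the (ranked or unranked) descriptor and
// converted to the index type with llvm.ptrtoint.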
1576 class ConvertExtractAlignedPointerAsIndex
1577  : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1578 public:
1579  using ConvertOpToLLVMPattern<
1580  memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1581 
1582  LogicalResult
1583  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1584  OpAdaptor adaptor,
1585  ConversionPatternRewriter &rewriter) const override {
1586  BaseMemRefType sourceTy = extractOp.getSource().getType();
1587 
1588  Value alignedPtr;
1589  if (sourceTy.hasRank()) {
1590  MemRefDescriptor desc(adaptor.getSource());
1591  alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1592  } else {
1593  auto elementPtrTy = LLVM::LLVMPointerType::get(
1594  rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1595 
1596  UnrankedMemRefDescriptor desc(adaptor.getSource());
1597  Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1598 
1599  alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1600  rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1601  elementPtrTy);
1602  }
1603 
1604  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1605  extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1606  return success();
1607  }
1608 };
1609 
1610 /// Materialize the results of an ExtractStridedMetadataOp from the fields
1611 /// of the source MemRef descriptor.
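// Illustrative MLIR (types abridged):
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %m
//       : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
// Each result is read from the corresponding descriptor field; %base is
// rebuilt as a rank-0 descriptor around the source pointers.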
1612 class ExtractStridedMetadataOpLowering
1613  : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1614 public:
1615  using ConvertOpToLLVMPattern<
1616  memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1617 
1618  LogicalResult
1619  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1620  OpAdaptor adaptor,
1621  ConversionPatternRewriter &rewriter) const override {
1622 
1623  if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1624  return failure();
1625 
1626  // Create the descriptor.
1627  MemRefDescriptor sourceMemRef(adaptor.getSource());
1628  Location loc = extractStridedMetadataOp.getLoc();
1629  Value source = extractStridedMetadataOp.getSource();
1630 
1631  auto sourceMemRefType = cast<MemRefType>(source.getType());
1632  int64_t rank = sourceMemRefType.getRank();
1633  SmallVector<Value> results;
1634  results.reserve(2 + rank * 2);
1635 
1636  // Base buffer.
1637  Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1638  Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1639  Value dstMemRef = MemRefDescriptor::fromStaticShape(
1640  rewriter, loc, *getTypeConverter(),
1641  cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1642  baseBuffer, alignedBuffer);
1643  results.push_back((Value)dstMemRef);
1644 
1645  // Offset.
1646  results.push_back(sourceMemRef.offset(rewriter, loc));
1647 
1648  // Sizes.
1649  for (unsigned i = 0; i < rank; ++i)
1650  results.push_back(sourceMemRef.size(rewriter, loc, i));
1651  // Strides.
1652  for (unsigned i = 0; i < rank; ++i)
1653  results.push_back(sourceMemRef.stride(rewriter, loc, i));
1654 
1655  rewriter.replaceOp(extractStridedMetadataOp, results);
1656  return success();
1657  }
1658 };
1659 
1660 } // namespace
1661 
1662 void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
1663  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1664  // clang-format off
1665  patterns.add<
1666  AllocaOpLowering,
1667  AllocaScopeOpLowering,
1668  AtomicRMWOpLowering,
1669  AssumeAlignmentOpLowering,
1670  ConvertExtractAlignedPointerAsIndex,
1671  DimOpLowering,
1672  ExtractStridedMetadataOpLowering,
1673  GenericAtomicRMWOpLowering,
1674  GlobalMemrefOpLowering,
1675  GetGlobalMemrefOpLowering,
1676  LoadOpLowering,
1677  MemRefCastOpLowering,
1678  MemRefCopyOpLowering,
1679  MemorySpaceCastOpLowering,
1680  MemRefReinterpretCastOpLowering,
1681  MemRefReshapeOpLowering,
1682  PrefetchOpLowering,
1683  RankOpLowering,
1684  ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1685  ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1686  StoreOpLowering,
1687  SubViewOpLowering,
1688  TransposeOpLowering,
1689  ViewOpLowering>(converter);
1690  // clang-format on
1691  auto allocLowering = converter.getOptions().allocLowering;
1692  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1693  patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1694  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1695  patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1696 }
1697 
1698 namespace {
1699 struct FinalizeMemRefToLLVMConversionPass
1700  : public impl::FinalizeMemRefToLLVMConversionPassBase<
1701  FinalizeMemRefToLLVMConversionPass> {
1702  using FinalizeMemRefToLLVMConversionPassBase::
1703  FinalizeMemRefToLLVMConversionPassBase;
1704 
1705  void runOnOperation() override {
1706  Operation *op = getOperation();
1707  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1708  LowerToLLVMOptions options(&getContext(),
1709  dataLayoutAnalysis.getAtOrAbove(op));
1710  options.allocLowering =
1711  (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1712  : LowerToLLVMOptions::AllocLowering::Malloc);
1713 
1714  options.useGenericFunctions = useGenericFunctions;
1715 
1716  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1717  options.overrideIndexBitwidth(indexBitwidth);
1718 
1719  LLVMTypeConverter typeConverter(&getContext(), options,
1720  &dataLayoutAnalysis);
1721  RewritePatternSet patterns(&getContext());
1722  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1723  LLVMConversionTarget target(getContext());
1724  target.addLegalOp<func::FuncOp>();
1725  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1726  signalPassFailure();
1727  }
1728 };
1729 
1730 /// Implement the interface to convert MemRef to LLVM.
1731 struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
1732  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
1733  void loadDependentDialects(MLIRContext *context) const final {
1734  context->loadDialect<LLVM::LLVMDialect>();
1735  }
1736 
1737  /// Hook for derived dialect interface to provide conversion patterns
1738  /// and mark dialect legal for the conversion target.
1739  void populateConvertToLLVMConversionPatterns(
1740  ConversionTarget &target, LLVMTypeConverter &typeConverter,
1741  RewritePatternSet &patterns) const final {
1742  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1743  }
1744 };
1745 
1746 } // namespace
1747 
1748 void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
1749  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
1750  dialect->addInterfaces<MemRefToLLVMDialectInterface>();
1751  });
1752 }
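// Illustrative usage (hypothetical tool setup):
//   DialectRegistry registry;
//   registry.insert<memref::MemRefDialect>();
//   registerConvertMemRefToLLVMInterface(registry);
//   MLIRContext context(registry);
// The interface extension is attached once the MemRef dialect is loaded.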